* initial attempt at get_redirect_target

* removing the `i` from the redirect detection while-loop
This commit is contained in:
jonathan vanasco
2017-02-10 13:53:23 -05:00
parent ad65b0cb19
commit 70f31a3166
4 changed files with 70 additions and 11 deletions
+2
View File
@@ -177,3 +177,5 @@ Patches and Suggestions
- Andrii Soldatenko (`@a_soldatenko <https://github.com/andriisoldatenko>`_)
- Moinuddin Quadri <moin18@gmail.com> (`@moin18 <https://github.com/moin18>`_)
- Matt Kohl (`@mattkohl <https://github.com/mattkohl>`_)
- Jonathan Vanasco (`@jvanasco <https://github.com/jvanasco>`_)
+9
View File
@@ -3,6 +3,15 @@
Release History
---------------
**Unreleased**
+++++++++++++++++++
- The behavior of ``SessionRedirectMixin`` was slightly altered.
``resolve_redirects`` will now detect a redirect by calling
``get_redirect_target(response)`` instead of directly
querying ``Response.is_redirect`` and ``Response.headers['location']``.
Advanced users can override ``get_redirect_target`` in a ``Session`` subclass to process malformed redirects more easily.
2.13.0 (2017-01-24)
+++++++++++++++++++
+16 -11
View File
@@ -86,35 +86,39 @@ def merge_hooks(request_hooks, session_hooks, dict_class=OrderedDict):
class SessionRedirectMixin(object):
def get_redirect_target(self, resp):
    """Return the URI that ``resp`` redirects to, or ``None``.

    The value of the ``location`` header is returned only when
    ``resp.is_redirect`` is truthy; every other response yields ``None``.
    Subclasses may override this hook to treat additional responses as
    redirects.
    """
    # Guard clause: anything that is not flagged as a redirect has no target.
    if not resp.is_redirect:
        return None
    return resp.headers['location']
def resolve_redirects(self, resp, req, stream=False, timeout=None,
verify=True, cert=None, proxies=None, **adapter_kwargs):
"""Receives a Response. Returns a generator of Responses."""
i = 0
hist = [] # keep track of history
while resp.is_redirect:
url = self.get_redirect_target(resp)
while url:
prepared_request = req.copy()
if i > 0:
# Update history and keep track of redirects.
hist.append(resp)
new_hist = list(hist)
resp.history = new_hist
# Update history and keep track of redirects.
# resp.history must ignore the original request in this loop
hist.append(resp)
resp.history = hist[1:]
try:
resp.content # Consume socket so it can be released
except (ChunkedEncodingError, ContentDecodingError, RuntimeError):
resp.raw.read(decode_content=False)
if i >= self.max_redirects:
if len(resp.history) >= self.max_redirects:
raise TooManyRedirects('Exceeded %s redirects.' % self.max_redirects, response=resp)
# Release the connection back into the pool.
resp.close()
url = resp.headers['location']
# Handle redirection without scheme (see: RFC 1808 Section 4)
if url.startswith('//'):
parsed_rurl = urlparse(resp.url)
@@ -192,7 +196,8 @@ class SessionRedirectMixin(object):
extract_cookies_to_jar(self.cookies, prepared_request, resp.raw)
i += 1
# extract redirect url, if any, for the next loop
url = self.get_redirect_target(resp)
yield resp
def rebuild_auth(self, prepared_request, response):
+43
View File
@@ -1740,6 +1740,49 @@ class TestRequests:
assert 'Transfer-Encoding' in prepared_request.headers
assert 'Content-Length' not in prepared_request.headers
def test_custom_redirect_mixin(self, httpbin):
    """Check that overriding ``get_redirect_target`` changes redirect handling.

    A subclassed ``requests.Session`` must be able to follow one particular
    kind of malformed redirect response:

    1. the original request receives a proper response: 302 redirect
    2. following the redirect, a malformed response is given:
        status code = HTTP 200
        location = alternate url
    3. the custom session catches the edge case and follows the redirect
    """
    # Build the chain back to front: the final page, then the malformed hop
    # (a response carrying a ``location`` header), then the genuine redirect
    # that points at the malformed hop.
    final_url = httpbin('html')
    malformed_url = httpbin(
        'response-headers?%s' % urlencode({'location': final_url}))
    start_url = httpbin(
        'redirect-to?%s' % urlencode({'url': malformed_url}))

    class CustomRedirectSession(requests.Session):
        def get_redirect_target(self, resp):
            headers = resp.headers
            # default behavior
            if resp.is_redirect:
                return headers['location']
            # edge case - a non-redirect response may still carry a
            # 'location' header that points somewhere other than itself
            target = headers.get('location')
            if not target or target == resp.url:
                return None
            return target

    response = CustomRedirectSession().get(start_url)

    # Two intermediate hops are expected before the final page.
    assert len(response.history) == 2
    assert response.status_code == 200

    first_hop, second_hop = response.history
    # Hop 1 is a well-formed redirect.
    assert first_hop.status_code == 302
    assert first_hop.is_redirect
    # Hop 2 is the malformed one: followed even though it is not a redirect.
    assert second_hop.status_code == 200
    assert not second_hop.is_redirect

    assert response.url == final_url
class TestCaseInsensitiveDict: