diff --git a/AUTHORS.rst b/AUTHORS.rst index d29fa812..48cd155b 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -177,3 +177,5 @@ Patches and Suggestions - Andrii Soldatenko (`@a_soldatenko `_) - Moinuddin Quadri (`@moin18 `_) - Matt Kohl (`@mattkohl `_) +- Jonathan Vanasco (`@jvanasco `_) + diff --git a/HISTORY.rst b/HISTORY.rst index 5ec8bb21..c26036c1 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -3,6 +3,15 @@ Release History --------------- +**Unreleased** ++++++++++++++++++++ + +- The behavior of ``SessionRedirectMixin`` was slightly altered. + ``resolve_redirects`` will now detect a redirect by calling + ``get_redirect_target(response)`` instead of directly + querying ``Response.is_redirect`` and ``Response.headers['location']``. + Advanced users can override ``get_redirect_target`` to process malformed redirects more easily. + 2.13.0 (2017-01-24) +++++++++++++++++++ diff --git a/requests/sessions.py b/requests/sessions.py index 7983282a..72ef179f 100644 --- a/requests/sessions.py +++ b/requests/sessions.py @@ -86,35 +86,39 @@ def merge_hooks(request_hooks, session_hooks, dict_class=OrderedDict): class SessionRedirectMixin(object): + + def get_redirect_target(self, resp): + """Receives a Response. Returns a redirect URI or ``None``""" + if resp.is_redirect: + return resp.headers['location'] + return None + def resolve_redirects(self, resp, req, stream=False, timeout=None, verify=True, cert=None, proxies=None, **adapter_kwargs): """Receives a Response. Returns a generator of Responses.""" - i = 0 hist = [] # keep track of history - while resp.is_redirect: + url = self.get_redirect_target(resp) + while url: prepared_request = req.copy() - if i > 0: - # Update history and keep track of redirects. - hist.append(resp) - new_hist = list(hist) - resp.history = new_hist + # Update history and keep track of redirects. 
+ # resp.history must ignore the original request in this loop + hist.append(resp) + resp.history = hist[1:] try: resp.content # Consume socket so it can be released except (ChunkedEncodingError, ContentDecodingError, RuntimeError): resp.raw.read(decode_content=False) - if i >= self.max_redirects: + if len(resp.history) >= self.max_redirects: raise TooManyRedirects('Exceeded %s redirects.' % self.max_redirects, response=resp) # Release the connection back into the pool. resp.close() - url = resp.headers['location'] - # Handle redirection without scheme (see: RFC 1808 Section 4) if url.startswith('//'): parsed_rurl = urlparse(resp.url) @@ -192,7 +196,8 @@ class SessionRedirectMixin(object): extract_cookies_to_jar(self.cookies, prepared_request, resp.raw) - i += 1 + # extract redirect url, if any, for the next loop + url = self.get_redirect_target(resp) yield resp def rebuild_auth(self, prepared_request, response): diff --git a/tests/test_requests.py b/tests/test_requests.py index fd35103b..cd4c68db 100755 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1740,6 +1740,49 @@ class TestRequests: assert 'Transfer-Encoding' in prepared_request.headers assert 'Content-Length' not in prepared_request.headers + def test_custom_redirect_mixin(self, httpbin): + """Tests a custom mixin to override ``get_redirect_target``. + + Ensures a subclassed ``requests.Session`` can handle a certain type of + malformed redirect response. + + 1. original request receives a proper response: 302 redirect + 2. following the redirect, a malformed response is given: + status code = HTTP 200 + location = alternate url + 3. 
the custom session catches the edge case and follows the redirect + """ + url_final = httpbin('html') + querystring_malformed = urlencode({'location': url_final}) + url_redirect_malformed = httpbin('response-headers?%s' % querystring_malformed) + querystring_redirect = urlencode({'url': url_redirect_malformed}) + url_redirect = httpbin('redirect-to?%s' % querystring_redirect) + urls_test = [url_redirect, + url_redirect_malformed, + url_final, + ] + + class CustomRedirectSession(requests.Session): + def get_redirect_target(self, resp): + # default behavior + if resp.is_redirect: + return resp.headers['location'] + # edge case - check to see if 'location' is in headers anyway + location = resp.headers.get('location') + if location and (location != resp.url): + return location + return None + + session = CustomRedirectSession() + r = session.get(urls_test[0]) + assert len(r.history) == 2 + assert r.status_code == 200 + assert r.history[0].status_code == 302 + assert r.history[0].is_redirect + assert r.history[1].status_code == 200 + assert not r.history[1].is_redirect + assert r.url == urls_test[2] + class TestCaseInsensitiveDict: