def to_native_string(string, encoding='ascii'):
    """Return ``string`` as a native ``str``.

    A ``str`` argument is handed back untouched. Any other object is
    converted by calling its ``decode`` method with ``encoding``, which
    defaults to ASCII.
    """
    return string if isinstance(string, str) else string.decode(encoding)
async def resolve_redirects(self, resp, req, stream=False, timeout=None,
                            verify=True, cert=None, proxies=None,
                            yield_requests=False, **adapter_kwargs):
    """Asynchronously follow the chain of redirects that starts at ``resp``.

    This is an async generator: iterate it with ``async for``.

    :param resp: The response whose redirect target (as reported by
        ``self.get_redirect_target``) starts the chain.
    :param req: The prepared request that produced ``resp``. It is copied,
        never mutated, for each hop.
    :param stream, timeout, verify, cert, proxies, adapter_kwargs: Forwarded
        to ``self.send`` for every follow-up request. ``proxies`` is first
        passed through ``self.rebuild_proxies`` on each hop.
    :param yield_requests: When true, yield the prepared follow-up request
        instead of sending it. In this mode the redirect target is never
        refreshed, so callers should take a single item only.
    :raises TooManyRedirects: When the number of followed redirects reaches
        ``self.max_redirects``.
    :returns: An async generator of responses (or of prepared requests when
        ``yield_requests`` is true).
    """
    # ``urljoin`` is needed for relative ``Location`` headers but the module
    # only imports ``urlparse``; import it here so that branch cannot raise
    # NameError.
    from urllib.parse import urljoin

    hist = []  # keep track of history

    url = self.get_redirect_target(resp)
    previous_fragment = urlparse(req.url).fragment
    while url:
        prepared_request = req.copy()

        # Update history and keep track of redirects.
        # resp.history must ignore the original request in this loop.
        hist.append(resp)
        resp.history = hist[1:]

        try:
            resp.content  # Consume socket so it can be released
        except (ChunkedEncodingError, ContentDecodingError, RuntimeError):
            resp.raw.read(decode_content=False)

        if len(resp.history) >= self.max_redirects:
            raise TooManyRedirects('Exceeded %s redirects.' % self.max_redirects, response=resp)

        # Release the connection back into the pool.
        resp.close()

        # Handle redirection without scheme (see: RFC 1808 Section 4)
        if url.startswith('//'):
            parsed_rurl = urlparse(resp.url)
            url = '%s:%s' % (to_native_string(parsed_rurl.scheme), url)

        # Attach the previous fragment if the new URL has none (RFC 7231 7.1.2),
        # otherwise remember the new fragment for later hops.
        parsed = urlparse(url)
        if parsed.fragment == '' and previous_fragment:
            parsed = parsed._replace(fragment=previous_fragment)
        elif parsed.fragment:
            previous_fragment = parsed.fragment
        url = parsed.geturl()

        # Facilitate relative 'location' headers, as allowed by RFC 7231.
        # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
        # Compliant with RFC3986, we percent encode the url.
        if not parsed.netloc:
            url = urljoin(resp.url, requote_uri(url))
        else:
            url = requote_uri(url)

        prepared_request.url = to_native_string(url)

        self.rebuild_method(prepared_request, resp)

        # https://github.com/requests/requests/issues/1084
        if resp.status_code not in (codes.temporary_redirect, codes.permanent_redirect):
            # https://github.com/requests/requests/issues/3490
            purged_headers = ('Content-Length', 'Content-Type', 'Transfer-Encoding')
            for header in purged_headers:
                prepared_request.headers.pop(header, None)
            prepared_request.body = None

        headers = prepared_request.headers
        try:
            del headers['Cookie']
        except KeyError:
            pass

        # Extract any cookies sent on the response to the cookiejar
        # in the new request. Because we've mutated our copied prepared
        # request, use the old one that we haven't yet touched.
        extract_cookies_to_jar(prepared_request._cookies, req, resp.raw)
        merge_cookies(prepared_request._cookies, self.cookies)
        prepared_request.prepare_cookies(prepared_request._cookies)

        # Rebuild auth and proxy information.
        proxies = self.rebuild_proxies(prepared_request, proxies)
        self.rebuild_auth(prepared_request, resp)

        # A failed tell() sets `_body_position` to `object()`. This non-None
        # value ensures `rewindable` will be True, allowing us to raise an
        # UnrewindableBodyError, instead of hanging the connection.
        rewindable = (
            prepared_request._body_position is not None and
            ('Content-Length' in headers or 'Transfer-Encoding' in headers)
        )

        # Attempt to rewind consumed file-like object.
        if rewindable:
            # The module never imports ``rewind_body``; pull it in here so
            # this branch cannot raise NameError.
            from requests.utils import rewind_body
            rewind_body(prepared_request)

        # Override the original request.
        req = prepared_request

        if yield_requests:
            yield req
        else:
            resp = await self.send(
                req,
                stream=stream,
                timeout=timeout,
                verify=verify,
                cert=cert,
                proxies=proxies,
                allow_redirects=False,
                **adapter_kwargs
            )

            extract_cookies_to_jar(self.cookies, prepared_request, resp.raw)

            # extract redirect url, if any, for the next loop
            url = self.get_redirect_target(resp)
            yield resp
import requests_async
import pytest


@pytest.mark.asyncio
async def test_redirects(server):
    """A GET on /redirect1 is followed through to /redirect3."""
    start_url = "http://127.0.0.1:8000/redirect1"
    final = await requests_async.get(start_url)
    assert final.status_code == 200
    assert final.json() == {"hello": "world"}
    assert final.url == "http://127.0.0.1:8000/redirect3"
    assert len(final.history) == 2


@pytest.mark.asyncio
async def test_redirects_disallowed(server):
    """With allow_redirects=False the 302 itself is returned, with no history."""
    start_url = "http://127.0.0.1:8000/redirect1"
    first_hop = await requests_async.get(start_url, allow_redirects=False)
    assert first_hop.status_code == 302
    assert first_hop.url == "http://127.0.0.1:8000/redirect1"
    assert len(first_hop.history) == 0