mirror of
https://github.com/not-kennethreitz/requests-async.git
synced 2026-06-05 06:56:15 +00:00
async redirects
This commit is contained in:
@@ -9,7 +9,6 @@ Next set of things to deal with:
|
||||
* https support, and certificate checking.
|
||||
* streaming support for uploads and downloads.
|
||||
* connection pooling.
|
||||
* async redirections.
|
||||
* async cookie persistence, for on-disk cookie stores.
|
||||
* make sure authentication works okay (does it use adapters, is the API broken there now?)
|
||||
* timeouts
|
||||
|
||||
+139
-12
@@ -1,6 +1,21 @@
|
||||
import datetime
|
||||
import requests
|
||||
from . import adapters
|
||||
from requests.exceptions import TooManyRedirects, InvalidSchema, ChunkedEncodingError, ContentDecodingError
|
||||
from requests.cookies import extract_cookies_to_jar, merge_cookies
|
||||
from requests.status_codes import codes
|
||||
from requests.utils import requote_uri
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def to_native_string(string, encoding='ascii'):
|
||||
"""Given a string object, regardless of type, returns a representation of
|
||||
that string in the native string type, encoding and decoding where
|
||||
necessary. This assumes ASCII unless told otherwise.
|
||||
"""
|
||||
if isinstance(string, str):
|
||||
return string
|
||||
return string.decode(encoding)
|
||||
|
||||
|
||||
class Session(requests.Session):
|
||||
@@ -130,11 +145,11 @@ class Session(requests.Session):
|
||||
|
||||
requests.cookies.extract_cookies_to_jar(self.cookies, request, r.raw)
|
||||
|
||||
# Redirect resolving generator.
|
||||
gen = self.resolve_redirects(r, request, **kwargs)
|
||||
|
||||
# Resolve redirects if allowed.
|
||||
history = [resp for resp in gen] if allow_redirects else []
|
||||
# Redirect resolving.
|
||||
history = []
|
||||
if allow_redirects:
|
||||
async for resp in self.resolve_redirects(r, request, **kwargs):
|
||||
history.append(resp)
|
||||
|
||||
# Shuffle things around if there's history.
|
||||
if history:
|
||||
@@ -145,15 +160,127 @@ class Session(requests.Session):
|
||||
r.history = history
|
||||
|
||||
# If redirects aren't being followed, store the response on the Request for Response.next().
|
||||
if not allow_redirects:
|
||||
try:
|
||||
r._next = next(
|
||||
self.resolve_redirects(r, request, yield_requests=True, **kwargs)
|
||||
)
|
||||
except StopIteration:
|
||||
pass
|
||||
# if not allow_redirects:
|
||||
# try:
|
||||
# r._next = next(
|
||||
# self.resolve_redirects(r, request, yield_requests=True, **kwargs)
|
||||
# )
|
||||
# except StopIteration:
|
||||
# pass
|
||||
|
||||
if not stream:
|
||||
r.content
|
||||
|
||||
return r
|
||||
|
||||
async def resolve_redirects(self, resp, req, stream=False, timeout=None,
|
||||
verify=True, cert=None, proxies=None, yield_requests=False, **adapter_kwargs):
|
||||
"""Receives a Response. Returns a generator of Responses or Requests."""
|
||||
hist = [] # keep track of history
|
||||
|
||||
url = self.get_redirect_target(resp)
|
||||
previous_fragment = urlparse(req.url).fragment
|
||||
while url:
|
||||
prepared_request = req.copy()
|
||||
|
||||
# Update history and keep track of redirects.
|
||||
# resp.history must ignore the original request in this loop
|
||||
hist.append(resp)
|
||||
resp.history = hist[1:]
|
||||
|
||||
try:
|
||||
resp.content # Consume socket so it can be released
|
||||
except (ChunkedEncodingError, ContentDecodingError, RuntimeError):
|
||||
resp.raw.read(decode_content=False)
|
||||
|
||||
if len(resp.history) >= self.max_redirects:
|
||||
raise TooManyRedirects('Exceeded %s redirects.' % self.max_redirects, response=resp)
|
||||
|
||||
# Release the connection back into the pool.
|
||||
resp.close()
|
||||
|
||||
# Handle redirection without scheme (see: RFC 1808 Section 4)
|
||||
if url.startswith('//'):
|
||||
parsed_rurl = urlparse(resp.url)
|
||||
url = '%s:%s' % (to_native_string(parsed_rurl.scheme), url)
|
||||
|
||||
# Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2)
|
||||
parsed = urlparse(url)
|
||||
if parsed.fragment == '' and previous_fragment:
|
||||
parsed = parsed._replace(fragment=previous_fragment)
|
||||
elif parsed.fragment:
|
||||
previous_fragment = parsed.fragment
|
||||
url = parsed.geturl()
|
||||
|
||||
# Facilitate relative 'location' headers, as allowed by RFC 7231.
|
||||
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
|
||||
# Compliant with RFC3986, we percent encode the url.
|
||||
if not parsed.netloc:
|
||||
url = urljoin(resp.url, requote_uri(url))
|
||||
else:
|
||||
url = requote_uri(url)
|
||||
|
||||
prepared_request.url = to_native_string(url)
|
||||
|
||||
self.rebuild_method(prepared_request, resp)
|
||||
|
||||
# https://github.com/requests/requests/issues/1084
|
||||
if resp.status_code not in (codes.temporary_redirect, codes.permanent_redirect):
|
||||
# https://github.com/requests/requests/issues/3490
|
||||
purged_headers = ('Content-Length', 'Content-Type', 'Transfer-Encoding')
|
||||
for header in purged_headers:
|
||||
prepared_request.headers.pop(header, None)
|
||||
prepared_request.body = None
|
||||
|
||||
headers = prepared_request.headers
|
||||
try:
|
||||
del headers['Cookie']
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Extract any cookies sent on the response to the cookiejar
|
||||
# in the new request. Because we've mutated our copied prepared
|
||||
# request, use the old one that we haven't yet touched.
|
||||
extract_cookies_to_jar(prepared_request._cookies, req, resp.raw)
|
||||
merge_cookies(prepared_request._cookies, self.cookies)
|
||||
prepared_request.prepare_cookies(prepared_request._cookies)
|
||||
|
||||
# Rebuild auth and proxy information.
|
||||
proxies = self.rebuild_proxies(prepared_request, proxies)
|
||||
self.rebuild_auth(prepared_request, resp)
|
||||
|
||||
# A failed tell() sets `_body_position` to `object()`. This non-None
|
||||
# value ensures `rewindable` will be True, allowing us to raise an
|
||||
# UnrewindableBodyError, instead of hanging the connection.
|
||||
rewindable = (
|
||||
prepared_request._body_position is not None and
|
||||
('Content-Length' in headers or 'Transfer-Encoding' in headers)
|
||||
)
|
||||
|
||||
# Attempt to rewind consumed file-like object.
|
||||
if rewindable:
|
||||
rewind_body(prepared_request)
|
||||
|
||||
# Override the original request.
|
||||
req = prepared_request
|
||||
|
||||
if yield_requests:
|
||||
yield req
|
||||
else:
|
||||
|
||||
resp = await self.send(
|
||||
req,
|
||||
stream=stream,
|
||||
timeout=timeout,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
proxies=proxies,
|
||||
allow_redirects=False,
|
||||
**adapter_kwargs
|
||||
)
|
||||
|
||||
extract_cookies_to_jar(self.cookies, prepared_request, resp.raw)
|
||||
|
||||
# extract redirect url, if any, for the next loop
|
||||
url = self.get_redirect_target(resp)
|
||||
yield resp
|
||||
|
||||
+18
-1
@@ -2,7 +2,7 @@ import asyncio
|
||||
import pytest
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.responses import JSONResponse, RedirectResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from uvicorn.config import Config
|
||||
@@ -31,9 +31,26 @@ async def echo_form_data(request):
|
||||
)
|
||||
|
||||
|
||||
async def redirect1(request):
|
||||
url = request.url_for('redirect2')
|
||||
return RedirectResponse(url)
|
||||
|
||||
|
||||
async def redirect2(request):
|
||||
url = request.url_for('redirect3')
|
||||
return RedirectResponse(url)
|
||||
|
||||
|
||||
async def redirect3(request):
|
||||
return JSONResponse({'hello': 'world'})
|
||||
|
||||
|
||||
routes = [
|
||||
Route("/", echo_request, methods=["GET", "DELETE", "OPTIONS", "POST", "PUT", "PATCH"]),
|
||||
Route("/echo_form_data", echo_form_data, methods=["POST", "PUT", "PATCH"]),
|
||||
Route("/redirect1", redirect1, name='redirect1'),
|
||||
Route("/redirect2", redirect2, name='redirect2'),
|
||||
Route("/redirect3", redirect3, name='redirect3'),
|
||||
]
|
||||
|
||||
app = Starlette(routes=routes)
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import requests_async
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirects(server):
|
||||
url = "http://127.0.0.1:8000/redirect1"
|
||||
response = await requests_async.get(url)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"hello": "world"}
|
||||
assert response.url == "http://127.0.0.1:8000/redirect3"
|
||||
assert len(response.history) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirects_disallowed(server):
|
||||
url = "http://127.0.0.1:8000/redirect1"
|
||||
response = await requests_async.get(url, allow_redirects=False)
|
||||
assert response.status_code == 302
|
||||
assert response.url == "http://127.0.0.1:8000/redirect1"
|
||||
assert len(response.history) == 0
|
||||
Reference in New Issue
Block a user