mirror of
https://github.com/kennethreitz/requests.git
synced 2026-06-05 22:50:18 +00:00
Merge pull request #3655 from nateprewitt/tzickel_post_redirect_w_streamable
Rewind readable body before POST redirect
This commit is contained in:
@@ -50,6 +50,7 @@ if is_py2:
|
||||
str = unicode
|
||||
basestring = basestring
|
||||
numeric_types = (int, long, float)
|
||||
integer_types = (int, long)
|
||||
|
||||
elif is_py3:
|
||||
from urllib.parse import urlparse, urlunparse, urljoin, urlsplit, urlencode, quote, unquote, quote_plus, unquote_plus, urldefrag
|
||||
@@ -64,3 +65,4 @@ elif is_py3:
|
||||
bytes = bytes
|
||||
basestring = (str, bytes)
|
||||
numeric_types = (int, float)
|
||||
integer_types = (int,)
|
||||
|
||||
@@ -100,6 +100,8 @@ class StreamConsumedError(RequestException, TypeError):
|
||||
class RetryError(RequestException):
|
||||
"""Custom retries logic failed"""
|
||||
|
||||
class UnrewindableBodyError(RequestException):
|
||||
"""Requests encountered an error when trying to rewind a body"""
|
||||
|
||||
# Warnings
|
||||
|
||||
|
||||
@@ -291,6 +291,8 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
|
||||
self.body = None
|
||||
#: dictionary of callback hooks, for internal usage.
|
||||
self.hooks = default_hooks()
|
||||
#: integer denoting starting position of a readable file-like body.
|
||||
self._body_position = None
|
||||
|
||||
def prepare(self, method=None, url=None, headers=None, files=None,
|
||||
data=None, params=None, auth=None, cookies=None, hooks=None, json=None):
|
||||
@@ -320,6 +322,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
|
||||
p._cookies = _copy_cookie_jar(self._cookies)
|
||||
p.body = self.body
|
||||
p.hooks = self.hooks
|
||||
p._body_position = self._body_position
|
||||
return p
|
||||
|
||||
def prepare_method(self, method):
|
||||
@@ -447,6 +450,17 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
|
||||
if is_stream:
|
||||
body = data
|
||||
|
||||
if getattr(body, 'tell', None) is not None:
|
||||
# Record the current file position before reading.
|
||||
# This will allow us to rewind a file in the event
|
||||
# of a redirect.
|
||||
try:
|
||||
self._body_position = body.tell()
|
||||
except (IOError, OSError):
|
||||
# This differentiates from None, allowing us to catch
|
||||
# a failed `tell()` later when trying to rewind the body
|
||||
self._body_position = object()
|
||||
|
||||
if files:
|
||||
raise NotImplementedError('Streamed bodies and files are mutually exclusive.')
|
||||
|
||||
|
||||
+13
-1
@@ -28,7 +28,7 @@ from .adapters import HTTPAdapter
|
||||
|
||||
from .utils import (
|
||||
requote_uri, get_environ_proxies, get_netrc_auth, should_bypass_proxies,
|
||||
get_auth_from_url
|
||||
get_auth_from_url, rewind_body
|
||||
)
|
||||
|
||||
from .status_codes import codes
|
||||
@@ -164,6 +164,18 @@ class SessionRedirectMixin(object):
|
||||
proxies = self.rebuild_proxies(prepared_request, proxies)
|
||||
self.rebuild_auth(prepared_request, resp)
|
||||
|
||||
# A failed tell() sets `_body_position` to `object()`. This non-None
|
||||
# value ensures `rewindable` will be True, allowing us to raise an
|
||||
# UnrewindableBodyError, instead of hanging the connection.
|
||||
rewindable = (
|
||||
prepared_request._body_position is not None and
|
||||
('Content-Length' in headers or 'Transfer-Encoding' in headers)
|
||||
)
|
||||
|
||||
# Attempt to rewind consumed file-like object.
|
||||
if rewindable:
|
||||
rewind_body(prepared_request)
|
||||
|
||||
# Override the original request.
|
||||
req = prepared_request
|
||||
|
||||
|
||||
+19
-3
@@ -23,11 +23,13 @@ from . import certs
|
||||
# to_native_string is unused here, but imported here for backwards compatibility
|
||||
from ._internal_utils import to_native_string
|
||||
from .compat import parse_http_list as _parse_list_header
|
||||
from .compat import (quote, urlparse, bytes, str, OrderedDict, unquote,
|
||||
getproxies, proxy_bypass, urlunparse, basestring)
|
||||
from .compat import (
|
||||
quote, urlparse, bytes, str, OrderedDict, unquote, getproxies,
|
||||
proxy_bypass, urlunparse, basestring, integer_types)
|
||||
from .cookies import RequestsCookieJar, cookiejar_from_dict
|
||||
from .structures import CaseInsensitiveDict
|
||||
from .exceptions import InvalidURL, InvalidHeader, FileModeWarning
|
||||
from .exceptions import (
|
||||
InvalidURL, InvalidHeader, FileModeWarning, UnrewindableBodyError)
|
||||
|
||||
_hush_pyflakes = (RequestsCookieJar,)
|
||||
|
||||
@@ -809,3 +811,17 @@ def urldefragauth(url):
|
||||
netloc = netloc.rsplit('@', 1)[-1]
|
||||
|
||||
return urlunparse((scheme, netloc, path, params, query, ''))
|
||||
|
||||
def rewind_body(prepared_request):
|
||||
"""Move file pointer back to its recorded starting position
|
||||
so it can be read again on redirect.
|
||||
"""
|
||||
body_seek = getattr(prepared_request.body, 'seek', None)
|
||||
if body_seek is not None and isinstance(prepared_request._body_position, integer_types):
|
||||
try:
|
||||
body_seek(prepared_request._body_position)
|
||||
except (IOError, OSError):
|
||||
raise UnrewindableBodyError("An error occured when rewinding request "
|
||||
"body for redirect.")
|
||||
else:
|
||||
raise UnrewindableBodyError("Unable to rewind request body for redirect.")
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ coverage==4.0.3
|
||||
decorator==4.0.9
|
||||
docutils==0.12
|
||||
Flask==0.10.1
|
||||
httpbin==0.4.1
|
||||
httpbin==0.5.0
|
||||
itsdangerous==0.24
|
||||
Jinja2==2.8
|
||||
MarkupSafe==0.23
|
||||
|
||||
+117
-1
@@ -24,7 +24,7 @@ from requests.cookies import (
|
||||
from requests.exceptions import (
|
||||
ConnectionError, ConnectTimeout, InvalidSchema, InvalidURL,
|
||||
MissingSchema, ReadTimeout, Timeout, RetryError, TooManyRedirects,
|
||||
ProxyError, InvalidHeader)
|
||||
ProxyError, InvalidHeader, UnrewindableBodyError)
|
||||
from requests.models import PreparedRequest
|
||||
from requests.structures import CaseInsensitiveDict
|
||||
from requests.sessions import SessionRedirectMixin
|
||||
@@ -165,6 +165,21 @@ class TestRequests:
|
||||
assert r.history[0].status_code == 302
|
||||
assert r.history[0].is_redirect
|
||||
|
||||
def test_HTTP_307_ALLOW_REDIRECT_POST(self, httpbin):
|
||||
r = requests.post(httpbin('redirect-to'), data='test', params={'url': 'post', 'status_code': 307})
|
||||
assert r.status_code == 200
|
||||
assert r.history[0].status_code == 307
|
||||
assert r.history[0].is_redirect
|
||||
assert r.json()['data'] == 'test'
|
||||
|
||||
def test_HTTP_307_ALLOW_REDIRECT_POST_WITH_SEEKABLE(self, httpbin):
|
||||
byte_str = b'test'
|
||||
r = requests.post(httpbin('redirect-to'), data=io.BytesIO(byte_str), params={'url': 'post', 'status_code': 307})
|
||||
assert r.status_code == 200
|
||||
assert r.history[0].status_code == 307
|
||||
assert r.history[0].is_redirect
|
||||
assert r.json()['data'] == byte_str.decode('utf-8')
|
||||
|
||||
def test_HTTP_302_TOO_MANY_REDIRECTS(self, httpbin):
|
||||
try:
|
||||
requests.get(httpbin('relative-redirect', '50'))
|
||||
@@ -1386,6 +1401,107 @@ class TestRequests:
|
||||
r3 = next(rg)
|
||||
assert not r3.is_redirect
|
||||
|
||||
def test_prepare_body_position_non_stream(self):
|
||||
data = b'the data'
|
||||
s = requests.Session()
|
||||
prep = requests.Request('GET', 'http://example.com', data=data).prepare()
|
||||
assert prep._body_position is None
|
||||
|
||||
def test_rewind_body(self):
|
||||
data = io.BytesIO(b'the data')
|
||||
s = requests.Session()
|
||||
prep = requests.Request('GET', 'http://example.com', data=data).prepare()
|
||||
assert prep._body_position == 0
|
||||
assert prep.body.read() == b'the data'
|
||||
|
||||
# the data has all been read
|
||||
assert prep.body.read() == b''
|
||||
|
||||
# rewind it back
|
||||
requests.utils.rewind_body(prep)
|
||||
assert prep.body.read() == b'the data'
|
||||
|
||||
def test_rewind_partially_read_body(self):
|
||||
data = io.BytesIO(b'the data')
|
||||
s = requests.Session()
|
||||
data.read(4) # read some data
|
||||
prep = requests.Request('GET', 'http://example.com', data=data).prepare()
|
||||
assert prep._body_position == 4
|
||||
assert prep.body.read() == b'data'
|
||||
|
||||
# the data has all been read
|
||||
assert prep.body.read() == b''
|
||||
|
||||
# rewind it back
|
||||
requests.utils.rewind_body(prep)
|
||||
assert prep.body.read() == b'data'
|
||||
|
||||
def test_rewind_body_no_seek(self):
|
||||
class BadFileObj:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def tell(self):
|
||||
return 0
|
||||
|
||||
def __iter__(self):
|
||||
return
|
||||
|
||||
data = BadFileObj('the data')
|
||||
s = requests.Session()
|
||||
prep = requests.Request('GET', 'http://example.com', data=data).prepare()
|
||||
assert prep._body_position == 0
|
||||
|
||||
with pytest.raises(UnrewindableBodyError) as e:
|
||||
requests.utils.rewind_body(prep)
|
||||
|
||||
assert 'Unable to rewind request body' in str(e)
|
||||
|
||||
def test_rewind_body_failed_seek(self):
|
||||
class BadFileObj:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def tell(self):
|
||||
return 0
|
||||
|
||||
def seek(self, pos):
|
||||
raise OSError()
|
||||
|
||||
def __iter__(self):
|
||||
return
|
||||
|
||||
data = BadFileObj('the data')
|
||||
s = requests.Session()
|
||||
prep = requests.Request('GET', 'http://example.com', data=data).prepare()
|
||||
assert prep._body_position == 0
|
||||
|
||||
with pytest.raises(UnrewindableBodyError) as e:
|
||||
requests.utils.rewind_body(prep)
|
||||
|
||||
assert 'error occured when rewinding request body' in str(e)
|
||||
|
||||
def test_rewind_body_failed_tell(self):
|
||||
class BadFileObj:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def tell(self):
|
||||
raise OSError()
|
||||
|
||||
def __iter__(self):
|
||||
return
|
||||
|
||||
data = BadFileObj('the data')
|
||||
s = requests.Session()
|
||||
prep = requests.Request('GET', 'http://example.com', data=data).prepare()
|
||||
assert prep._body_position is not None
|
||||
|
||||
with pytest.raises(UnrewindableBodyError) as e:
|
||||
requests.utils.rewind_body(prep)
|
||||
|
||||
assert 'Unable to rewind request body' in str(e)
|
||||
|
||||
def _patch_adapter_gzipped_redirect(self, session, url):
|
||||
adapter = session.get_adapter(url=url)
|
||||
org_build_response = adapter.build_response
|
||||
|
||||
Reference in New Issue
Block a user