diff --git a/HISTORY.rst b/HISTORY.rst index 19c2ff1b..896dca4c 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -14,6 +14,11 @@ Release History - The ``timeout`` parameter now affects requests with both ``stream=True`` and ``stream=False`` equally. +**Bugfixes** + +- No longer expose Authorization or Proxy-Authorization headers on redirect. + Fix CVE-2014-1829 and CVE-2014-1830 respectively. + 2.2.1 (2014-01-23) ++++++++++++++++++ diff --git a/docs/user/install.rst b/docs/user/install.rst index 718c9c51..d890fe81 100644 --- a/docs/user/install.rst +++ b/docs/user/install.rst @@ -10,7 +10,7 @@ The first step to using any software package is getting it properly installed. Distribute & Pip ---------------- -Installing Requests is simple with `pip `_:: +Installing Requests is simple with `pip `_, just run this in your terminal:: $ pip install requests @@ -21,16 +21,6 @@ or, with `easy_install `_:: But, you really `shouldn't do that `_. - -Cheeseshop (PyPI) Mirror ------------------------- - -If the Cheeseshop (a.k.a. PyPI) is down, you can also install Requests from one -of the mirrors. `Crate.io `_ is one of them:: - - $ pip install -i http://simple.crate.io/ requests - - Get the Code ------------ @@ -39,7 +29,7 @@ Requests is actively developed on GitHub, where the code is You can either clone the public repository:: - git clone git://github.com/kennethreitz/requests.git + $ git clone git://github.com/kennethreitz/requests.git Download the `tarball `_:: diff --git a/requests/models.py b/requests/models.py index e2fa09f8..7390d1c6 100644 --- a/requests/models.py +++ b/requests/models.py @@ -408,9 +408,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin): is_stream = all([ hasattr(data, '__iter__'), - not isinstance(data, basestring), - not isinstance(data, list), - not isinstance(data, dict) + not isinstance(data, (basestring, list, tuple, dict)) ]) try: diff --git a/requests/packages/urllib3/__init__.py b/requests/packages/urllib3/__init__.py index 73071f70..086387f3 100644 --- a/requests/packages/urllib3/__init__.py +++ b/requests/packages/urllib3/__init__.py @@ -10,7 +10,7 @@ urllib3 - Thread-safe connection pooling and re-using. __author__ = 'Andrey Petrov (andrey.petrov@shazow.net)' __license__ = 'MIT' -__version__ = 'dev' +__version__ = '1.8' from .connectionpool import ( diff --git a/requests/packages/urllib3/_collections.py b/requests/packages/urllib3/_collections.py index 5907b0dc..9cea3a44 100644 --- a/requests/packages/urllib3/_collections.py +++ b/requests/packages/urllib3/_collections.py @@ -4,7 +4,7 @@ # This module is part of urllib3 and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from collections import MutableMapping +from collections import Mapping, MutableMapping try: from threading import RLock except ImportError: # Platform-specific: No threads available @@ -20,9 +20,10 @@ try: # Python 2.7+ from collections import OrderedDict except ImportError: from .packages.ordered_dict import OrderedDict +from .packages.six import itervalues -__all__ = ['RecentlyUsedContainer'] +__all__ = ['RecentlyUsedContainer', 'HTTPHeaderDict'] _Null = object() @@ -101,3 +102,104 @@ class RecentlyUsedContainer(MutableMapping): def keys(self): with self.lock: return self._container.keys() + + +class HTTPHeaderDict(MutableMapping): + """ + :param headers: + An iterable of field-value pairs. Must not contain multiple field names + when compared case-insensitively. + + :param kwargs: + Additional field-value pairs to pass in to ``dict.update``. + + A ``dict`` like container for storing HTTP Headers. + + Field names are stored and compared case-insensitively in compliance with + RFC 2616. Iteration provides the first case-sensitive key seen for each + case-insensitive pair. + + Using ``__setitem__`` syntax overwrites fields that compare equal + case-insensitively in order to maintain ``dict``'s api. For fields that + compare equal, instead create a new ``HTTPHeaderDict`` and use ``.add`` + in a loop. + + If multiple fields that are equal case-insensitively are passed to the + constructor or ``.update``, the behavior is undefined and some will be + lost. + + >>> headers = HTTPHeaderDict() + >>> headers.add('Set-Cookie', 'foo=bar') + >>> headers.add('set-cookie', 'baz=quxx') + >>> headers['content-length'] = '7' + >>> headers['SET-cookie'] + 'foo=bar, baz=quxx' + >>> headers['Content-Length'] + '7' + + If you want to access the raw headers with their original casing + for debugging purposes you can access the private ``._data`` attribute + which is a normal python ``dict`` that maps the case-insensitive key to a + list of tuples stored as (case-sensitive-original-name, value). Using the + structure from above as our example: + + >>> headers._data + {'set-cookie': [('Set-Cookie', 'foo=bar'), ('set-cookie', 'baz=quxx')], + 'content-length': [('content-length', '7')]} + """ + + def __init__(self, headers=None, **kwargs): + self._data = {} + if headers is None: + headers = {} + self.update(headers, **kwargs) + + def add(self, key, value): + """Adds a (name, value) pair, doesn't overwrite the value if it already + exists. + + >>> headers = HTTPHeaderDict(foo='bar') + >>> headers.add('Foo', 'baz') + >>> headers['foo'] + 'bar, baz' + """ + self._data.setdefault(key.lower(), []).append((key, value)) + + def getlist(self, key): + """Returns a list of all the values for the named field. Returns an + empty list if the key doesn't exist.""" + return self[key].split(', ') if key in self else [] + + def copy(self): + h = HTTPHeaderDict() + for key in self._data: + for rawkey, value in self._data[key]: + h.add(rawkey, value) + return h + + def __eq__(self, other): + if not isinstance(other, Mapping): + return False + other = HTTPHeaderDict(other) + return dict((k1, self[k1]) for k1 in self._data) == \ + dict((k2, other[k2]) for k2 in other._data) + + def __getitem__(self, key): + values = self._data[key.lower()] + return ', '.join(value[1] for value in values) + + def __setitem__(self, key, value): + self._data[key.lower()] = [(key, value)] + + def __delitem__(self, key): + del self._data[key.lower()] + + def __len__(self): + return len(self._data) + + def __iter__(self): + for headers in itervalues(self._data): + yield headers[0][0] + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, dict(self.items())) diff --git a/requests/packages/urllib3/connection.py b/requests/packages/urllib3/connection.py index 21247745..662bd2e4 100644 --- a/requests/packages/urllib3/connection.py +++ b/requests/packages/urllib3/connection.py @@ -4,6 +4,7 @@ # This module is part of urllib3 and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import sys import socket from socket import timeout as SocketTimeout @@ -38,6 +39,7 @@ from .exceptions import ( ConnectTimeoutError, ) from .packages.ssl_match_hostname import match_hostname +from .packages import six from .util import ( assert_fingerprint, resolve_cert_reqs, @@ -53,27 +55,40 @@ port_by_scheme = { class HTTPConnection(_HTTPConnection, object): + """ + Based on httplib.HTTPConnection but provides an extra constructor + backwards-compatibility layer between older and newer Pythons. + """ + default_port = port_by_scheme['http'] # By default, disable Nagle's Algorithm. tcp_nodelay = 1 + def __init__(self, *args, **kw): + if six.PY3: # Python 3 + kw.pop('strict', None) + + if sys.version_info < (2, 7): # Python 2.6 and earlier + kw.pop('source_address', None) + self.source_address = None + + _HTTPConnection.__init__(self, *args, **kw) + def _new_conn(self): """ Establish a socket connection and set nodelay settings on it :return: a new socket connection """ - try: - conn = socket.create_connection( - (self.host, self.port), - self.timeout, - self.source_address, - ) - except AttributeError: # Python 2.6 - conn = socket.create_connection( - (self.host, self.port), - self.timeout, - ) + extra_args = [] + if self.source_address: # Python 2.7+ + extra_args.append(self.source_address) + + conn = socket.create_connection( + (self.host, self.port), + self.timeout, + *extra_args + ) conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, self.tcp_nodelay) return conn @@ -95,10 +110,12 @@ class HTTPSConnection(HTTPConnection): def __init__(self, host, port=None, key_file=None, cert_file=None, strict=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None): - try: - HTTPConnection.__init__(self, host, port, strict, timeout, source_address) - except TypeError: # Python 2.6 - HTTPConnection.__init__(self, host, port, strict, timeout) + + HTTPConnection.__init__(self, host, port, + strict=strict, + timeout=timeout, + source_address=source_address) + self.key_file = key_file self.cert_file = cert_file diff --git a/requests/packages/urllib3/connectionpool.py b/requests/packages/urllib3/connectionpool.py index 243d700e..6d0dbb18 100644 --- a/requests/packages/urllib3/connectionpool.py +++ b/requests/packages/urllib3/connectionpool.py @@ -19,6 +19,7 @@ except ImportError: from .exceptions import ( ClosedPoolError, + ConnectionError, ConnectTimeoutError, EmptyPoolError, HostChangedError, @@ -170,13 +171,9 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): log.info("Starting new HTTP connection (%d): %s" % (self.num_connections, self.host)) - extra_params = {} - if not six.PY3: # Python 2 - extra_params['strict'] = self.strict - conn = self.ConnectionCls(host=self.host, port=self.port, timeout=self.timeout.connect_timeout, - **extra_params) + strict=self.strict) if self.proxy is not None: # Enable Nagle's algorithm for proxies, to avoid packet # fragmentation. @@ -238,8 +235,9 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): pass except Full: # This should never happen if self.block == True - log.warning("HttpConnectionPool is full, discarding connection: %s" - % self.host) + log.warning( + "Connection pool is full, discarding connection: %s" % + self.host) # Connection never got put back into the pool, close it. if conn: @@ -414,10 +412,13 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): :param retries: Number of retries to allow before raising a MaxRetryError exception. + If `False`, then retries are disabled and any exception is raised + immediately. :param redirect: If True, automatically handle redirects (status codes 301, 302, - 303, 307, 308). Each redirect counts as a retry. + 303, 307, 308). Each redirect counts as a retry. Disabling retries + will disable redirect, too. :param assert_same_host: If ``True``, will make sure that the host of the pool requests is @@ -451,7 +452,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): if headers is None: headers = self.headers - if retries < 0: + if retries < 0 and retries is not False: raise MaxRetryError(self, url) if release_conn is None: @@ -470,6 +471,10 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): headers = headers.copy() headers.update(self.proxy_headers) + # Must keep the exception bound to a separate variable or else Python 3 + # complains about UnboundLocalError. + err = None + try: # Request a connection from the queue conn = self._get_conn(timeout=pool_timeout) @@ -497,38 +502,41 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): # ``response.read()``) except Empty: - # Timed out by queue + # Timed out by queue. raise EmptyPoolError(self, "No pool connections are available.") - except BaseSSLError as e: + except (BaseSSLError, CertificateError) as e: + # Release connection unconditionally because there is no way to + # close it externally in case of exception. + release_conn = True raise SSLError(e) - except CertificateError as e: - # Name mismatch - raise SSLError(e) + except (TimeoutError, HTTPException, SocketError) as e: + if conn: + # Discard the connection for these exceptions. It will be + # be replaced during the next _get_conn() call. + conn.close() + conn = None - except TimeoutError as e: - # Connection broken, discard. - conn = None - # Save the error off for retry logic. + if not retries: + if isinstance(e, TimeoutError): + # TimeoutError is exempt from MaxRetryError-wrapping. + # FIXME: ... Not sure why. Add a reason here. + raise + + # Wrap unexpected exceptions with the most appropriate + # module-level exception and re-raise. + if isinstance(e, SocketError) and self.proxy: + raise ProxyError('Cannot connect to proxy.', e) + + if retries is False: + raise ConnectionError('Connection failed.', e) + + raise MaxRetryError(self, url, e) + + # Keep track of the error for the retry warning. err = e - if retries == 0: - raise - - except (HTTPException, SocketError) as e: - # Connection broken, discard. It will be replaced next _get_conn(). - conn = None - # This is necessary so we can access e below - err = e - - if retries == 0: - if isinstance(e, SocketError) and self.proxy is not None: - raise ProxyError('Cannot connect to proxy. ' - 'Socket error: %s.' % e) - else: - raise MaxRetryError(self, url, e) - finally: if release_conn: # Put the connection back to be reused. If the connection is @@ -538,8 +546,8 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): if not conn: # Try again - log.warn("Retrying (%d attempts remain) after connection " - "broken by '%r': %s" % (retries, err, url)) + log.warning("Retrying (%d attempts remain) after connection " + "broken by '%r': %s" % (retries, err, url)) return self.urlopen(method, url, body, headers, retries - 1, redirect, assert_same_host, timeout=timeout, pool_timeout=pool_timeout, @@ -547,7 +555,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): # Handle redirect? redirect_location = redirect and response.get_redirect_location() - if redirect_location: + if redirect_location and retries is not False: if response.status == 303: method = 'GET' log.info("Redirecting %s -> %s" % (url, redirect_location)) diff --git a/requests/packages/urllib3/contrib/pyopenssl.py b/requests/packages/urllib3/contrib/pyopenssl.py index d9bda15a..7c513f3a 100644 --- a/requests/packages/urllib3/contrib/pyopenssl.py +++ b/requests/packages/urllib3/contrib/pyopenssl.py @@ -29,9 +29,8 @@ Module Variables ---------------- :var DEFAULT_SSL_CIPHER_LIST: The list of supported SSL/TLS cipher suites. - Default: ``EECDH+ECDSA+AESGCM EECDH+aRSA+AESGCM EECDH+ECDSA+SHA256 - EECDH+aRSA+SHA256 EECDH+aRSA+RC4 EDH+aRSA EECDH RC4 !aNULL !eNULL !LOW !3DES - !MD5 !EXP !PSK !SRP !DSS'`` + Default: ``ECDH+AESGCM:DH+AESGCM:ECDH+AES256:DH+AES256:ECDH+AES128:DH+AES: + ECDH+3DES:DH+3DES:RSA+AESGCM:RSA+AES:RSA+3DES:!aNULL:!MD5:!DSS`` .. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication .. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit) @@ -43,7 +42,7 @@ from ndg.httpsclient.subj_alt_name import SubjectAltName as BaseSubjectAltName import OpenSSL.SSL from pyasn1.codec.der import decoder as der_decoder from pyasn1.type import univ, constraint -from socket import _fileobject +from socket import _fileobject, timeout import ssl import select from cStringIO import StringIO @@ -69,12 +68,22 @@ _openssl_verify = { + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, } -# Default SSL/TLS cipher list. -# Recommendation by https://community.qualys.com/blogs/securitylabs/2013/08/05/ -# configuring-apache-nginx-and-openssl-for-forward-secrecy -DEFAULT_SSL_CIPHER_LIST = 'EECDH+ECDSA+AESGCM EECDH+aRSA+AESGCM ' + \ - 'EECDH+ECDSA+SHA256 EECDH+aRSA+SHA256 EECDH+aRSA+RC4 EDH+aRSA ' + \ - 'EECDH RC4 !aNULL !eNULL !LOW !3DES !MD5 !EXP !PSK !SRP !DSS' +# A secure default. +# Sources for more information on TLS ciphers: +# +# - https://wiki.mozilla.org/Security/Server_Side_TLS +# - https://www.ssllabs.com/projects/best-practices/index.html +# - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/ +# +# The general intent is: +# - Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE), +# - prefer ECDHE over DHE for better performance, +# - prefer any AES-GCM over any AES-CBC for better performance and security, +# - use 3DES as fallback which is secure but slow, +# - disable NULL authentication, MD5 MACs and DSS for security reasons. +DEFAULT_SSL_CIPHER_LIST = "ECDH+AESGCM:DH+AESGCM:ECDH+AES256:DH+AES256:" + \ + "ECDH+AES128:DH+AES:ECDH+3DES:DH+3DES:RSA+AESGCM:RSA+AES:RSA+3DES:" + \ + "!aNULL:!MD5:!DSS" orig_util_HAS_SNI = util.HAS_SNI @@ -139,6 +148,13 @@ def get_subj_alt_name(peer_cert): class fileobject(_fileobject): + def _wait_for_sock(self): + rd, wd, ed = select.select([self._sock], [], [], + self._sock.gettimeout()) + if not rd: + raise timeout() + + def read(self, size=-1): # Use max, disallow tiny reads in a loop as they are very inefficient. # We never leave read() with any leftover data from a new recv() call @@ -156,6 +172,7 @@ class fileobject(_fileobject): try: data = self._sock.recv(rbufsize) except OpenSSL.SSL.WantReadError: + self._wait_for_sock() continue if not data: break @@ -183,6 +200,7 @@ class fileobject(_fileobject): try: data = self._sock.recv(left) except OpenSSL.SSL.WantReadError: + self._wait_for_sock() continue if not data: break @@ -234,6 +252,7 @@ class fileobject(_fileobject): break buffers.append(data) except OpenSSL.SSL.WantReadError: + self._wait_for_sock() continue break return "".join(buffers) @@ -244,6 +263,7 @@ class fileobject(_fileobject): try: data = self._sock.recv(self._rbufsize) except OpenSSL.SSL.WantReadError: + self._wait_for_sock() continue if not data: break @@ -271,7 +291,8 @@ class fileobject(_fileobject): try: data = self._sock.recv(self._rbufsize) except OpenSSL.SSL.WantReadError: - continue + self._wait_for_sock() + continue if not data: break left = size - buf_len @@ -366,6 +387,8 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None, ctx.load_verify_locations(ca_certs, None) except OpenSSL.SSL.Error as e: raise ssl.SSLError('bad ca_certs: %r' % ca_certs, e) + else: + ctx.set_default_verify_paths() # Disable TLS compression to migitate CRIME attack (issue #309) OP_NO_COMPRESSION = 0x20000 diff --git a/requests/packages/urllib3/exceptions.py b/requests/packages/urllib3/exceptions.py index 98ef9abc..b4df831f 100644 --- a/requests/packages/urllib3/exceptions.py +++ b/requests/packages/urllib3/exceptions.py @@ -44,6 +44,11 @@ class ProxyError(HTTPError): pass +class ConnectionError(HTTPError): + "Raised when a normal connection fails." + pass + + class DecodeError(HTTPError): "Raised when automatic decoding based on Content-Type fails." pass diff --git a/requests/packages/urllib3/packages/ssl_match_hostname/__init__.py b/requests/packages/urllib3/packages/ssl_match_hostname/__init__.py index 3aa5b2e1..dd59a75f 100644 --- a/requests/packages/urllib3/packages/ssl_match_hostname/__init__.py +++ b/requests/packages/urllib3/packages/ssl_match_hostname/__init__.py @@ -7,7 +7,7 @@ except ImportError: from backports.ssl_match_hostname import CertificateError, match_hostname except ImportError: # Our vendored copy - from _implementation import CertificateError, match_hostname + from ._implementation import CertificateError, match_hostname # Not needed, but documenting what we provide. __all__ = ('CertificateError', 'match_hostname') diff --git a/requests/packages/urllib3/response.py b/requests/packages/urllib3/response.py index 6a1fe1a7..db441828 100644 --- a/requests/packages/urllib3/response.py +++ b/requests/packages/urllib3/response.py @@ -9,6 +9,7 @@ import logging import zlib import io +from ._collections import HTTPHeaderDict from .exceptions import DecodeError from .packages.six import string_types as basestring, binary_type from .util import is_fp_closed @@ -79,7 +80,10 @@ class HTTPResponse(io.IOBase): def __init__(self, body='', headers=None, status=0, version=0, reason=None, strict=0, preload_content=True, decode_content=True, original_response=None, pool=None, connection=None): - self.headers = headers or {} + + self.headers = HTTPHeaderDict() + if headers: + self.headers.update(headers) self.status = status self.version = version self.reason = reason @@ -249,17 +253,9 @@ class HTTPResponse(io.IOBase): with ``original_response=r``. """ - # Normalize headers between different versions of Python - headers = {} + headers = HTTPHeaderDict() for k, v in r.getheaders(): - # Python 3: Header keys are returned capitalised - k = k.lower() - - has_value = headers.get(k) - if has_value: # Python 3: Repeating header keys are unmerged. - v = ', '.join([has_value, v]) - - headers[k] = v + headers.add(k, v) # HTTPResponse objects in Python 3 don't have a .strict attribute strict = getattr(r, 'strict', 0) diff --git a/requests/sessions.py b/requests/sessions.py index 79ea7773..28b5da9b 100644 --- a/requests/sessions.py +++ b/requests/sessions.py @@ -162,8 +162,11 @@ class SessionRedirectMixin(object): proxies = self.rebuild_proxies(prepared_request, proxies) self.rebuild_auth(prepared_request, resp) + # Override the original request. + req = prepared_request + resp = self.send( - prepared_request, + req, stream=stream, timeout=timeout, verify=verify, @@ -583,7 +586,7 @@ class Session(SessionRedirectMixin): history.insert(0, r) # Get the last request made r = history.pop() - r.history = tuple(history) + r.history = history return r diff --git a/requests/utils.py b/requests/utils.py index 6f4eb500..b703ac4c 100644 --- a/requests/utils.py +++ b/requests/utils.py @@ -61,7 +61,7 @@ def super_len(o): return os.fstat(fileno).st_size if hasattr(o, 'getvalue'): - # e.g. BytesIO, cStringIO.StringI + # e.g. BytesIO, cStringIO.StringIO return len(o.getvalue()) diff --git a/test_requests.py b/test_requests.py index 1bebb1ad..55e7a3d1 100755 --- a/test_requests.py +++ b/test_requests.py @@ -8,6 +8,7 @@ import json import os import pickle import unittest +import collections import requests import pytest @@ -19,6 +20,9 @@ from requests.cookies import cookiejar_from_dict, morsel_to_cookie from requests.exceptions import InvalidURL, MissingSchema from requests.models import PreparedRequest, Response from requests.structures import CaseInsensitiveDict +from requests.sessions import SessionRedirectMixin +from requests.models import PreparedRequest, urlencode +from requests.hooks import default_hooks try: import StringIO @@ -212,6 +216,16 @@ class RequestsTestCase(unittest.TestCase): req_urls = [r.request.url for r in resp.history] assert urls == req_urls + def test_history_is_always_a_list(self): + """ + Show that even with redirects, Response.history is always a list. + """ + resp = requests.get(httpbin('get')) + assert isinstance(resp.history, list) + resp = requests.get(httpbin('redirect/1')) + assert isinstance(resp.history, list) + assert not isinstance(resp.history, tuple) + def test_headers_on_session_with_None_are_not_sent(self): """Do not send headers in Session.headers with None values.""" ses = requests.Session() @@ -1204,5 +1218,89 @@ class TestTimeout: assert 'Read timed out' in e.args[0].args[0] +SendCall = collections.namedtuple('SendCall', ('args', 'kwargs')) + + +class RedirectSession(SessionRedirectMixin): + def __init__(self, order_of_redirects): + self.redirects = order_of_redirects + self.calls = [] + self.max_redirects = 30 + self.cookies = {} + self.trust_env = False + + def send(self, *args, **kwargs): + self.calls.append(SendCall(args, kwargs)) + return self.build_response() + + def build_response(self): + request = self.calls[-1].args[0] + r = requests.Response() + + try: + r.status_code = int(self.redirects.pop(0)) + except IndexError: + r.status_code = 200 + + r.headers = CaseInsensitiveDict({'Location': '/'}) + r.raw = self._build_raw() + r.request = request + return r + + def _build_raw(self): + string = StringIO.StringIO('') + setattr(string, 'release_conn', lambda *args: args) + return string + + +class TestRedirects: + default_keyword_args = { + 'stream': False, + 'verify': True, + 'cert': None, + 'timeout': None, + 'allow_redirects': False, + 'proxies': {}, + } + + def test_requests_are_updated_each_time(self): + session = RedirectSession([303, 307]) + prep = requests.Request('POST', 'http://httpbin.org/post').prepare() + r0 = session.send(prep) + assert r0.request.method == 'POST' + assert session.calls[-1] == SendCall((r0.request,), {}) + redirect_generator = session.resolve_redirects(r0, prep) + for response in redirect_generator: + assert response.request.method == 'GET' + send_call = SendCall((response.request,), + TestRedirects.default_keyword_args) + assert session.calls[-1] == send_call + + +@pytest.fixture +def list_of_tuples(): + return [ + (('a', 'b'), ('c', 'd')), + (('c', 'd'), ('a', 'b')), + (('a', 'b'), ('c', 'd'), ('e', 'f')), + ] + + +def test_data_argument_accepts_tuples(list_of_tuples): + """ + Ensure that the data argument will accept tuples of strings + and properly encode them. + """ + for data in list_of_tuples: + p = PreparedRequest() + p.prepare( + method='GET', + url='http://www.example.com', + data=data, + hooks=default_hooks() + ) + assert p.body == urlencode(data) + + if __name__ == '__main__': unittest.main()