This commit is contained in:
2019-04-20 13:30:15 -04:00
parent 13de38df0a
commit 260cd50aec
85 changed files with 1788 additions and 2130 deletions
+1 -1
View File
@@ -138,7 +138,6 @@ logging.getLogger(__name__).addHandler(NullHandler())
warnings.simplefilter("default", FileModeWarning, append=True)
# -*- coding: utf-8 -*-
"""
requests.api
@@ -215,6 +214,7 @@ def get(url: types.URL, *, params: types.Params = None, **kwargs) -> types.Respo
kwargs.setdefault("allow_redirects", True)
return request("get", url, params=params, **kwargs)
def head(url: types.URL, **kwargs) -> types.Response:
r"""Sends a HEAD request.
+2 -4
View File
@@ -10,22 +10,20 @@ Available hooks:
``response``:
The response generated from a Request.
"""
HOOKS = ['response']
HOOKS = ["response"]
def default_hooks():
return {event: [] for event in HOOKS}
# TODO: response is the only one
def dispatch_hook(key, hooks, hook_data, **kwargs):
"""Dispatches a hook dictionary on a given piece of data."""
hooks = hooks or {}
hooks = hooks.get(key)
if hooks:
if hasattr(hooks, '__call__'):
if hasattr(hooks, "__call__"):
hooks = [hooks]
for hook in hooks:
_hook_data = hook(hook_data, **kwargs)
+2 -2
View File
@@ -10,7 +10,7 @@ which depend on extremely few external helpers (such as compat)
from ._basics import builtin_str, str
def to_native_string(string, encoding='ascii'):
def to_native_string(string, encoding="ascii"):
"""Given a string object, regardless of type, returns a representation of
that string in the native string type, encoding and decoding where
necessary. This assumes ASCII unless told otherwise.
@@ -33,7 +33,7 @@ def unicode_is_ascii(u_string):
return None
try:
u_string.encode('ascii')
u_string.encode("ascii")
return True
except UnicodeEncodeError:
+13 -19
View File
@@ -37,7 +37,8 @@ class CaseInsensitiveDict(collections.MutableMapping):
operations are given keys that have equal ``.lower()``s, the
behavior is undefined.
"""
__slots__ = ('_store')
__slots__ = "_store"
def __init__(self, data=None, **kwargs):
self._store = collections.OrderedDict()
@@ -64,9 +65,7 @@ class CaseInsensitiveDict(collections.MutableMapping):
def lower_items(self):
"""Like iteritems(), but with all lowercase keys."""
return (
(lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items()
)
return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items())
def __eq__(self, other):
if isinstance(other, collections.Mapping):
@@ -77,7 +76,6 @@ class CaseInsensitiveDict(collections.MutableMapping):
# Compare insensitively
return dict(self.lower_items()) == dict(other.lower_items())
# Copy is required
def copy(self):
return CaseInsensitiveDict(self._store.values())
@@ -96,14 +94,13 @@ class HTTPHeaderDict(CaseInsensitiveDict):
super(HTTPHeaderDict, self).__init__()
self.extend({} if data is None else data, **kwargs)
# We'll store tuples in the internal dictionary, but present them as a
# concatenated string when we use item access methods.
#
def __setitem__(self, key, val):
# Specialcase null values.
if (not isinstance(val, basestring)) and (val is not None):
raise ValueError('only string-type values (or None) are allowed')
raise ValueError("only string-type values (or None) are allowed")
super(HTTPHeaderDict, self).__setitem__(key, (val,))
@@ -113,12 +110,10 @@ class HTTPHeaderDict(CaseInsensitiveDict):
if len(val) == 1 and val[0] is None:
return val[0]
return ', '.join(val)
return ", ".join(val)
def lower_items(self):
return (
(lk, ', '.join(vals)) for (lk, (k, vals)) in self._store.items()
)
return ((lk, ", ".join(vals)) for (lk, (k, vals)) in self._store.items())
def copy(self):
return type(self)(self)
@@ -132,10 +127,10 @@ class HTTPHeaderDict(CaseInsensitiveDict):
"""Set a sequence of strings to the associated key - this will overwrite
any previously stored value."""
if not isinstance(values, (list, tuple)):
raise ValueError('argument is not sequence')
raise ValueError("argument is not sequence")
if any(not isinstance(v, basestring) for v in values):
raise ValueError('non-string items in sequence')
raise ValueError("non-string items in sequence")
if not values:
self.pop(key, None)
@@ -157,7 +152,7 @@ class HTTPHeaderDict(CaseInsensitiveDict):
value for this key, then the value will be appended to those values.
"""
if not isinstance(val, basestring):
raise ValueError('value must be a string-type object')
raise ValueError("value must be a string-type object")
self._extend(key, (val,))
@@ -168,8 +163,7 @@ class HTTPHeaderDict(CaseInsensitiveDict):
"""
if len(args) > 1:
raise TypeError(
f"extend() takes at most 1 positional "
"arguments ({len(args)} given)"
f"extend() takes at most 1 positional " "arguments ({len(args)} given)"
)
for other in args + (kwargs,):
@@ -177,7 +171,7 @@ class HTTPHeaderDict(CaseInsensitiveDict):
# See if looks like a HTTPHeaderDict (either urllib3's
# implementation or ours). If so, then we have to add values
# in one go for each key.
multiget = getattr(other, 'getlist', None)
multiget = getattr(other, "getlist", None)
if multiget:
for key in other:
self._extend(key, tuple(multiget(key)))
@@ -191,7 +185,7 @@ class HTTPHeaderDict(CaseInsensitiveDict):
if isinstance(iv, basestring):
self._extend(ik, (iv,))
elif any(not isinstance(v, basestring) for v in iv):
raise ValueError('non-string items in sequence')
raise ValueError("non-string items in sequence")
else:
self._extend(ik, tuple(iv))
@@ -216,7 +210,7 @@ class LookupDict(dict):
super(LookupDict, self).__init__()
def __repr__(self):
return f'<lookup \'{self.name}\'>'
return f"<lookup '{self.name}'>"
def __getitem__(self, key):
# We allow fall-through here, so values default to None
@@ -37,7 +37,8 @@ class CaseInsensitiveDict(collections.MutableMapping):
operations are given keys that have equal ``.lower()``s, the
behavior is undefined.
"""
__slots__ = ('_store')
__slots__ = "_store"
def __init__(self, data=None, **kwargs):
self._store = collections.OrderedDict()
@@ -64,9 +65,7 @@ class CaseInsensitiveDict(collections.MutableMapping):
def lower_items(self):
"""Like iteritems(), but with all lowercase keys."""
return (
(lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items()
)
return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items())
def __eq__(self, other):
if isinstance(other, collections.Mapping):
@@ -77,7 +76,6 @@ class CaseInsensitiveDict(collections.MutableMapping):
# Compare insensitively
return dict(self.lower_items()) == dict(other.lower_items())
# Copy is required
def copy(self):
return CaseInsensitiveDict(self._store.values())
@@ -96,14 +94,13 @@ class HTTPHeaderDict(CaseInsensitiveDict):
super(HTTPHeaderDict, self).__init__()
self.extend({} if data is None else data, **kwargs)
# We'll store tuples in the internal dictionary, but present them as a
# concatenated string when we use item access methods.
#
def __setitem__(self, key, val):
# Specialcase null values.
if (not isinstance(val, basestring)) and (val is not None):
raise ValueError('only string-type values (or None) are allowed')
raise ValueError("only string-type values (or None) are allowed")
super(HTTPHeaderDict, self).__setitem__(key, (val,))
@@ -113,12 +110,10 @@ class HTTPHeaderDict(CaseInsensitiveDict):
if len(val) == 1 and val[0] is None:
return val[0]
return ', '.join(val)
return ", ".join(val)
def lower_items(self):
return (
(lk, ', '.join(vals)) for (lk, (k, vals)) in self._store.items()
)
return ((lk, ", ".join(vals)) for (lk, (k, vals)) in self._store.items())
def copy(self):
return type(self)(self)
@@ -132,10 +127,10 @@ class HTTPHeaderDict(CaseInsensitiveDict):
"""Set a sequence of strings to the associated key - this will overwrite
any previously stored value."""
if not isinstance(values, (list, tuple)):
raise ValueError('argument is not sequence')
raise ValueError("argument is not sequence")
if any(not isinstance(v, basestring) for v in values):
raise ValueError('non-string items in sequence')
raise ValueError("non-string items in sequence")
if not values:
self.pop(key, None)
@@ -157,7 +152,7 @@ class HTTPHeaderDict(CaseInsensitiveDict):
value for this key, then the value will be appended to those values.
"""
if not isinstance(val, basestring):
raise ValueError('value must be a string-type object')
raise ValueError("value must be a string-type object")
self._extend(key, (val,))
@@ -168,8 +163,7 @@ class HTTPHeaderDict(CaseInsensitiveDict):
"""
if len(args) > 1:
raise TypeError(
f"extend() takes at most 1 positional "
"arguments ({len(args)} given)"
f"extend() takes at most 1 positional " "arguments ({len(args)} given)"
)
for other in args + (kwargs,):
@@ -177,7 +171,7 @@ class HTTPHeaderDict(CaseInsensitiveDict):
# See if looks like a HTTPHeaderDict (either urllib3's
# implementation or ours). If so, then we have to add values
# in one go for each key.
multiget = getattr(other, 'getlist', None)
multiget = getattr(other, "getlist", None)
if multiget:
for key in other:
self._extend(key, tuple(multiget(key)))
@@ -191,7 +185,7 @@ class HTTPHeaderDict(CaseInsensitiveDict):
if isinstance(iv, basestring):
self._extend(ik, (iv,))
elif any(not isinstance(v, basestring) for v in iv):
raise ValueError('non-string items in sequence')
raise ValueError("non-string items in sequence")
else:
self._extend(ik, tuple(iv))
@@ -216,7 +210,7 @@ class LookupDict(dict):
super(LookupDict, self).__init__()
def __repr__(self):
return f'<lookup \'{self.name}\'>'
return f"<lookup '{self.name}'>"
def __getitem__(self, key):
# We allow fall-through here, so values default to None
+2 -5
View File
@@ -14,7 +14,7 @@ from typing import (
Dict,
)
from .import http_auth as auth
from . import http_auth as auth
from .http_models import Response, PreparedRequest
from .http_cookies import RequestsCookieJar
from .http_sessions import Session
@@ -50,10 +50,7 @@ Headers = Optional[Union[None, MutableMapping[Text, Text]]]
Cookies = Optional[Union[None, RequestsCookieJar, MutableMapping[Text, Text]]]
Files = Optional[MutableMapping[Text, IO]]
Auth = Union[
None,
Tuple[Text, Text],
auth.AuthBase,
Callable[[PreparedRequest], PreparedRequest],
None, Tuple[Text, Text], auth.AuthBase, Callable[[PreparedRequest], PreparedRequest]
]
Timeout = Union[None, float, Tuple[float, float]]
AllowRedirects = Optional[bool]
+9 -8
View File
@@ -3,10 +3,11 @@ import trio
from ._http import AsyncPoolManager, PoolManager
from ._http._backends import TrioBackend
from .import _http
from . import _http
__all__ = ["request", "blocking_request"]
async def request(
method,
url,
@@ -22,13 +23,13 @@ async def request(
if not pool:
pool = AsyncPoolManager(backend=TrioBackend())
return await pool.urlopen(
method=method,
url=url,
headers=headers,
preload_content=preload_content,
body=body,
**kwargs
)
method=method,
url=url,
headers=headers,
preload_content=preload_content,
body=body,
**kwargs
)
def blocking_request(
+42 -34
View File
@@ -4,11 +4,7 @@ urllib3 - Thread-safe connection pooling and re-using.
from __future__ import absolute_import
import warnings
from .connectionpool import (
HTTPConnectionPool,
HTTPSConnectionPool,
connection_from_url
)
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url
from . import exceptions
from .filepost import encode_multipart_formdata
@@ -22,47 +18,60 @@ from .util.retry import Retry
# Set default logging handler to avoid "No handler found" warnings.
import logging
try: # Python 2.7+
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
__author__ = 'Andrey Petrov (andrey.petrov@shazow.net)'
__license__ = 'MIT'
__version__ = '2.0.dev0+bleach.spike.proof.of.concept.dont.use'
__author__ = "Andrey Petrov (andrey.petrov@shazow.net)"
__license__ = "MIT"
__version__ = "2.0.dev0+bleach.spike.proof.of.concept.dont.use"
__all__ = [
'HTTPConnectionPool',
'HTTPSConnectionPool',
'PoolManager',
'ProxyManager',
'HTTPResponse',
'Retry',
'Timeout',
'add_stderr_logger',
'connection_from_url',
'disable_warnings',
'encode_multipart_formdata',
'get_host',
'make_headers',
'proxy_from_url',
"HTTPConnectionPool",
"HTTPSConnectionPool",
"PoolManager",
"ProxyManager",
"HTTPResponse",
"Retry",
"Timeout",
"add_stderr_logger",
"connection_from_url",
"disable_warnings",
"encode_multipart_formdata",
"get_host",
"make_headers",
"proxy_from_url",
]
# For now we only support async on 3.6, because we use async generators
import sys
if sys.version_info >= (3, 6):
from ._async.connectionpool import (
HTTPConnectionPool as AsyncHTTPConnectionPool,
HTTPSConnectionPool as AsyncHTTPSConnectionPool)
HTTPSConnectionPool as AsyncHTTPSConnectionPool,
)
from ._async.poolmanager import (
PoolManager as AsyncPoolManager,
ProxyManager as AsyncProxyManager)
ProxyManager as AsyncProxyManager,
)
from ._async.response import HTTPResponse as AsyncHTTPResponse
__all__.extend(
('AsyncHTTPConnectionPool', 'AsyncHTTPSConnectionPool',
'AsyncPoolManager', 'AsyncProxyManager', 'AsyncHTTPResponse'))
(
"AsyncHTTPConnectionPool",
"AsyncHTTPSConnectionPool",
"AsyncPoolManager",
"AsyncProxyManager",
"AsyncHTTPResponse",
)
)
logging.getLogger(__name__).addHandler(NullHandler())
@@ -79,10 +88,10 @@ def add_stderr_logger(level=logging.DEBUG):
# even if urllib3 is vendored within another package.
logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)
logger.setLevel(level)
logger.debug('Added a stderr logging handler to logger: %s', __name__)
logger.debug("Added a stderr logging handler to logger: %s", __name__)
return handler
@@ -94,18 +103,17 @@ del NullHandler
# shouldn't be: otherwise, it's very hard for users to use most Python
# mechanisms to silence them.
# SecurityWarning's always go off by default.
warnings.simplefilter('always', exceptions.SecurityWarning, append=True)
warnings.simplefilter("always", exceptions.SecurityWarning, append=True)
# SubjectAltNameWarning's should go off once per host
warnings.simplefilter('default', exceptions.SubjectAltNameWarning, append=True)
warnings.simplefilter("default", exceptions.SubjectAltNameWarning, append=True)
# InsecurePlatformWarning's don't vary between requests, so we keep it default.
warnings.simplefilter('default', exceptions.InsecurePlatformWarning,
append=True)
warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True)
# SNIMissingWarnings should go off only once.
warnings.simplefilter('default', exceptions.SNIMissingWarning, append=True)
warnings.simplefilter("default", exceptions.SNIMissingWarning, append=True)
def disable_warnings(category=exceptions.HTTPWarning):
"""
Helper for quickly disabling all urllib3 warnings.
"""
warnings.simplefilter('ignore', category)
warnings.simplefilter("ignore", category)
+43 -58
View File
@@ -45,7 +45,7 @@ except ImportError:
# within two years of the current date, and no
# earlier than 6 months ago.
RECENT_DATE = datetime.date(2016, 1, 1)
_SUPPORTED_VERSIONS = frozenset([b'1.0', b'1.1'])
_SUPPORTED_VERSIONS = frozenset([b"1.0", b"1.1"])
# A sentinel object returned when some syscalls return EAGAIN.
_EAGAIN = object()
@@ -61,9 +61,9 @@ def _headers_to_native_string(headers):
# 3 and need to decode the headers using Latin1.
for n, v in headers:
if not isinstance(n, str):
n = n.decode('latin1')
n = n.decode("latin1")
if not isinstance(v, str):
v = v.decode('latin1')
v = v.decode("latin1")
yield (n, v)
@@ -74,11 +74,11 @@ def _stringify_headers(headers):
# TODO: revisit
for name, value in headers:
if isinstance(name, six.text_type):
name = name.encode('ascii')
name = name.encode("ascii")
if isinstance(value, six.text_type):
value = value.encode('latin-1')
value = value.encode("latin-1")
elif isinstance(value, int):
value = str(value).encode('ascii')
value = str(value).encode("ascii")
yield (name, value)
@@ -93,8 +93,6 @@ def _read_readable(readable):
yield datablock
# XX this should return an async iterator
def _make_body_iterable(body):
"""
@@ -122,17 +120,13 @@ def _make_body_iterable(body):
elif hasattr(body, "read"):
return _read_readable(body)
elif isinstance(body, collections.Iterable) and not isinstance(
body, six.text_type
):
elif isinstance(body, collections.Iterable) and not isinstance(body, six.text_type):
return body
else:
raise InvalidBodyError("Unacceptable body type: %s" % type(body))
# XX this should return an async iterator
def _request_bytes_iterable(request, state_machine):
"""
@@ -158,7 +152,7 @@ def _response_from_h11(h11_response, body_object):
if h11_response.http_version not in _SUPPORTED_VERSIONS:
raise BadVersionError(h11_response.http_version)
version = b'HTTP/' + h11_response.http_version
version = b"HTTP/" + h11_response.http_version
our_response = Response(
status_code=h11_response.status_code,
headers=_headers_to_native_string(h11_response.headers),
@@ -175,9 +169,9 @@ def _build_tunnel_request(host, port, headers):
"""
target = "%s:%d" % (host, port)
if not isinstance(target, bytes):
target = target.encode('latin1')
target = target.encode("latin1")
tunnel_request = Request(method=b"CONNECT", target=target, headers=headers)
tunnel_request.add_host(host=host, port=port, scheme='http')
tunnel_request.add_host(host=host, port=port, scheme="http")
return tunnel_request
@@ -195,14 +189,14 @@ async def _start_http_request(request, state_machine, conn):
"""
# Before we begin, confirm that the state machine is ok.
if (
state_machine.our_state is not h11.IDLE or
state_machine.their_state is not h11.IDLE
state_machine.our_state is not h11.IDLE
or state_machine.their_state is not h11.IDLE
):
raise ProtocolError("Invalid internal state transition")
request_bytes_iterable = _request_bytes_iterable(request, state_machine)
# Hack around Python 2 lack of nonlocal
context = {'send_aborted': True, 'h11_response': None}
context = {"send_aborted": True, "h11_response": None}
async def next_bytes_to_send():
try:
@@ -210,7 +204,7 @@ async def _start_http_request(request, state_machine, conn):
except StopIteration:
# We successfully sent the whole body!
context['send_aborted'] = False
context["send_aborted"] = False
return None
def consume_bytes(data):
@@ -226,7 +220,7 @@ async def _start_http_request(request, state_machine, conn):
elif isinstance(event, h11.Response):
# We have our response! Save it and get out of here.
context['h11_response'] = event
context["h11_response"] = event
raise LoopAbort
else:
@@ -234,8 +228,8 @@ async def _start_http_request(request, state_machine, conn):
raise RuntimeError("Unexpected h11 event {}".format(event))
await conn.send_and_receive_for_a_while(next_bytes_to_send, consume_bytes)
assert context['h11_response'] is not None
if context['send_aborted']:
assert context["h11_response"] is not None
if context["send_aborted"]:
# Our state machine thinks we sent a bunch of data... but maybe we
# didn't! Maybe our send got cancelled while we were only half-way
# through sending the last chunk, and then h11 thinks we sent a
@@ -246,7 +240,7 @@ async def _start_http_request(request, state_machine, conn):
# state_machine.poison()
# XX kluge for now
state_machine._cstate.process_error(state_machine.our_role)
return context['h11_response']
return context["h11_response"]
async def _read_until_event(state_machine, conn):
@@ -281,6 +275,7 @@ class HTTP1Connection(object):
data is buffered it will issue one read syscall and return all of that
data. Buffering of response data must happen at a higher layer.
"""
# : Disable Nagle's algorithm by default.
#: ``[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]``
default_socket_options = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]
@@ -301,7 +296,9 @@ class HTTP1Connection(object):
self._host = host
self._port = port
self._socket_options = (
socket_options if socket_options is not _DEFAULT_SOCKET_OPTIONS else self.default_socket_options
socket_options
if socket_options is not _DEFAULT_SOCKET_OPTIONS
else self.default_socket_options
)
self._source_address = source_address
self._tunnel_host = tunnel_host
@@ -310,9 +307,7 @@ class HTTP1Connection(object):
self._sock = None
self._state_machine = h11.Connection(our_role=h11.CLIENT)
async def _wrap_socket(
self, conn, ssl_context, fingerprint, assert_hostname
):
async def _wrap_socket(self, conn, ssl_context, fingerprint, assert_hostname):
"""
Handles extra logic to wrap the socket in TLS magic.
"""
@@ -320,11 +315,9 @@ class HTTP1Connection(object):
if is_time_off:
warnings.warn(
(
'System time is way off (before {0}). This will probably '
'lead to SSL verification errors'
).format(
RECENT_DATE
),
"System time is way off (before {0}). This will probably "
"lead to SSL verification errors"
).format(RECENT_DATE),
SystemTimeWarning,
)
# XX need to know whether this is the proxy or the final host that
@@ -337,30 +330,24 @@ class HTTP1Connection(object):
check_host = check_host.rstrip(".")
conn = await conn.start_tls(check_host, ssl_context)
if fingerprint:
ssl_util.assert_fingerprint(
conn.getpeercert(binary_form=True), fingerprint
)
elif (
ssl_context.verify_mode != ssl.CERT_NONE and
assert_hostname is not False
):
ssl_util.assert_fingerprint(conn.getpeercert(binary_form=True), fingerprint)
elif ssl_context.verify_mode != ssl.CERT_NONE and assert_hostname is not False:
cert = conn.getpeercert()
if not cert.get('subjectAltName', ()):
if not cert.get("subjectAltName", ()):
warnings.warn(
(
'Certificate for {0} has no `subjectAltName`, falling '
'back to check for a `commonName` for now. This '
'feature is being removed by major browsers and '
'deprecated by RFC 2818. (See '
'https://github.com/shazow/urllib3/issues/497 for '
'details.)'.format(self._host)
"Certificate for {0} has no `subjectAltName`, falling "
"back to check for a `commonName` for now. This "
"feature is being removed by major browsers and "
"deprecated by RFC 2818. (See "
"https://github.com/shazow/urllib3/issues/497 for "
"details.)".format(self._host)
),
SubjectAltNameWarning,
)
ssl_util.match_hostname(cert, check_host)
self.is_verified = (
ssl_context.verify_mode == ssl.CERT_REQUIRED and
(assert_hostname is not False or fingerprint)
self.is_verified = ssl_context.verify_mode == ssl.CERT_REQUIRED and (
assert_hostname is not False or fingerprint
)
return conn
@@ -415,24 +402,22 @@ class HTTP1Connection(object):
extra_kw = {}
if self._source_address:
extra_kw['source_address'] = self._source_address
extra_kw["source_address"] = self._source_address
if self._socket_options:
extra_kw['socket_options'] = self._socket_options
extra_kw["socket_options"] = self._socket_options
# XX pass connect_timeout to backend
# This was factored out into a separate function to allow overriding
# by subclasses, but in the backend approach the way to to this is to
# provide a custom backend. (Composition >> inheritance.)
try:
conn = await self._backend.connect(
self._host, self._port, **extra_kw
)
conn = await self._backend.connect(self._host, self._port, **extra_kw)
# XX these two error handling blocks needs to be re-done in a
# backend-agnostic way
except socket.timeout:
raise ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)" %
(self._host, connect_timeout),
"Connection to %s timed out. (connect timeout=%s)"
% (self._host, connect_timeout),
)
except socket.error as e:
@@ -501,7 +486,7 @@ class HTTP1Connection(object):
"""
our_state = self._state_machine.our_state
their_state = self._state_machine.their_state
return (our_state is h11.IDLE and their_state is h11.IDLE)
return our_state is h11.IDLE and their_state is h11.IDLE
def __aiter__(self):
return self
+40 -63
View File
@@ -66,12 +66,12 @@ def _add_transport_headers(headers):
This should be a bit smarter: in particular, it should allow for bad or
unexpected versions of these headers, particularly transfer-encoding.
"""
transfer_headers = ('content-length', 'transfer-encoding')
transfer_headers = ("content-length", "transfer-encoding")
for header_name in headers:
if header_name.lower() in transfer_headers:
return
headers['transfer-encoding'] = 'chunked'
headers["transfer-encoding"] = "chunked"
def _build_context(
@@ -97,14 +97,13 @@ def _build_context(
return context
# Pool objects
class ConnectionPool(object):
"""
Base class for all connection pools, such as
:class:`.HTTPConnectionPool` and :class:`.HTTPSConnectionPool`.
"""
scheme = None
QueueCls = queue.LifoQueue
@@ -116,9 +115,7 @@ class ConnectionPool(object):
self.port = port
def __str__(self):
return '%s(host=%r, port=%r)' % (
type(self).__name__, self.host, self.port
)
return "%s(host=%r, port=%r)" % (type(self).__name__, self.host, self.port)
def __enter__(self):
return self
@@ -198,7 +195,8 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
Additional parameters are used to create fresh :class:`urllib3.connection.HTTPConnection`,
:class:`urllib3.connection.HTTPSConnection` instances.
"""
scheme = 'http'
scheme = "http"
ConnectionCls = HTTP1Connection
ResponseCls = HTTPResponse
@@ -238,7 +236,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# Enable Nagle's algorithm for proxies, to avoid packet fragmentation.
# We cannot know if the user has added default socket options, so we cannot replace the
# list.
self.conn_kw.setdefault('socket_options', [])
self.conn_kw.setdefault("socket_options", [])
def _new_conn(self):
"""
@@ -247,7 +245,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
self.num_connections += 1
# TODO: Huge hack.
for kw in ('strict',):
for kw in ("strict",):
if kw in self.conn_kw:
self.conn_kw.pop(kw)
@@ -257,9 +255,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
self.host,
self.port or "80",
)
conn = self.ConnectionCls(
host=self.host, port=self.port, ** self.conn_kw
)
conn = self.ConnectionCls(host=self.host, port=self.port, **self.conn_kw)
return conn
async def _get_conn(self, timeout=None):
@@ -284,8 +280,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
if self.block:
raise EmptyPoolError(
self,
"Pool reached maximum size and no more "
"connections are allowed.",
"Pool reached maximum size and no more " "connections are allowed.",
)
pass # Oh well, we'll create a new connection then
@@ -318,9 +313,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
pass
except queue.Full:
# This should never happen if self.block == True
log.warning(
"Connection pool is full, discarding connection: %s", self.host
)
log.warning("Connection pool is full, discarding connection: %s", self.host)
# Connection never got put back into the pool, close it.
if conn:
conn.close()
@@ -353,7 +346,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# See the above comment about EAGAIN in Python 3. In Python 2 we have
# to specifically catch it and throw the timeout error
if hasattr(err, 'errno') and err.errno in _blocking_errnos:
if hasattr(err, "errno") and err.errno in _blocking_errnos:
raise ReadTimeoutError(
self, url, "Read timed out. (read timeout=%s)" % timeout_value
)
@@ -362,7 +355,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# case, rethrow the original. We need to do this because of:
# http://bugs.python.org/issue10272
# TODO: Can we remove this?
if 'timed out' in str(err) or 'did not complete (read)' in str(
if "timed out" in str(err) or "did not complete (read)" in str(
err
): # Python 2.6
raise ReadTimeoutError(
@@ -398,9 +391,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
raise
# TODO: We need to encapsulate our proxy logic in here somewhere.
request = Request(
method=method, target=url, headers=headers, body=body
)
request = Request(method=method, target=url, headers=headers, body=body)
host = self.host
port = self.port
scheme = self.scheme
@@ -421,17 +412,15 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
read_timeout = socket.getdefaulttimeout()
# Receive the response from the server
try:
response = await conn.send_request(
request, read_timeout=read_timeout
)
response = await conn.send_request(request, read_timeout=read_timeout)
except (SocketTimeout, BaseSSLError, SocketError) as e:
self._raise_timeout(err=e, url=url, timeout_value=read_timeout)
raise
# AppEngine doesn't have a version attr.
http_version = getattr(conn, '_http_vsn_str', 'HTTP/?')
http_version = getattr(conn, "_http_vsn_str", "HTTP/?")
log.debug(
"%s://%s:%s \"%s %s %s\" %s",
'%s://%s:%s "%s %s %s" %s',
self.scheme,
self.host,
self.port,
@@ -443,9 +432,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
return response
def _absolute_url(self, path):
return Url(
scheme=self.scheme, host=self.host, port=self.port, path=path
).url
return Url(scheme=self.scheme, host=self.host, port=self.port, path=path).url
def close(self):
"""
@@ -469,7 +456,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
Check if the given ``url`` is a member of the same host as this
connection pool.
"""
if url.startswith('/'):
if url.startswith("/"):
return True
# TODO: Add optional support for socket.gethostbyname checking.
@@ -555,9 +542,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
if headers is None:
headers = self.headers
if not isinstance(retries, Retry):
retries = Retry.from_int(
retries, default=self.retries, redirect=False
)
retries = Retry.from_int(retries, default=self.retries, redirect=False)
conn = None
# Track whether `conn` needs to be released before
# returning/raising/recursing.
@@ -565,7 +550,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# Merge the proxy headers. Only do this in HTTP. We have to copy the
# headers dict so we can safely change it without those changes being
# reflected in anyone else's copy.
if self.scheme == 'http':
if self.scheme == "http":
headers = headers.copy()
headers.update(self.proxy_headers)
# Must keep the exception bound to a separate variable or else Python 3
@@ -586,15 +571,10 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
conn.timeout = timeout_obj.connect_timeout
# Make the request on the base connection object.
base_response = await self._make_request(
conn,
method,
url,
timeout=timeout_obj,
body=body,
headers=headers,
conn, method, url, timeout=timeout_obj, body=body, headers=headers
)
# Pass method to Response for length checking
response_kw['request_method'] = method
response_kw["request_method"] = method
# Import httplib's response into our own wrapper object
response = self.ResponseCls.from_base(
base_response, pool=self, retries=retries, **response_kw
@@ -619,12 +599,10 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
clean_exit = False
if isinstance(e, (BaseSSLError, CertificateError)):
e = SSLError(e)
elif isinstance(
e, (SocketError, NewConnectionError)
) and self.proxy:
e = ProxyError('Cannot connect to proxy.', e)
elif isinstance(e, (SocketError, NewConnectionError)) and self.proxy:
e = ProxyError("Cannot connect to proxy.", e)
elif isinstance(e, (SocketError, h11.ProtocolError)):
e = ProtocolError('Connection aborted.', e)
e = ProtocolError("Connection aborted.", e)
retries = retries.increment(
method, url, error=e, _pool=self, _stacktrace=sys.exc_info()[2]
)
@@ -679,12 +657,10 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
pass
# Check if we should retry the HTTP response.
has_retry_after = bool(response.getheader('Retry-After'))
has_retry_after = bool(response.getheader("Retry-After"))
if retries.is_retry(method, response.status, has_retry_after):
try:
retries = retries.increment(
method, url, response=response, _pool=self
)
retries = retries.increment(method, url, response=response, _pool=self)
except MaxRetryError:
if retries.raise_on_status:
# Drain and release the connection for this response, since
@@ -730,7 +706,8 @@ class HTTPSConnectionPool(HTTPConnectionPool):
available and are fed into :meth:`urllib3.util.ssl_wrap_socket` to upgrade
the connection socket into an SSL socket.
"""
scheme = 'https'
scheme = "https"
def __init__(
self,
@@ -771,7 +748,7 @@ class HTTPSConnectionPool(HTTPConnectionPool):
raise SSLError("SSL module is not available")
if ca_certs and cert_reqs is None:
cert_reqs = 'CERT_REQUIRED'
cert_reqs = "CERT_REQUIRED"
self.ssl_context = _build_context(
ssl_context,
keyfile=key_file,
@@ -808,7 +785,7 @@ class HTTPSConnectionPool(HTTPConnectionPool):
tunnel_headers = self.proxy_headers
# TODO: Huge hack.
for kw in ('strict', 'redirect'):
for kw in ("strict", "redirect"):
if kw in self.conn_kw:
self.conn_kw.pop(kw)
@@ -818,7 +795,7 @@ class HTTPSConnectionPool(HTTPConnectionPool):
tunnel_host=tunnel_host,
tunnel_port=tunnel_port,
tunnel_headers=tunnel_headers,
** self.conn_kw
**self.conn_kw
)
return conn
@@ -835,10 +812,10 @@ class HTTPSConnectionPool(HTTPConnectionPool):
if not conn.is_verified:
warnings.warn(
(
'Unverified HTTPS request is being made. '
'Adding certificate verification is strongly advised. See: '
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html'
'#ssl-warnings'
"Unverified HTTPS request is being made. "
"Adding certificate verification is strongly advised. See: "
"https://urllib3.readthedocs.io/en/latest/advanced-usage.html"
"#ssl-warnings"
),
InsecureRequestWarning,
)
@@ -866,7 +843,7 @@ def connection_from_url(url, **kw):
"""
scheme, host, port = get_host(url)
port = port or DEFAULT_PORTS.get(scheme, 80)
if scheme == 'https':
if scheme == "https":
return HTTPSConnectionPool(host, port=port, **kw)
else:
@@ -886,6 +863,6 @@ def _ipv6_host(host):
#
# Also if an IPv6 address literal has a zone identifier, the
# percent sign might be URIencoded, convert it back into ASCII
if host.startswith('[') and host.endswith(']'):
host = host.replace('%25', '%').strip('[]')
if host.startswith("[") and host.endswith("]"):
host = host.replace("%25", "%").strip("[]")
return host
+81 -101
View File
@@ -13,47 +13,47 @@ from ..util.url import parse_url
from ..util.request import set_file_position
from ..util.retry import Retry
__all__ = ['PoolManager', 'ProxyManager', 'proxy_from_url']
__all__ = ["PoolManager", "ProxyManager", "proxy_from_url"]
log = logging.getLogger(__name__)
SSL_KEYWORDS = (
'key_file',
'cert_file',
'cert_reqs',
'ca_certs',
'ssl_version',
'ca_cert_dir',
'ssl_context',
"key_file",
"cert_file",
"cert_reqs",
"ca_certs",
"ssl_version",
"ca_cert_dir",
"ssl_context",
)
# All known keyword arguments that could be provided to the pool manager, its
# pools, or the underlying connections. This is used to construct a pool key.
_key_fields = (
'key_scheme', # str
'key_host', # str
'key_strict',
'key_port', # int
'key_timeout', # int or float or Timeout
'key_retries', # int or Retry
'key_block', # bool
'key_source_address', # str
'key_key_file', # str
'key_cert_file', # str
'key_cert_reqs', # str
'key_ca_certs', # str
'key_ssl_version', # str
'key_ca_cert_dir', # str
'key_ssl_context', # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext
'key_maxsize', # int
'key_headers', # dict
'key__proxy', # parsed proxy url
'key__proxy_headers', # dict
'key_socket_options', # list of (level (int), optname (int), value (int or str)) tuples
'key__socks_options', # dict
'key_assert_hostname', # bool or string
'key_assert_fingerprint', # str
"key_scheme", # str
"key_host", # str
"key_strict",
"key_port", # int
"key_timeout", # int or float or Timeout
"key_retries", # int or Retry
"key_block", # bool
"key_source_address", # str
"key_key_file", # str
"key_cert_file", # str
"key_cert_reqs", # str
"key_ca_certs", # str
"key_ssl_version", # str
"key_ca_cert_dir", # str
"key_ssl_context", # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext
"key_maxsize", # int
"key_headers", # dict
"key__proxy", # parsed proxy url
"key__proxy_headers", # dict
"key_socket_options", # list of (level (int), optname (int), value (int or str)) tuples
"key__socks_options", # dict
"key_assert_hostname", # bool or string
"key_assert_fingerprint", # str
)
# : The namedtuple class used to construct keys for the connection pool.
#: All custom key schemes should include the fields in this key at a minimum.
PoolKey = collections.namedtuple('PoolKey', _key_fields)
PoolKey = collections.namedtuple("PoolKey", _key_fields)
def _default_key_normalizer(key_class, request_context):
@@ -78,21 +78,21 @@ def _default_key_normalizer(key_class, request_context):
"""
# Since we mutate the dictionary, make a copy first
context = request_context.copy()
context['scheme'] = context['scheme'].lower()
context['host'] = context['host'].lower()
context["scheme"] = context["scheme"].lower()
context["host"] = context["host"].lower()
# These are both dictionaries and need to be transformed into frozensets
for key in ('headers', '_proxy_headers', '_socks_options'):
for key in ("headers", "_proxy_headers", "_socks_options"):
if key in context and context[key] is not None:
context[key] = frozenset(context[key].items())
# The socket_options key may be a list and needs to be transformed into a
# tuple.
socket_opts = context.get('socket_options')
socket_opts = context.get("socket_options")
if socket_opts is not None:
context['socket_options'] = tuple(socket_opts)
context["socket_options"] = tuple(socket_opts)
# Map the kwargs to the names in the namedtuple - this is necessary since
# namedtuples can't have fields starting with '_'.
for key in list(context.keys()):
context['key_' + key] = context.pop(key)
context["key_" + key] = context.pop(key)
# Default to ``None`` for keys missing from the context
for field in key_class._fields:
if field not in context:
@@ -105,12 +105,10 @@ def _default_key_normalizer(key_class, request_context):
#: Each PoolManager makes a copy of this dictionary so they can be configured
#: globally here, or individually on the instance.
key_fn_by_scheme = {
'http': functools.partial(_default_key_normalizer, PoolKey),
'https': functools.partial(_default_key_normalizer, PoolKey),
}
pool_classes_by_scheme = {
'http': HTTPConnectionPool, 'https': HTTPSConnectionPool
"http": functools.partial(_default_key_normalizer, PoolKey),
"https": functools.partial(_default_key_normalizer, PoolKey),
}
pool_classes_by_scheme = {"http": HTTPConnectionPool, "https": HTTPSConnectionPool}
class PoolManager(RequestMethods):
@@ -140,16 +138,13 @@ class PoolManager(RequestMethods):
2
"""
proxy = None
def __init__(
self, num_pools=10, headers=None, backend=None, **connection_pool_kw
):
def __init__(self, num_pools=10, headers=None, backend=None, **connection_pool_kw):
RequestMethods.__init__(self, headers)
self.connection_pool_kw = connection_pool_kw
self.pools = RecentlyUsedContainer(
num_pools, dispose_func=lambda p: p.close()
)
self.pools = RecentlyUsedContainer(num_pools, dispose_func=lambda p: p.close())
# Locally set the pool classes and keys so other PoolManagers can
# override them.
self.pool_classes_by_scheme = pool_classes_by_scheme
@@ -181,9 +176,9 @@ class PoolManager(RequestMethods):
# this function has historically only used the scheme, host, and port
# in the positional args. When an API change is acceptable these can
# be removed.
for key in ('scheme', 'host', 'port'):
for key in ("scheme", "host", "port"):
request_context.pop(key, None)
if scheme == 'http':
if scheme == "http":
for kw in SSL_KEYWORDS:
request_context.pop(kw, None)
return pool_cls(host, port, backend=self.backend, **request_context)
@@ -197,9 +192,7 @@ class PoolManager(RequestMethods):
"""
self.pools.clear()
def connection_from_host(
self, host, port=None, scheme='http', pool_kwargs=None
):
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
"""
Get a :class:`ConnectionPool` based on the host, port, and scheme.
@@ -213,11 +206,11 @@ class PoolManager(RequestMethods):
raise LocationValueError("No host specified.")
request_context = self._merge_pool_kwargs(pool_kwargs)
request_context['scheme'] = scheme or 'http'
request_context["scheme"] = scheme or "http"
if not port:
port = DEFAULT_PORTS.get(request_context['scheme'].lower(), 80)
request_context['port'] = port
request_context['host'] = host
port = DEFAULT_PORTS.get(request_context["scheme"].lower(), 80)
request_context["port"] = port
request_context["host"] = host
return self.connection_from_context(request_context)
def connection_from_context(self, request_context):
@@ -227,12 +220,10 @@ class PoolManager(RequestMethods):
``request_context`` must at least contain the ``scheme`` key and its
value must be a key in ``key_fn_by_scheme`` instance variable.
"""
scheme = request_context['scheme'].lower()
scheme = request_context["scheme"].lower()
pool_key_constructor = self.key_fn_by_scheme[scheme]
pool_key = pool_key_constructor(request_context)
return self.connection_from_pool_key(
pool_key, request_context=request_context
)
return self.connection_from_pool_key(pool_key, request_context=request_context)
def connection_from_pool_key(self, pool_key, request_context=None):
"""
@@ -250,12 +241,10 @@ class PoolManager(RequestMethods):
return pool
# Make a fresh ConnectionPool of the desired type
scheme = request_context['scheme']
host = request_context['host']
port = request_context['port']
pool = self._new_pool(
scheme, host, port, request_context=request_context
)
scheme = request_context["scheme"]
host = request_context["host"]
port = request_context["port"]
pool = self._new_pool(scheme, host, port, request_context=request_context)
self.pools[pool_key] = pool
return pool
@@ -308,11 +297,11 @@ class PoolManager(RequestMethods):
conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme)
# Rewind body position, if needed. Record current position
# for future rewinds in the event of a redirect/retry.
body = kw.get('body')
body_pos = kw.get('body_pos')
kw['body_pos'] = set_file_position(body, body_pos)
if 'headers' not in kw:
kw['headers'] = self.headers
body = kw.get("body")
body_pos = kw.get("body_pos")
kw["body_pos"] = set_file_position(body, body_pos)
if "headers" not in kw:
kw["headers"] = self.headers
if self.proxy is not None and u.scheme == "http":
response = await conn.urlopen(method, url, **kw)
else:
@@ -325,22 +314,20 @@ class PoolManager(RequestMethods):
redirect_location = urljoin(url, redirect_location)
# RFC 7231, Section 6.4.4
if response.status == 303:
method = 'GET'
retries = kw.get('retries')
method = "GET"
retries = kw.get("retries")
if not isinstance(retries, Retry):
retries = Retry.from_int(retries, redirect=redirect)
try:
retries = retries.increment(
method, url, response=response, _pool=conn
)
retries = retries.increment(method, url, response=response, _pool=conn)
except MaxRetryError:
if retries.raise_on_redirect:
raise
return response
kw['retries'] = retries
kw['redirect'] = redirect
kw["retries"] = retries
kw["redirect"] = redirect
retries.sleep_for_retry(response)
log.info("Redirecting %s -> %s", url, redirect_location)
return self.urlopen(method, redirect_location, **kw)
@@ -382,8 +369,10 @@ class ProxyManager(PoolManager):
**connection_pool_kw
):
if isinstance(proxy_url, HTTPConnectionPool):
proxy_url = '%s://%s:%i' % (
proxy_url.scheme, proxy_url.host, proxy_url.port
proxy_url = "%s://%s:%i" % (
proxy_url.scheme,
proxy_url.host,
proxy_url.port,
)
proxy = parse_url(proxy_url)
if not proxy.port:
@@ -394,25 +383,18 @@ class ProxyManager(PoolManager):
self.proxy = proxy
self.proxy_headers = proxy_headers or {}
connection_pool_kw['_proxy'] = self.proxy
connection_pool_kw['_proxy_headers'] = self.proxy_headers
super(ProxyManager, self).__init__(
num_pools, headers, **connection_pool_kw
)
connection_pool_kw["_proxy"] = self.proxy
connection_pool_kw["_proxy_headers"] = self.proxy_headers
super(ProxyManager, self).__init__(num_pools, headers, **connection_pool_kw)
def connection_from_host(
self, host, port=None, scheme='http', pool_kwargs=None
):
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
if scheme == "https":
return super(ProxyManager, self).connection_from_host(
host, port, scheme, pool_kwargs=pool_kwargs
)
return super(ProxyManager, self).connection_from_host(
self.proxy.host,
self.proxy.port,
self.proxy.scheme,
pool_kwargs=pool_kwargs,
self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs
)
def _set_proxy_headers(self, url, headers=None):
@@ -420,10 +402,10 @@ class ProxyManager(PoolManager):
Sets headers needed by proxies: specifically, the Accept and Host
headers. Only sets headers not provided by the user.
"""
headers_ = {'Accept': '*/*'}
headers_ = {"Accept": "*/*"}
netloc = parse_url(url).netloc
if netloc:
headers_['Host'] = netloc
headers_["Host"] = netloc
if headers:
headers_.update(headers)
return headers_
@@ -435,11 +417,9 @@ class ProxyManager(PoolManager):
# For proxied HTTPS requests, httplib sets the necessary headers
# on the CONNECT to the proxy. For HTTP, we'll definitely
# need to set 'Host' at the very least.
headers = kw.get('headers', self.headers)
kw['headers'] = self._set_proxy_headers(url, headers)
return super(ProxyManager, self).urlopen(
method, url, redirect=redirect, **kw
)
headers = kw.get("headers", self.headers)
kw["headers"] = self._set_proxy_headers(url, headers)
return super(ProxyManager, self).urlopen(method, url, redirect=redirect, **kw)
def proxy_from_url(url, **kw):
+24 -28
View File
@@ -9,7 +9,7 @@ from socket import error as SocketError
import h11
from .._collections import HTTPHeaderDict
from ..exceptions import (ProtocolError, DecodeError, ReadTimeoutError)
from ..exceptions import ProtocolError, DecodeError, ReadTimeoutError
from ..packages.six import string_types as basestring, binary_type
from ..util.ssl_ import BaseSSLError
@@ -17,7 +17,6 @@ log = logging.getLogger(__name__)
class DeflateDecoder(object):
def __init__(self):
self._first_try = True
self._data = binary_type()
@@ -52,7 +51,6 @@ class DeflateDecoder(object):
class GzipDecoder(object):
def __init__(self):
self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS)
@@ -67,7 +65,7 @@ class GzipDecoder(object):
def _get_decoder(mode):
if mode == 'gzip':
if mode == "gzip":
return GzipDecoder()
return DeflateDecoder()
@@ -97,12 +95,13 @@ class HTTPResponse(io.IOBase):
The retries contains the last :class:`~urllib3.util.retry.Retry` that
was used during the request.
"""
CONTENT_DECODERS = ['gzip', 'deflate']
CONTENT_DECODERS = ["gzip", "deflate"]
REDIRECT_STATUSES = [301, 302, 303, 307, 308]
def __init__(
self,
body='',
body="",
headers=None,
status=0,
version=0,
@@ -131,7 +130,7 @@ class HTTPResponse(io.IOBase):
self._fp = None
self._original_response = original_response
self._fp_bytes_read = 0
self._buffer = b''
self._buffer = b""
if body and isinstance(body, (basestring, binary_type)):
self._body = body
else:
@@ -151,7 +150,7 @@ class HTTPResponse(io.IOBase):
location. ``False`` if not a redirect status code.
"""
if self.status in self.REDIRECT_STATUSES:
return self.headers.get('location')
return self.headers.get("location")
return False
@@ -189,7 +188,7 @@ class HTTPResponse(io.IOBase):
"""
# Note: content-encoding value should be case-insensitive, per RFC 7230
# Section 3.2
content_encoding = self.headers.get('content-encoding', '').lower()
content_encoding = self.headers.get("content-encoding", "").lower()
if self._decoder is None and content_encoding in self.CONTENT_DECODERS:
self._decoder = _get_decoder(content_encoding)
@@ -201,7 +200,7 @@ class HTTPResponse(io.IOBase):
if decode_content and self._decoder:
data = self._decoder.decompress(data)
except (IOError, zlib.error) as e:
content_encoding = self.headers.get('content-encoding', '').lower()
content_encoding = self.headers.get("content-encoding", "").lower()
raise DecodeError(
"Received response with content-encoding: %s, but "
"failed to decode it." % content_encoding,
@@ -218,10 +217,10 @@ class HTTPResponse(io.IOBase):
being used.
"""
if self._decoder:
buf = self._decoder.decompress(b'')
buf = self._decoder.decompress(b"")
return buf + self._decoder.flush()
return b''
return b""
@contextmanager
def _error_catcher(self):
@@ -240,20 +239,20 @@ class HTTPResponse(io.IOBase):
except SocketTimeout:
# FIXME: Ideally we'd like to include the url in the ReadTimeoutError but
# there is yet no clean way to get at it from this context.
raise ReadTimeoutError(self._pool, None, 'Read timed out.')
raise ReadTimeoutError(self._pool, None, "Read timed out.")
except BaseSSLError as e:
# FIXME: Is there a better way to differentiate between SSLErrors?
if 'read operation timed out' not in str(e): # Defensive:
if "read operation timed out" not in str(e): # Defensive:
# This shouldn't happen but just in case we're missing an edge
# case, let's avoid swallowing SSL errors.
raise
raise ReadTimeoutError(self._pool, None, 'Read timed out.')
raise ReadTimeoutError(self._pool, None, "Read timed out.")
except (h11.ProtocolError, SocketError) as e:
# This includes IncompleteRead.
raise ProtocolError('Connection broken: %r' % e, e)
raise ProtocolError("Connection broken: %r" % e, e)
except GeneratorExit:
# We swallow GeneratorExit when it is emitted: this allows the
@@ -305,7 +304,7 @@ class HTTPResponse(io.IOBase):
# data into the buffer. That's unfortunate, but right now I'm not smart
# enough to come up with a way to solve that problem.
if self._fp is None and not self._buffer:
return b''
return b""
data = self._buffer
with self._error_catcher():
@@ -313,8 +312,8 @@ class HTTPResponse(io.IOBase):
chunks = []
async for chunk in self.stream(decode_content):
chunks.append(chunk)
data += b''.join(chunks)
self._buffer = b''
data += b"".join(chunks)
self._buffer = b""
# We only cache the body data for simple read calls.
self._body = data
else:
@@ -330,7 +329,7 @@ class HTTPResponse(io.IOBase):
else:
chunks.append(chunk)
data_len += len(chunk)
data = b''.join(chunks)
data = b"".join(chunks)
self._buffer = data[amt:]
data = data[:amt]
return data
@@ -366,7 +365,7 @@ class HTTPResponse(io.IOBase):
# coverage. Happily, the code here is so simple that testing the
# branch we don't enter is basically entirely unnecessary (it's
# just a yield statement).
final_chunk = self._decode(b'', decode_content, flush_decoder=True)
final_chunk = self._decode(b"", decode_content, flush_decoder=True)
if final_chunk: # Platform-specific: Jython
yield final_chunk
@@ -382,7 +381,7 @@ class HTTPResponse(io.IOBase):
with ``original_response=r``.
"""
# TODO: Huge hack.
for kw in ('redirect', 'assert_same_host', 'enforce_content_length'):
for kw in ("redirect", "assert_same_host", "enforce_content_length"):
if kw in response_kw:
response_kw.pop(kw)
@@ -397,7 +396,6 @@ class HTTPResponse(io.IOBase):
)
return resp
# Backwards-compatibility methods for httplib.HTTPResponse
def getheaders(self):
return self.headers
@@ -405,17 +403,15 @@ class HTTPResponse(io.IOBase):
def getheader(self, name, default=None):
return self.headers.get(name, default)
# Backwards compatibility for http.cookiejar
def info(self):
return self.headers
# Overrides from io.IOBase
def close(self):
if not self.closed:
self._fp.close()
self._buffer = b''
self._buffer = b""
self._fp = None
if self._connection:
self._connection.close()
@@ -426,7 +422,7 @@ class HTTPResponse(io.IOBase):
if self._fp is None and not self._buffer:
return True
elif hasattr(self._fp, 'complete'):
elif hasattr(self._fp, "complete"):
return self._fp.complete
else:
@@ -457,5 +453,5 @@ class HTTPResponse(io.IOBase):
return 0
else:
b[:len(temp)] = temp
b[: len(temp)] = temp
return len(temp)
+3 -2
View File
@@ -1,9 +1,10 @@
from ..packages import six
from .sync_backend import SyncBackend
__all__ = ['SyncBackend']
__all__ = ["SyncBackend"]
if six.PY3:
from .trio_backend import TrioBackend
from .twisted_backend import TwistedBackend
__all__ += ['TrioBackend', 'TwistedBackend']
__all__ += ["TrioBackend", "TwistedBackend"]
@@ -26,4 +26,5 @@ class LoopAbort(Exception):
"""
Tell backends that enough bytes have been consumed
"""
pass
@@ -13,7 +13,6 @@ BUFSIZE = 65536
class SyncBackend(object):
def __init__(self, connect_timeout=None, read_timeout=None):
self._connect_timeout = connect_timeout
self._read_timeout = read_timeout
@@ -29,7 +28,6 @@ class SyncBackend(object):
class SyncSocket(object):
def __init__(self, sock, read_timeout):
self._sock = sock
self._read_timeout = read_timeout
@@ -40,14 +38,11 @@ class SyncSocket(object):
def start_tls(self, server_hostname, ssl_context):
self._sock.setblocking(True)
wrapped = ssl_wrap_socket(
self._sock,
server_hostname=server_hostname,
ssl_context=ssl_context,
self._sock, server_hostname=server_hostname, ssl_context=ssl_context
)
wrapped.setblocking(False)
return SyncSocket(wrapped, self._read_timeout)
# Only for SSL-wrapped sockets
def getpeercert(self, binary=False):
return self._sock.getpeercert(binary_form=binary)
+1 -10
View File
@@ -6,10 +6,7 @@ BUFSIZE = 65536
class TrioBackend:
async def connect(
self, host, port, source_address=None, socket_options=None
):
async def connect(self, host, port, source_address=None, socket_options=None):
if source_address is not None:
# You can't really combine source_address= and happy eyeballs
# (can we get rid of source_address? or at least make it a source
@@ -30,14 +27,11 @@ class TrioBackend:
return len(self) > other
# XX it turns out that we don't need SSLStream to be robustified against
# cancellation, but we probably should do something to detect when the stream
# has been broken by cancellation (e.g. a timeout) and make is_readable return
# True so the connection won't be reused.
class TrioSocket:
def __init__(self, stream):
self._stream = stream
@@ -57,7 +51,6 @@ class TrioSocket:
return await self._stream.receive_some(BUFSIZE)
async def send_and_receive_for_a_while(self, produce_bytes, consume_bytes):
async def sender():
while True:
outgoing = await produce_bytes()
@@ -78,7 +71,6 @@ class TrioSocket:
except LoopAbort:
pass
# Pull out the underlying trio socket, because it turns out HTTP is not so
# great at respecting abstraction boundaries.
def _socket(self):
@@ -89,7 +81,6 @@ class TrioSocket:
# Now we have a SocketStream
return stream.socket
# We want this to be synchronous, and don't care about graceful teardown
# of the SSL/TLS layer.
def forceful_close(self):
@@ -4,7 +4,10 @@ from twisted.internet import protocol, ssl
from twisted.internet.interfaces import IHandshakeListener
from twisted.internet.endpoints import HostnameEndpoint, connectProtocol
from twisted.internet.defer import (
Deferred, DeferredList, CancelledError, ensureDeferred
Deferred,
DeferredList,
CancelledError,
ensureDeferred,
)
from zope.interface import implementer
@@ -12,16 +15,12 @@ from ..contrib.pyopenssl import get_subj_alt_name
from ._common import LoopAbort
# XX need to add timeout support, esp. on connect
class TwistedBackend:
def __init__(self, reactor):
self._reactor = reactor
async def connect(
self, host, port, source_address=None, socket_options=None
):
async def connect(self, host, port, source_address=None, socket_options=None):
# HostnameEndpoint only supports setting source host, not source port
if source_address is not None:
raise NotImplementedError(
@@ -45,8 +44,6 @@ class TwistedBackend:
return TwistedSocket(protocol)
# enums
class _DATA_RECEIVED:
pass
@@ -62,7 +59,6 @@ class _HANDSHAKE_COMPLETED:
@implementer(IHandshakeListener)
class TwistedSocketProtocol(protocol.Protocol):
def connectionMade(self):
self._receive_buffer = bytearray()
self.transport.pauseProducing()
@@ -161,7 +157,6 @@ class TwistedSocketProtocol(protocol.Protocol):
class DoubleError(Exception):
def __init__(self, exc1, exc2):
self.exc1 = exc1
self.exc2 = exc2
@@ -171,7 +166,6 @@ class DoubleError(Exception):
class TwistedSocket:
def __init__(self, protocol):
self._protocol = protocol
@@ -185,9 +179,7 @@ class TwistedSocket:
return x509
if binary:
return OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_ASN1, x509
)
return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509)
return {
"subject": ((("commonName", x509.get_subject().CN),),),
@@ -198,7 +190,6 @@ class TwistedSocket:
return await self._protocol.receive_some()
async def send_and_receive_for_a_while(self, produce_bytes, consume_bytes):
async def sender():
while True:
outgoing = await produce_bytes()
@@ -226,7 +217,6 @@ class TwistedSocket:
receive_loop.cancel()
return failure
# If the receive_loop errors out *or* exits cleanly due to LoopAbort,
# then cancel the send_loop and preserve the result
@receive_loop.addBoth
+14 -16
View File
@@ -9,7 +9,6 @@ try:
except ImportError: # Platform-specific: No threads available
class RLock:
def __enter__(self):
pass
@@ -24,7 +23,7 @@ except ImportError:
from .exceptions import InvalidHeader
from .packages.six import iterkeys, itervalues, PY3
__all__ = ['RecentlyUsedContainer', 'HTTPHeaderDict']
__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"]
_Null = object()
@@ -41,6 +40,7 @@ class RecentlyUsedContainer(MutableMapping):
Every time an item is evicted from the container,
``dispose_func(value)`` is called. Callback which will get called
"""
ContainerCls = OrderedDict
def __init__(self, maxsize=10, dispose_func=None):
@@ -81,7 +81,7 @@ class RecentlyUsedContainer(MutableMapping):
def __iter__(self):
raise NotImplementedError(
'Iteration over this class is unlikely to be threadsafe.'
"Iteration over this class is unlikely to be threadsafe."
)
def clear(self):
@@ -149,7 +149,7 @@ class HTTPHeaderDict(MutableMapping):
def __getitem__(self, key):
val = self._container[key.lower()]
return ', '.join(val[1:])
return ", ".join(val[1:])
def __delitem__(self, key):
del self._container[key.lower()]
@@ -158,14 +158,13 @@ class HTTPHeaderDict(MutableMapping):
return key.lower() in self._container
def __eq__(self, other):
if not isinstance(other, Mapping) and not hasattr(other, 'keys'):
if not isinstance(other, Mapping) and not hasattr(other, "keys"):
return False
if not isinstance(other, type(self)):
other = type(self)(other)
return (
dict((k.lower(), v) for k, v in self.itermerged()) ==
dict((k.lower(), v) for k, v in other.itermerged())
return dict((k.lower(), v) for k, v in self.itermerged()) == dict(
(k.lower(), v) for k, v in other.itermerged()
)
def __ne__(self, other):
@@ -185,9 +184,9 @@ class HTTPHeaderDict(MutableMapping):
yield vals[0]
def pop(self, key, default=__marker):
'''D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
"""D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised.
'''
"""
# Using the MutableMapping function directly fails due to the private marker.
# Using ordinary dict.pop would expose the internal structures.
# So let's reinvent the wheel.
@@ -300,7 +299,7 @@ class HTTPHeaderDict(MutableMapping):
"""Iterate over all headers, merging duplicate ones together."""
for key in self:
val = self._container[key.lower()]
yield val[0], ', '.join(val[1:])
yield val[0], ", ".join(val[1:])
def items(self):
return list(self.iteritems())
@@ -311,7 +310,7 @@ class HTTPHeaderDict(MutableMapping):
# python2.7 does not expose a proper API for exporting multiheaders
# efficiently. This function re-reads raw lines from the message
# object and extracts the multiheaders properly.
obs_fold_continued_leaders = (' ', '\t')
obs_fold_continued_leaders = (" ", "\t")
headers = []
for line in message.headers:
if line.startswith(obs_fold_continued_leaders):
@@ -320,15 +319,14 @@ class HTTPHeaderDict(MutableMapping):
# in RFC-7230 S3.2.4. This indicates a multiline header, but
# there exists no previous header to which we can attach it.
raise InvalidHeader(
'Header continuation with no previous header: %s' %
line
"Header continuation with no previous header: %s" % line
)
else:
key, value = headers[-1]
headers[-1] = (key, value + ' ' + line.strip())
headers[-1] = (key, value + " " + line.strip())
continue
key, value = line.split(':', 1)
key, value = line.split(":", 1)
headers.append((key, value.strip()))
return cls(headers)
+46 -67
View File
@@ -45,7 +45,7 @@ except ImportError:
# within two years of the current date, and no
# earlier than 6 months ago.
RECENT_DATE = datetime.date(2016, 1, 1)
_SUPPORTED_VERSIONS = frozenset([b'1.0', b'1.1'])
_SUPPORTED_VERSIONS = frozenset([b"1.0", b"1.1"])
# A sentinel object returned when some syscalls return EAGAIN.
_EAGAIN = object()
@@ -61,9 +61,9 @@ def _headers_to_native_string(headers):
# 3 and need to decode the headers using Latin1.
for n, v in headers:
if not isinstance(n, str):
n = n.decode('latin1')
n = n.decode("latin1")
if not isinstance(v, str):
v = v.decode('latin1')
v = v.decode("latin1")
yield (n, v)
@@ -74,11 +74,11 @@ def _stringify_headers(headers):
# TODO: revisit
for name, value in headers:
if isinstance(name, six.text_type):
name = name.encode('ascii')
name = name.encode("ascii")
if isinstance(value, six.text_type):
value = value.encode('latin-1')
value = value.encode("latin-1")
elif isinstance(value, int):
value = str(value).encode('ascii')
value = str(value).encode("ascii")
yield (name, value)
@@ -93,8 +93,6 @@ def _read_readable(readable):
yield datablock
# XX this should return an async iterator
def _make_body_iterable(body):
"""
@@ -122,17 +120,13 @@ def _make_body_iterable(body):
elif hasattr(body, "read"):
return _read_readable(body)
elif isinstance(body, collections.Iterable) and not isinstance(
body, six.text_type
):
elif isinstance(body, collections.Iterable) and not isinstance(body, six.text_type):
return body
else:
raise InvalidBodyError("Unacceptable body type: %s" % type(body))
# XX this should return an async iterator
def _request_bytes_iterable(request, state_machine):
"""
@@ -158,7 +152,7 @@ def _response_from_h11(h11_response, body_object):
if h11_response.http_version not in _SUPPORTED_VERSIONS:
raise BadVersionError(h11_response.http_version)
version = b'HTTP/' + h11_response.http_version
version = b"HTTP/" + h11_response.http_version
our_response = Response(
status_code=h11_response.status_code,
headers=_headers_to_native_string(h11_response.headers),
@@ -175,9 +169,9 @@ def _build_tunnel_request(host, port, headers):
"""
target = "%s:%d" % (host, port)
if not isinstance(target, bytes):
target = target.encode('latin1')
target = target.encode("latin1")
tunnel_request = Request(method=b"CONNECT", target=target, headers=headers)
tunnel_request.add_host(host=host, port=port, scheme='http')
tunnel_request.add_host(host=host, port=port, scheme="http")
return tunnel_request
@@ -195,14 +189,14 @@ def _start_http_request(request, state_machine, conn):
"""
# Before we begin, confirm that the state machine is ok.
if (
state_machine.our_state is not h11.IDLE or
state_machine.their_state is not h11.IDLE
state_machine.our_state is not h11.IDLE
or state_machine.their_state is not h11.IDLE
):
raise ProtocolError("Invalid internal state transition")
request_bytes_iterable = _request_bytes_iterable(request, state_machine)
# Hack around Python 2 lack of nonlocal
context = {'send_aborted': True, 'h11_response': None}
context = {"send_aborted": True, "h11_response": None}
def next_bytes_to_send():
try:
@@ -210,7 +204,7 @@ def _start_http_request(request, state_machine, conn):
except StopIteration:
# We successfully sent the whole body!
context['send_aborted'] = False
context["send_aborted"] = False
return None
def consume_bytes(data):
@@ -226,7 +220,7 @@ def _start_http_request(request, state_machine, conn):
elif isinstance(event, h11.Response):
# We have our response! Save it and get out of here.
context['h11_response'] = event
context["h11_response"] = event
raise LoopAbort
else:
@@ -234,8 +228,8 @@ def _start_http_request(request, state_machine, conn):
raise RuntimeError("Unexpected h11 event {}".format(event))
conn.send_and_receive_for_a_while(next_bytes_to_send, consume_bytes)
assert context['h11_response'] is not None
if context['send_aborted']:
assert context["h11_response"] is not None
if context["send_aborted"]:
# Our state machine thinks we sent a bunch of data... but maybe we
# didn't! Maybe our send got cancelled while we were only half-way
# through sending the last chunk, and then h11 thinks we sent a
@@ -246,7 +240,7 @@ def _start_http_request(request, state_machine, conn):
# state_machine.poison()
# XX kluge for now
state_machine._cstate.process_error(state_machine.our_role)
return context['h11_response']
return context["h11_response"]
def _read_until_event(state_machine, conn):
@@ -281,6 +275,7 @@ class HTTP1Connection(object):
data is buffered it will issue one read syscall and return all of that
data. Buffering of response data must happen at a higher layer.
"""
# : Disable Nagle's algorithm by default.
#: ``[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]``
default_socket_options = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]
@@ -301,7 +296,9 @@ class HTTP1Connection(object):
self._host = host
self._port = port
self._socket_options = (
socket_options if socket_options is not _DEFAULT_SOCKET_OPTIONS else self.default_socket_options
socket_options
if socket_options is not _DEFAULT_SOCKET_OPTIONS
else self.default_socket_options
)
self._source_address = source_address
self._tunnel_host = tunnel_host
@@ -310,9 +307,7 @@ class HTTP1Connection(object):
self._sock = None
self._state_machine = h11.Connection(our_role=h11.CLIENT)
def _wrap_socket(
self, conn, ssl_context, fingerprint, assert_hostname
):
def _wrap_socket(self, conn, ssl_context, fingerprint, assert_hostname):
"""
Handles extra logic to wrap the socket in TLS magic.
"""
@@ -320,11 +315,9 @@ class HTTP1Connection(object):
if is_time_off:
warnings.warn(
(
'System time is way off (before {0}). This will probably '
'lead to SSL verification errors'
).format(
RECENT_DATE
),
"System time is way off (before {0}). This will probably "
"lead to SSL verification errors"
).format(RECENT_DATE),
SystemTimeWarning,
)
# XX need to know whether this is the proxy or the final host that
@@ -337,30 +330,24 @@ class HTTP1Connection(object):
check_host = check_host.rstrip(".")
conn = conn.start_tls(check_host, ssl_context)
if fingerprint:
ssl_util.assert_fingerprint(
conn.getpeercert(binary_form=True), fingerprint
)
elif (
ssl_context.verify_mode != ssl.CERT_NONE and
assert_hostname is not False
):
ssl_util.assert_fingerprint(conn.getpeercert(binary_form=True), fingerprint)
elif ssl_context.verify_mode != ssl.CERT_NONE and assert_hostname is not False:
cert = conn.getpeercert()
if not cert.get('subjectAltName', ()):
if not cert.get("subjectAltName", ()):
warnings.warn(
(
'Certificate for {0} has no `subjectAltName`, falling '
'back to check for a `commonName` for now. This '
'feature is being removed by major browsers and '
'deprecated by RFC 2818. (See '
'https://github.com/shazow/urllib3/issues/497 for '
'details.)'.format(self._host)
"Certificate for {0} has no `subjectAltName`, falling "
"back to check for a `commonName` for now. This "
"feature is being removed by major browsers and "
"deprecated by RFC 2818. (See "
"https://github.com/shazow/urllib3/issues/497 for "
"details.)".format(self._host)
),
SubjectAltNameWarning,
)
ssl_util.match_hostname(cert, check_host)
self.is_verified = (
ssl_context.verify_mode == ssl.CERT_REQUIRED and
(assert_hostname is not False or fingerprint)
self.is_verified = ssl_context.verify_mode == ssl.CERT_REQUIRED and (
assert_hostname is not False or fingerprint
)
return conn
@@ -368,9 +355,7 @@ class HTTP1Connection(object):
"""
Given a Request object, performs the logic required to get a response.
"""
h11_response = _start_http_request(
request, self._state_machine, self._sock
)
h11_response = _start_http_request(request, self._state_machine, self._sock)
return _response_from_h11(h11_response, self)
def _tunnel(self, conn):
@@ -383,9 +368,7 @@ class HTTP1Connection(object):
self._tunnel_host, self._tunnel_port, self._tunnel_headers
)
tunnel_state_machine = h11.Connection(our_role=h11.CLIENT)
h11_response = _start_http_request(
tunnel_request, tunnel_state_machine, conn
)
h11_response = _start_http_request(tunnel_request, tunnel_state_machine, conn)
# XX this is wrong -- 'self' here will try to iterate using
# self._state_machine, not tunnel_state_machine. Also, we need to
# think about how this failure case interacts with the pool's
@@ -415,24 +398,22 @@ class HTTP1Connection(object):
extra_kw = {}
if self._source_address:
extra_kw['source_address'] = self._source_address
extra_kw["source_address"] = self._source_address
if self._socket_options:
extra_kw['socket_options'] = self._socket_options
extra_kw["socket_options"] = self._socket_options
# XX pass connect_timeout to backend
# This was factored out into a separate function to allow overriding
# by subclasses, but in the backend approach the way to to this is to
# provide a custom backend. (Composition >> inheritance.)
try:
conn = self._backend.connect(
self._host, self._port, **extra_kw
)
conn = self._backend.connect(self._host, self._port, **extra_kw)
# XX these two error handling blocks needs to be re-done in a
# backend-agnostic way
except socket.timeout:
raise ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)" %
(self._host, connect_timeout),
"Connection to %s timed out. (connect timeout=%s)"
% (self._host, connect_timeout),
)
except socket.error as e:
@@ -443,9 +424,7 @@ class HTTP1Connection(object):
if ssl_context is not None:
if self._tunnel_host is not None:
self._tunnel(conn)
conn = self._wrap_socket(
conn, ssl_context, fingerprint, assert_hostname
)
conn = self._wrap_socket(conn, ssl_context, fingerprint, assert_hostname)
# XX We should pick one of these names and use it consistently...
self._sock = conn
@@ -501,7 +480,7 @@ class HTTP1Connection(object):
"""
our_state = self._state_machine.our_state
their_state = self._state_machine.their_state
return (our_state is h11.IDLE and their_state is h11.IDLE)
return our_state is h11.IDLE and their_state is h11.IDLE
def __iter__(self):
return self
+40 -63
View File
@@ -66,12 +66,12 @@ def _add_transport_headers(headers):
This should be a bit smarter: in particular, it should allow for bad or
unexpected versions of these headers, particularly transfer-encoding.
"""
transfer_headers = ('content-length', 'transfer-encoding')
transfer_headers = ("content-length", "transfer-encoding")
for header_name in headers:
if header_name.lower() in transfer_headers:
return
headers['transfer-encoding'] = 'chunked'
headers["transfer-encoding"] = "chunked"
def _build_context(
@@ -97,14 +97,13 @@ def _build_context(
return context
# Pool objects
class ConnectionPool(object):
"""
Base class for all connection pools, such as
:class:`.HTTPConnectionPool` and :class:`.HTTPSConnectionPool`.
"""
scheme = None
QueueCls = queue.LifoQueue
@@ -116,9 +115,7 @@ class ConnectionPool(object):
self.port = port
def __str__(self):
return '%s(host=%r, port=%r)' % (
type(self).__name__, self.host, self.port
)
return "%s(host=%r, port=%r)" % (type(self).__name__, self.host, self.port)
def __enter__(self):
return self
@@ -198,7 +195,8 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
Additional parameters are used to create fresh :class:`urllib3.connection.HTTPConnection`,
:class:`urllib3.connection.HTTPSConnection` instances.
"""
scheme = 'http'
scheme = "http"
ConnectionCls = HTTP1Connection
ResponseCls = HTTPResponse
@@ -238,7 +236,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# Enable Nagle's algorithm for proxies, to avoid packet fragmentation.
# We cannot know if the user has added default socket options, so we cannot replace the
# list.
self.conn_kw.setdefault('socket_options', [])
self.conn_kw.setdefault("socket_options", [])
def _new_conn(self):
"""
@@ -247,7 +245,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
self.num_connections += 1
# TODO: Huge hack.
for kw in ('strict',):
for kw in ("strict",):
if kw in self.conn_kw:
self.conn_kw.pop(kw)
@@ -257,9 +255,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
self.host,
self.port or "80",
)
conn = self.ConnectionCls(
host=self.host, port=self.port, ** self.conn_kw
)
conn = self.ConnectionCls(host=self.host, port=self.port, **self.conn_kw)
return conn
def _get_conn(self, timeout=None):
@@ -284,8 +280,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
if self.block:
raise EmptyPoolError(
self,
"Pool reached maximum size and no more "
"connections are allowed.",
"Pool reached maximum size and no more " "connections are allowed.",
)
pass # Oh well, we'll create a new connection then
@@ -318,9 +313,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
pass
except queue.Full:
# This should never happen if self.block == True
log.warning(
"Connection pool is full, discarding connection: %s", self.host
)
log.warning("Connection pool is full, discarding connection: %s", self.host)
# Connection never got put back into the pool, close it.
if conn:
conn.close()
@@ -353,7 +346,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# See the above comment about EAGAIN in Python 3. In Python 2 we have
# to specifically catch it and throw the timeout error
if hasattr(err, 'errno') and err.errno in _blocking_errnos:
if hasattr(err, "errno") and err.errno in _blocking_errnos:
raise ReadTimeoutError(
self, url, "Read timed out. (read timeout=%s)" % timeout_value
)
@@ -362,7 +355,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# case, rethrow the original. We need to do this because of:
# http://bugs.python.org/issue10272
# TODO: Can we remove this?
if 'timed out' in str(err) or 'did not complete (read)' in str(
if "timed out" in str(err) or "did not complete (read)" in str(
err
): # Python 2.6
raise ReadTimeoutError(
@@ -398,9 +391,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
raise
# TODO: We need to encapsulate our proxy logic in here somewhere.
request = Request(
method=method, target=url, headers=headers, body=body
)
request = Request(method=method, target=url, headers=headers, body=body)
host = self.host
port = self.port
scheme = self.scheme
@@ -421,17 +412,15 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
read_timeout = socket.getdefaulttimeout()
# Receive the response from the server
try:
response = conn.send_request(
request, read_timeout=read_timeout
)
response = conn.send_request(request, read_timeout=read_timeout)
except (SocketTimeout, BaseSSLError, SocketError) as e:
self._raise_timeout(err=e, url=url, timeout_value=read_timeout)
raise
# AppEngine doesn't have a version attr.
http_version = getattr(conn, '_http_vsn_str', 'HTTP/?')
http_version = getattr(conn, "_http_vsn_str", "HTTP/?")
log.debug(
"%s://%s:%s \"%s %s %s\" %s",
'%s://%s:%s "%s %s %s" %s',
self.scheme,
self.host,
self.port,
@@ -443,9 +432,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
return response
def _absolute_url(self, path):
return Url(
scheme=self.scheme, host=self.host, port=self.port, path=path
).url
return Url(scheme=self.scheme, host=self.host, port=self.port, path=path).url
def close(self):
"""
@@ -469,7 +456,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
Check if the given ``url`` is a member of the same host as this
connection pool.
"""
if url.startswith('/'):
if url.startswith("/"):
return True
# TODO: Add optional support for socket.gethostbyname checking.
@@ -555,9 +542,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
if headers is None:
headers = self.headers
if not isinstance(retries, Retry):
retries = Retry.from_int(
retries, default=self.retries, redirect=False
)
retries = Retry.from_int(retries, default=self.retries, redirect=False)
conn = None
# Track whether `conn` needs to be released before
# returning/raising/recursing.
@@ -565,7 +550,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# Merge the proxy headers. Only do this in HTTP. We have to copy the
# headers dict so we can safely change it without those changes being
# reflected in anyone else's copy.
if self.scheme == 'http':
if self.scheme == "http":
headers = headers.copy()
headers.update(self.proxy_headers)
# Must keep the exception bound to a separate variable or else Python 3
@@ -586,15 +571,10 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
conn.timeout = timeout_obj.connect_timeout
# Make the request on the base connection object.
base_response = self._make_request(
conn,
method,
url,
timeout=timeout_obj,
body=body,
headers=headers,
conn, method, url, timeout=timeout_obj, body=body, headers=headers
)
# Pass method to Response for length checking
response_kw['request_method'] = method
response_kw["request_method"] = method
# Import httplib's response into our own wrapper object
response = self.ResponseCls.from_base(
base_response, pool=self, retries=retries, **response_kw
@@ -619,12 +599,10 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
clean_exit = False
if isinstance(e, (BaseSSLError, CertificateError)):
e = SSLError(e)
elif isinstance(
e, (SocketError, NewConnectionError)
) and self.proxy:
e = ProxyError('Cannot connect to proxy.', e)
elif isinstance(e, (SocketError, NewConnectionError)) and self.proxy:
e = ProxyError("Cannot connect to proxy.", e)
elif isinstance(e, (SocketError, h11.ProtocolError)):
e = ProtocolError('Connection aborted.', e)
e = ProtocolError("Connection aborted.", e)
retries = retries.increment(
method, url, error=e, _pool=self, _stacktrace=sys.exc_info()[2]
)
@@ -679,12 +657,10 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
pass
# Check if we should retry the HTTP response.
has_retry_after = bool(response.getheader('Retry-After'))
has_retry_after = bool(response.getheader("Retry-After"))
if retries.is_retry(method, response.status, has_retry_after):
try:
retries = retries.increment(
method, url, response=response, _pool=self
)
retries = retries.increment(method, url, response=response, _pool=self)
except MaxRetryError:
if retries.raise_on_status:
# Drain and release the connection for this response, since
@@ -730,7 +706,8 @@ class HTTPSConnectionPool(HTTPConnectionPool):
available and are fed into :meth:`urllib3.util.ssl_wrap_socket` to upgrade
the connection socket into an SSL socket.
"""
scheme = 'https'
scheme = "https"
def __init__(
self,
@@ -771,7 +748,7 @@ class HTTPSConnectionPool(HTTPConnectionPool):
raise SSLError("SSL module is not available")
if ca_certs and cert_reqs is None:
cert_reqs = 'CERT_REQUIRED'
cert_reqs = "CERT_REQUIRED"
self.ssl_context = _build_context(
ssl_context,
keyfile=key_file,
@@ -808,7 +785,7 @@ class HTTPSConnectionPool(HTTPConnectionPool):
tunnel_headers = self.proxy_headers
# TODO: Huge hack.
for kw in ('strict', 'redirect'):
for kw in ("strict", "redirect"):
if kw in self.conn_kw:
self.conn_kw.pop(kw)
@@ -818,7 +795,7 @@ class HTTPSConnectionPool(HTTPConnectionPool):
tunnel_host=tunnel_host,
tunnel_port=tunnel_port,
tunnel_headers=tunnel_headers,
** self.conn_kw
**self.conn_kw
)
return conn
@@ -835,10 +812,10 @@ class HTTPSConnectionPool(HTTPConnectionPool):
if not conn.is_verified:
warnings.warn(
(
'Unverified HTTPS request is being made. '
'Adding certificate verification is strongly advised. See: '
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html'
'#ssl-warnings'
"Unverified HTTPS request is being made. "
"Adding certificate verification is strongly advised. See: "
"https://urllib3.readthedocs.io/en/latest/advanced-usage.html"
"#ssl-warnings"
),
InsecureRequestWarning,
)
@@ -866,7 +843,7 @@ def connection_from_url(url, **kw):
"""
scheme, host, port = get_host(url)
port = port or DEFAULT_PORTS.get(scheme, 80)
if scheme == 'https':
if scheme == "https":
return HTTPSConnectionPool(host, port=port, **kw)
else:
@@ -886,6 +863,6 @@ def _ipv6_host(host):
#
# Also if an IPv6 address literal has a zone identifier, the
# percent sign might be URIencoded, convert it back into ASCII
if host.startswith('[') and host.endswith(']'):
host = host.replace('%25', '%').strip('[]')
if host.startswith("[") and host.endswith("]"):
host = host.replace("%25", "%").strip("[]")
return host
+81 -101
View File
@@ -13,47 +13,47 @@ from ..util.url import parse_url
from ..util.request import set_file_position
from ..util.retry import Retry
__all__ = ['PoolManager', 'ProxyManager', 'proxy_from_url']
__all__ = ["PoolManager", "ProxyManager", "proxy_from_url"]
log = logging.getLogger(__name__)
SSL_KEYWORDS = (
'key_file',
'cert_file',
'cert_reqs',
'ca_certs',
'ssl_version',
'ca_cert_dir',
'ssl_context',
"key_file",
"cert_file",
"cert_reqs",
"ca_certs",
"ssl_version",
"ca_cert_dir",
"ssl_context",
)
# All known keyword arguments that could be provided to the pool manager, its
# pools, or the underlying connections. This is used to construct a pool key.
_key_fields = (
'key_scheme', # str
'key_host', # str
'key_strict',
'key_port', # int
'key_timeout', # int or float or Timeout
'key_retries', # int or Retry
'key_block', # bool
'key_source_address', # str
'key_key_file', # str
'key_cert_file', # str
'key_cert_reqs', # str
'key_ca_certs', # str
'key_ssl_version', # str
'key_ca_cert_dir', # str
'key_ssl_context', # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext
'key_maxsize', # int
'key_headers', # dict
'key__proxy', # parsed proxy url
'key__proxy_headers', # dict
'key_socket_options', # list of (level (int), optname (int), value (int or str)) tuples
'key__socks_options', # dict
'key_assert_hostname', # bool or string
'key_assert_fingerprint', # str
"key_scheme", # str
"key_host", # str
"key_strict",
"key_port", # int
"key_timeout", # int or float or Timeout
"key_retries", # int or Retry
"key_block", # bool
"key_source_address", # str
"key_key_file", # str
"key_cert_file", # str
"key_cert_reqs", # str
"key_ca_certs", # str
"key_ssl_version", # str
"key_ca_cert_dir", # str
"key_ssl_context", # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext
"key_maxsize", # int
"key_headers", # dict
"key__proxy", # parsed proxy url
"key__proxy_headers", # dict
"key_socket_options", # list of (level (int), optname (int), value (int or str)) tuples
"key__socks_options", # dict
"key_assert_hostname", # bool or string
"key_assert_fingerprint", # str
)
# : The namedtuple class used to construct keys for the connection pool.
#: All custom key schemes should include the fields in this key at a minimum.
PoolKey = collections.namedtuple('PoolKey', _key_fields)
PoolKey = collections.namedtuple("PoolKey", _key_fields)
def _default_key_normalizer(key_class, request_context):
@@ -78,21 +78,21 @@ def _default_key_normalizer(key_class, request_context):
"""
# Since we mutate the dictionary, make a copy first
context = request_context.copy()
context['scheme'] = context['scheme'].lower()
context['host'] = context['host'].lower()
context["scheme"] = context["scheme"].lower()
context["host"] = context["host"].lower()
# These are both dictionaries and need to be transformed into frozensets
for key in ('headers', '_proxy_headers', '_socks_options'):
for key in ("headers", "_proxy_headers", "_socks_options"):
if key in context and context[key] is not None:
context[key] = frozenset(context[key].items())
# The socket_options key may be a list and needs to be transformed into a
# tuple.
socket_opts = context.get('socket_options')
socket_opts = context.get("socket_options")
if socket_opts is not None:
context['socket_options'] = tuple(socket_opts)
context["socket_options"] = tuple(socket_opts)
# Map the kwargs to the names in the namedtuple - this is necessary since
# namedtuples can't have fields starting with '_'.
for key in list(context.keys()):
context['key_' + key] = context.pop(key)
context["key_" + key] = context.pop(key)
# Default to ``None`` for keys missing from the context
for field in key_class._fields:
if field not in context:
@@ -105,12 +105,10 @@ def _default_key_normalizer(key_class, request_context):
#: Each PoolManager makes a copy of this dictionary so they can be configured
#: globally here, or individually on the instance.
key_fn_by_scheme = {
'http': functools.partial(_default_key_normalizer, PoolKey),
'https': functools.partial(_default_key_normalizer, PoolKey),
}
pool_classes_by_scheme = {
'http': HTTPConnectionPool, 'https': HTTPSConnectionPool
"http": functools.partial(_default_key_normalizer, PoolKey),
"https": functools.partial(_default_key_normalizer, PoolKey),
}
pool_classes_by_scheme = {"http": HTTPConnectionPool, "https": HTTPSConnectionPool}
class PoolManager(RequestMethods):
@@ -140,16 +138,13 @@ class PoolManager(RequestMethods):
2
"""
proxy = None
def __init__(
self, num_pools=10, headers=None, backend=None, **connection_pool_kw
):
def __init__(self, num_pools=10, headers=None, backend=None, **connection_pool_kw):
RequestMethods.__init__(self, headers)
self.connection_pool_kw = connection_pool_kw
self.pools = RecentlyUsedContainer(
num_pools, dispose_func=lambda p: p.close()
)
self.pools = RecentlyUsedContainer(num_pools, dispose_func=lambda p: p.close())
# Locally set the pool classes and keys so other PoolManagers can
# override them.
self.pool_classes_by_scheme = pool_classes_by_scheme
@@ -181,9 +176,9 @@ class PoolManager(RequestMethods):
# this function has historically only used the scheme, host, and port
# in the positional args. When an API change is acceptable these can
# be removed.
for key in ('scheme', 'host', 'port'):
for key in ("scheme", "host", "port"):
request_context.pop(key, None)
if scheme == 'http':
if scheme == "http":
for kw in SSL_KEYWORDS:
request_context.pop(kw, None)
return pool_cls(host, port, backend=self.backend, **request_context)
@@ -197,9 +192,7 @@ class PoolManager(RequestMethods):
"""
self.pools.clear()
def connection_from_host(
self, host, port=None, scheme='http', pool_kwargs=None
):
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
"""
Get a :class:`ConnectionPool` based on the host, port, and scheme.
@@ -213,11 +206,11 @@ class PoolManager(RequestMethods):
raise LocationValueError("No host specified.")
request_context = self._merge_pool_kwargs(pool_kwargs)
request_context['scheme'] = scheme or 'http'
request_context["scheme"] = scheme or "http"
if not port:
port = DEFAULT_PORTS.get(request_context['scheme'].lower(), 80)
request_context['port'] = port
request_context['host'] = host
port = DEFAULT_PORTS.get(request_context["scheme"].lower(), 80)
request_context["port"] = port
request_context["host"] = host
return self.connection_from_context(request_context)
def connection_from_context(self, request_context):
@@ -227,12 +220,10 @@ class PoolManager(RequestMethods):
``request_context`` must at least contain the ``scheme`` key and its
value must be a key in ``key_fn_by_scheme`` instance variable.
"""
scheme = request_context['scheme'].lower()
scheme = request_context["scheme"].lower()
pool_key_constructor = self.key_fn_by_scheme[scheme]
pool_key = pool_key_constructor(request_context)
return self.connection_from_pool_key(
pool_key, request_context=request_context
)
return self.connection_from_pool_key(pool_key, request_context=request_context)
def connection_from_pool_key(self, pool_key, request_context=None):
"""
@@ -250,12 +241,10 @@ class PoolManager(RequestMethods):
return pool
# Make a fresh ConnectionPool of the desired type
scheme = request_context['scheme']
host = request_context['host']
port = request_context['port']
pool = self._new_pool(
scheme, host, port, request_context=request_context
)
scheme = request_context["scheme"]
host = request_context["host"]
port = request_context["port"]
pool = self._new_pool(scheme, host, port, request_context=request_context)
self.pools[pool_key] = pool
return pool
@@ -308,11 +297,11 @@ class PoolManager(RequestMethods):
conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme)
# Rewind body position, if needed. Record current position
# for future rewinds in the event of a redirect/retry.
body = kw.get('body')
body_pos = kw.get('body_pos')
kw['body_pos'] = set_file_position(body, body_pos)
if 'headers' not in kw:
kw['headers'] = self.headers
body = kw.get("body")
body_pos = kw.get("body_pos")
kw["body_pos"] = set_file_position(body, body_pos)
if "headers" not in kw:
kw["headers"] = self.headers
if self.proxy is not None and u.scheme == "http":
response = conn.urlopen(method, url, **kw)
else:
@@ -325,22 +314,20 @@ class PoolManager(RequestMethods):
redirect_location = urljoin(url, redirect_location)
# RFC 7231, Section 6.4.4
if response.status == 303:
method = 'GET'
retries = kw.get('retries')
method = "GET"
retries = kw.get("retries")
if not isinstance(retries, Retry):
retries = Retry.from_int(retries, redirect=redirect)
try:
retries = retries.increment(
method, url, response=response, _pool=conn
)
retries = retries.increment(method, url, response=response, _pool=conn)
except MaxRetryError:
if retries.raise_on_redirect:
raise
return response
kw['retries'] = retries
kw['redirect'] = redirect
kw["retries"] = retries
kw["redirect"] = redirect
retries.sleep_for_retry(response)
log.info("Redirecting %s -> %s", url, redirect_location)
return self.urlopen(method, redirect_location, **kw)
@@ -382,8 +369,10 @@ class ProxyManager(PoolManager):
**connection_pool_kw
):
if isinstance(proxy_url, HTTPConnectionPool):
proxy_url = '%s://%s:%i' % (
proxy_url.scheme, proxy_url.host, proxy_url.port
proxy_url = "%s://%s:%i" % (
proxy_url.scheme,
proxy_url.host,
proxy_url.port,
)
proxy = parse_url(proxy_url)
if not proxy.port:
@@ -394,25 +383,18 @@ class ProxyManager(PoolManager):
self.proxy = proxy
self.proxy_headers = proxy_headers or {}
connection_pool_kw['_proxy'] = self.proxy
connection_pool_kw['_proxy_headers'] = self.proxy_headers
super(ProxyManager, self).__init__(
num_pools, headers, **connection_pool_kw
)
connection_pool_kw["_proxy"] = self.proxy
connection_pool_kw["_proxy_headers"] = self.proxy_headers
super(ProxyManager, self).__init__(num_pools, headers, **connection_pool_kw)
def connection_from_host(
self, host, port=None, scheme='http', pool_kwargs=None
):
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
if scheme == "https":
return super(ProxyManager, self).connection_from_host(
host, port, scheme, pool_kwargs=pool_kwargs
)
return super(ProxyManager, self).connection_from_host(
self.proxy.host,
self.proxy.port,
self.proxy.scheme,
pool_kwargs=pool_kwargs,
self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs
)
def _set_proxy_headers(self, url, headers=None):
@@ -420,10 +402,10 @@ class ProxyManager(PoolManager):
Sets headers needed by proxies: specifically, the Accept and Host
headers. Only sets headers not provided by the user.
"""
headers_ = {'Accept': '*/*'}
headers_ = {"Accept": "*/*"}
netloc = parse_url(url).netloc
if netloc:
headers_['Host'] = netloc
headers_["Host"] = netloc
if headers:
headers_.update(headers)
return headers_
@@ -435,11 +417,9 @@ class ProxyManager(PoolManager):
# For proxied HTTPS requests, httplib sets the necessary headers
# on the CONNECT to the proxy. For HTTP, we'll definitely
# need to set 'Host' at the very least.
headers = kw.get('headers', self.headers)
kw['headers'] = self._set_proxy_headers(url, headers)
return super(ProxyManager, self).urlopen(
method, url, redirect=redirect, **kw
)
headers = kw.get("headers", self.headers)
kw["headers"] = self._set_proxy_headers(url, headers)
return super(ProxyManager, self).urlopen(method, url, redirect=redirect, **kw)
def proxy_from_url(url, **kw):
+24 -28
View File
@@ -9,7 +9,7 @@ from socket import error as SocketError
import h11
from .._collections import HTTPHeaderDict
from ..exceptions import (ProtocolError, DecodeError, ReadTimeoutError)
from ..exceptions import ProtocolError, DecodeError, ReadTimeoutError
from ..packages.six import string_types as basestring, binary_type
from ..util.ssl_ import BaseSSLError
@@ -17,7 +17,6 @@ log = logging.getLogger(__name__)
class DeflateDecoder(object):
def __init__(self):
self._first_try = True
self._data = binary_type()
@@ -52,7 +51,6 @@ class DeflateDecoder(object):
class GzipDecoder(object):
def __init__(self):
self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS)
@@ -67,7 +65,7 @@ class GzipDecoder(object):
def _get_decoder(mode):
if mode == 'gzip':
if mode == "gzip":
return GzipDecoder()
return DeflateDecoder()
@@ -97,12 +95,13 @@ class HTTPResponse(io.IOBase):
The retries contains the last :class:`~urllib3.util.retry.Retry` that
was used during the request.
"""
CONTENT_DECODERS = ['gzip', 'deflate']
CONTENT_DECODERS = ["gzip", "deflate"]
REDIRECT_STATUSES = [301, 302, 303, 307, 308]
def __init__(
self,
body='',
body="",
headers=None,
status=0,
version=0,
@@ -131,7 +130,7 @@ class HTTPResponse(io.IOBase):
self._fp = None
self._original_response = original_response
self._fp_bytes_read = 0
self._buffer = b''
self._buffer = b""
if body and isinstance(body, (basestring, binary_type)):
self._body = body
else:
@@ -151,7 +150,7 @@ class HTTPResponse(io.IOBase):
location. ``False`` if not a redirect status code.
"""
if self.status in self.REDIRECT_STATUSES:
return self.headers.get('location')
return self.headers.get("location")
return False
@@ -189,7 +188,7 @@ class HTTPResponse(io.IOBase):
"""
# Note: content-encoding value should be case-insensitive, per RFC 7230
# Section 3.2
content_encoding = self.headers.get('content-encoding', '').lower()
content_encoding = self.headers.get("content-encoding", "").lower()
if self._decoder is None and content_encoding in self.CONTENT_DECODERS:
self._decoder = _get_decoder(content_encoding)
@@ -201,7 +200,7 @@ class HTTPResponse(io.IOBase):
if decode_content and self._decoder:
data = self._decoder.decompress(data)
except (IOError, zlib.error) as e:
content_encoding = self.headers.get('content-encoding', '').lower()
content_encoding = self.headers.get("content-encoding", "").lower()
raise DecodeError(
"Received response with content-encoding: %s, but "
"failed to decode it." % content_encoding,
@@ -218,10 +217,10 @@ class HTTPResponse(io.IOBase):
being used.
"""
if self._decoder:
buf = self._decoder.decompress(b'')
buf = self._decoder.decompress(b"")
return buf + self._decoder.flush()
return b''
return b""
@contextmanager
def _error_catcher(self):
@@ -240,20 +239,20 @@ class HTTPResponse(io.IOBase):
except SocketTimeout:
# FIXME: Ideally we'd like to include the url in the ReadTimeoutError but
# there is yet no clean way to get at it from this context.
raise ReadTimeoutError(self._pool, None, 'Read timed out.')
raise ReadTimeoutError(self._pool, None, "Read timed out.")
except BaseSSLError as e:
# FIXME: Is there a better way to differentiate between SSLErrors?
if 'read operation timed out' not in str(e): # Defensive:
if "read operation timed out" not in str(e): # Defensive:
# This shouldn't happen but just in case we're missing an edge
# case, let's avoid swallowing SSL errors.
raise
raise ReadTimeoutError(self._pool, None, 'Read timed out.')
raise ReadTimeoutError(self._pool, None, "Read timed out.")
except (h11.ProtocolError, SocketError) as e:
# This includes IncompleteRead.
raise ProtocolError('Connection broken: %r' % e, e)
raise ProtocolError("Connection broken: %r" % e, e)
except GeneratorExit:
# We swallow GeneratorExit when it is emitted: this allows the
@@ -305,7 +304,7 @@ class HTTPResponse(io.IOBase):
# data into the buffer. That's unfortunate, but right now I'm not smart
# enough to come up with a way to solve that problem.
if self._fp is None and not self._buffer:
return b''
return b""
data = self._buffer
with self._error_catcher():
@@ -313,8 +312,8 @@ class HTTPResponse(io.IOBase):
chunks = []
for chunk in self.stream(decode_content):
chunks.append(chunk)
data += b''.join(chunks)
self._buffer = b''
data += b"".join(chunks)
self._buffer = b""
# We only cache the body data for simple read calls.
self._body = data
else:
@@ -330,7 +329,7 @@ class HTTPResponse(io.IOBase):
else:
chunks.append(chunk)
data_len += len(chunk)
data = b''.join(chunks)
data = b"".join(chunks)
self._buffer = data[amt:]
data = data[:amt]
return data
@@ -366,7 +365,7 @@ class HTTPResponse(io.IOBase):
# coverage. Happily, the code here is so simple that testing the
# branch we don't enter is basically entirely unnecessary (it's
# just a yield statement).
final_chunk = self._decode(b'', decode_content, flush_decoder=True)
final_chunk = self._decode(b"", decode_content, flush_decoder=True)
if final_chunk: # Platform-specific: Jython
yield final_chunk
@@ -382,7 +381,7 @@ class HTTPResponse(io.IOBase):
with ``original_response=r``.
"""
# TODO: Huge hack.
for kw in ('redirect', 'assert_same_host', 'enforce_content_length'):
for kw in ("redirect", "assert_same_host", "enforce_content_length"):
if kw in response_kw:
response_kw.pop(kw)
@@ -397,7 +396,6 @@ class HTTPResponse(io.IOBase):
)
return resp
# Backwards-compatibility methods for httplib.HTTPResponse
def getheaders(self):
return self.headers
@@ -405,17 +403,15 @@ class HTTPResponse(io.IOBase):
def getheader(self, name, default=None):
return self.headers.get(name, default)
# Backwards compatibility for http.cookiejar
def info(self):
return self.headers
# Overrides from io.IOBase
def close(self):
if not self.closed:
self._fp.close()
self._buffer = b''
self._buffer = b""
self._fp = None
if self._connection:
self._connection.close()
@@ -426,7 +422,7 @@ class HTTPResponse(io.IOBase):
if self._fp is None and not self._buffer:
return True
elif hasattr(self._fp, 'complete'):
elif hasattr(self._fp, "complete"):
return self._fp.complete
else:
@@ -457,5 +453,5 @@ class HTTPResponse(io.IOBase):
return 0
else:
b[:len(temp)] = temp
b[: len(temp)] = temp
return len(temp)
+1 -1
View File
@@ -60,7 +60,7 @@ class Request(object):
do not duplicate the host header: if there already is one, we just use
that one.
"""
if b'host' not in self.headers:
if b"host" not in self.headers:
# We test against a sentinel object here to forcibly always insert
# the port for schemes we don't understand.
if port is DEFAULT_PORTS.get(scheme, object()):
+48 -55
View File
@@ -52,7 +52,7 @@ from .util import connection
from ._collections import HTTPHeaderDict
log = logging.getLogger(__name__)
port_by_scheme = {'http': 80, 'https': 443}
port_by_scheme = {"http": 80, "https": 443}
# When updating RECENT_DATE, move it to within two years of the current date,
# and not less than 6 months ago.
# Example: if Today is 2018-01-01, then RECENT_DATE should be any date on or
@@ -62,6 +62,7 @@ RECENT_DATE = datetime.date(2017, 6, 30)
class DummyConnection(object):
"""Used to detect a failed ConnectionCls import."""
pass
@@ -91,7 +92,8 @@ class HTTPConnection(_HTTPConnection, object):
Or you may want to disable the defaults by passing an empty list (e.g., ``[]``).
"""
default_port = port_by_scheme['http']
default_port = port_by_scheme["http"]
# : Disable Nagle's algorithm by default.
#: ``[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]``
default_socket_options = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]
@@ -100,20 +102,18 @@ class HTTPConnection(_HTTPConnection, object):
def __init__(self, *args, **kw):
if six.PY3: # Python 3
kw.pop('strict', None)
kw.pop("strict", None)
# Pre-set source_address in case we have an older Python like 2.6.
self.source_address = kw.get('source_address')
self.source_address = kw.get("source_address")
if sys.version_info < (2, 7): # Python 2.6
# _HTTPConnection on Python 2.6 will balk at this keyword arg, but
# not newer versions. We can still use it when creating a
# connection though, so we pop it *after* we have saved it as
# self.source_address.
kw.pop('source_address', None)
kw.pop("source_address", None)
# : The socket options provided by the user. If no options are
#: provided, we use the default options.
self.socket_options = kw.pop(
'socket_options', self.default_socket_options
)
self.socket_options = kw.pop("socket_options", self.default_socket_options)
# Superclass also sets self.source_address in Python 2.7+.
_HTTPConnection.__init__(self, *args, **kw)
@@ -134,7 +134,7 @@ class HTTPConnection(_HTTPConnection, object):
those cases where it's appropriate (i.e., when doing DNS lookup to establish the
actual TCP connection across which we're going to send HTTP requests).
"""
return self._dns_host.rstrip('.')
return self._dns_host.rstrip(".")
@host.setter
def host(self, value):
@@ -153,9 +153,9 @@ class HTTPConnection(_HTTPConnection, object):
"""
extra_kw = {}
if self.source_address:
extra_kw['source_address'] = self.source_address
extra_kw["source_address"] = self.source_address
if self.socket_options:
extra_kw['socket_options'] = self.socket_options
extra_kw["socket_options"] = self.socket_options
try:
conn = connection.create_connection(
(self._dns_host, self.port), self.timeout, **extra_kw
@@ -163,8 +163,8 @@ class HTTPConnection(_HTTPConnection, object):
except SocketTimeout as e:
raise ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)" %
(self.host, self.timeout),
"Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
except SocketError as e:
@@ -179,7 +179,7 @@ class HTTPConnection(_HTTPConnection, object):
# the _tunnel_host attribute was added in python 2.6.3 (via
# http://hg.python.org/cpython/rev/0f57b30a152f) so pythons 2.6(0-2) do
# not have them.
if getattr(self, '_tunnel_host', None):
if getattr(self, "_tunnel_host", None):
# TODO: Fix tunnel so it doesn't depend on self.sock state.
self._tunnel()
# Mark this connection as not reusable
@@ -195,18 +195,15 @@ class HTTPConnection(_HTTPConnection, object):
body with chunked encoding and not as one block
"""
headers = HTTPHeaderDict(headers if headers is not None else {})
skip_accept_encoding = 'accept-encoding' in headers
skip_host = 'host' in headers
skip_accept_encoding = "accept-encoding" in headers
skip_host = "host" in headers
self.putrequest(
method,
url,
skip_accept_encoding=skip_accept_encoding,
skip_host=skip_host,
method, url, skip_accept_encoding=skip_accept_encoding, skip_host=skip_host
)
for header, value in headers.items():
self.putheader(header, value)
if 'transfer-encoding' not in headers:
self.putheader('Transfer-Encoding', 'chunked')
if "transfer-encoding" not in headers:
self.putheader("Transfer-Encoding", "chunked")
self.endheaders()
if body is not None:
stringish_types = six.string_types + (six.binary_type,)
@@ -217,18 +214,18 @@ class HTTPConnection(_HTTPConnection, object):
continue
if not isinstance(chunk, six.binary_type):
chunk = chunk.encode('utf8')
chunk = chunk.encode("utf8")
len_str = hex(len(chunk))[2:]
self.send(len_str.encode('utf-8'))
self.send(b'\r\n')
self.send(len_str.encode("utf-8"))
self.send(b"\r\n")
self.send(chunk)
self.send(b'\r\n')
self.send(b"\r\n")
# After the if clause, to always have a closed body
self.send(b'0\r\n\r\n')
self.send(b"0\r\n\r\n")
class HTTPSConnection(HTTPConnection):
default_port = port_by_scheme['https']
default_port = port_by_scheme["https"]
ssl_version = None
def __init__(
@@ -242,23 +239,20 @@ class HTTPSConnection(HTTPConnection):
ssl_context=None,
**kw
):
HTTPConnection.__init__(
self, host, port, strict=strict, timeout=timeout, **kw
)
HTTPConnection.__init__(self, host, port, strict=strict, timeout=timeout, **kw)
self.key_file = key_file
self.cert_file = cert_file
self.ssl_context = ssl_context
# Required property for Google AppEngine 1.9.0 which otherwise causes
# HTTPS requests to go out as HTTP. (See Issue #356)
self._protocol = 'https'
self._protocol = "https"
def connect(self):
conn = self._new_conn()
self._prepare_conn(conn)
if self.ssl_context is None:
self.ssl_context = create_urllib3_context(
ssl_version=resolve_ssl_version(None),
cert_reqs=resolve_cert_reqs(None),
ssl_version=resolve_ssl_version(None), cert_reqs=resolve_cert_reqs(None)
)
self.sock = ssl_wrap_socket(
sock=conn,
@@ -273,6 +267,7 @@ class VerifiedHTTPSConnection(HTTPSConnection):
Based on httplib.HTTPSConnection but wraps the socket with
SSL certification.
"""
cert_reqs = None
ca_certs = None
ca_cert_dir = None
@@ -298,7 +293,7 @@ class VerifiedHTTPSConnection(HTTPSConnection):
# it.
if cert_reqs is None:
if ca_certs or ca_cert_dir:
cert_reqs = 'CERT_REQUIRED'
cert_reqs = "CERT_REQUIRED"
elif self.ssl_context is not None:
cert_reqs = self.ssl_context.verify_mode
self.key_file = key_file
@@ -313,7 +308,7 @@ class VerifiedHTTPSConnection(HTTPSConnection):
# Add certificate verification
conn = self._new_conn()
hostname = self.host
if getattr(self, '_tunnel_host', None):
if getattr(self, "_tunnel_host", None):
# _tunnel_host was added in Python 2.6.3
# (See: http://hg.python.org/cpython/rev/0f57b30a152f)
self.sock = conn
@@ -328,11 +323,9 @@ class VerifiedHTTPSConnection(HTTPSConnection):
if is_time_off:
warnings.warn(
(
'System time is way off (before {0}). This will probably '
'lead to SSL verification errors'
).format(
RECENT_DATE
),
"System time is way off (before {0}). This will probably "
"lead to SSL verification errors"
).format(RECENT_DATE),
SystemTimeWarning,
)
# Wrap socket using verification with the root certs in
@@ -355,30 +348,31 @@ class VerifiedHTTPSConnection(HTTPSConnection):
)
if self.assert_fingerprint:
assert_fingerprint(
self.sock.getpeercert(binary_form=True),
self.assert_fingerprint,
self.sock.getpeercert(binary_form=True), self.assert_fingerprint
)
elif context.verify_mode != ssl.CERT_NONE and not getattr(
context, 'check_hostname', False
) and self.assert_hostname is not False:
elif (
context.verify_mode != ssl.CERT_NONE
and not getattr(context, "check_hostname", False)
and self.assert_hostname is not False
):
# While urllib3 attempts to always turn off hostname matching from
# the TLS library, this cannot always be done. So we check whether
# the TLS Library still thinks it's matching hostnames.
cert = self.sock.getpeercert()
if not cert.get('subjectAltName', ()):
if not cert.get("subjectAltName", ()):
warnings.warn(
(
'Certificate for {0} has no `subjectAltName`, falling back to check for a '
'`commonName` for now. This feature is being removed by major browsers and '
'deprecated by RFC 2818. (See https://github.com/shazow/urllib3/issues/497 '
'for details.)'.format(hostname)
"Certificate for {0} has no `subjectAltName`, falling back to check for a "
"`commonName` for now. This feature is being removed by major browsers and "
"deprecated by RFC 2818. (See https://github.com/shazow/urllib3/issues/497 "
"for details.)".format(hostname)
),
SubjectAltNameWarning,
)
_match_hostname(cert, self.assert_hostname or hostname)
self.is_verified = (
context.verify_mode == ssl.CERT_REQUIRED or
self.assert_fingerprint is not None
context.verify_mode == ssl.CERT_REQUIRED
or self.assert_fingerprint is not None
)
@@ -387,8 +381,7 @@ def _match_hostname(cert, asserted_hostname):
match_hostname(cert, asserted_hostname)
except CertificateError as e:
log.error(
'Certificate did not match expected hostname: %s. '
'Certificate: %s',
"Certificate did not match expected hostname: %s. " "Certificate: %s",
asserted_hostname,
cert,
)
+4 -4
View File
@@ -6,8 +6,8 @@ from ._sync.connectionpool import (
)
__all__ = [
'ConnectionPool',
'HTTPConnectionPool',
'HTTPSConnectionPool',
'connection_from_url',
"ConnectionPool",
"HTTPConnectionPool",
"HTTPSConnectionPool",
"connection_from_url",
]
@@ -46,20 +46,20 @@ from ctypes import (
)
from ctypes import CDLL, POINTER, CFUNCTYPE
security_path = find_library('Security')
security_path = find_library("Security")
if not security_path:
raise ImportError('The library Security could not be found')
raise ImportError("The library Security could not be found")
core_foundation_path = find_library('CoreFoundation')
core_foundation_path = find_library("CoreFoundation")
if not core_foundation_path:
raise ImportError('The library CoreFoundation could not be found')
raise ImportError("The library CoreFoundation could not be found")
version = platform.mac_ver()[0]
version_info = tuple(map(int, version.split('.')))
version_info = tuple(map(int, version.split(".")))
if version_info < (10, 8):
raise OSError(
'Only OS X 10.8 and newer are supported, not %s.%s' %
(version_info[0], version_info[1])
"Only OS X 10.8 and newer are supported, not %s.%s"
% (version_info[0], version_info[1])
)
Security = CDLL(security_path, use_errno=True)
@@ -121,16 +121,16 @@ try:
Security.SecIdentityGetTypeID.restype = CFTypeID
Security.SecKeyGetTypeID.argtypes = []
Security.SecKeyGetTypeID.restype = CFTypeID
Security.SecCertificateCreateWithData.argtypes = [
CFAllocatorRef, CFDataRef
]
Security.SecCertificateCreateWithData.argtypes = [CFAllocatorRef, CFDataRef]
Security.SecCertificateCreateWithData.restype = SecCertificateRef
Security.SecCertificateCopyData.argtypes = [SecCertificateRef]
Security.SecCertificateCopyData.restype = CFDataRef
Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
Security.SecCopyErrorMessageString.restype = CFStringRef
Security.SecIdentityCreateWithCertificate.argtypes = [
CFTypeRef, SecCertificateRef, POINTER(SecIdentityRef)
CFTypeRef,
SecCertificateRef,
POINTER(SecIdentityRef),
]
Security.SecIdentityCreateWithCertificate.restype = OSStatus
Security.SecKeychainCreate.argtypes = [
@@ -145,96 +145,83 @@ try:
Security.SecKeychainDelete.argtypes = [SecKeychainRef]
Security.SecKeychainDelete.restype = OSStatus
Security.SecPKCS12Import.argtypes = [
CFDataRef, CFDictionaryRef, POINTER(CFArrayRef)
CFDataRef,
CFDictionaryRef,
POINTER(CFArrayRef),
]
Security.SecPKCS12Import.restype = OSStatus
SSLReadFunc = CFUNCTYPE(
OSStatus, SSLConnectionRef, c_void_p, POINTER(c_size_t)
)
SSLReadFunc = CFUNCTYPE(OSStatus, SSLConnectionRef, c_void_p, POINTER(c_size_t))
SSLWriteFunc = CFUNCTYPE(
OSStatus, SSLConnectionRef, POINTER(c_byte), POINTER(c_size_t)
)
Security.SSLSetIOFuncs.argtypes = [
SSLContextRef, SSLReadFunc, SSLWriteFunc
]
Security.SSLSetIOFuncs.argtypes = [SSLContextRef, SSLReadFunc, SSLWriteFunc]
Security.SSLSetIOFuncs.restype = OSStatus
Security.SSLSetPeerID.argtypes = [SSLContextRef, c_char_p, c_size_t]
Security.SSLSetPeerID.restype = OSStatus
Security.SSLSetCertificate.argtypes = [SSLContextRef, CFArrayRef]
Security.SSLSetCertificate.restype = OSStatus
Security.SSLSetCertificateAuthorities.argtypes = [
SSLContextRef, CFTypeRef, Boolean
]
Security.SSLSetCertificateAuthorities.argtypes = [SSLContextRef, CFTypeRef, Boolean]
Security.SSLSetCertificateAuthorities.restype = OSStatus
Security.SSLSetConnection.argtypes = [SSLContextRef, SSLConnectionRef]
Security.SSLSetConnection.restype = OSStatus
Security.SSLSetPeerDomainName.argtypes = [
SSLContextRef, c_char_p, c_size_t
]
Security.SSLSetPeerDomainName.argtypes = [SSLContextRef, c_char_p, c_size_t]
Security.SSLSetPeerDomainName.restype = OSStatus
Security.SSLHandshake.argtypes = [SSLContextRef]
Security.SSLHandshake.restype = OSStatus
Security.SSLRead.argtypes = [
SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)
]
Security.SSLRead.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
Security.SSLRead.restype = OSStatus
Security.SSLWrite.argtypes = [
SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)
]
Security.SSLWrite.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
Security.SSLWrite.restype = OSStatus
Security.SSLClose.argtypes = [SSLContextRef]
Security.SSLClose.restype = OSStatus
Security.SSLGetNumberSupportedCiphers.argtypes = [
SSLContextRef, POINTER(c_size_t)
]
Security.SSLGetNumberSupportedCiphers.argtypes = [SSLContextRef, POINTER(c_size_t)]
Security.SSLGetNumberSupportedCiphers.restype = OSStatus
Security.SSLGetSupportedCiphers.argtypes = [
SSLContextRef, POINTER(SSLCipherSuite), POINTER(c_size_t)
SSLContextRef,
POINTER(SSLCipherSuite),
POINTER(c_size_t),
]
Security.SSLGetSupportedCiphers.restype = OSStatus
Security.SSLSetEnabledCiphers.argtypes = [
SSLContextRef, POINTER(SSLCipherSuite), c_size_t
SSLContextRef,
POINTER(SSLCipherSuite),
c_size_t,
]
Security.SSLSetEnabledCiphers.restype = OSStatus
Security.SSLGetNumberEnabledCiphers.argtype = [
SSLContextRef, POINTER(c_size_t)
]
Security.SSLGetNumberEnabledCiphers.argtype = [SSLContextRef, POINTER(c_size_t)]
Security.SSLGetNumberEnabledCiphers.restype = OSStatus
Security.SSLGetEnabledCiphers.argtypes = [
SSLContextRef, POINTER(SSLCipherSuite), POINTER(c_size_t)
SSLContextRef,
POINTER(SSLCipherSuite),
POINTER(c_size_t),
]
Security.SSLGetEnabledCiphers.restype = OSStatus
Security.SSLGetNegotiatedCipher.argtypes = [
SSLContextRef, POINTER(SSLCipherSuite)
]
Security.SSLGetNegotiatedCipher.argtypes = [SSLContextRef, POINTER(SSLCipherSuite)]
Security.SSLGetNegotiatedCipher.restype = OSStatus
Security.SSLGetNegotiatedProtocolVersion.argtypes = [
SSLContextRef, POINTER(SSLProtocol)
SSLContextRef,
POINTER(SSLProtocol),
]
Security.SSLGetNegotiatedProtocolVersion.restype = OSStatus
Security.SSLCopyPeerTrust.argtypes = [SSLContextRef, POINTER(SecTrustRef)]
Security.SSLCopyPeerTrust.restype = OSStatus
Security.SecTrustSetAnchorCertificates.argtypes = [SecTrustRef, CFArrayRef]
Security.SecTrustSetAnchorCertificates.restype = OSStatus
Security.SecTrustSetAnchorCertificatesOnly.argstypes = [
SecTrustRef, Boolean
]
Security.SecTrustSetAnchorCertificatesOnly.argstypes = [SecTrustRef, Boolean]
Security.SecTrustSetAnchorCertificatesOnly.restype = OSStatus
Security.SecTrustEvaluate.argtypes = [
SecTrustRef, POINTER(SecTrustResultType)
]
Security.SecTrustEvaluate.argtypes = [SecTrustRef, POINTER(SecTrustResultType)]
Security.SecTrustEvaluate.restype = OSStatus
Security.SecTrustGetCertificateCount.argtypes = [SecTrustRef]
Security.SecTrustGetCertificateCount.restype = CFIndex
Security.SecTrustGetCertificateAtIndex.argtypes = [SecTrustRef, CFIndex]
Security.SecTrustGetCertificateAtIndex.restype = SecCertificateRef
Security.SSLCreateContext.argtypes = [
CFAllocatorRef, SSLProtocolSide, SSLConnectionType
CFAllocatorRef,
SSLProtocolSide,
SSLConnectionType,
]
Security.SSLCreateContext.restype = SSLContextRef
Security.SSLSetSessionOption.argtypes = [
SSLContextRef, SSLSessionOption, Boolean
]
Security.SSLSetSessionOption.argtypes = [SSLContextRef, SSLSessionOption, Boolean]
Security.SSLSetSessionOption.restype = OSStatus
Security.SSLSetProtocolVersionMin.argtypes = [SSLContextRef, SSLProtocol]
Security.SSLSetProtocolVersionMin.restype = OSStatus
@@ -254,10 +241,10 @@ try:
Security.SecExternalFormat = SecExternalFormat
Security.OSStatus = OSStatus
Security.kSecImportExportPassphrase = CFStringRef.in_dll(
Security, 'kSecImportExportPassphrase'
Security, "kSecImportExportPassphrase"
)
Security.kSecImportItemIdentity = CFStringRef.in_dll(
Security, 'kSecImportItemIdentity'
Security, "kSecImportItemIdentity"
)
# CoreFoundation time!
CoreFoundation.CFRetain.argtypes = [CFTypeRef]
@@ -267,15 +254,18 @@ try:
CoreFoundation.CFGetTypeID.argtypes = [CFTypeRef]
CoreFoundation.CFGetTypeID.restype = CFTypeID
CoreFoundation.CFStringCreateWithCString.argtypes = [
CFAllocatorRef, c_char_p, CFStringEncoding
CFAllocatorRef,
c_char_p,
CFStringEncoding,
]
CoreFoundation.CFStringCreateWithCString.restype = CFStringRef
CoreFoundation.CFStringGetCStringPtr.argtypes = [
CFStringRef, CFStringEncoding
]
CoreFoundation.CFStringGetCStringPtr.argtypes = [CFStringRef, CFStringEncoding]
CoreFoundation.CFStringGetCStringPtr.restype = c_char_p
CoreFoundation.CFStringGetCString.argtypes = [
CFStringRef, c_char_p, CFIndex, CFStringEncoding
CFStringRef,
c_char_p,
CFIndex,
CFStringEncoding,
]
CoreFoundation.CFStringGetCString.restype = c_bool
CoreFoundation.CFDataCreate.argtypes = [CFAllocatorRef, c_char_p, CFIndex]
@@ -296,11 +286,16 @@ try:
CoreFoundation.CFDictionaryGetValue.argtypes = [CFDictionaryRef, CFTypeRef]
CoreFoundation.CFDictionaryGetValue.restype = CFTypeRef
CoreFoundation.CFArrayCreate.argtypes = [
CFAllocatorRef, POINTER(CFTypeRef), CFIndex, CFArrayCallBacks
CFAllocatorRef,
POINTER(CFTypeRef),
CFIndex,
CFArrayCallBacks,
]
CoreFoundation.CFArrayCreate.restype = CFArrayRef
CoreFoundation.CFArrayCreateMutable.argtypes = [
CFAllocatorRef, CFIndex, CFArrayCallBacks
CFAllocatorRef,
CFIndex,
CFArrayCallBacks,
]
CoreFoundation.CFArrayCreateMutable.restype = CFMutableArrayRef
CoreFoundation.CFArrayAppendValue.argtypes = [CFMutableArrayRef, c_void_p]
@@ -310,23 +305,23 @@ try:
CoreFoundation.CFArrayGetValueAtIndex.argtypes = [CFArrayRef, CFIndex]
CoreFoundation.CFArrayGetValueAtIndex.restype = c_void_p
CoreFoundation.kCFAllocatorDefault = CFAllocatorRef.in_dll(
CoreFoundation, 'kCFAllocatorDefault'
CoreFoundation, "kCFAllocatorDefault"
)
CoreFoundation.kCFTypeArrayCallBacks = c_void_p.in_dll(
CoreFoundation, 'kCFTypeArrayCallBacks'
CoreFoundation, "kCFTypeArrayCallBacks"
)
CoreFoundation.kCFTypeDictionaryKeyCallBacks = c_void_p.in_dll(
CoreFoundation, 'kCFTypeDictionaryKeyCallBacks'
CoreFoundation, "kCFTypeDictionaryKeyCallBacks"
)
CoreFoundation.kCFTypeDictionaryValueCallBacks = c_void_p.in_dll(
CoreFoundation, 'kCFTypeDictionaryValueCallBacks'
CoreFoundation, "kCFTypeDictionaryValueCallBacks"
)
CoreFoundation.CFTypeRef = CFTypeRef
CoreFoundation.CFArrayRef = CFArrayRef
CoreFoundation.CFStringRef = CFStringRef
CoreFoundation.CFDictionaryRef = CFDictionaryRef
except (AttributeError):
raise ImportError('Error initializing ctypes')
raise ImportError("Error initializing ctypes")
class CFConst(object):
@@ -334,6 +329,7 @@ class CFConst(object):
A class object that acts as essentially a namespace for CoreFoundation
constants.
"""
kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)
@@ -341,6 +337,7 @@ class SecurityConst(object):
"""
A class object that acts as essentially a namespace for Security constants.
"""
kSSLSessionOptionBreakOnServerAuth = 0
kSSLProtocol2 = 1
kSSLProtocol3 = 2
@@ -70,11 +70,11 @@ def _cf_string_to_unicode(value):
value_as_void_p, buffer, 1024, CFConst.kCFStringEncodingUTF8
)
if not result:
raise OSError('Error copying C string from CFStringRef')
raise OSError("Error copying C string from CFStringRef")
string = buffer.value
if string is not None:
string = string.decode('utf-8')
string = string.decode("utf-8")
return string
@@ -89,8 +89,8 @@ def _assert_no_error(error, exception_class=None):
cf_error_string = Security.SecCopyErrorMessageString(error, None)
output = _cf_string_to_unicode(cf_error_string)
CoreFoundation.CFRelease(cf_error_string)
if output is None or output == u'':
output = u'OSStatus %s' % error
if output is None or output == u"":
output = u"OSStatus %s" % error
if exception_class is None:
exception_class = ssl.SSLError
raise exception_class(output)
@@ -102,8 +102,7 @@ def _cert_array_from_pem(pem_bundle):
that can be used to validate a cert chain.
"""
der_certs = [
base64.b64decode(match.group(1))
for match in _PEM_CERTS_RE.finditer(pem_bundle)
base64.b64decode(match.group(1)) for match in _PEM_CERTS_RE.finditer(pem_bundle)
]
if not der_certs:
raise ssl.SSLError("No root certificates specified")
@@ -173,19 +172,14 @@ def _temporary_keychain():
# some random bytes to password-protect the keychain we're creating, so we
# ask for 40 random bytes.
random_bytes = os.urandom(40)
filename = base64.b64encode(random_bytes[:8]).decode('utf-8')
filename = base64.b64encode(random_bytes[:8]).decode("utf-8")
password = base64.b64encode(random_bytes[8:]) # Must be valid UTF-8
tempdirectory = tempfile.mkdtemp()
keychain_path = os.path.join(tempdirectory, filename).encode('utf-8')
keychain_path = os.path.join(tempdirectory, filename).encode("utf-8")
# We now want to create the keychain itself.
keychain = Security.SecKeychainRef()
status = Security.SecKeychainCreate(
keychain_path,
len(password),
password,
False,
None,
ctypes.byref(keychain),
keychain_path, len(password), password, False, None, ctypes.byref(keychain)
)
_assert_no_error(status)
# Having created the keychain, we want to pass it off to the caller.
@@ -202,7 +196,7 @@ def _load_items_from_file(keychain, path):
certificates = []
identities = []
result_array = None
with open(path, 'rb') as f:
with open(path, "rb") as f:
raw_filedata = f.read()
try:
filedata = CoreFoundation.CFDataCreate(
@@ -279,9 +273,7 @@ def _load_client_cert_chain(keychain, *paths):
paths = (path for path in paths if path)
try:
for file_path in paths:
new_identities, new_certs = _load_items_from_file(
keychain, file_path
)
new_identities, new_certs = _load_items_from_file(keychain, file_path)
identities.extend(new_identities)
certificates.extend(new_certs)
# Ok, we have everything. The question is: do we have an identity? If
+23 -33
View File
@@ -144,9 +144,7 @@ class AppEngineManager(RequestMethods):
):
retries = self._get_retries(retries, redirect)
try:
follow_redirects = (
redirect and retries.redirect != 0 and retries.total
)
follow_redirects = redirect and retries.redirect != 0 and retries.total
response = urlfetch.fetch(
url,
payload=body,
@@ -161,7 +159,7 @@ class AppEngineManager(RequestMethods):
raise TimeoutError(self, e)
except urlfetch.InvalidURLError as e:
if 'too large' in str(e):
if "too large" in str(e):
raise AppEnginePlatformError(
"URLFetch request too large, URLFetch only "
"supports requests up to 10mb in size.",
@@ -171,7 +169,7 @@ class AppEngineManager(RequestMethods):
raise ProtocolError(e)
except urlfetch.DownloadError as e:
if 'Too many redirects' in str(e):
if "Too many redirects" in str(e):
raise MaxRetryError(self, url, reason=e)
raise ProtocolError(e)
@@ -198,12 +196,12 @@ class AppEngineManager(RequestMethods):
redirect_location = redirect and http_response.get_redirect_location()
if redirect_location:
# Check for redirect response
if (self.urlfetch_retries and retries.raise_on_redirect):
if self.urlfetch_retries and retries.raise_on_redirect:
raise MaxRetryError(self, url, "too many redirects")
else:
if http_response.status == 303:
method = 'GET'
method = "GET"
try:
retries = retries.increment(
method, url, response=http_response, _pool=self
@@ -229,11 +227,9 @@ class AppEngineManager(RequestMethods):
)
# Check if we should retry the HTTP response.
has_retry_after = bool(http_response.getheader('Retry-After'))
has_retry_after = bool(http_response.getheader("Retry-After"))
if retries.is_retry(method, http_response.status, has_retry_after):
retries = retries.increment(
method, url, response=http_response, _pool=self
)
retries = retries.increment(method, url, response=http_response, _pool=self)
log.debug("Retry: %s", url)
retries.sleep(http_response)
return self.urlopen(
@@ -249,22 +245,20 @@ class AppEngineManager(RequestMethods):
return http_response
def _urlfetch_response_to_http_response(
self, urlfetch_resp, **response_kw
):
def _urlfetch_response_to_http_response(self, urlfetch_resp, **response_kw):
if is_prod_appengine():
# Production GAE handles deflate encoding automatically, but does
# not remove the encoding header.
content_encoding = urlfetch_resp.headers.get('content-encoding')
if content_encoding == 'deflate':
del urlfetch_resp.headers['content-encoding']
transfer_encoding = urlfetch_resp.headers.get('transfer-encoding')
content_encoding = urlfetch_resp.headers.get("content-encoding")
if content_encoding == "deflate":
del urlfetch_resp.headers["content-encoding"]
transfer_encoding = urlfetch_resp.headers.get("transfer-encoding")
# We have a full response's content,
# so let's make sure we don't report ourselves as chunked data.
if transfer_encoding == 'chunked':
if transfer_encoding == "chunked":
encodings = transfer_encoding.split(",")
encodings.remove('chunked')
urlfetch_resp.headers['transfer-encoding'] = ','.join(encodings)
encodings.remove("chunked")
urlfetch_resp.headers["transfer-encoding"] = ",".join(encodings)
return HTTPResponse(
# In order for decoding to work, we must present the content as
# a file-like object.
@@ -291,9 +285,7 @@ class AppEngineManager(RequestMethods):
def _get_retries(self, retries, redirect):
if not isinstance(retries, Retry):
retries = Retry.from_int(
retries, redirect=redirect, default=self.retries
)
retries = Retry.from_int(retries, redirect=redirect, default=self.retries)
if retries.connect or retries.read or retries.redirect:
warnings.warn(
"URLFetch only supports total retries and does not "
@@ -304,9 +296,7 @@ class AppEngineManager(RequestMethods):
def is_appengine():
return (
is_local_appengine() or is_prod_appengine() or is_prod_appengine_mvms()
)
return is_local_appengine() or is_prod_appengine() or is_prod_appengine_mvms()
def is_appengine_sandbox():
@@ -315,18 +305,18 @@ def is_appengine_sandbox():
def is_local_appengine():
return (
'APPENGINE_RUNTIME' in os.environ and
'Development/' in os.environ['SERVER_SOFTWARE']
"APPENGINE_RUNTIME" in os.environ
and "Development/" in os.environ["SERVER_SOFTWARE"]
)
def is_prod_appengine():
return (
'APPENGINE_RUNTIME' in os.environ and
'Google App Engine/' in os.environ['SERVER_SOFTWARE'] and
not is_prod_appengine_mvms()
"APPENGINE_RUNTIME" in os.environ
and "Google App Engine/" in os.environ["SERVER_SOFTWARE"]
and not is_prod_appengine_mvms()
)
def is_prod_appengine_mvms():
return os.environ.get('GAE_VM', False) == 'true'
return os.environ.get("GAE_VM", False) == "true"
+36 -50
View File
@@ -63,7 +63,7 @@ import sys
from .. import util
__all__ = ['inject_into_urllib3', 'extract_from_urllib3']
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# SNI always works.
HAS_SNI = True
# Map from urllib3 to PyOpenSSL compatible parameter-values.
@@ -71,9 +71,9 @@ _openssl_versions = {
ssl.PROTOCOL_SSLv23: OpenSSL.SSL.SSLv23_METHOD,
ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
}
if hasattr(ssl, 'PROTOCOL_TLSv1_1') and hasattr(OpenSSL.SSL, 'TLSv1_1_METHOD'):
if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"):
_openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD
if hasattr(ssl, 'PROTOCOL_TLSv1_2') and hasattr(OpenSSL.SSL, 'TLSv1_2_METHOD'):
if hasattr(ssl, "PROTOCOL_TLSv1_2") and hasattr(OpenSSL.SSL, "TLSv1_2_METHOD"):
_openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD
try:
_openssl_versions.update({ssl.PROTOCOL_SSLv3: OpenSSL.SSL.SSLv3_METHOD})
@@ -82,12 +82,10 @@ except AttributeError:
_stdlib_to_openssl_verify = {
ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE,
ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER,
ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER +
OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER
+ OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}
_openssl_to_stdlib_verify = dict(
(v, k) for k, v in _stdlib_to_openssl_verify.items()
)
_openssl_to_stdlib_verify = dict((v, k) for k, v in _stdlib_to_openssl_verify.items())
# OpenSSL will only write 16K at a time
SSL_WRITE_BLOCKSIZE = 16384
orig_util_HAS_SNI = util.HAS_SNI
@@ -96,7 +94,7 @@ log = logging.getLogger(__name__)
def inject_into_urllib3():
'Monkey-patch urllib3 with PyOpenSSL-backed SSL-support.'
"Monkey-patch urllib3 with PyOpenSSL-backed SSL-support."
_validate_dependencies_met()
util.ssl_.SSLContext = PyOpenSSLContext
util.HAS_SNI = HAS_SNI
@@ -106,7 +104,7 @@ def inject_into_urllib3():
def extract_from_urllib3():
'Undo monkey-patching by :func:`inject_into_urllib3`.'
"Undo monkey-patching by :func:`inject_into_urllib3`."
util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI
@@ -159,16 +157,16 @@ def _dnsname_to_stdlib(name):
"""
import idna
for prefix in [u'*.', u'.']:
for prefix in [u"*.", u"."]:
if name.startswith(prefix):
name = name[len(prefix):]
return prefix.encode('ascii') + idna.encode(name)
name = name[len(prefix) :]
return prefix.encode("ascii") + idna.encode(name)
return idna.encode(name)
name = idna_encode(name)
if sys.version_info >= (3, 0):
name = name.decode('utf-8')
name = name.decode("utf-8")
return name
@@ -186,9 +184,7 @@ def get_subj_alt_name(peer_cert):
# We want to find the SAN extension. Ask Cryptography to locate it (it's
# faster than looping in Python)
try:
ext = cert.extensions.get_extension_for_class(
x509.SubjectAlternativeName
).value
ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
except x509.ExtensionNotFound:
# No such extension, return the empty list.
return []
@@ -216,22 +212,21 @@ def get_subj_alt_name(peer_cert):
# decoded. This is pretty frustrating, but that's what the standard library
# does with certificates, and so we need to attempt to do the same.
names = [
('DNS', _dnsname_to_stdlib(name))
("DNS", _dnsname_to_stdlib(name))
for name in ext.get_values_for_type(x509.DNSName)
]
names.extend(
('IP Address', str(name))
for name in ext.get_values_for_type(x509.IPAddress)
("IP Address", str(name)) for name in ext.get_values_for_type(x509.IPAddress)
)
return names
class WrappedSocket(object):
'''API-compatibility wrapper for Python OpenSSL's Connection-class.
"""API-compatibility wrapper for Python OpenSSL's Connection-class.
Note: _makefile_refs, _drop() and _reuse() are needed for the garbage
collector of pypy.
'''
"""
def __init__(self, connection, socket, suppress_ragged_eofs=True):
self.connection = connection
@@ -243,7 +238,6 @@ class WrappedSocket(object):
def fileno(self):
return self.socket.fileno()
# Copy-pasted from Python 3.5 source code
def _decref_socketios(self):
if self._makefile_refs > 0:
@@ -255,15 +249,15 @@ class WrappedSocket(object):
try:
data = self.connection.recv(*args, **kwargs)
except OpenSSL.SSL.SysCallError as e:
if self.suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'):
return b''
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
return b""
else:
raise SocketError(str(e))
except OpenSSL.SSL.ZeroReturnError as e:
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
return b''
return b""
else:
raise
@@ -271,7 +265,7 @@ class WrappedSocket(object):
except OpenSSL.SSL.WantReadError:
rd = util.wait_for_read(self.socket, self.socket.gettimeout())
if not rd:
raise timeout('The read operation timed out')
raise timeout("The read operation timed out")
else:
return self.recv(*args, **kwargs)
@@ -284,7 +278,7 @@ class WrappedSocket(object):
return self.connection.recv_into(*args, **kwargs)
except OpenSSL.SSL.SysCallError as e:
if self.suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'):
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
return 0
else:
@@ -300,7 +294,7 @@ class WrappedSocket(object):
except OpenSSL.SSL.WantReadError:
rd = util.wait_for_read(self.socket, self.socket.gettimeout())
if not rd:
raise timeout('The read operation timed out')
raise timeout("The read operation timed out")
else:
return self.recv_into(*args, **kwargs)
@@ -330,7 +324,7 @@ class WrappedSocket(object):
total_sent = 0
while total_sent < len(data):
sent = self._send_until_done(
data[total_sent:total_sent + SSL_WRITE_BLOCKSIZE]
data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE]
)
total_sent += sent
@@ -356,13 +350,11 @@ class WrappedSocket(object):
return x509
if binary_form:
return OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_ASN1, x509
)
return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509)
return {
'subject': ((('commonName', x509.get_subject().CN),),),
'subjectAltName': get_subj_alt_name(x509),
"subject": ((("commonName", x509.get_subject().CN),),),
"subjectAltName": get_subj_alt_name(x509),
}
def setblocking(self, flag):
@@ -418,23 +410,21 @@ class PyOpenSSLContext(object):
@verify_mode.setter
def verify_mode(self, value):
self._ctx.set_verify(
_stdlib_to_openssl_verify[value], _verify_callback
)
self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback)
def set_default_verify_paths(self):
self._ctx.set_default_verify_paths()
def set_ciphers(self, ciphers):
if isinstance(ciphers, six.text_type):
ciphers = ciphers.encode('utf-8')
ciphers = ciphers.encode("utf-8")
self._ctx.set_cipher_list(ciphers)
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
if cafile is not None:
cafile = cafile.encode('utf-8')
cafile = cafile.encode("utf-8")
if capath is not None:
capath = capath.encode('utf-8')
capath = capath.encode("utf-8")
self._ctx.load_verify_locations(cafile, capath)
if cadata is not None:
self._ctx.load_verify_locations(BytesIO(cadata))
@@ -442,9 +432,7 @@ class PyOpenSSLContext(object):
def load_cert_chain(self, certfile, keyfile=None, password=None):
self._ctx.use_certificate_chain_file(certfile)
if password is not None:
self._ctx.set_passwd_cb(
lambda max_length, prompt_twice, userdata: password
)
self._ctx.set_passwd_cb(lambda max_length, prompt_twice, userdata: password)
self._ctx.use_privatekey_file(keyfile or certfile)
def wrap_socket(
@@ -456,10 +444,8 @@ class PyOpenSSLContext(object):
server_hostname=None,
):
cnx = OpenSSL.SSL.Connection(self._ctx, sock)
if isinstance(
server_hostname, six.text_type
): # Platform-specific: Python 3
server_hostname = server_hostname.encode('utf-8')
if isinstance(server_hostname, six.text_type): # Platform-specific: Python 3
server_hostname = server_hostname.encode("utf-8")
if server_hostname is not None:
cnx.set_tlsext_host_name(server_hostname)
cnx.set_connect_state()
@@ -469,12 +455,12 @@ class PyOpenSSLContext(object):
except OpenSSL.SSL.WantReadError:
rd = util.wait_for_read(sock, sock.gettimeout())
if not rd:
raise timeout('select timed out')
raise timeout("select timed out")
continue
except OpenSSL.SSL.Error as e:
raise ssl.SSLError('bad handshake: %r' % e)
raise ssl.SSLError("bad handshake: %r" % e)
break
+31 -52
View File
@@ -37,9 +37,7 @@ import threading
import weakref
from .. import util
from ._securetransport.bindings import (
Security, SecurityConst, CoreFoundation
)
from ._securetransport.bindings import Security, SecurityConst, CoreFoundation
from ._securetransport.low_level import (
_assert_no_error,
_cert_array_from_pem,
@@ -53,11 +51,11 @@ except ImportError: # Platform-specific: Python 3
_fileobject = None
from ..packages.backports.makefile import backport_makefile
try:
memoryview(b'')
memoryview(b"")
except NameError:
raise ImportError("SecureTransport only works on Pythons with memoryview")
__all__ = ['inject_into_urllib3', 'extract_from_urllib3']
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# SNI always works
HAS_SNI = True
orig_util_HAS_SNI = util.HAS_SNI
@@ -124,34 +122,35 @@ CIPHER_SUITES = [
# Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of
# TLSv1 and a high of TLSv1.2. For everything else, we pin to that version.
_protocol_to_min_max = {
ssl.PROTOCOL_SSLv23: (
SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12
)
ssl.PROTOCOL_SSLv23: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12)
}
if hasattr(ssl, "PROTOCOL_SSLv2"):
_protocol_to_min_max[ssl.PROTOCOL_SSLv2] = (
SecurityConst.kSSLProtocol2, SecurityConst.kSSLProtocol2
SecurityConst.kSSLProtocol2,
SecurityConst.kSSLProtocol2,
)
if hasattr(ssl, "PROTOCOL_SSLv3"):
_protocol_to_min_max[ssl.PROTOCOL_SSLv3] = (
SecurityConst.kSSLProtocol3, SecurityConst.kSSLProtocol3
SecurityConst.kSSLProtocol3,
SecurityConst.kSSLProtocol3,
)
if hasattr(ssl, "PROTOCOL_TLSv1"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1] = (
SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol1
SecurityConst.kTLSProtocol1,
SecurityConst.kTLSProtocol1,
)
if hasattr(ssl, "PROTOCOL_TLSv1_1"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1_1] = (
SecurityConst.kTLSProtocol11, SecurityConst.kTLSProtocol11
SecurityConst.kTLSProtocol11,
SecurityConst.kTLSProtocol11,
)
if hasattr(ssl, "PROTOCOL_TLSv1_2"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1_2] = (
SecurityConst.kTLSProtocol12, SecurityConst.kTLSProtocol12
SecurityConst.kTLSProtocol12,
SecurityConst.kTLSProtocol12,
)
if hasattr(ssl, "PROTOCOL_TLS"):
_protocol_to_min_max[ssl.PROTOCOL_TLS] = _protocol_to_min_max[
ssl.PROTOCOL_SSLv23
]
_protocol_to_min_max[ssl.PROTOCOL_TLS] = _protocol_to_min_max[ssl.PROTOCOL_SSLv23]
def inject_into_urllib3():
@@ -199,7 +198,7 @@ def _read_callback(connection_id, data_buffer, data_length_pointer):
if timeout is None or timeout >= 0:
readables = util.wait_for_read([base_socket], timeout)
if not readables:
raise socket.error(errno.EAGAIN, 'timed out')
raise socket.error(errno.EAGAIN, "timed out")
# We need to tell ctypes that we have a buffer that can be
# written to. Upsettingly, we do that like this:
@@ -255,7 +254,7 @@ def _write_callback(connection_id, data_buffer, data_length_pointer):
if timeout is None or timeout >= 0:
writables = util.wait_for_write([base_socket], timeout)
if not writables:
raise socket.error(errno.EAGAIN, 'timed out')
raise socket.error(errno.EAGAIN, "timed out")
chunk_sent = base_socket.send(data)
sent += chunk_sent
@@ -342,9 +341,7 @@ class WrappedSocket(object):
custom and doesn't allow changing at this time, mostly because parsing
OpenSSL cipher strings is going to be a freaking nightmare.
"""
ciphers = (Security.SSLCipherSuite * len(CIPHER_SUITES))(
*CIPHER_SUITES
)
ciphers = (Security.SSLCipherSuite * len(CIPHER_SUITES))(*CIPHER_SUITES)
result = Security.SSLSetEnabledCiphers(
self.context, ciphers, len(CIPHER_SUITES)
)
@@ -362,7 +359,7 @@ class WrappedSocket(object):
# We want data in memory, so load it up.
if os.path.isfile(trust_bundle):
with open(trust_bundle, 'rb') as f:
with open(trust_bundle, "rb") as f:
trust_bundle = f.read()
cert_array = None
trust = Security.SecTrustRef()
@@ -373,9 +370,7 @@ class WrappedSocket(object):
# created for this connection, shove our CAs into it, tell ST to
# ignore everything else it knows, and then ask if it can build a
# chain. This is a buuuunch of code.
result = Security.SSLCopyPeerTrust(
self.context, ctypes.byref(trust)
)
result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
_assert_no_error(result)
if not trust:
raise ssl.SSLError("Failed to copy trust reference")
@@ -385,9 +380,7 @@ class WrappedSocket(object):
result = Security.SecTrustSetAnchorCertificatesOnly(trust, True)
_assert_no_error(result)
trust_result = Security.SecTrustResultType()
result = Security.SecTrustEvaluate(
trust, ctypes.byref(trust_result)
)
result = Security.SecTrustEvaluate(trust, ctypes.byref(trust_result))
_assert_no_error(result)
finally:
if trust:
@@ -401,8 +394,7 @@ class WrappedSocket(object):
)
if trust_result.value not in successes:
raise ssl.SSLError(
"certificate verify failed, error code: %d" %
trust_result.value
"certificate verify failed, error code: %d" % trust_result.value
)
def handshake(
@@ -442,7 +434,7 @@ class WrappedSocket(object):
# If we have a server hostname, we should set that too.
if server_hostname:
if not isinstance(server_hostname, bytes):
server_hostname = server_hostname.encode('utf-8')
server_hostname = server_hostname.encode("utf-8")
result = Security.SSLSetPeerDomainName(
self.context, server_hostname, len(server_hostname)
)
@@ -460,9 +452,7 @@ class WrappedSocket(object):
# authing in that case.
if not verify or trust_bundle is not None:
result = Security.SSLSetSessionOption(
self.context,
SecurityConst.kSSLSessionOptionBreakOnServerAuth,
True,
self.context, SecurityConst.kSSLSessionOptionBreakOnServerAuth, True
)
_assert_no_error(result)
# If there's a client cert, we need to use it.
@@ -471,9 +461,7 @@ class WrappedSocket(object):
self._client_cert_chain = _load_client_cert_chain(
self._keychain, client_cert, client_key
)
result = Security.SSLSetCertificate(
self.context, self._client_cert_chain
)
result = Security.SSLSetCertificate(self.context, self._client_cert_chain)
_assert_no_error(result)
while True:
with self._raise_on_error():
@@ -492,7 +480,6 @@ class WrappedSocket(object):
def fileno(self):
return self.socket.fileno()
# Copy-pasted from Python 3.5 source code
def _decref_socketios(self):
if self._makefile_refs > 0:
@@ -522,7 +509,7 @@ class WrappedSocket(object):
# There are some result codes that we want to treat as "not always
# errors". Specifically, those are errSSLWouldBlock,
# errSSLClosedGraceful, and errSSLClosedNoNotify.
if (result == SecurityConst.errSSLWouldBlock):
if result == SecurityConst.errSSLWouldBlock:
# If we didn't process any bytes, then this was just a time out.
# However, we can get errSSLWouldBlock in situations when we *did*
# read some data, and in those cases we should just read "short"
@@ -570,7 +557,7 @@ class WrappedSocket(object):
def sendall(self, data):
total_sent = 0
while total_sent < len(data):
sent = self.send(data[total_sent:total_sent + SSL_WRITE_BLOCKSIZE])
sent = self.send(data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE])
total_sent += sent
def shutdown(self):
@@ -618,18 +605,14 @@ class WrappedSocket(object):
# instead to just flag to urllib3 that it shouldn't do its own hostname
# validation when using SecureTransport.
if not binary_form:
raise ValueError(
"SecureTransport only supports dumping binary certs"
)
raise ValueError("SecureTransport only supports dumping binary certs")
trust = Security.SecTrustRef()
certdata = None
der_bytes = None
try:
# Grab the trust store.
result = Security.SSLCopyPeerTrust(
self.context, ctypes.byref(trust)
)
result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
_assert_no_error(result)
if not trust:
# Probably we haven't done the handshake yet. No biggie.
@@ -758,16 +741,12 @@ class SecureTransportContext(object):
def set_ciphers(self, ciphers):
# For now, we just require the default cipher string.
if ciphers != util.ssl_.DEFAULT_CIPHERS:
raise ValueError(
"SecureTransport doesn't support custom cipher strings"
)
raise ValueError("SecureTransport doesn't support custom cipher strings")
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
# OK, we only really support cadata and cafile.
if capath is not None:
raise ValueError(
"SecureTransport does not support cert directories"
)
raise ValueError("SecureTransport does not support cert directories")
self._trust_bundle = cafile or cadata
+32 -33
View File
@@ -31,9 +31,9 @@ except ImportError:
warnings.warn(
(
'SOCKS support in urllib3 requires the installation of optional '
'dependencies: specifically, PySocks. For more information, see '
'https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies'
"SOCKS support in urllib3 requires the installation of optional "
"dependencies: specifically, PySocks. For more information, see "
"https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies"
),
DependencyWarning,
)
@@ -41,8 +41,8 @@ except ImportError:
from socket import error as SocketError, timeout as SocketTimeout
from .._sync.connection import (HTTP1Connection)
from ..connectionpool import (HTTPConnectionPool, HTTPSConnectionPool)
from .._sync.connection import HTTP1Connection
from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from ..exceptions import ConnectTimeoutError, NewConnectionError
from ..poolmanager import PoolManager
from ..util.url import parse_url
@@ -54,7 +54,7 @@ class SOCKSConnection(HTTP1Connection):
"""
def __init__(self, *args, **kwargs):
self._socks_options = kwargs.pop('_socks_options')
self._socks_options = kwargs.pop("_socks_options")
super(SOCKSConnection, self).__init__(*args, **kwargs)
def _do_socket_connect(self, connect_timeout, connect_kw):
@@ -64,20 +64,20 @@ class SOCKSConnection(HTTP1Connection):
try:
conn = socks.create_connection(
(self._host, self._port),
proxy_type=self._socks_options['socks_version'],
proxy_addr=self._socks_options['proxy_host'],
proxy_port=self._socks_options['proxy_port'],
proxy_username=self._socks_options['username'],
proxy_password=self._socks_options['password'],
proxy_rdns=self._socks_options['rdns'],
proxy_type=self._socks_options["socks_version"],
proxy_addr=self._socks_options["proxy_host"],
proxy_port=self._socks_options["proxy_port"],
proxy_username=self._socks_options["username"],
proxy_password=self._socks_options["password"],
proxy_rdns=self._socks_options["rdns"],
timeout=connect_timeout,
**connect_kw
)
except SocketTimeout as e:
raise ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)" %
(self._host, connect_timeout),
"Connection to %s timed out. (connect timeout=%s)"
% (self._host, connect_timeout),
)
except socks.ProxyError as e:
@@ -88,14 +88,13 @@ class SOCKSConnection(HTTP1Connection):
if isinstance(error, SocketTimeout):
raise ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)" %
(self._host, connect_timeout),
"Connection to %s timed out. (connect timeout=%s)"
% (self._host, connect_timeout),
)
else:
raise NewConnectionError(
self,
"Failed to establish a new connection: %s" % error,
self, "Failed to establish a new connection: %s" % error
)
else:
@@ -124,8 +123,10 @@ class SOCKSProxyManager(PoolManager):
A version of the urllib3 ProxyManager that routes connections via the
defined SOCKS proxy.
"""
pool_classes_by_scheme = {
'http': SOCKSHTTPConnectionPool, 'https': SOCKSHTTPSConnectionPool
"http": SOCKSHTTPConnectionPool,
"https": SOCKSHTTPSConnectionPool,
}
def __init__(
@@ -138,33 +139,31 @@ class SOCKSProxyManager(PoolManager):
**connection_pool_kw
):
parsed = parse_url(proxy_url)
if parsed.scheme == 'socks5':
if parsed.scheme == "socks5":
socks_version = socks.PROXY_TYPE_SOCKS5
rdns = False
elif parsed.scheme == 'socks5h':
elif parsed.scheme == "socks5h":
socks_version = socks.PROXY_TYPE_SOCKS5
rdns = True
elif parsed.scheme == 'socks4':
elif parsed.scheme == "socks4":
socks_version = socks.PROXY_TYPE_SOCKS4
rdns = False
elif parsed.scheme == 'socks4a':
elif parsed.scheme == "socks4a":
socks_version = socks.PROXY_TYPE_SOCKS4
rdns = True
else:
raise ValueError(
"Unable to determine SOCKS version from %s" % proxy_url
)
raise ValueError("Unable to determine SOCKS version from %s" % proxy_url)
self.proxy_url = proxy_url
socks_options = {
'socks_version': socks_version,
'proxy_host': parsed.host,
'proxy_port': parsed.port,
'username': username,
'password': password,
'rdns': rdns,
"socks_version": socks_version,
"proxy_host": parsed.host,
"proxy_port": parsed.port,
"username": username,
"password": password,
"rdns": rdns,
}
connection_pool_kw['_socks_options'] = socks_options
connection_pool_kw["_socks_options"] = socks_options
super(SOCKSProxyManager, self).__init__(
num_pools, headers, **connection_pool_kw
)
+8 -11
View File
@@ -1,7 +1,6 @@
from __future__ import absolute_import
# Base Exceptions
class HTTPError(Exception):
"Base exception used by this module."
@@ -74,14 +73,13 @@ class MaxRetryError(RequestError):
def __init__(self, pool, url, reason=None):
self.reason = reason
message = "Max retries exceeded with url: %s (Caused by %r)" % (
url, reason
)
message = "Max retries exceeded with url: %s (Caused by %r)" % (url, reason)
RequestError.__init__(self, pool, url, message)
class TimeoutStateError(HTTPError):
""" Raised when passing an invalid state to a timeout """
pass
@@ -91,6 +89,7 @@ class TimeoutError(HTTPError):
Catching this error will catch both :exc:`ReadTimeoutErrors
<ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`.
"""
pass
@@ -99,8 +98,6 @@ class ReadTimeoutError(TimeoutError, RequestError):
pass
# This timeout error does not have a URL attached and needs to inherit from the
# base HTTPError
class ConnectTimeoutError(TimeoutError):
@@ -139,8 +136,8 @@ class LocationParseError(LocationValueError):
class ResponseError(HTTPError):
"Used as a container for an error reason supplied in a MaxRetryError."
GENERIC_ERROR = 'too many error responses'
SPECIFIC_ERROR = 'too many {status_code} error responses'
GENERIC_ERROR = "too many error responses"
SPECIFIC_ERROR = "too many {status_code} error responses"
class SecurityWarning(HTTPWarning):
@@ -178,6 +175,7 @@ class DependencyWarning(HTTPWarning):
Warned when an attempt is made to import a module with missing optional
dependencies.
"""
pass
@@ -209,9 +207,7 @@ class HeaderParsingError(HTTPError):
"Raised by assert_header_parsing, but we convert it to a log.warning statement."
def __init__(self, defects, unparsed_data):
message = '%s, unparsed data: %r' % (
defects or 'Unknown', unparsed_data
)
message = "%s, unparsed data: %r" % (defects or "Unknown", unparsed_data)
super(HeaderParsingError, self).__init__(message)
@@ -235,4 +231,5 @@ class InvalidBodyError(HTTPError):
An attempt was made to send a request with a body object that urllib3 does
not support.
"""
pass
+18 -23
View File
@@ -5,7 +5,7 @@ import mimetypes
from .packages import six
def guess_content_type(filename, default='application/octet-stream'):
def guess_content_type(filename, default="application/octet-stream"):
"""
Guess the "Content-Type" of a file.
@@ -36,16 +36,16 @@ def format_header_param(name, value):
if not any(ch in value for ch in '"\\\r\n'):
result = '%s="%s"' % (name, value)
try:
result.encode('ascii')
result.encode("ascii")
except (UnicodeEncodeError, UnicodeDecodeError):
pass
else:
return result
if not six.PY3 and isinstance(value, six.text_type): # Python 2:
value = value.encode('utf-8')
value = email.utils.encode_rfc2231(value, 'utf-8')
value = '%s*=%s' % (name, value)
value = value.encode("utf-8")
value = email.utils.encode_rfc2231(value, "utf-8")
value = "%s*=%s" % (name, value)
return value
@@ -132,29 +132,26 @@ class RequestField(object):
for name, value in iterable:
if value is not None:
parts.append(self._render_part(name, value))
return '; '.join(parts)
return "; ".join(parts)
def render_headers(self):
"""
Renders the headers for this request field.
"""
lines = []
sort_keys = ['Content-Disposition', 'Content-Type', 'Content-Location']
sort_keys = ["Content-Disposition", "Content-Type", "Content-Location"]
for sort_key in sort_keys:
if self.headers.get(sort_key, False):
lines.append('%s: %s' % (sort_key, self.headers[sort_key]))
lines.append("%s: %s" % (sort_key, self.headers[sort_key]))
for header_name, header_value in self.headers.items():
if header_name not in sort_keys:
if header_value:
lines.append('%s: %s' % (header_name, header_value))
lines.append('\r\n')
return '\r\n'.join(lines)
lines.append("%s: %s" % (header_name, header_value))
lines.append("\r\n")
return "\r\n".join(lines)
def make_multipart(
self,
content_disposition=None,
content_type=None,
content_location=None,
self, content_disposition=None, content_type=None, content_location=None
):
"""
Makes this request field into a multipart request field.
@@ -168,16 +165,14 @@ class RequestField(object):
The 'Content-Location' of the request body.
"""
self.headers[
'Content-Disposition'
] = content_disposition or 'form-data'
self.headers['Content-Disposition'] += '; '.join(
self.headers["Content-Disposition"] = content_disposition or "form-data"
self.headers["Content-Disposition"] += "; ".join(
[
'',
"",
self._render_parts(
(('name', self._name), ('filename', self._filename))
(("name", self._name), ("filename", self._filename))
),
]
)
self.headers['Content-Type'] = content_type
self.headers['Content-Location'] = content_location
self.headers["Content-Type"] = content_type
self.headers["Content-Location"] = content_location
+6 -5
View File
@@ -7,7 +7,7 @@ from .packages import six
from .packages.six import b
from .fields import RequestField
writer = codecs.lookup('utf-8')[3]
writer = codecs.lookup("utf-8")[3]
def choose_boundary():
@@ -22,6 +22,7 @@ def choose_boundary():
to affect our entire library.
"""
from uuid import uuid4
return uuid4().hex
@@ -78,7 +79,7 @@ def encode_multipart_formdata(fields, boundary=None):
if boundary is None:
boundary = choose_boundary()
for field in iter_field_objects(fields):
body.write(b('--%s\r\n' % (boundary)))
body.write(b("--%s\r\n" % (boundary)))
writer(body).write(field.render_headers())
data = field.data
if isinstance(data, int):
@@ -87,7 +88,7 @@ def encode_multipart_formdata(fields, boundary=None):
writer(body).write(data)
else:
body.write(data)
body.write(b'\r\n')
body.write(b('--%s--\r\n' % (boundary)))
content_type = str('multipart/form-data; boundary=%s' % boundary)
body.write(b"\r\n")
body.write(b("--%s--\r\n" % (boundary)))
content_type = str("multipart/form-data; boundary=%s" % boundary)
return body.getvalue(), content_type
+1 -1
View File
@@ -2,4 +2,4 @@ from __future__ import absolute_import
from . import ssl_match_hostname
__all__ = ('ssl_match_hostname',)
__all__ = ("ssl_match_hostname",)
+37 -39
View File
@@ -13,7 +13,7 @@ except ImportError:
class OrderedDict(dict):
'Dictionary that remembers insertion order'
"Dictionary that remembers insertion order"
# An inherited dict maps keys to values.
# The inherited dict provides __getitem__, __len__, __contains__, and get.
@@ -24,13 +24,13 @@ class OrderedDict(dict):
# The sentinel element never gets deleted (this simplifies the algorithm).
# Each link is stored as a list of length three: [PREV, NEXT, KEY].
def __init__(self, *args, **kwds):
'''Initialize an ordered dictionary. Signature is the same as for
"""Initialize an ordered dictionary. Signature is the same as for
regular dictionaries, but keyword arguments are not recommended
because their insertion order is arbitrary.
'''
"""
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
raise TypeError("expected at most 1 arguments, got %d" % len(args))
try:
self.__root
@@ -41,7 +41,7 @@ class OrderedDict(dict):
self.__update(*args, **kwds)
def __setitem__(self, key, value, dict_setitem=dict.__setitem__):
'od.__setitem__(i, y) <==> od[i]=y'
"od.__setitem__(i, y) <==> od[i]=y"
# Setting a new item creates a new link which goes at the end of the linked
# list, and the inherited dictionary is updated with the new key/value pair.
if key not in self:
@@ -51,7 +51,7 @@ class OrderedDict(dict):
dict_setitem(self, key, value)
def __delitem__(self, key, dict_delitem=dict.__delitem__):
'od.__delitem__(y) <==> del od[y]'
"od.__delitem__(y) <==> del od[y]"
# Deleting an existing item uses self.__map to find the link which is
# then removed by updating the links in the predecessor and successor nodes.
dict_delitem(self, key)
@@ -60,7 +60,7 @@ class OrderedDict(dict):
link_next[0] = link_prev
def __iter__(self):
'od.__iter__() <==> iter(od)'
"od.__iter__() <==> iter(od)"
root = self.__root
curr = root[1]
while curr is not root:
@@ -69,7 +69,7 @@ class OrderedDict(dict):
curr = curr[1]
def __reversed__(self):
'od.__reversed__() <==> reversed(od)'
"od.__reversed__() <==> reversed(od)"
root = self.__root
curr = root[0]
while curr is not root:
@@ -78,7 +78,7 @@ class OrderedDict(dict):
curr = curr[0]
def clear(self):
'od.clear() -> None. Remove all items from od.'
"od.clear() -> None. Remove all items from od."
try:
for node in self.__map.itervalues():
del node[:]
@@ -90,12 +90,12 @@ class OrderedDict(dict):
dict.clear(self)
def popitem(self, last=True):
'''od.popitem() -> (k, v), return and remove a (key, value) pair.
"""od.popitem() -> (k, v), return and remove a (key, value) pair.
Pairs are returned in LIFO order if last is true or FIFO order if false.
'''
"""
if not self:
raise KeyError('dictionary is empty')
raise KeyError("dictionary is empty")
root = self.__root
if last:
@@ -113,51 +113,50 @@ class OrderedDict(dict):
value = dict.pop(self, key)
return key, value
# -- the following methods do not depend on the internal structure --
def keys(self):
'od.keys() -> list of keys in od'
"od.keys() -> list of keys in od"
return list(self)
def values(self):
'od.values() -> list of values in od'
"od.values() -> list of values in od"
return [self[key] for key in self]
def items(self):
'od.items() -> list of (key, value) pairs in od'
"od.items() -> list of (key, value) pairs in od"
return [(key, self[key]) for key in self]
def iterkeys(self):
'od.iterkeys() -> an iterator over the keys in od'
"od.iterkeys() -> an iterator over the keys in od"
return iter(self)
def itervalues(self):
'od.itervalues -> an iterator over the values in od'
"od.itervalues -> an iterator over the values in od"
for k in self:
yield self[k]
def iteritems(self):
'od.iteritems -> an iterator over the (key, value) items in od'
"od.iteritems -> an iterator over the (key, value) items in od"
for k in self:
yield (k, self[k])
def update(*args, **kwds):
'''od.update(E, **F) -> None. Update od from dict/iterable E and F.
"""od.update(E, **F) -> None. Update od from dict/iterable E and F.
If E is a dict instance, does: for k in E: od[k] = E[k]
If E has a .keys() method, does: for k in E.keys(): od[k] = E[k]
Or if E is an iterable of items, does: for k, v in E: od[k] = v
In either case, this is followed by: for k, v in F.items(): od[k] = v
'''
"""
if len(args) > 2:
raise TypeError(
'update() takes at most 2 positional '
'arguments (%d given)' % (len(args),)
"update() takes at most 2 positional "
"arguments (%d given)" % (len(args),)
)
elif not args:
raise TypeError('update() takes at least 1 argument (0 given)')
raise TypeError("update() takes at least 1 argument (0 given)")
self = args[0]
# Make progressively weaker assumptions about "other"
@@ -167,7 +166,7 @@ class OrderedDict(dict):
if isinstance(other, dict):
for key in other:
self[key] = other[key]
elif hasattr(other, 'keys'):
elif hasattr(other, "keys"):
for key in other.keys():
self[key] = other[key]
else:
@@ -180,10 +179,10 @@ class OrderedDict(dict):
__marker = object()
def pop(self, key, default=__marker):
'''od.pop(k[,d]) -> v, remove specified key and return the corresponding value.
"""od.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised.
'''
"""
if key in self:
result = self[key]
del self[key]
@@ -195,7 +194,7 @@ class OrderedDict(dict):
return default
def setdefault(self, key, default=None):
'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od'
"od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od"
if key in self:
return self[key]
@@ -203,23 +202,23 @@ class OrderedDict(dict):
return default
def __repr__(self, _repr_running={}):
'od.__repr__() <==> repr(od)'
"od.__repr__() <==> repr(od)"
call_key = id(self), _get_ident()
if call_key in _repr_running:
return '...'
return "..."
_repr_running[call_key] = 1
try:
if not self:
return '%s()' % (self.__class__.__name__,)
return "%s()" % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, self.items())
return "%s(%r)" % (self.__class__.__name__, self.items())
finally:
del _repr_running[call_key]
def __reduce__(self):
'Return state information for pickling'
"Return state information for pickling"
items = [[k, self[k]] for k in self]
inst_dict = vars(self).copy()
for k in vars(OrderedDict()):
@@ -230,25 +229,25 @@ class OrderedDict(dict):
return self.__class__, (items,)
def copy(self):
'od.copy() -> a shallow copy of od'
"od.copy() -> a shallow copy of od"
return self.__class__(self)
@classmethod
def fromkeys(cls, iterable, value=None):
'''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S
"""OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S
and values equal to v (which defaults to None).
'''
"""
d = cls()
for key in iterable:
d[key] = value
return d
def __eq__(self, other):
'''od.__eq__(y) <==> od==y. Comparison to another OD is order-sensitive
"""od.__eq__(y) <==> od==y. Comparison to another OD is order-sensitive
while comparison to a regular mapping is order-insensitive.
'''
"""
if isinstance(other, OrderedDict):
return len(self) == len(other) and self.items() == other.items()
@@ -257,7 +256,6 @@ class OrderedDict(dict):
def __ne__(self, other):
return not self == other
# -- the following methods are only used in Python 2.7 --
def viewkeys(self):
"od.viewkeys() -> a set-like object providing a view on od's keys"
+39 -76
View File
@@ -33,14 +33,14 @@ PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
PY34 = sys.version_info[0:2] >= (3, 4)
if PY3:
string_types = str,
integer_types = int,
class_types = type,
string_types = (str,)
integer_types = (int,)
class_types = (type,)
text_type = str
binary_type = bytes
MAXSIZE = sys.maxsize
else:
string_types = basestring,
string_types = (basestring,)
integer_types = (int, long)
class_types = (type, types.ClassType)
text_type = unicode
@@ -52,7 +52,6 @@ else:
# It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object):
def __len__(self):
return 1 << 31
@@ -79,7 +78,6 @@ def _import_module(name):
class _LazyDescr(object):
def __init__(self, name):
self.name = name
@@ -96,7 +94,6 @@ class _LazyDescr(object):
class MovedModule(_LazyDescr):
def __init__(self, name, old, new=None):
super(MovedModule, self).__init__(name)
if PY3:
@@ -117,7 +114,6 @@ class MovedModule(_LazyDescr):
class _LazyModule(types.ModuleType):
def __init__(self, name):
super(_LazyModule, self).__init__(name)
self.__doc__ = self.__class__.__doc__
@@ -132,7 +128,6 @@ class _LazyModule(types.ModuleType):
class MovedAttribute(_LazyDescr):
def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
super(MovedAttribute, self).__init__(name)
if PY3:
@@ -227,6 +222,7 @@ _importer = _SixMetaPathImporter(__name__)
class _MovedItems(_LazyModule):
"""Lazy loading of moved objects"""
__path__ = [] # mark as package
@@ -243,10 +239,7 @@ _moved_attributes = [
MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"),
MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"),
MovedAttribute(
"reload_module",
"__builtin__",
"importlib" if PY34 else "imp",
"reload",
"reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"
),
MovedAttribute("reduce", "__builtin__", "functools"),
MovedAttribute("shlex_quote", "pipes", "shlex", "quote"),
@@ -269,13 +262,9 @@ _moved_attributes = [
MovedModule("html_entities", "htmlentitydefs", "html.entities"),
MovedModule("html_parser", "HTMLParser", "html.parser"),
MovedModule("http_client", "httplib", "http.client"),
MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"),
MovedModule(
"email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"
),
MovedModule(
"email_mime_nonmultipart",
"email.MIMENonMultipart",
"email.mime.nonmultipart",
"email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"
),
MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"),
MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"),
@@ -290,37 +279,21 @@ _moved_attributes = [
MovedModule("tkinter", "Tkinter"),
MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"),
MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"),
MovedModule(
"tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"
),
MovedModule(
"tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"
),
MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"),
MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"),
MovedModule("tkinter_tix", "Tix", "tkinter.tix"),
MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"),
MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
MovedModule(
"tkinter_colorchooser", "tkColorChooser", "tkinter.colorchooser"
),
MovedModule(
"tkinter_commondialog", "tkCommonDialog", "tkinter.commondialog"
),
MovedModule("tkinter_colorchooser", "tkColorChooser", "tkinter.colorchooser"),
MovedModule("tkinter_commondialog", "tkCommonDialog", "tkinter.commondialog"),
MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
MovedModule("tkinter_font", "tkFont", "tkinter.font"),
MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
MovedModule(
"tkinter_tksimpledialog", "tkSimpleDialog", "tkinter.simpledialog"
),
MovedModule(
"urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"
),
MovedModule(
"urllib_error", __name__ + ".moves.urllib_error", "urllib.error"
),
MovedModule(
"urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"
),
MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", "tkinter.simpledialog"),
MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"),
MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"),
MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"),
MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"),
MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"),
MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"),
@@ -417,9 +390,7 @@ _urllib_request_moved_attributes = [
MovedAttribute("ProxyHandler", "urllib2", "urllib.request"),
MovedAttribute("BaseHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"),
MovedAttribute(
"HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"
),
MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"),
MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"),
@@ -481,11 +452,11 @@ _urllib_robotparser_moved_attributes = [
for attr in _urllib_robotparser_moved_attributes:
setattr(Module_six_moves_urllib_robotparser, attr.name, attr)
del attr
Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes
Module_six_moves_urllib_robotparser._moved_attributes = (
_urllib_robotparser_moved_attributes
)
_importer._add_module(
Module_six_moves_urllib_robotparser(
__name__ + ".moves.urllib.robotparser"
),
Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"),
"moves.urllib_robotparser",
"moves.urllib.robotparser",
)
@@ -493,6 +464,7 @@ _importer._add_module(
class Module_six_moves_urllib(types.ModuleType):
"""Create a six.moves.urllib namespace that resembles the Python 3 namespace"""
__path__ = [] # mark as package
parse = _importer._get_module("moves.urllib_parse")
error = _importer._get_module("moves.urllib_error")
@@ -501,7 +473,7 @@ class Module_six_moves_urllib(types.ModuleType):
robotparser = _importer._get_module("moves.urllib_robotparser")
def __dir__(self):
return ['parse', 'error', 'request', 'response', 'robotparser']
return ["parse", "error", "request", "response", "robotparser"]
_importer._add_module(
@@ -579,14 +551,12 @@ else:
return types.MethodType(func, None, cls)
class Iterator(object):
def next(self):
return type(self).__next__(self)
callable = callable
_add_doc(
get_unbound_function,
"""Get the function out of a possibly unbound function""",
get_unbound_function, """Get the function out of a possibly unbound function"""
)
get_method_function = operator.attrgetter(_meth_func)
get_method_self = operator.attrgetter(_meth_self)
@@ -630,13 +600,9 @@ else:
viewitems = operator.methodcaller("viewitems")
_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.")
_add_doc(itervalues, "Return an iterator over the values of a dictionary.")
_add_doc(iteritems, "Return an iterator over the (key, value) pairs of a dictionary.")
_add_doc(
iteritems,
"Return an iterator over the (key, value) pairs of a dictionary.",
)
_add_doc(
iterlists,
"Return an iterator over the (key, [values]) pairs of a dictionary.",
iterlists, "Return an iterator over the (key, [values]) pairs of a dictionary."
)
if PY3:
@@ -670,10 +636,9 @@ else:
def b(s):
return s
# Workaround for standalone backslash
def u(s):
return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape")
return unicode(s.replace(r"\\", r"\\\\"), "unicode_escape")
unichr = unichr
int2byte = chr
@@ -731,7 +696,7 @@ else:
del frame
elif _locs_ is None:
_locs_ = _globs_
exec ("""exec _code_ in _globs_, _locs_""")
exec("""exec _code_ in _globs_, _locs_""")
exec_(
"""def reraise(tp, value, tb=None):
@@ -772,9 +737,9 @@ if print_ is None:
data = str(data)
# If the file has an encoding, encode unicode with it.
if (
isinstance(fp, file) and
isinstance(data, unicode) and
fp.encoding is not None
isinstance(fp, file)
and isinstance(data, unicode)
and fp.encoding is not None
):
errors = getattr(fp, "errors", None)
if errors is None:
@@ -842,7 +807,6 @@ if sys.version_info[0:2] < (3, 4):
assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES,
):
def wrapper(f):
f = functools.wraps(wrapped, assigned, updated)(f)
f.__wrapped__ = wrapped
@@ -862,11 +826,10 @@ def with_metaclass(meta, *bases):
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(meta):
def __new__(cls, name, this_bases, d):
return meta(name, bases, d)
return type.__new__(metaclass, 'temporary_class', (), {})
return type.__new__(metaclass, "temporary_class", (), {})
def add_metaclass(metaclass):
@@ -874,14 +837,14 @@ def add_metaclass(metaclass):
def wrapper(cls):
orig_vars = cls.__dict__.copy()
slots = orig_vars.get('__slots__')
slots = orig_vars.get("__slots__")
if slots is not None:
if isinstance(slots, str):
slots = [slots]
for slots_var in slots:
orig_vars.pop(slots_var)
orig_vars.pop('__dict__', None)
orig_vars.pop('__weakref__', None)
orig_vars.pop("__dict__", None)
orig_vars.pop("__weakref__", None)
return metaclass(cls.__name__, cls.__bases__, orig_vars)
return wrapper
@@ -896,14 +859,14 @@ def python_2_unicode_compatible(klass):
returning text and apply this decorator to the class.
"""
if PY2:
if '__str__' not in klass.__dict__:
if "__str__" not in klass.__dict__:
raise ValueError(
"@python_2_unicode_compatible cannot be applied "
"to %s because it doesn't define __str__()." % klass.__name__
)
klass.__unicode__ = klass.__str__
klass.__str__ = lambda self: self.__unicode__().encode('utf-8')
klass.__str__ = lambda self: self.__unicode__().encode("utf-8")
return klass
@@ -924,8 +887,8 @@ if sys.meta_path:
# the six meta path importer, since the other six instance will have
# inserted an importer with different class.
if (
type(importer).__name__ == "_SixMetaPathImporter" and
importer.name == __name__
type(importer).__name__ == "_SixMetaPathImporter"
and importer.name == __name__
):
del sys.meta_path[i]
break
@@ -15,4 +15,4 @@ except ImportError:
# Our vendored copy
from ._implementation import CertificateError, match_hostname
# Not needed, but documenting what we provide.
__all__ = ('CertificateError', 'match_hostname')
__all__ = ("CertificateError", "match_hostname")
@@ -13,7 +13,7 @@ try:
import ipaddress
except ImportError:
ipaddress = None
__version__ = '3.5.0.1'
__version__ = "3.5.0.1"
class CertificateError(ValueError):
@@ -31,10 +31,10 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
# Ported from python3-syntax:
# leftmost, *remainder = dn.split(r'.')
parts = dn.split(r'.')
parts = dn.split(r".")
leftmost = parts[0]
remainder = parts[1:]
wildcards = leftmost.count('*')
wildcards = leftmost.count("*")
if wildcards > max_wildcards:
# Issue #17980: avoid denials of service by refusing more
# than one wildcard per fragment. A survey of established
@@ -51,11 +51,11 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
# RFC 6125, section 6.4.3, subitem 1.
# The client SHOULD NOT attempt to match a presented identifier in which
# the wildcard character comprises a label other than the left-most label.
if leftmost == '*':
if leftmost == "*":
# When '*' is a fragment by itself, it matches a non-empty dotless
# fragment.
pats.append('[^.]+')
elif leftmost.startswith('xn--') or hostname.startswith('xn--'):
pats.append("[^.]+")
elif leftmost.startswith("xn--") or hostname.startswith("xn--"):
# RFC 6125, section 6.4.3, subitem 3.
# The client SHOULD NOT attempt to match a presented identifier
# where the wildcard character is embedded within an A-label or
@@ -63,17 +63,17 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
pats.append(re.escape(leftmost))
else:
# Otherwise, '*' matches any dotless string, e.g. www*
pats.append(re.escape(leftmost).replace(r'\*', '[^.]*'))
pats.append(re.escape(leftmost).replace(r"\*", "[^.]*"))
# add the remaining fragments, ignore any wildcards
for frag in remainder:
pats.append(re.escape(frag))
pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
pat = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE)
return pat.match(hostname)
def _to_unicode(obj):
if isinstance(obj, str) and sys.version_info < (3,):
obj = unicode(obj, encoding='ascii', errors='strict')
obj = unicode(obj, encoding="ascii", errors="strict")
return obj
@@ -123,14 +123,14 @@ def match_hostname(cert, hostname):
raise
dnsnames = []
san = cert.get('subjectAltName', ())
san = cert.get("subjectAltName", ())
for key, value in san:
if key == 'DNS':
if key == "DNS":
if host_ip is None and _dnsname_match(value, hostname):
return
dnsnames.append(value)
elif key == 'IP Address':
elif key == "IP Address":
if host_ip is not None and _ipaddress_match(value, host_ip):
return
@@ -138,11 +138,11 @@ def match_hostname(cert, hostname):
if not dnsnames:
# The subject is only checked when there is no dNSName entry
# in subjectAltName
for sub in cert.get('subject', ()):
for sub in cert.get("subject", ()):
for key, value in sub:
# XXX according to RFC 2818, the most specific Common Name
# must be used.
if key == 'commonName':
if key == "commonName":
if _dnsname_match(value, hostname):
return
@@ -150,8 +150,7 @@ def match_hostname(cert, hostname):
if len(dnsnames) > 1:
raise CertificateError(
"hostname %r "
"doesn't match either of %s" %
(hostname, ', '.join(map(repr, dnsnames)))
"doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames)))
)
elif len(dnsnames) == 1:
+1 -1
View File
@@ -1,3 +1,3 @@
from ._sync.poolmanager import PoolManager, ProxyManager, proxy_from_url
__all__ = ['PoolManager', 'ProxyManager', 'proxy_from_url']
__all__ = ["PoolManager", "ProxyManager", "proxy_from_url"]
+16 -16
View File
@@ -4,7 +4,7 @@ from .filepost import encode_multipart_formdata
from .packages import six
from .packages.six.moves.urllib.parse import urlencode
__all__ = ['RequestMethods']
__all__ = ["RequestMethods"]
class RequestMethods(object):
@@ -35,7 +35,8 @@ class RequestMethods(object):
Headers to include with all requests, unless other headers are given
explicitly.
"""
_encode_url_methods = set(['DELETE', 'GET', 'HEAD', 'OPTIONS'])
_encode_url_methods = set(["DELETE", "GET", "HEAD", "OPTIONS"])
def __init__(self, headers=None):
self.headers = headers or {}
@@ -77,19 +78,17 @@ class RequestMethods(object):
method, url, fields=fields, headers=headers, **urlopen_kw
)
def request_encode_url(
self, method, url, fields=None, headers=None, **urlopen_kw
):
def request_encode_url(self, method, url, fields=None, headers=None, **urlopen_kw):
"""
Make a request using :meth:`urlopen` with the ``fields`` encoded in
the url. This is useful for request methods like GET, HEAD, DELETE, etc.
"""
if headers is None:
headers = self.headers
extra_kw = {'headers': headers}
extra_kw = {"headers": headers}
extra_kw.update(urlopen_kw)
if fields:
url += '?' + urlencode(fields)
url += "?" + urlencode(fields)
return self.urlopen(method, url, **extra_kw)
def request_encode_body(
@@ -139,9 +138,9 @@ class RequestMethods(object):
"""
if headers is None:
headers = self.headers
extra_kw = {'headers': {}}
extra_kw = {"headers": {}}
if fields:
if 'body' in urlopen_kw:
if "body" in urlopen_kw:
raise TypeError(
"request got values for both 'fields' and 'body', can only specify one."
)
@@ -151,13 +150,14 @@ class RequestMethods(object):
fields, boundary=multipart_boundary
)
else:
body, content_type = urlencode(
fields
), 'application/x-www-form-urlencoded'
body, content_type = (
urlencode(fields),
"application/x-www-form-urlencoded",
)
if isinstance(body, six.text_type):
body = body.encode('utf-8')
extra_kw['body'] = body
extra_kw['headers'] = {'Content-Type': content_type}
extra_kw['headers'].update(headers)
body = body.encode("utf-8")
extra_kw["body"] = body
extra_kw["headers"] = {"Content-Type": content_type}
extra_kw["headers"].update(headers)
extra_kw.update(urlopen_kw)
return self.urlopen(method, url, **extra_kw)
+1 -1
View File
@@ -1,3 +1,3 @@
from ._sync.response import DeflateDecoder, GzipDecoder, HTTPResponse
__all__ = ['DeflateDecoder', 'GzipDecoder', 'HTTPResponse']
__all__ = ["DeflateDecoder", "GzipDecoder", "HTTPResponse"]
+23 -23
View File
@@ -14,31 +14,31 @@ from .ssl_ import (
resolve_ssl_version,
ssl_wrap_socket,
)
from .timeout import (current_time, Timeout)
from .timeout import current_time, Timeout
from .retry import Retry
from .url import (get_host, parse_url, split_first, Url)
from .wait import (wait_for_read, wait_for_write)
from .url import get_host, parse_url, split_first, Url
from .wait import wait_for_read, wait_for_write
__all__ = (
'HAS_SNI',
'IS_PYOPENSSL',
'IS_SECURETRANSPORT',
'SSLContext',
'Retry',
'Timeout',
'Url',
'assert_fingerprint',
'current_time',
'is_connection_dropped',
'is_fp_closed',
'get_host',
'parse_url',
'make_headers',
'resolve_cert_reqs',
'resolve_ssl_version',
'split_first',
'ssl_wrap_socket',
'wait_for_read',
'wait_for_write',
"HAS_SNI",
"IS_PYOPENSSL",
"IS_SECURETRANSPORT",
"SSLContext",
"Retry",
"Timeout",
"Url",
"assert_fingerprint",
"current_time",
"is_connection_dropped",
"is_fp_closed",
"get_host",
"parse_url",
"make_headers",
"resolve_cert_reqs",
"resolve_ssl_version",
"split_first",
"ssl_wrap_socket",
"wait_for_read",
"wait_for_write",
)
+3 -5
View File
@@ -10,8 +10,6 @@ def is_connection_dropped(conn): # Platform-specific
return conn.is_dropped()
# This function is copied from socket.py in the Python 2.7 standard
# library test suite. Added to its signature is only `socket_options`.
# One additional modification is that we avoid binding to IPv6 servers
@@ -34,8 +32,8 @@ def create_connection(
An host of '' or port 0 tells the OS to use the default.
"""
host, port = address
if host.startswith('['):
host = host.strip('[]')
if host.startswith("["):
host = host.strip("[]")
err = None
# Using the value from allowed_gai_family() in the context of getaddrinfo lets
# us select whether to work with IPv4 DNS records, IPv6 records, or both.
@@ -105,4 +103,4 @@ def _has_ipv6(host):
return has_ipv6
HAS_IPV6 = _has_ipv6('::1')
HAS_IPV6 = _has_ipv6("::1")
+13 -19
View File
@@ -4,7 +4,7 @@ from base64 import b64encode
from ..packages.six import b, integer_types
from ..exceptions import UnrewindableBodyError
ACCEPT_ENCODING = 'gzip,deflate'
ACCEPT_ENCODING = "gzip,deflate"
_FAILEDTELL = object()
@@ -55,26 +55,22 @@ def make_headers(
if isinstance(accept_encoding, str):
pass
elif isinstance(accept_encoding, list):
accept_encoding = ','.join(accept_encoding)
accept_encoding = ",".join(accept_encoding)
else:
accept_encoding = ACCEPT_ENCODING
headers['accept-encoding'] = accept_encoding
headers["accept-encoding"] = accept_encoding
if user_agent:
headers['user-agent'] = user_agent
headers["user-agent"] = user_agent
if keep_alive:
headers['connection'] = 'keep-alive'
headers["connection"] = "keep-alive"
if basic_auth:
headers['authorization'] = 'Basic ' + b64encode(b(basic_auth)).decode(
'utf-8'
)
headers["authorization"] = "Basic " + b64encode(b(basic_auth)).decode("utf-8")
if proxy_basic_auth:
headers['proxy-authorization'] = 'Basic ' + b64encode(
headers["proxy-authorization"] = "Basic " + b64encode(
b(proxy_basic_auth)
).decode(
'utf-8'
)
).decode("utf-8")
if disable_cache:
headers['cache-control'] = 'no-cache'
headers["cache-control"] = "no-cache"
return headers
@@ -85,7 +81,7 @@ def set_file_position(body, pos):
"""
if pos is not None:
rewind_body(body, pos)
elif getattr(body, 'tell', None) is not None:
elif getattr(body, "tell", None) is not None:
try:
pos = body.tell()
except (IOError, OSError):
@@ -106,14 +102,13 @@ def rewind_body(body, body_pos):
:param int pos:
Position to seek to in file.
"""
body_seek = getattr(body, 'seek', None)
body_seek = getattr(body, "seek", None)
if body_seek is not None and isinstance(body_pos, integer_types):
try:
body_seek(body_pos)
except (IOError, OSError):
raise UnrewindableBodyError(
"An error occurred when rewinding request "
"body for redirect/retry."
"An error occurred when rewinding request " "body for redirect/retry."
)
elif body_pos is _FAILEDTELL:
@@ -124,6 +119,5 @@ def rewind_body(body, body_pos):
else:
raise ValueError(
"body_pos must be of type integer, "
"instead it was %s." % type(body_pos)
"body_pos must be of type integer, " "instead it was %s." % type(body_pos)
)
+18 -31
View File
@@ -19,7 +19,7 @@ from ..packages import six
log = logging.getLogger(__name__)
# Data structure for representing the metadata of requests that result in a retry.
RequestHistory = namedtuple(
'RequestHistory', ["method", "url", "error", "status", "redirect_location"]
"RequestHistory", ["method", "url", "error", "status", "redirect_location"]
)
@@ -139,8 +139,9 @@ class Retry(object):
:attr:`Retry.RETRY_AFTER_STATUS_CODES` or not.
"""
DEFAULT_METHOD_WHITELIST = frozenset(
['HEAD', 'GET', 'PUT', 'DELETE', 'OPTIONS', 'TRACE']
["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"]
)
RETRY_AFTER_STATUS_CODES = frozenset([413, 429, 503])
# : Maximum backoff time.
@@ -215,18 +216,13 @@ class Retry(object):
# We want to consider only the last consecutive errors sequence (Ignore redirects).
consecutive_errors_len = len(
list(
takewhile(
lambda x: x.redirect_location is None,
reversed(self.history),
)
takewhile(lambda x: x.redirect_location is None, reversed(self.history))
)
)
if consecutive_errors_len <= 1:
return 0
backoff_value = self.backoff_factor * (
2 ** (consecutive_errors_len - 1)
)
backoff_value = self.backoff_factor * (2 ** (consecutive_errors_len - 1))
return min(self.BACKOFF_MAX, backoff_value)
def parse_retry_after(self, retry_after):
@@ -236,9 +232,7 @@ class Retry(object):
else:
retry_date_tuple = email.utils.parsedate(retry_after)
if retry_date_tuple is None:
raise InvalidHeader(
"Invalid Retry-After header: %s" % retry_after
)
raise InvalidHeader("Invalid Retry-After header: %s" % retry_after)
retry_date = time.mktime(retry_date_tuple)
seconds = retry_date - time.time()
@@ -300,8 +294,7 @@ class Retry(object):
""" Checks if a given HTTP method should be retried upon, depending if
it is included on the method whitelist.
"""
if self.method_whitelist and method.upper(
) not in self.method_whitelist:
if self.method_whitelist and method.upper() not in self.method_whitelist:
return False
return True
@@ -320,17 +313,15 @@ class Retry(object):
return True
return (
self.total and
self.respect_retry_after_header and
has_retry_after and
(status_code in self.RETRY_AFTER_STATUS_CODES)
self.total
and self.respect_retry_after_header
and has_retry_after
and (status_code in self.RETRY_AFTER_STATUS_CODES)
)
def is_exhausted(self):
""" Are we out of retries? """
retry_counts = (
self.total, self.connect, self.read, self.redirect, self.status
)
retry_counts = (self.total, self.connect, self.read, self.redirect, self.status)
retry_counts = list(filter(None, retry_counts))
if not retry_counts:
return False
@@ -367,7 +358,7 @@ class Retry(object):
read = self.read
redirect = self.redirect
status_count = self.status
cause = 'unknown'
cause = "unknown"
status = None
redirect_location = None
if error and self._is_connection_error(error):
@@ -388,7 +379,7 @@ class Retry(object):
# Redirect retry?
if redirect is not None:
redirect -= 1
cause = 'too many redirects'
cause = "too many redirects"
redirect_location = response.get_redirect_location()
status = response.status
else:
@@ -398,9 +389,7 @@ class Retry(object):
if response and response.status:
if status_count is not None:
status_count -= 1
cause = ResponseError.SPECIFIC_ERROR.format(
status_code=response.status
)
cause = ResponseError.SPECIFIC_ERROR.format(status_code=response.status)
status = response.status
history = self.history + (
RequestHistory(method, url, error, status, redirect_location),
@@ -421,11 +410,9 @@ class Retry(object):
def __repr__(self):
return (
'{cls.__name__}(total={self.total}, connect={self.connect}, '
'read={self.read}, redirect={self.redirect}, status={self.status})'
).format(
cls=type(self), self=self
)
"{cls.__name__}(total={self.total}, connect={self.connect}, "
"read={self.read}, redirect={self.redirect}, status={self.status})"
).format(cls=type(self), self=self)
# For backwards compatibility (equivalent to pre-v1.9):
+19 -37
View File
@@ -21,15 +21,14 @@ try:
monotonic = time.monotonic
except (AttributeError, ImportError): # Python 3.3<
monotonic = time.time
EVENT_READ = (1 << 0)
EVENT_WRITE = (1 << 1)
EVENT_READ = 1 << 0
EVENT_WRITE = 1 << 1
HAS_SELECT = True # Variable that shows whether the platform has a selector.
_SYSCALL_SENTINEL = object() # Sentinel in case a system call returns None.
_DEFAULT_SELECTOR = None
class SelectorError(Exception):
def __init__(self, errcode):
super(SelectorError, self).__init__()
self.errno = errcode
@@ -94,9 +93,7 @@ else:
expires = monotonic() + timeout
args = list(args)
if recalc_timeout and "timeout" not in kwargs:
raise ValueError(
"Timeout must be in args or kwargs to be recalculated"
)
raise ValueError("Timeout must be in args or kwargs to be recalculated")
result = _SYSCALL_SENTINEL
while result is _SYSCALL_SENTINEL:
@@ -114,9 +111,8 @@ else:
elif hasattr(e, "args"):
errcode = e.args[0]
# Also test for the Windows equivalent of EINTR.
is_interrupt = (
errcode == errno.EINTR or
(hasattr(errno, "WSAEINTR") and errcode == errno.WSAEINTR)
is_interrupt = errcode == errno.EINTR or (
hasattr(errno, "WSAEINTR") and errcode == errno.WSAEINTR
)
if is_interrupt:
if expires is not None:
@@ -138,7 +134,7 @@ else:
return result
SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data'])
SelectorKey = namedtuple("SelectorKey", ["fileobj", "fd", "events", "data"])
class _SelectorMapping(Mapping):
@@ -393,9 +389,7 @@ if hasattr(select, "poll"):
def select(self, timeout=None):
ready = []
fd_events = _syscall_wrapper(
self._wrap_poll, True, timeout=timeout
)
fd_events = _syscall_wrapper(self._wrap_poll, True, timeout=timeout)
for fd, event_mask in fd_events:
events = 0
if event_mask & ~select.POLLIN:
@@ -489,14 +483,10 @@ if hasattr(select, "kqueue"):
def register(self, fileobj, events, data=None):
key = super(KqueueSelector, self).register(fileobj, events, data)
if events & EVENT_READ:
kevent = select.kevent(
key.fd, select.KQ_FILTER_READ, select.KQ_EV_ADD
)
kevent = select.kevent(key.fd, select.KQ_FILTER_READ, select.KQ_EV_ADD)
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
if events & EVENT_WRITE:
kevent = select.kevent(
key.fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD
)
kevent = select.kevent(key.fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD)
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
return key
@@ -507,9 +497,7 @@ if hasattr(select, "kqueue"):
key.fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE
)
try:
_syscall_wrapper(
self._kqueue.control, False, [kevent], 0, 0
)
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
except SelectorError:
pass
if key.events & EVENT_WRITE:
@@ -517,9 +505,7 @@ if hasattr(select, "kqueue"):
key.fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE
)
try:
_syscall_wrapper(
self._kqueue.control, False, [kevent], 0, 0
)
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
except SelectorError:
pass
return key
@@ -546,9 +532,7 @@ if hasattr(select, "kqueue"):
ready_fds[key.fd] = (key, events & key.events)
else:
old_events = ready_fds[key.fd][1]
ready_fds[key.fd] = (
key, (events | old_events) & key.events
)
ready_fds[key.fd] = (key, (events | old_events) & key.events)
return list(ready_fds.values())
def close(self):
@@ -556,7 +540,7 @@ if hasattr(select, "kqueue"):
super(KqueueSelector, self).close()
if not hasattr(select, 'select'): # Platform-specific: AppEngine
if not hasattr(select, "select"): # Platform-specific: AppEngine
HAS_SELECT = False
@@ -567,7 +551,7 @@ def _can_allocate(struct):
don't have it available will not advertise it. (ie: GAE) """
try:
# select.poll() objects won't fail until used.
if struct == 'poll':
if struct == "poll":
p = select.poll()
p.poll(0)
# All others will fail on allocation.
@@ -579,8 +563,6 @@ def _can_allocate(struct):
return False
# Choose the best implementation, roughly:
# kqueue == epoll > poll > select. Devpoll not supported. (See above)
# select() also can't accept a FD > FD_SETSIZE (usually around 1024)
@@ -590,15 +572,15 @@ def DefaultSelector():
by eventlet, greenlet, and preserve proper behavior. """
global _DEFAULT_SELECTOR
if _DEFAULT_SELECTOR is None:
if _can_allocate('kqueue'):
if _can_allocate("kqueue"):
_DEFAULT_SELECTOR = KqueueSelector
elif _can_allocate('epoll'):
elif _can_allocate("epoll"):
_DEFAULT_SELECTOR = EpollSelector
elif _can_allocate('poll'):
elif _can_allocate("poll"):
_DEFAULT_SELECTOR = PollSelector
elif hasattr(select, 'select'):
elif hasattr(select, "select"):
_DEFAULT_SELECTOR = SelectSelector
else: # Platform-specific: AppEngine
raise ValueError('Platform does not have a selector')
raise ValueError("Platform does not have a selector")
return _DEFAULT_SELECTOR()
+50 -57
View File
@@ -9,7 +9,8 @@ from hashlib import md5, sha1, sha256
from ..exceptions import SSLError, InsecurePlatformWarning, SNIMissingWarning
from ..packages.ssl_match_hostname import (
match_hostname as _match_hostname, CertificateError
match_hostname as _match_hostname,
CertificateError,
)
SSLContext = None
@@ -34,9 +35,7 @@ def _const_compare_digest_backport(a, b):
return result == 0
_const_compare_digest = getattr(
hmac, 'compare_digest', _const_compare_digest_backport
)
_const_compare_digest = getattr(hmac, "compare_digest", _const_compare_digest_backport)
try: # Test for SSL features
import ssl
from ssl import wrap_socket, CERT_NONE, PROTOCOL_SSLv23
@@ -68,24 +67,24 @@ except ImportError:
# security,
# - prefer AES-GCM over ChaCha20 because hardware-accelerated AES is common,
# - disable NULL authentication, MD5 MACs and DSS for security reasons.
DEFAULT_CIPHERS = ':'.join(
DEFAULT_CIPHERS = ":".join(
[
'TLS13-AES-256-GCM-SHA384',
'TLS13-CHACHA20-POLY1305-SHA256',
'TLS13-AES-128-GCM-SHA256',
'ECDH+AESGCM',
'ECDH+CHACHA20',
'DH+AESGCM',
'DH+CHACHA20',
'ECDH+AES256',
'DH+AES256',
'ECDH+AES128',
'DH+AES',
'RSA+AESGCM',
'RSA+AES',
'!aNULL',
'!eNULL',
'!MD5',
"TLS13-AES-256-GCM-SHA384",
"TLS13-CHACHA20-POLY1305-SHA256",
"TLS13-AES-128-GCM-SHA256",
"ECDH+AESGCM",
"ECDH+CHACHA20",
"DH+AESGCM",
"DH+CHACHA20",
"ECDH+AES256",
"DH+AES256",
"ECDH+AES128",
"DH+AES",
"RSA+AESGCM",
"RSA+AES",
"!aNULL",
"!eNULL",
"!MD5",
]
)
try:
@@ -95,7 +94,6 @@ except ImportError:
# TODO: Can we remove this by choosing to support only platforms with
# actual SSLContext objects?
class SSLContext(object): # Platform-specific: Python 2 & 3.1
def __init__(self, protocol_version):
self.protocol = protocol_version
# Use default values from a real SSLContext
@@ -121,21 +119,21 @@ except ImportError:
def wrap_socket(self, socket, server_hostname=None, server_side=False):
warnings.warn(
'A true SSLContext object is not available. This prevents '
'urllib3 from configuring SSL appropriately and may cause '
'certain SSL connections to fail. You can upgrade to a newer '
'version of Python to solve this. For more information, see '
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html'
'#ssl-warnings',
"A true SSLContext object is not available. This prevents "
"urllib3 from configuring SSL appropriately and may cause "
"certain SSL connections to fail. You can upgrade to a newer "
"version of Python to solve this. For more information, see "
"https://urllib3.readthedocs.io/en/latest/advanced-usage.html"
"#ssl-warnings",
InsecurePlatformWarning,
)
kwargs = {
'keyfile': self.keyfile,
'certfile': self.certfile,
'ca_certs': self.ca_certs,
'cert_reqs': self.verify_mode,
'ssl_version': self.protocol,
'server_side': server_side,
"keyfile": self.keyfile,
"certfile": self.certfile,
"ca_certs": self.ca_certs,
"cert_reqs": self.verify_mode,
"ssl_version": self.protocol,
"server_side": server_side,
}
return wrap_socket(socket, ciphers=self.ciphers, **kwargs)
@@ -149,13 +147,11 @@ def assert_fingerprint(cert, fingerprint):
:param fingerprint:
Fingerprint as string of hexdigits, can be interspersed by colons.
"""
fingerprint = fingerprint.replace(':', '').lower()
fingerprint = fingerprint.replace(":", "").lower()
digest_length = len(fingerprint)
hashfunc = HASHFUNC_MAP.get(digest_length)
if not hashfunc:
raise SSLError(
'Fingerprint of invalid length: {0}'.format(fingerprint)
)
raise SSLError("Fingerprint of invalid length: {0}".format(fingerprint))
# We need encode() here for py32; works on py2 and p33.
fingerprint_bytes = unhexlify(fingerprint.encode())
@@ -185,7 +181,7 @@ def resolve_cert_reqs(candidate):
if isinstance(candidate, str):
res = getattr(ssl, candidate, None)
if res is None:
res = getattr(ssl, 'CERT_' + candidate)
res = getattr(ssl, "CERT_" + candidate)
return res
return candidate
@@ -201,7 +197,7 @@ def resolve_ssl_version(candidate):
if isinstance(candidate, str):
res = getattr(ssl, candidate, None)
if res is None:
res = getattr(ssl, 'PROTOCOL_' + candidate)
res = getattr(ssl, "PROTOCOL_" + candidate)
return res
return candidate
@@ -258,9 +254,9 @@ def create_urllib3_context(
context.options |= options
context.set_ciphers(ciphers or DEFAULT_CIPHERS)
context.verify_mode = cert_reqs
if getattr(
context, 'check_hostname', None
) is not None: # Platform-specific: Python 3.2
if (
getattr(context, "check_hostname", None) is not None
): # Platform-specific: Python 3.2
# We do our own verification, including fingerprints and alternative
# hostnames. So disable it here
context.check_hostname = False
@@ -294,7 +290,7 @@ def merge_context_settings(
raise
elif getattr(context, 'load_default_certs', None) is not None:
elif getattr(context, "load_default_certs", None) is not None:
# try to load OS default certs; works well on Windows (require Python3.4+)
context.load_default_certs()
if certfile:
@@ -335,9 +331,7 @@ def ssl_wrap_socket(
# Note: This branch of code and all the variables in it are no longer
# used by urllib3 itself. We should consider deprecating and removing
# this code.
context = create_urllib3_context(
ssl_version, cert_reqs, ciphers=ciphers
)
context = create_urllib3_context(ssl_version, cert_reqs, ciphers=ciphers)
if ca_certs or ca_cert_dir:
try:
context.load_verify_locations(ca_certs, ca_cert_dir)
@@ -352,7 +346,7 @@ def ssl_wrap_socket(
raise
elif getattr(context, 'load_default_certs', None) is not None:
elif getattr(context, "load_default_certs", None) is not None:
# try to load OS default certs; works well on Windows (require Python3.4+)
context.load_default_certs()
if certfile:
@@ -361,13 +355,13 @@ def ssl_wrap_socket(
return context.wrap_socket(sock, server_hostname=server_hostname)
warnings.warn(
'An HTTPS request has been made, but the SNI (Server Name '
'Indication) extension to TLS is not available on this platform. '
'This may cause the server to present an incorrect TLS '
'certificate, which can cause validation failures. You can upgrade to '
'a newer version of Python to solve this. For more information, see '
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html'
'#ssl-warnings',
"An HTTPS request has been made, but the SNI (Server Name "
"Indication) extension to TLS is not available on this platform. "
"This may cause the server to present an incorrect TLS "
"certificate, which can cause validation failures. You can upgrade to "
"a newer version of Python to solve this. For more information, see "
"https://urllib3.readthedocs.io/en/latest/advanced-usage.html"
"#ssl-warnings",
SNIMissingWarning,
)
return context.wrap_socket(sock)
@@ -378,8 +372,7 @@ def match_hostname(cert, asserted_hostname):
_match_hostname(cert, asserted_hostname)
except CertificateError as e:
log.error(
'Certificate did not match expected hostname: %s. '
'Certificate: %s',
"Certificate did not match expected hostname: %s. " "Certificate: %s",
asserted_hostname,
cert,
)
+15 -15
View File
@@ -85,18 +85,22 @@ class Timeout(object):
time, consider having a second "watcher" thread to cut off a slow
request.
"""
# : A sentinel object representing the default timeout value
DEFAULT_TIMEOUT = _GLOBAL_DEFAULT_TIMEOUT
def __init__(self, total=None, connect=_Default, read=_Default):
self._connect = self._validate_timeout(connect, 'connect')
self._read = self._validate_timeout(read, 'read')
self.total = self._validate_timeout(total, 'total')
self._connect = self._validate_timeout(connect, "connect")
self._read = self._validate_timeout(read, "read")
self.total = self._validate_timeout(total, "total")
self._start_connect = None
def __str__(self):
return '%s(connect=%r, read=%r, total=%r)' % (
type(self).__name__, self._connect, self._read, self.total
return "%s(connect=%r, read=%r, total=%r)" % (
type(self).__name__,
self._connect,
self._read,
self.total,
)
@classmethod
@@ -174,9 +178,7 @@ class Timeout(object):
# We can't use copy.deepcopy because that will also create a new object
# for _GLOBAL_DEFAULT_TIMEOUT, which socket.py uses as a sentinel to
# detect the user default.
return Timeout(
connect=self._connect, read=self._read, total=self.total
)
return Timeout(connect=self._connect, read=self._read, total=self.total)
def start_connect(self):
""" Start the timeout clock, used during a connect() attempt
@@ -241,18 +243,16 @@ class Timeout(object):
has not yet been called on this object.
"""
if (
self.total is not None and
self.total is not self.DEFAULT_TIMEOUT and
self._read is not None and
self._read is not self.DEFAULT_TIMEOUT
self.total is not None
and self.total is not self.DEFAULT_TIMEOUT
and self._read is not None
and self._read is not self.DEFAULT_TIMEOUT
):
# In case the connect timeout has not yet been established.
if self._start_connect is None:
return self._read
return max(
0, min(self.total - self.get_connect_duration(), self._read)
)
return max(0, min(self.total - self.get_connect_duration(), self._read))
elif self.total is not None and self.total is not self.DEFAULT_TIMEOUT:
return max(0, self.total - self.get_connect_duration())
+32 -31
View File
@@ -3,18 +3,19 @@ from collections import namedtuple
from ..exceptions import LocationParseError
url_attrs = ['scheme', 'auth', 'host', 'port', 'path', 'query', 'fragment']
url_attrs = ["scheme", "auth", "host", "port", "path", "query", "fragment"]
# We only want to normalize urls with an HTTP(S) scheme.
# urllib3 infers URLs without a scheme (None) to be http.
NORMALIZABLE_SCHEMES = ('http', 'https', None)
NORMALIZABLE_SCHEMES = ("http", "https", None)
class Url(namedtuple('Url', url_attrs)):
class Url(namedtuple("Url", url_attrs)):
"""
Datastructure for representing an HTTP URL. Used as a return value for
:func:`parse_url`. Both the scheme and host are normalized as they are
both case-insensitive according to RFC 3986.
"""
__slots__ = ()
def __new__(
@@ -27,8 +28,8 @@ class Url(namedtuple('Url', url_attrs)):
query=None,
fragment=None,
):
if path and not path.startswith('/'):
path = '/' + path
if path and not path.startswith("/"):
path = "/" + path
if scheme:
scheme = scheme.lower()
if host and scheme in NORMALIZABLE_SCHEMES:
@@ -45,16 +46,16 @@ class Url(namedtuple('Url', url_attrs)):
@property
def request_uri(self):
"""Absolute path including the query string."""
uri = self.path or '/'
uri = self.path or "/"
if self.query is not None:
uri += '?' + self.query
uri += "?" + self.query
return uri
@property
def netloc(self):
"""Network location including host and port"""
if self.port:
return '%s:%d' % (self.host, self.port)
return "%s:%d" % (self.host, self.port)
return self.host
@@ -78,22 +79,22 @@ class Url(namedtuple('Url', url_attrs)):
'http://username:password@host.com:80/path?query#fragment'
"""
scheme, auth, host, port, path, query, fragment = self
url = ''
url = ""
# We use "is not None" we want things to happen with empty strings (or 0 port)
if scheme is not None:
url += scheme + '://'
url += scheme + "://"
if auth is not None:
url += auth + '@'
url += auth + "@"
if host is not None:
url += host
if port is not None:
url += ':' + str(port)
url += ":" + str(port)
if path is not None:
url += path
if query is not None:
url += '?' + query
url += "?" + query
if fragment is not None:
url += '#' + fragment
url += "#" + fragment
return url
def __str__(self):
@@ -127,9 +128,9 @@ def split_first(s, delims):
min_idx = idx
min_delim = d
if min_idx is None or min_idx < 0:
return s, '', None
return s, "", None
return s[:min_idx], s[min_idx + 1:], min_delim
return s[:min_idx], s[min_idx + 1 :], min_delim
def parse_url(url):
@@ -164,25 +165,25 @@ def parse_url(url):
fragment = None
query = None
# Scheme
if '://' in url:
scheme, url = url.split('://', 1)
if "://" in url:
scheme, url = url.split("://", 1)
# Find the earliest Authority Terminator
# (http://tools.ietf.org/html/rfc3986#section-3.2)
url, path_, delim = split_first(url, ['/', '?', '#'])
url, path_, delim = split_first(url, ["/", "?", "#"])
if delim:
# Reassemble the path
path = delim + path_
# Auth
if '@' in url:
if "@" in url:
# Last '@' denotes end of auth part
auth, url = url.rsplit('@', 1)
auth, url = url.rsplit("@", 1)
# IPv6
if url and url[0] == '[':
host, url = url.split(']', 1)
host += ']'
if url and url[0] == "[":
host, url = url.split("]", 1)
host += "]"
# Port
if ':' in url:
_host, port = url.split(':', 1)
if ":" in url:
_host, port = url.split(":", 1)
if not host:
host = _host
if port:
@@ -205,11 +206,11 @@ def parse_url(url):
return Url(scheme, auth, host, port, path, query, fragment)
# Fragment
if '#' in path:
path, fragment = path.split('#', 1)
if "#" in path:
path, fragment = path.split("#", 1)
# Query
if '?' in path:
path, query = path.split('?', 1)
if "?" in path:
path, query = path.split("?", 1)
return Url(scheme, auth, host, port, path, query, fragment)
@@ -218,4 +219,4 @@ def get_host(url):
Deprecated. Use :func:`parse_url` instead.
"""
p = parse_url(url)
return p.scheme or 'http', p.hostname, p.port
return p.scheme or "http", p.hostname, p.port
+3 -7
View File
@@ -1,4 +1,4 @@
from .selectors import (HAS_SELECT, DefaultSelector, EVENT_READ, EVENT_WRITE)
from .selectors import HAS_SELECT, DefaultSelector, EVENT_READ, EVENT_WRITE
def _wait_for_io_events(socks, events, timeout=None):
@@ -6,7 +6,7 @@ def _wait_for_io_events(socks, events, timeout=None):
or optionally a single socket if passed in. Returns a list of
sockets that can be interacted with immediately. """
if not HAS_SELECT:
raise ValueError('Platform does not have a selector')
raise ValueError("Platform does not have a selector")
if not isinstance(socks, list):
# Probably just a single socket.
@@ -18,11 +18,7 @@ def _wait_for_io_events(socks, events, timeout=None):
with DefaultSelector() as selector:
for sock in socks:
selector.register(sock, events)
return [
key[0].fileobj
for key in selector.select(timeout)
if key[1] & events
]
return [key[0].fileobj for key in selector.select(timeout) if key[1] & events]
def wait_for_read(socks, timeout=None):
+6 -9
View File
@@ -15,14 +15,10 @@ class RequestException(IOError):
def __init__(self, *args, **kwargs):
"""Initialize RequestException with `request` and `response` objects."""
response = kwargs.pop('response', None)
response = kwargs.pop("response", None)
self.response = response
self.request = kwargs.pop('request', None)
if (
response is not None and
not self.request and
hasattr(response, 'request')
):
self.request = kwargs.pop("request", None)
if response is not None and not self.request and hasattr(response, "request"):
self.request = self.response.request
super(RequestException, self).__init__(*args, **kwargs)
@@ -115,19 +111,20 @@ class InvalidBodyError(RequestException, ValueError):
"""An invalid request body was specified"""
# Warnings
class RequestsWarning(Warning):
"""Base warning for Requests."""
pass
class FileModeWarning(RequestsWarning, DeprecationWarning):
"""A file was opened in text mode, but Requests determined its binary length."""
pass
class RequestsDependencyWarning(RequestsWarning):
"""An imported dependency doesn't match the expected version range."""
pass
+57 -92
View File
@@ -92,26 +92,26 @@ def _pool_kwargs(verify, cert):
"invalid path: {0}".format(cert_loc)
)
pool_kwargs['cert_reqs'] = 'CERT_REQUIRED'
pool_kwargs["cert_reqs"] = "CERT_REQUIRED"
if not os.path.isdir(cert_loc):
pool_kwargs['ca_certs'] = cert_loc
pool_kwargs['ca_cert_dir'] = None
pool_kwargs["ca_certs"] = cert_loc
pool_kwargs["ca_cert_dir"] = None
else:
pool_kwargs['ca_cert_dir'] = cert_loc
pool_kwargs['ca_certs'] = None
pool_kwargs["ca_cert_dir"] = cert_loc
pool_kwargs["ca_certs"] = None
else:
pool_kwargs['cert_reqs'] = 'CERT_NONE'
pool_kwargs['ca_certs'] = None
pool_kwargs['ca_cert_dir'] = None
pool_kwargs["cert_reqs"] = "CERT_NONE"
pool_kwargs["ca_certs"] = None
pool_kwargs["ca_cert_dir"] = None
if cert:
if not isinstance(cert, basestring):
pool_kwargs['cert_file'] = cert[0]
pool_kwargs['key_file'] = cert[1]
pool_kwargs["cert_file"] = cert[0]
pool_kwargs["key_file"] = cert[1]
else:
pool_kwargs['cert_file'] = cert
pool_kwargs['key_file'] = None
cert_file = pool_kwargs['cert_file']
key_file = pool_kwargs['key_file']
pool_kwargs["cert_file"] = cert
pool_kwargs["key_file"] = None
cert_file = pool_kwargs["cert_file"]
key_file = pool_kwargs["key_file"]
if cert_file and not os.path.exists(cert_file):
raise IOError(
"Could not find the TLS certificate file, "
@@ -120,8 +120,7 @@ def _pool_kwargs(verify, cert):
if key_file and not os.path.exists(key_file):
raise IOError(
"Could not find the TLS key file, "
"invalid path: {0}".format(key_file)
"Could not find the TLS key file, " "invalid path: {0}".format(key_file)
)
return pool_kwargs
@@ -134,13 +133,7 @@ class BaseAdapter(object):
super(BaseAdapter, self).__init__()
def send(
self,
request,
stream=False,
timeout=None,
verify=True,
cert=None,
proxies=None,
self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None
):
"""Sends PreparedRequest object. Returns Response object.
@@ -189,12 +182,13 @@ class HTTPAdapter(BaseAdapter):
>>> a = requests.adapters.HTTPAdapter(max_retries=3)
>>> s.mount('http://', a)
"""
__attrs__ = [
'max_retries',
'config',
'_pool_connections',
'_pool_maxsize',
'_pool_block',
"max_retries",
"config",
"_pool_connections",
"_pool_maxsize",
"_pool_block",
]
def __init__(
@@ -270,7 +264,7 @@ class HTTPAdapter(BaseAdapter):
"""
if proxy in self.proxy_manager:
manager = self.proxy_manager[proxy]
elif proxy.lower().startswith('socks'):
elif proxy.lower().startswith("socks"):
username, password = get_auth_from_url(proxy)
manager = self.proxy_manager[proxy] = SOCKSProxyManager(
proxy,
@@ -305,15 +299,15 @@ class HTTPAdapter(BaseAdapter):
"""
response = Response()
# Fallback to None if there's no status_code, for whatever reason.
response.status_code = getattr(resp, 'status', None)
response.status_code = getattr(resp, "status", None)
# Make headers case-insensitive.
response.headers = HTTPHeaderDict(getattr(resp, 'headers', {}))
response.headers = HTTPHeaderDict(getattr(resp, "headers", {}))
# Set encoding.
response.encoding = get_encoding_from_headers(response.headers)
response.raw = resp
response.reason = response.raw.reason
if isinstance(req.url, bytes):
response.url = req.url.decode('utf-8')
response.url = req.url.decode("utf-8")
else:
response.url = req.url
# Add new cookies from the server.
@@ -335,18 +329,14 @@ class HTTPAdapter(BaseAdapter):
pool_kwargs = _pool_kwargs(verify, cert)
proxy = select_proxy(url, proxies)
if proxy:
proxy = prepend_scheme_if_needed(proxy, 'http')
proxy = prepend_scheme_if_needed(proxy, "http")
proxy_manager = self.proxy_manager_for(proxy)
conn = proxy_manager.connection_from_url(
url, pool_kwargs=pool_kwargs
)
conn = proxy_manager.connection_from_url(url, pool_kwargs=pool_kwargs)
else:
# Only scheme should be lower case
parsed = urlparse(url)
url = parsed.geturl()
conn = self.poolmanager.connection_from_url(
url, pool_kwargs=pool_kwargs
)
conn = self.poolmanager.connection_from_url(url, pool_kwargs=pool_kwargs)
return conn
def close(self):
@@ -375,11 +365,11 @@ class HTTPAdapter(BaseAdapter):
"""
proxy = select_proxy(request.url, proxies)
scheme = urlparse(request.url).scheme
is_proxied_http_request = (proxy and scheme != 'https')
is_proxied_http_request = proxy and scheme != "https"
using_socks_proxy = False
if proxy:
proxy_scheme = urlparse(proxy).scheme.lower()
using_socks_proxy = proxy_scheme.startswith('socks')
using_socks_proxy = proxy_scheme.startswith("socks")
url = request.path_url
if is_proxied_http_request and not using_socks_proxy:
url = urldefragauth(request.url)
@@ -415,19 +405,11 @@ class HTTPAdapter(BaseAdapter):
headers = {}
username, password = get_auth_from_url(proxy)
if username:
headers['Proxy-Authorization'] = _basic_auth_str(
username, password
)
headers["Proxy-Authorization"] = _basic_auth_str(username, password)
return headers
def send(
self,
request,
stream=False,
timeout=None,
verify=True,
cert=None,
proxies=None,
self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None
):
"""Sends PreparedRequest object. Returns Response object.
@@ -447,9 +429,7 @@ class HTTPAdapter(BaseAdapter):
conn = self.get_connection(request.url, proxies, verify, cert)
url = self.request_url(request, proxies)
self.add_headers(request)
chunked = not (
request.body is None or 'Content-Length' in request.headers
)
chunked = not (request.body is None or "Content-Length" in request.headers)
if isinstance(timeout, tuple):
try:
connect, read = timeout
@@ -481,17 +461,15 @@ class HTTPAdapter(BaseAdapter):
retries=self.max_retries,
timeout=timeout,
enforce_content_length=True,
pool=conn
pool=conn,
)
# Send the request.
else:
if hasattr(conn, 'proxy_pool'):
if hasattr(conn, "proxy_pool"):
conn = conn.proxy_pool
low_conn = conn._get_conn(timeout=DEFAULT_POOL_TIMEOUT)
try:
low_conn.putrequest(
request.method, url, skip_accept_encoding=True
)
low_conn.putrequest(request.method, url, skip_accept_encoding=True)
for header, value in request.headers.items():
low_conn.putheader(header, value)
low_conn.endheaders()
@@ -500,11 +478,11 @@ class HTTPAdapter(BaseAdapter):
if chunk_size == 0:
continue
low_conn.send(hex(chunk_size)[2:].encode('utf-8'))
low_conn.send(b'\r\n')
low_conn.send(hex(chunk_size)[2:].encode("utf-8"))
low_conn.send(b"\r\n")
low_conn.send(i)
low_conn.send(b'\r\n')
low_conn.send(b'0\r\n\r\n')
low_conn.send(b"\r\n")
low_conn.send(b"0\r\n\r\n")
# Receive the response from the server
try:
# For Python 2.7, use buffering of HTTP responses
@@ -570,6 +548,7 @@ class HTTPAdapter(BaseAdapter):
class AsyncHTTPAdapter(HTTPAdapter):
"""docstring for AsyncHTTPAdapter"""
def __init__(self, backend=None, *args, **kwargs):
self.backend = backend or TrioBackend()
super(AsyncHTTPAdapter, self).__init__(*args, **kwargs)
@@ -586,15 +565,15 @@ class AsyncHTTPAdapter(HTTPAdapter):
"""
response = AsyncResponse()
# Fallback to None if there's no status_code, for whatever reason.
response.status_code = getattr(resp, 'status', None)
response.status_code = getattr(resp, "status", None)
# Make headers case-insensitive.
response.headers = HTTPHeaderDict(getattr(resp, 'headers', {}))
response.headers = HTTPHeaderDict(getattr(resp, "headers", {}))
# Set encoding.
response.encoding = get_encoding_from_headers(response.headers)
response.raw = resp
response.reason = response.raw.reason
if isinstance(req.url, bytes):
response.url = req.url.decode('utf-8')
response.url = req.url.decode("utf-8")
else:
response.url = req.url
# Add new cookies from the server.
@@ -643,18 +622,14 @@ class AsyncHTTPAdapter(HTTPAdapter):
pool_kwargs = _pool_kwargs(verify, cert)
proxy = select_proxy(url, proxies)
if proxy:
proxy = prepend_scheme_if_needed(proxy, 'http')
proxy = prepend_scheme_if_needed(proxy, "http")
proxy_manager = self.proxy_manager_for(proxy)
conn = proxy_manager.connection_from_url(
url, pool_kwargs=pool_kwargs
)
conn = proxy_manager.connection_from_url(url, pool_kwargs=pool_kwargs)
else:
# Only scheme should be lower case
parsed = urlparse(url)
url = parsed.geturl()
conn = self.poolmanager.connection_from_url(
url, pool_kwargs=pool_kwargs
)
conn = self.poolmanager.connection_from_url(url, pool_kwargs=pool_kwargs)
return conn
def close(self):
@@ -669,13 +644,7 @@ class AsyncHTTPAdapter(HTTPAdapter):
pass
async def send(
self,
request,
stream=False,
timeout=None,
verify=True,
cert=None,
proxies=None,
self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None
):
"""Sends PreparedRequest object. Returns Response object.
@@ -696,9 +665,7 @@ class AsyncHTTPAdapter(HTTPAdapter):
url = self.request_url(request, proxies)
self.add_headers(request)
chunked = not (
request.body is None or 'Content-Length' in request.headers
)
chunked = not (request.body is None or "Content-Length" in request.headers)
if isinstance(timeout, tuple):
try:
connect, read = timeout
@@ -730,18 +697,16 @@ class AsyncHTTPAdapter(HTTPAdapter):
retries=self.max_retries,
timeout=timeout,
enforce_content_length=True,
pool=conn
pool=conn,
)
# Send the request.
else:
if hasattr(conn, 'proxy_pool'):
if hasattr(conn, "proxy_pool"):
conn = conn.proxy_pool
low_conn = conn._get_conn(timeout=DEFAULT_POOL_TIMEOUT)
try:
low_conn.putrequest(
request.method, url, skip_accept_encoding=True
)
low_conn.putrequest(request.method, url, skip_accept_encoding=True)
for header, value in request.headers.items():
low_conn.putheader(header, value)
low_conn.endheaders()
@@ -750,11 +715,11 @@ class AsyncHTTPAdapter(HTTPAdapter):
if chunk_size == 0:
continue
low_conn.send(hex(chunk_size)[2:].encode('utf-8'))
low_conn.send(b'\r\n')
low_conn.send(hex(chunk_size)[2:].encode("utf-8"))
low_conn.send(b"\r\n")
low_conn.send(i)
low_conn.send(b'\r\n')
low_conn.send(b'0\r\n\r\n')
low_conn.send(b"\r\n")
low_conn.send(b"0\r\n\r\n")
# Receive the response from the server
try:
# For Python 2.7, use buffering of HTTP responses
+58 -67
View File
@@ -51,11 +51,11 @@ class MockRequest(object):
def get_full_url(self):
# Only return the response's URL if the user hadn't set the Host
# header
if not self._r.headers.get('Host'):
if not self._r.headers.get("Host"):
return self._r.url
# If they did set it, retrieve it and reconstruct the expected domain
host = to_native_string(self._r.headers['Host'], encoding='utf-8')
host = to_native_string(self._r.headers["Host"], encoding="utf-8")
parsed = urlparse(self._r.url)
# Reconstruct the URL as we expect it
return urlunparse(
@@ -131,9 +131,7 @@ def extract_cookies_to_jar(jar, request, response):
:param request: our own requests.Request object
:param response: urllib3.HTTPResponse object
"""
if not (
hasattr(response, '_original_response') and response._original_response
):
if not (hasattr(response, "_original_response") and response._original_response):
return
# the _original_response field is the wrapped httplib.HTTPResponse object,
@@ -151,7 +149,7 @@ def get_cookie_header(jar, request):
"""
r = MockRequest(request)
jar.add_cookie_header(r)
return r.get_new_headers().get('Cookie')
return r.get_new_headers().get("Cookie")
def remove_cookie_by_name(cookiejar, name, domain=None, path=None):
@@ -220,10 +218,7 @@ class RequestsCookieJar(cookielib.CookieJar, MutableMapping):
# support client code that unsets cookies by assignment of a None value:
if value is None:
remove_cookie_by_name(
self,
name,
domain=kwargs.get('domain'),
path=kwargs.get('path'),
self, name, domain=kwargs.get("domain"), path=kwargs.get("path")
)
return
@@ -325,9 +320,8 @@ class RequestsCookieJar(cookielib.CookieJar, MutableMapping):
"""
dictionary = {}
for cookie in iter(self):
if (
(domain is None or cookie.domain == domain) and
(path is None or cookie.path == path)
if (domain is None or cookie.domain == domain) and (
path is None or cookie.path == path
):
dictionary[cookie.name] = cookie.value
return dictionary
@@ -362,15 +356,13 @@ class RequestsCookieJar(cookielib.CookieJar, MutableMapping):
remove_cookie_by_name(self, name)
def set_cookie(self, cookie, *args, **kwargs):
if hasattr(cookie.value, 'startswith') and cookie.value.startswith(
'"'
) and cookie.value.endswith(
'"'
if (
hasattr(cookie.value, "startswith")
and cookie.value.startswith('"')
and cookie.value.endswith('"')
):
cookie.value = cookie.value.replace('\\"', '')
return super(RequestsCookieJar, self).set_cookie(
cookie, *args, **kwargs
)
cookie.value = cookie.value.replace('\\"', "")
return super(RequestsCookieJar, self).set_cookie(cookie, *args, **kwargs)
def update(self, other):
"""Updates this jar with cookies from another CookieJar or dict-like"""
@@ -398,7 +390,7 @@ class RequestsCookieJar(cookielib.CookieJar, MutableMapping):
if path is None or cookie.path == path:
return cookie.value
raise KeyError('name=%r, domain=%r, path=%r' % (name, domain, path))
raise KeyError("name=%r, domain=%r, path=%r" % (name, domain, path))
def _find_no_duplicates(self, name, domain=None, path=None):
"""Both ``__get_item__`` and ``get`` call this function: it's never
@@ -417,29 +409,32 @@ class RequestsCookieJar(cookielib.CookieJar, MutableMapping):
if cookie.name == name:
if domain is None or cookie.domain == domain:
if path is None or cookie.path == path:
if toReturn is not None: # if there are multiple cookies that meet passed in criteria
if (
toReturn is not None
): # if there are multiple cookies that meet passed in criteria
raise CookieConflictError(
'There are multiple cookies with name, %r' %
(name)
"There are multiple cookies with name, %r" % (name)
)
toReturn = cookie.value # we will eventually return this as long as no cookie conflict
toReturn = (
cookie.value
) # we will eventually return this as long as no cookie conflict
if toReturn:
return toReturn
raise KeyError('name=%r, domain=%r, path=%r' % (name, domain, path))
raise KeyError("name=%r, domain=%r, path=%r" % (name, domain, path))
def __getstate__(self):
"""Unlike a normal CookieJar, this class is pickleable."""
state = self.__dict__.copy()
# remove the unpickleable RLock object
state.pop('_cookies_lock')
state.pop("_cookies_lock")
return state
def __setstate__(self, state):
"""Unlike a normal CookieJar, this class is pickleable."""
self.__dict__.update(state)
if '_cookies_lock' not in self.__dict__:
if "_cookies_lock" not in self.__dict__:
self._cookies_lock = threading.RLock()
def copy(self):
@@ -457,7 +452,7 @@ def _copy_cookie_jar(jar):
if jar is None:
return None
if hasattr(jar, 'copy'):
if hasattr(jar, "copy"):
# We're dealing with an instance of RequestsCookieJar
return jar.copy()
@@ -476,61 +471,59 @@ def create_cookie(name, value, **kwargs):
and sent on every request (this is sometimes called a "supercookie").
"""
result = {
'version': 0,
'name': name,
'value': value,
'port': None,
'domain': '',
'path': '/',
'secure': False,
'expires': None,
'discard': True,
'comment': None,
'comment_url': None,
'rest': {'HttpOnly': None},
'rfc2109': False,
"version": 0,
"name": name,
"value": value,
"port": None,
"domain": "",
"path": "/",
"secure": False,
"expires": None,
"discard": True,
"comment": None,
"comment_url": None,
"rest": {"HttpOnly": None},
"rfc2109": False,
}
badargs = set(kwargs) - set(result)
if badargs:
err = 'create_cookie() got unexpected keyword arguments: %s'
err = "create_cookie() got unexpected keyword arguments: %s"
raise TypeError(err % list(badargs))
result.update(kwargs)
result['port_specified'] = bool(result['port'])
result['domain_specified'] = bool(result['domain'])
result['domain_initial_dot'] = result['domain'].startswith('.')
result['path_specified'] = bool(result['path'])
result["port_specified"] = bool(result["port"])
result["domain_specified"] = bool(result["domain"])
result["domain_initial_dot"] = result["domain"].startswith(".")
result["path_specified"] = bool(result["path"])
return cookielib.Cookie(**result)
def morsel_to_cookie(morsel):
"""Convert a Morsel object into a Cookie containing the one k/v pair."""
expires = None
if morsel['max-age']:
if morsel["max-age"]:
try:
expires = int(time.time() + int(morsel['max-age']))
expires = int(time.time() + int(morsel["max-age"]))
except ValueError:
raise TypeError('max-age: %s must be integer' % morsel['max-age'])
raise TypeError("max-age: %s must be integer" % morsel["max-age"])
elif morsel['expires']:
time_template = '%a, %d-%b-%Y %H:%M:%S GMT'
expires = calendar.timegm(
time.strptime(morsel['expires'], time_template)
)
elif morsel["expires"]:
time_template = "%a, %d-%b-%Y %H:%M:%S GMT"
expires = calendar.timegm(time.strptime(morsel["expires"], time_template))
return create_cookie(
comment=morsel['comment'],
comment_url=bool(morsel['comment']),
comment=morsel["comment"],
comment_url=bool(morsel["comment"]),
discard=False,
domain=morsel['domain'],
domain=morsel["domain"],
expires=expires,
name=morsel.key,
path=morsel['path'],
path=morsel["path"],
port=None,
rest={'HttpOnly': morsel['httponly']},
rest={"HttpOnly": morsel["httponly"]},
rfc2109=False,
secure=bool(morsel['secure']),
secure=bool(morsel["secure"]),
value=morsel.value,
version=morsel['version'] or 0,
version=morsel["version"] or 0,
)
@@ -561,12 +554,10 @@ def merge_cookies(cookiejar, cookies):
:rtype: CookieJar
"""
if not isinstance(cookiejar, cookielib.CookieJar):
raise ValueError('You can only merge into CookieJar')
raise ValueError("You can only merge into CookieJar")
if isinstance(cookies, dict):
cookiejar = cookiejar_from_dict(
cookies, cookiejar=cookiejar, overwrite=False
)
cookiejar = cookiejar_from_dict(cookies, cookiejar=cookiejar, overwrite=False)
elif isinstance(cookies, cookielib.CookieJar):
try:
cookiejar.update(cookies)
+141 -149
View File
@@ -20,7 +20,10 @@ import encodings.idna
from .core._http.fields import RequestField
from .core._http.filepost import encode_multipart_formdata
from .core._http.exceptions import (
DecodeError, ReadTimeoutError, ProtocolError, LocationParseError
DecodeError,
ReadTimeoutError,
ProtocolError,
LocationParseError,
)
from io import UnsupportedOperation
@@ -73,11 +76,11 @@ from .http_stati import codes
# : The set of HTTP status codes that indicate an automatically
#: processable redirect.
REDIRECT_STATI = (
codes['moved'], # 301
codes['found'], # 302
codes['other'], # 303
codes['temporary_redirect'], # 307
codes['permanent_redirect'], # 308
codes["moved"], # 301
codes["found"], # 302
codes["other"], # 303
codes["temporary_redirect"], # 307
codes["permanent_redirect"], # 308
)
DEFAULT_REDIRECT_LIMIT = 30
CONTENT_CHUNK_SIZE = 10 * 1024
@@ -85,7 +88,6 @@ ITER_CHUNK_SIZE = 512
class RequestEncodingMixin(object):
@property
def path_url(self):
"""Build the path URL to use."""
@@ -93,13 +95,13 @@ class RequestEncodingMixin(object):
p = urlsplit(self.url)
path = p.path
if not path:
path = '/'
path = "/"
url.append(path)
query = p.query
if query:
url.append('?')
url.append("?")
url.append(query)
return ''.join(url)
return "".join(url)
@staticmethod
def _encode_params(data):
@@ -112,20 +114,20 @@ class RequestEncodingMixin(object):
if isinstance(data, (str, bytes)):
return data
elif hasattr(data, 'read'):
elif hasattr(data, "read"):
return data
elif hasattr(data, '__iter__'):
elif hasattr(data, "__iter__"):
result = []
for k, vs in to_key_val_list(data):
if isinstance(vs, basestring) or not hasattr(vs, '__iter__'):
if isinstance(vs, basestring) or not hasattr(vs, "__iter__"):
vs = [vs]
for v in vs:
if v is not None:
result.append(
(
k.encode('utf-8') if isinstance(k, str) else k,
v.encode('utf-8') if isinstance(v, str) else v,
k.encode("utf-8") if isinstance(k, str) else k,
v.encode("utf-8") if isinstance(v, str) else v,
)
)
return urlencode(result, doseq=True)
@@ -143,7 +145,7 @@ class RequestEncodingMixin(object):
The tuples may be 2-tuples (filename, fileobj), 3-tuples (filename, fileobj, contentype)
or 4-tuples (filename, fileobj, contentype, custom_headers).
"""
if (not files):
if not files:
raise ValueError("Files must be provided.")
elif isinstance(data, basestring):
@@ -153,7 +155,7 @@ class RequestEncodingMixin(object):
fields = to_key_val_list(data or {})
files = to_key_val_list(files or {})
for field, val in fields:
if isinstance(val, basestring) or not hasattr(val, '__iter__'):
if isinstance(val, basestring) or not hasattr(val, "__iter__"):
val = [val]
for v in val:
if v is not None:
@@ -162,10 +164,10 @@ class RequestEncodingMixin(object):
v = str(v)
new_fields.append(
(
field.decode('utf-8') if isinstance(
field, bytes
) else field,
v.encode('utf-8') if isinstance(v, str) else v,
field.decode("utf-8")
if isinstance(field, bytes)
else field,
v.encode("utf-8") if isinstance(v, str) else v,
)
)
for (k, v) in files:
@@ -184,7 +186,7 @@ class RequestEncodingMixin(object):
fp = v
if isinstance(fp, (str, bytes, bytearray)):
fdata = fp
elif hasattr(fp, 'read'):
elif hasattr(fp, "read"):
fdata = fp.read()
elif fp is None:
continue
@@ -199,7 +201,6 @@ class RequestEncodingMixin(object):
class RequestHooksMixin(object):
def register_hook(self, event, hook):
"""Properly register a hook."""
if event not in self.hooks:
@@ -209,7 +210,7 @@ class RequestHooksMixin(object):
if isinstance(hook, Callable):
self.hooks[event].append(hook)
elif hasattr(hook, '__iter__'):
elif hasattr(hook, "__iter__"):
self.hooks[event].extend(h for h in hook if isinstance(h, Callable))
def deregister_hook(self, event, hook):
@@ -251,17 +252,18 @@ class Request(RequestHooksMixin):
>>> req.prepare()
<PreparedRequest [GET]>
"""
__slots__ = (
'method',
'url',
'headers',
'files',
'data',
'params',
'auth',
'cookies',
'hooks',
'json',
"method",
"url",
"headers",
"files",
"data",
"params",
"auth",
"cookies",
"hooks",
"json",
)
def __init__(
@@ -297,7 +299,7 @@ class Request(RequestHooksMixin):
self.cookies = cookies
def __repr__(self):
return '<Request [%s]>' % (self.method)
return "<Request [%s]>" % (self.method)
def prepare(self):
"""Constructs a :class:`PreparedRequest <PreparedRequest>` for transmission and returns it."""
@@ -334,14 +336,15 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
>>> s.send(r)
<Response [200]>
"""
__slots__ = (
'method',
'url',
'headers',
'_cookies',
'body',
'hooks',
'_body_position',
"method",
"url",
"headers",
"_cookies",
"body",
"hooks",
"_body_position",
)
def __init__(self):
@@ -387,7 +390,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
self.prepare_hooks(hooks)
def __repr__(self):
return f'<PreparedRequest [{self.method}]>'
return f"<PreparedRequest [{self.method}]>"
def copy(self):
p = PreparedRequest()
@@ -413,7 +416,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
import idna
try:
host = idna.encode(host, uts46=True).decode('utf-8')
host = idna.encode(host, uts46=True).decode("utf-8")
except idna.IDNAError:
raise UnicodeError
@@ -427,7 +430,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
#: on python 3.x.
#: https://github.com/requests/requests/pull/2238
if isinstance(url, bytes):
url = url.decode('utf8')
url = url.decode("utf8")
else:
url = str(url)
# Ignore any leading and trailing whitespace characters.
@@ -435,7 +438,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
# Don't do any URL preparation for non-HTTP schemes like `mailto`,
# `data` etc to work around exceptions from `url_parse`, which
# handles RFC 3986 only.
if ':' in url and not url.lower().startswith('http'):
if ":" in url and not url.lower().startswith("http"):
self.url = url
return
@@ -451,7 +454,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
error = (
"Invalid URL {0!r}: No scheme supplied. Perhaps you meant http://{0}?"
)
error = error.format(to_native_string(url, 'utf8'))
error = error.format(to_native_string(url, "utf8"))
raise MissingScheme(error)
if not uri.host:
@@ -465,20 +468,20 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
try:
uri = uri.copy_with(host=self._get_idna_encoded_host(uri.host))
except UnicodeError:
raise InvalidURL('URL has an invalid label.')
raise InvalidURL("URL has an invalid label.")
elif uri.host.startswith(u'*'):
raise InvalidURL('URL has an invalid label.')
elif uri.host.startswith("*"):
raise InvalidURL("URL has an invalid label.")
# Bare domains aren't valid URLs.
if not uri.path:
uri = uri.copy_with(path='/')
uri = uri.copy_with(path="/")
if isinstance(params, (str, bytes)):
params = to_native_string(params)
enc_params = self._encode_params(params)
if enc_params:
if uri.query:
uri = uri.copy_with(query=f'{uri.query}&{enc_params}')
uri = uri.copy_with(query=f"{uri.query}&{enc_params}")
else:
uri = uri.copy_with(query=enc_params)
# url = requote_uri(
@@ -507,15 +510,17 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
if not data and json is not None:
# urllib3 requires a bytes-like body. Python 2's json.dumps
# provides this natively, but Python 3 gives a Unicode string.
content_type = 'application/json'
content_type = "application/json"
body = complexjson.dumps(json)
if not isinstance(body, bytes):
body = body.encode('utf-8')
body = body.encode("utf-8")
is_stream = all([
hasattr(data, '__iter__'),
not isinstance(data, (basestring, list, tuple, Mapping))
])
is_stream = all(
[
hasattr(data, "__iter__"),
not isinstance(data, (basestring, list, tuple, Mapping)),
]
)
try:
length = super_len(data)
@@ -524,7 +529,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
if is_stream:
body = data
if getattr(body, 'tell', None) is not None:
if getattr(body, "tell", None) is not None:
# Record the current file position before reading.
# This will allow us to rewind a file in the event
# of a redirect.
@@ -536,7 +541,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
self._body_position = object()
if files:
raise NotImplementedError(
'Streamed bodies and files are mutually exclusive.'
"Streamed bodies and files are mutually exclusive."
)
else:
@@ -546,13 +551,13 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
else:
if data:
body = self._encode_params(data)
if isinstance(data, basestring) or hasattr(data, 'read'):
if isinstance(data, basestring) or hasattr(data, "read"):
content_type = None
else:
content_type = 'application/x-www-form-urlencoded'
content_type = "application/x-www-form-urlencoded"
# Add content-type if it wasn't explicitly provided.
if content_type and ('content-type' not in self.headers):
self.headers['Content-Type'] = content_type
if content_type and ("content-type" not in self.headers):
self.headers["Content-Type"] = content_type
self.prepare_content_length(body)
self.body = body
@@ -568,27 +573,28 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
if body is not None:
length = super_len(body)
if length:
self.headers['Content-Length'] = builtin_str(length)
self.headers["Content-Length"] = builtin_str(length)
elif is_stream(body):
self.headers['Transfer-Encoding'] = 'chunked'
self.headers["Transfer-Encoding"] = "chunked"
else:
raise InvalidBodyError(
'Non-null body must have length or be streamable.'
"Non-null body must have length or be streamable."
)
elif self.method not in ('GET', 'HEAD') and self.headers.get(
'Content-Length'
) is None:
elif (
self.method not in ("GET", "HEAD")
and self.headers.get("Content-Length") is None
):
# Set Content-Length to 0 for methods that can have a body
# but don't provide one. (i.e. not GET or HEAD)
self.headers['Content-Length'] = '0'
if 'Transfer-Encoding' in self.headers and 'Content-Length' in self.headers:
self.headers["Content-Length"] = "0"
if "Transfer-Encoding" in self.headers and "Content-Length" in self.headers:
raise InvalidHeader(
'Conflicting Headers: Both Transfer-Encoding and '
'Content-Length are set.'
"Conflicting Headers: Both Transfer-Encoding and "
"Content-Length are set."
)
def prepare_auth(self, auth, url=''):
def prepare_auth(self, auth, url=""):
"""Prepares the given HTTP auth data."""
# If no Auth is explicitly provided, extract it from the URL first.
if auth is None:
@@ -622,7 +628,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
self._cookies = cookiejar_from_dict(cookies)
cookie_header = get_cookie_header(self._cookies, self)
if cookie_header is not None:
self.headers['Cookie'] = cookie_header
self.headers["Cookie"] = cookie_header
def prepare_hooks(self, hooks):
"""Prepares the given hooks."""
@@ -645,19 +651,20 @@ class Response(object):
"""The :class:`Response <Response>` object, which contains a
server's response to an HTTP request.
"""
__attrs__ = [
'_content',
'status_code',
'headers',
'url',
'history',
'encoding',
'reason',
'cookies',
'elapsed',
'request',
"_content",
"status_code",
"headers",
"url",
"history",
"encoding",
"reason",
"cookies",
"elapsed",
"request",
]
__slots__ = __attrs__ + ['_content_consumed', 'raw', '_next', 'connection']
__slots__ = __attrs__ + ["_content_consumed", "raw", "_next", "connection"]
def __init__(self):
self._content = False
@@ -714,11 +721,11 @@ class Response(object):
for name, value in state.items():
setattr(self, name, value)
# pickled objects do not have .raw
setattr(self, '_content_consumed', True)
setattr(self, 'raw', None)
setattr(self, "_content_consumed", True)
setattr(self, "raw", None)
def __repr__(self):
return '<Response [%s]>' % (self.status_code)
return "<Response [%s]>" % (self.status_code)
def __iter__(self):
"""Allows you to use a response as an iterator."""
@@ -745,18 +752,14 @@ class Response(object):
"""True if this Response is a well-formed HTTP redirect that could have
been processed automatically (by :meth:`Session.resolve_redirects`).
"""
return (
'location' in self.headers and self.status_code in REDIRECT_STATI
)
return "location" in self.headers and self.status_code in REDIRECT_STATI
@property
def is_permanent_redirect(self):
"""True if this Response one of the permanent versions of redirect."""
return (
'location' in self.headers and
self.status_code in (
codes.moved_permanently, codes.permanent_redirect
)
return "location" in self.headers and self.status_code in (
codes.moved_permanently,
codes.permanent_redirect,
)
@property
@@ -767,7 +770,7 @@ class Response(object):
@property
def apparent_encoding(self):
"""The apparent encoding, provided by the chardet library."""
return chardet.detect(self.content)['encoding']
return chardet.detect(self.content)["encoding"]
def iter_content(self, decode_unicode=False):
"""Iterates over the response data. When stream=True is set on the
@@ -790,7 +793,7 @@ class Response(object):
def generate():
# Special case for urllib3.
if hasattr(self.raw, 'stream'):
if hasattr(self.raw, "stream"):
try:
for chunk in self.raw.stream(
# chunk_size, decode_content=True
@@ -799,7 +802,7 @@ class Response(object):
yield chunk
except ProtocolError as e:
if self.headers.get('Transfer-Encoding') == 'chunked':
if self.headers.get("Transfer-Encoding") == "chunked":
raise ChunkedEncodingError(e)
else:
@@ -838,8 +841,7 @@ class Response(object):
if decode_unicode:
if self.encoding is None:
raise TypeError(
'encoding must be set before consuming streaming '
'responses'
"encoding must be set before consuming streaming " "responses"
)
# check encoding value here, don't wait for the generator to be
@@ -848,15 +850,17 @@ class Response(object):
chunks = stream_decode_response_unicode(chunks, self)
return chunks
def iter_lines(self, chunk_size=ITER_CHUNK_SIZE, decode_unicode=False, delimiter=None):
def iter_lines(
self, chunk_size=ITER_CHUNK_SIZE, decode_unicode=False, delimiter=None
):
"""Iterates over the response data, one line at a time. When
stream=True is set on the request, this avoids reading the
content at once into memory for large responses.
.. note:: This method is not reentrant safe.
"""
carriage_return = u'\r' if decode_unicode else b'\r'
line_feed = u'\n' if decode_unicode else b'\n'
carriage_return = "\r" if decode_unicode else b"\r"
line_feed = "\n" if decode_unicode else b"\n"
pending = None
last_chunk_ends_with_cr = False
for chunk in self.iter_content(
@@ -927,9 +931,7 @@ class Response(object):
if self._content is False:
# Read the contents.
if self._content_consumed:
raise RuntimeError(
'The content for this response was already consumed'
)
raise RuntimeError("The content for this response was already consumed")
if self.status_code == 0 or self.raw is None:
self._content = None
@@ -938,9 +940,7 @@ class Response(object):
# print(bytes().join(
# [await self.iter_content(CONTENT_CHUNK_SIZE)]
# ))
self._content = bytes().join(
self.iter_content()
) or bytes()
self._content = bytes().join(self.iter_content()) or bytes()
self._content_consumed = True
# don't need to release the connection; that's been handled by urllib3
# since we exhausted the data.
@@ -962,14 +962,14 @@ class Response(object):
content = None
encoding = self.encoding
if not self.content:
return str('')
return str("")
# Fallback to auto-detected encoding.
if self.encoding is None:
encoding = self.apparent_encoding
# Decode unicode from given encoding.
try:
content = str(self.content, encoding, errors='replace')
content = str(self.content, encoding, errors="replace")
except (LookupError, TypeError):
# A LookupError is raised if the encoding was not found which could
# indicate a misspelling or similar mistake.
@@ -977,7 +977,7 @@ class Response(object):
# A TypeError can be raised if encoding is None
#
# So we try blindly encoding.
content = str(self.content, errors='replace')
content = str(self.content, errors="replace")
return content
def json(self, **kwargs):
@@ -995,9 +995,7 @@ class Response(object):
if encoding is not None:
try:
content = self.content
return complexjson.loads(
content.decode(encoding), **kwargs
)
return complexjson.loads(content.decode(encoding), **kwargs)
except UnicodeDecodeError:
# Wrong UTF codec detected; usually because it's not UTF-8
@@ -1010,38 +1008,42 @@ class Response(object):
@property
def links(self):
"""Returns the parsed header links of the response, if any."""
header = self.headers.get('link')
header = self.headers.get("link")
# l = MultiDict()
l = {}
if header:
links = parse_header_links(header)
for link in links:
key = link.get('rel') or link.get('url')
key = link.get("rel") or link.get("url")
l[key] = link
return l
def raise_for_status(self):
"""Raises stored :class:`HTTPError`, if one occurred.
Otherwise, returns the response object (self)."""
http_error_msg = ''
http_error_msg = ""
if isinstance(self.reason, bytes):
# We attempt to decode utf-8 first because some servers
# choose to localize their reason strings. If the string
# isn't utf-8, we fall back to iso-8859-1 for all other
# encodings. (See PR #3538)
try:
reason = self.reason.decode('utf-8')
reason = self.reason.decode("utf-8")
except UnicodeDecodeError:
reason = self.reason.decode('iso-8859-1')
reason = self.reason.decode("iso-8859-1")
else:
reason = self.reason
if 400 <= self.status_code < 500:
http_error_msg = u'%s Client Error: %s for url: %s' % (
self.status_code, reason, self.url
http_error_msg = "%s Client Error: %s for url: %s" % (
self.status_code,
reason,
self.url,
)
elif 500 <= self.status_code < 600:
http_error_msg = u'%s Server Error: %s for url: %s' % (
self.status_code, reason, self.url
http_error_msg = "%s Server Error: %s for url: %s" % (
self.status_code,
reason,
self.url,
)
if http_error_msg:
raise HTTPError(http_error_msg, response=self)
@@ -1056,7 +1058,7 @@ class Response(object):
"""
if not self._content_consumed:
self.raw.close()
release_conn = getattr(self.raw, 'release_conn', None)
release_conn = getattr(self.raw, "release_conn", None)
if release_conn is not None:
release_conn()
@@ -1080,9 +1082,7 @@ class AsyncResponse(Response):
if encoding is not None:
try:
content = await self.content
return complexjson.loads(
content.decode(encoding), **kwargs
)
return complexjson.loads(content.decode(encoding), **kwargs)
except UnicodeDecodeError:
# Wrong UTF codec detected; usually because it's not UTF-8
@@ -1108,14 +1108,14 @@ class AsyncResponse(Response):
content = None
encoding = self.encoding
if not await self.content:
return str('')
return str("")
# Fallback to auto-detected encoding.
if self.encoding is None:
encoding = self.apparent_encoding
# Decode unicode from given encoding.
try:
content = str(self.content, encoding, errors='replace')
content = str(self.content, encoding, errors="replace")
except (LookupError, TypeError):
# A LookupError is raised if the encoding was not found which could
# indicate a misspelling or similar mistake.
@@ -1123,7 +1123,7 @@ class AsyncResponse(Response):
# A TypeError can be raised if encoding is None
#
# So we try blindly encoding.
content = str(await self.content, errors='replace')
content = str(await self.content, errors="replace")
return content
@property
@@ -1132,9 +1132,7 @@ class AsyncResponse(Response):
if self._content is False:
# Read the contents.
if self._content_consumed:
raise RuntimeError(
'The content for this response was already consumed'
)
raise RuntimeError("The content for this response was already consumed")
if self.status_code == 0 or self.raw is None:
self._content = None
@@ -1143,19 +1141,16 @@ class AsyncResponse(Response):
# print(bytes().join(
# [await self.iter_content(CONTENT_CHUNK_SIZE)]
# ))
self._content = bytes().join(
[await self.iter_content()]
) or bytes()
self._content = bytes().join([await self.iter_content()]) or bytes()
self._content_consumed = True
# don't need to release the connection; that's been handled by urllib3
# since we exhausted the data.
return self._content
@property
async def apparent_encoding(self):
"""The apparent encoding, provided by the chardet library."""
return chardet.detect(await self.content)['encoding']
return chardet.detect(await self.content)["encoding"]
async def iter_content(self, decode_unicode=False):
"""Iterates over the response data. When stream=True is set on the
@@ -1178,15 +1173,13 @@ class AsyncResponse(Response):
async def generate():
# Special case for requests.core.
if hasattr(self.raw, 'stream'):
if hasattr(self.raw, "stream"):
try:
async for chunk in self.raw.stream(
decode_content=True
):
async for chunk in self.raw.stream(decode_content=True):
yield chunk
except ProtocolError as e:
if self.headers.get('Transfer-Encoding') == 'chunked':
if self.headers.get("Transfer-Encoding") == "chunked":
raise ChunkedEncodingError(e)
else:
@@ -1222,8 +1215,7 @@ class AsyncResponse(Response):
if decode_unicode:
if self.encoding is None:
raise TypeError(
'encoding must be set before consuming streaming '
'responses'
"encoding must be set before consuming streaming " "responses"
)
# check encoding value here, don't wait for the generator to be
+28 -14
View File
@@ -12,23 +12,37 @@ See https://toolbelt.readthedocs.io/ for documentation
from .adapters import SSLAdapter, SourceAddressAdapter
from .auth.guess import GuessAuth
from .multipart import (
MultipartEncoder, MultipartEncoderMonitor, MultipartDecoder,
ImproperBodyPartContentException, NonMultipartContentTypeException
)
MultipartEncoder,
MultipartEncoderMonitor,
MultipartDecoder,
ImproperBodyPartContentException,
NonMultipartContentTypeException,
)
from .streaming_iterator import StreamingIterator
from .utils.user_agent import user_agent
__title__ = 'requests-toolbelt'
__authors__ = 'Ian Cordasco, Cory Benfield'
__license__ = 'Apache v2.0'
__copyright__ = 'Copyright 2014 Ian Cordasco, Cory Benfield'
__version__ = '0.9.1'
__version_info__ = tuple(int(i) for i in __version__.split('.'))
__title__ = "requests-toolbelt"
__authors__ = "Ian Cordasco, Cory Benfield"
__license__ = "Apache v2.0"
__copyright__ = "Copyright 2014 Ian Cordasco, Cory Benfield"
__version__ = "0.9.1"
__version_info__ = tuple(int(i) for i in __version__.split("."))
__all__ = [
'GuessAuth', 'MultipartEncoder', 'MultipartEncoderMonitor',
'MultipartDecoder', 'SSLAdapter', 'SourceAddressAdapter',
'StreamingIterator', 'user_agent', 'ImproperBodyPartContentException',
'NonMultipartContentTypeException', '__title__', '__authors__',
'__license__', '__copyright__', '__version__', '__version_info__',
"GuessAuth",
"MultipartEncoder",
"MultipartEncoderMonitor",
"MultipartDecoder",
"SSLAdapter",
"SourceAddressAdapter",
"StreamingIterator",
"user_agent",
"ImproperBodyPartContentException",
"NonMultipartContentTypeException",
"__title__",
"__authors__",
"__license__",
"__copyright__",
"__version__",
"__version_info__",
]
+26 -24
View File
@@ -53,8 +53,7 @@ if requests.__build__ < 0x021200:
PyOpenSSLContext = None
else:
try:
from requests.packages.urllib3.contrib.pyopenssl \
import PyOpenSSLContext
from requests.packages.urllib3.contrib.pyopenssl import PyOpenSSLContext
except ImportError:
try:
from urllib3.contrib.pyopenssl import PyOpenSSLContext
@@ -130,7 +129,7 @@ class HTTPHeaderDict(MutableMapping):
def __getitem__(self, key):
val = self._container[key.lower()]
return ', '.join(val[1:])
return ", ".join(val[1:])
def __delitem__(self, key):
del self._container[key.lower()]
@@ -139,12 +138,13 @@ class HTTPHeaderDict(MutableMapping):
return key.lower() in self._container
def __eq__(self, other):
if not isinstance(other, Mapping) and not hasattr(other, 'keys'):
if not isinstance(other, Mapping) and not hasattr(other, "keys"):
return False
if not isinstance(other, type(self)):
other = type(self)(other)
return ({k.lower(): v for k, v in self.itermerged()} ==
{k.lower(): v for k, v in other.itermerged()})
return {k.lower(): v for k, v in self.itermerged()} == {
k.lower(): v for k, v in other.itermerged()
}
def __ne__(self, other):
return not self.__eq__(other)
@@ -218,8 +218,10 @@ class HTTPHeaderDict(MutableMapping):
with self.add instead of self.__setitem__
"""
if len(args) > 1:
raise TypeError("extend() takes at most 1 positional "
"arguments ({} given)".format(len(args)))
raise TypeError(
"extend() takes at most 1 positional "
"arguments ({} given)".format(len(args))
)
other = args[0] if len(args) >= 1 else ()
if isinstance(other, HTTPHeaderDict):
@@ -283,7 +285,7 @@ class HTTPHeaderDict(MutableMapping):
"""Iterate over all headers, merging duplicate ones together."""
for key in self:
val = self._container[key.lower()]
yield val[0], ', '.join(val[1:])
yield val[0], ", ".join(val[1:])
def items(self):
return list(self.iteritems())
@@ -297,28 +299,28 @@ class HTTPHeaderDict(MutableMapping):
headers = []
for line in message.headers:
if line.startswith((' ', '\t')):
if line.startswith((" ", "\t")):
key, value = headers[-1]
headers[-1] = (key, value + '\r\n' + line.rstrip())
headers[-1] = (key, value + "\r\n" + line.rstrip())
continue
key, value = line.split(':', 1)
key, value = line.split(":", 1)
headers.append((key, value.strip()))
return cls(headers)
__all__ = (
'basestring',
'connection',
'fields',
'filepost',
'poolmanager',
'timeout',
'HTTPHeaderDict',
'queue',
'urlencode',
'gaecontrib',
'urljoin',
'PyOpenSSLContext',
"basestring",
"connection",
"fields",
"filepost",
"poolmanager",
"timeout",
"HTTPHeaderDict",
"queue",
"urlencode",
"gaecontrib",
"urljoin",
"PyOpenSSLContext",
)
+1 -1
View File
@@ -12,4 +12,4 @@ See https://toolbelt.readthedocs.io/ for documentation
from .ssl import SSLAdapter
from .source import SourceAddressAdapter
__all__ = ['SSLAdapter', 'SourceAddressAdapter']
__all__ = ["SSLAdapter", "SourceAddressAdapter"]
+30 -15
View File
@@ -52,6 +52,7 @@ class AppEngineMROHack(adapters.HTTPAdapter):
monkeypatch, at which point this class becomes HTTPAdapter's base class.
In addition, we use an instantiation flag to avoid infinite recursion.
"""
_initialized = False
def __init__(self, *args, **kwargs):
@@ -71,7 +72,7 @@ class AppEngineAdapter(AppEngineMROHack, adapters.HTTPAdapter):
for Requests to be able to use it.
"""
__attrs__ = adapters.HTTPAdapter.__attrs__ + ['_validate_certificate']
__attrs__ = adapters.HTTPAdapter.__attrs__ + ["_validate_certificate"]
def __init__(self, validate_certificate=True, *args, **kwargs):
_check_version()
@@ -99,13 +100,17 @@ class InsecureAppEngineAdapter(AppEngineAdapter):
def __init__(self, *args, **kwargs):
if kwargs.pop("validate_certificate", False):
warnings.warn("Certificate validation cannot be specified on the "
"InsecureAppEngineAdapter, but was present. This "
"will be ignored and certificate validation will "
"remain off.", exc.IgnoringGAECertificateValidation)
warnings.warn(
"Certificate validation cannot be specified on the "
"InsecureAppEngineAdapter, but was present. This "
"will be ignored and certificate validation will "
"remain off.",
exc.IgnoringGAECertificateValidation,
)
super(InsecureAppEngineAdapter, self).__init__(
validate_certificate=False, *args, **kwargs)
validate_certificate=False, *args, **kwargs
)
class _AppEnginePoolManager(object):
@@ -119,7 +124,8 @@ class _AppEnginePoolManager(object):
def __init__(self, validate_certificate=True):
self.appengine_manager = gaecontrib.AppEngineManager(
validate_certificate=validate_certificate)
validate_certificate=validate_certificate
)
def connection_from_url(self, url):
return _AppEngineConnection(self.appengine_manager, url)
@@ -143,10 +149,20 @@ class _AppEngineConnection(object):
self.appengine_manager = appengine_manager
self.url = url
def urlopen(self, method, url, body=None, headers=None, retries=None,
redirect=True, assert_same_host=True,
timeout=timeout.Timeout.DEFAULT_TIMEOUT,
pool_timeout=None, release_conn=None, **response_kw):
def urlopen(
self,
method,
url,
body=None,
headers=None,
retries=None,
redirect=True,
assert_same_host=True,
timeout=timeout.Timeout.DEFAULT_TIMEOUT,
pool_timeout=None,
release_conn=None,
**response_kw
):
# This function's url argument is a host-relative URL,
# but the AppEngineManager expects an absolute URL.
# So we saved out the self.url when the AppEngineConnection
@@ -169,7 +185,8 @@ class _AppEngineConnection(object):
retries=retries,
redirect=redirect,
timeout=timeout,
**response_kw)
**response_kw
)
def monkeypatch(validate_certificate=True):
@@ -200,7 +217,5 @@ def _check_version():
if gaecontrib is None:
raise exc.VersionMismatchError(
"The toolbelt requires at least Requests 2.10.0 to be "
"installed. Version {} was found instead.".format(
requests.__version__
)
"installed. Version {} was found instead.".format(requests.__version__)
)
+3 -2
View File
@@ -33,7 +33,7 @@ class FingerprintAdapter(HTTPAdapter):
containing colons.
"""
__attrs__ = HTTPAdapter.__attrs__ + ['fingerprint']
__attrs__ = HTTPAdapter.__attrs__ + ["fingerprint"]
def __init__(self, fingerprint, **kwargs):
self.fingerprint = fingerprint
@@ -45,4 +45,5 @@ class FingerprintAdapter(HTTPAdapter):
num_pools=connections,
maxsize=maxsize,
block=block,
assert_fingerprint=self.fingerprint)
assert_fingerprint=self.fingerprint,
)
+24 -25
View File
@@ -39,20 +39,21 @@ class SocketOptionsAdapter(adapters.HTTPAdapter):
if connection is not None:
default_options = getattr(
connection.HTTPConnection,
'default_socket_options',
[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]
"default_socket_options",
[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)],
)
else:
default_options = []
warnings.warn(exc.RequestsVersionTooOld,
"This version of Requests is only compatible with a "
"version of urllib3 which is too old to support "
"setting options on a socket. This adapter is "
"functionally useless.")
warnings.warn(
exc.RequestsVersionTooOld,
"This version of Requests is only compatible with a "
"version of urllib3 which is too old to support "
"setting options on a socket. This adapter is "
"functionally useless.",
)
def __init__(self, **kwargs):
self.socket_options = kwargs.pop('socket_options',
self.default_options)
self.socket_options = kwargs.pop("socket_options", self.default_options)
super(SocketOptionsAdapter, self).__init__(**kwargs)
@@ -63,7 +64,7 @@ class SocketOptionsAdapter(adapters.HTTPAdapter):
num_pools=connections,
maxsize=maxsize,
block=block,
socket_options=self.socket_options
socket_options=self.socket_options,
)
else:
super(SocketOptionsAdapter, self).init_poolmanager(
@@ -98,30 +99,28 @@ class TCPKeepAliveAdapter(SocketOptionsAdapter):
"""
def __init__(self, **kwargs):
socket_options = kwargs.pop('socket_options',
SocketOptionsAdapter.default_options)
idle = kwargs.pop('idle', 60)
interval = kwargs.pop('interval', 20)
count = kwargs.pop('count', 5)
socket_options = socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
]
socket_options = kwargs.pop(
"socket_options", SocketOptionsAdapter.default_options
)
idle = kwargs.pop("idle", 60)
interval = kwargs.pop("interval", 20)
count = kwargs.pop("count", 5)
socket_options = socket_options + [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)]
# NOTE(Ian): OSX does not have these constants defined, so we
# set them conditionally.
if getattr(socket, 'TCP_KEEPINTVL', None) is not None:
socket_options += [(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL,
interval)]
elif sys.platform == 'darwin':
if getattr(socket, "TCP_KEEPINTVL", None) is not None:
socket_options += [(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval)]
elif sys.platform == "darwin":
# On OSX, TCP_KEEPALIVE from netinet/tcp.h is not exported
# by python's socket module
TCP_KEEPALIVE = getattr(socket, 'TCP_KEEPALIVE', 0x10)
TCP_KEEPALIVE = getattr(socket, "TCP_KEEPALIVE", 0x10)
socket_options += [(socket.IPPROTO_TCP, TCP_KEEPALIVE, interval)]
if getattr(socket, 'TCP_KEEPCNT', None) is not None:
if getattr(socket, "TCP_KEEPCNT", None) is not None:
socket_options += [(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, count)]
if getattr(socket, 'TCP_KEEPIDLE', None) is not None:
if getattr(socket, "TCP_KEEPIDLE", None) is not None:
socket_options += [(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, idle)]
super(TCPKeepAliveAdapter, self).__init__(
+5 -4
View File
@@ -42,6 +42,7 @@ class SourceAddressAdapter(HTTPAdapter):
s.mount('http://', SourceAddressAdapter('10.10.10.10'))
s.mount('https://', SourceAddressAdapter(('10.10.10.10', 8999)))
"""
def __init__(self, source_address, **kwargs):
if isinstance(source_address, basestring):
self.source_address = (source_address, 0)
@@ -59,9 +60,9 @@ class SourceAddressAdapter(HTTPAdapter):
num_pools=connections,
maxsize=maxsize,
block=block,
source_address=self.source_address)
source_address=self.source_address,
)
def proxy_manager_for(self, *args, **kwargs):
kwargs['source_address'] = self.source_address
return super(SourceAddressAdapter, self).proxy_manager_for(
*args, **kwargs)
kwargs["source_address"] = self.source_address
return super(SourceAddressAdapter, self).proxy_manager_for(*args, **kwargs)
+4 -3
View File
@@ -43,7 +43,7 @@ class SSLAdapter(HTTPAdapter):
properly when used with proxies.
"""
__attrs__ = HTTPAdapter.__attrs__ + ['ssl_version']
__attrs__ = HTTPAdapter.__attrs__ + ["ssl_version"]
def __init__(self, ssl_version=None, **kwargs):
self.ssl_version = ssl_version
@@ -55,12 +55,13 @@ class SSLAdapter(HTTPAdapter):
num_pools=connections,
maxsize=maxsize,
block=block,
ssl_version=self.ssl_version)
ssl_version=self.ssl_version,
)
if requests.__build__ >= 0x020400:
# Earlier versions of requests either don't have this method or, worse,
# don't allow passing arbitrary keyword arguments. As a result, only
# conditionally define this method.
def proxy_manager_for(self, *args, **kwargs):
kwargs['ssl_version'] = self.ssl_version
kwargs["ssl_version"] = self.ssl_version
return super(SSLAdapter, self).proxy_manager_for(*args, **kwargs)
+35 -29
View File
@@ -9,8 +9,10 @@ X.509 certificate without needing to convert it to a .pem file
from OpenSSL.crypto import PKey, X509
from cryptography import x509
from cryptography.hazmat.primitives.serialization import (load_pem_private_key,
load_der_private_key)
from cryptography.hazmat.primitives.serialization import (
load_pem_private_key,
load_der_private_key,
)
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.hazmat.backends import default_backend
@@ -82,49 +84,52 @@ class X509Adapter(HTTPAdapter):
def __init__(self, *args, **kwargs):
self._check_version()
cert_bytes = kwargs.pop('cert_bytes', None)
pk_bytes = kwargs.pop('pk_bytes', None)
password = kwargs.pop('password', None)
encoding = kwargs.pop('encoding', Encoding.PEM)
cert_bytes = kwargs.pop("cert_bytes", None)
pk_bytes = kwargs.pop("pk_bytes", None)
password = kwargs.pop("password", None)
encoding = kwargs.pop("encoding", Encoding.PEM)
password_bytes = None
if cert_bytes is None or not isinstance(cert_bytes, bytes):
raise ValueError('Invalid cert content provided. '
'You must provide an X.509 cert '
'formatted as a byte array.')
raise ValueError(
"Invalid cert content provided. "
"You must provide an X.509 cert "
"formatted as a byte array."
)
if pk_bytes is None or not isinstance(pk_bytes, bytes):
raise ValueError('Invalid private key content provided. '
'You must provide a private key '
'formatted as a byte array.')
raise ValueError(
"Invalid private key content provided. "
"You must provide a private key "
"formatted as a byte array."
)
if isinstance(password, bytes):
password_bytes = password
elif password:
password_bytes = password.encode('utf8')
password_bytes = password.encode("utf8")
self.ssl_context = create_ssl_context(cert_bytes, pk_bytes,
password_bytes, encoding)
self.ssl_context = create_ssl_context(
cert_bytes, pk_bytes, password_bytes, encoding
)
super(X509Adapter, self).__init__(*args, **kwargs)
def init_poolmanager(self, *args, **kwargs):
if self.ssl_context:
kwargs['ssl_context'] = self.ssl_context
kwargs["ssl_context"] = self.ssl_context
return super(X509Adapter, self).init_poolmanager(*args, **kwargs)
def proxy_manager_for(self, *args, **kwargs):
if self.ssl_context:
kwargs['ssl_context'] = self.ssl_context
kwargs["ssl_context"] = self.ssl_context
return super(X509Adapter, self).proxy_manager_for(*args, **kwargs)
def _check_version(self):
if PyOpenSSLContext is None:
raise exc.VersionMismatchError(
"The X509Adapter requires at least Requests 2.12.0 to be "
"installed. Version {} was found instead.".format(
requests.__version__
)
"installed. Version {} was found instead.".format(requests.__version__)
)
@@ -133,14 +138,16 @@ def check_cert_dates(cert):
now = datetime.utcnow()
if cert.not_valid_after < now or cert.not_valid_before > now:
raise ValueError('Client certificate expired: Not After: '
'{:%Y-%m-%d %H:%M:%SZ} '
'Not Before: {:%Y-%m-%d %H:%M:%SZ}'
.format(cert.not_valid_after, cert.not_valid_before))
raise ValueError(
"Client certificate expired: Not After: "
"{:%Y-%m-%d %H:%M:%SZ} "
"Not Before: {:%Y-%m-%d %H:%M:%SZ}".format(
cert.not_valid_after, cert.not_valid_before
)
)
def create_ssl_context(cert_byes, pk_bytes, password=None,
encoding=Encoding.PEM):
def create_ssl_context(cert_byes, pk_bytes, password=None, encoding=Encoding.PEM):
"""Create an SSL Context with the supplied cert/password.
:param cert_bytes array of bytes containing the cert encoded
@@ -166,11 +173,10 @@ def create_ssl_context(cert_byes, pk_bytes, password=None,
cert = x509.load_der_x509_certificate(cert_byes, backend)
key = load_der_private_key(pk_bytes, password, backend)
else:
raise ValueError('Invalid encoding provided: Must be PEM or DER')
raise ValueError("Invalid encoding provided: Must be PEM or DER")
if not (cert and key):
raise ValueError('Cert and key could not be parsed from '
'provided data')
raise ValueError("Cert and key could not be parsed from " "provided data")
check_cert_dates(cert)
ssl_context = PyOpenSSLContext(PROTOCOL)
ssl_context._ctx.use_certificate(X509.from_cryptography(cert))
@@ -15,12 +15,12 @@ class _ThreadingDescriptor(object):
class _HTTPDigestAuth(requests.auth.HTTPDigestAuth):
init = _ThreadingDescriptor('init', True)
last_nonce = _ThreadingDescriptor('last_nonce', '')
nonce_count = _ThreadingDescriptor('nonce_count', 0)
chal = _ThreadingDescriptor('chal', {})
pos = _ThreadingDescriptor('pos', None)
num_401_calls = _ThreadingDescriptor('num_401_calls', 1)
init = _ThreadingDescriptor("init", True)
last_nonce = _ThreadingDescriptor("last_nonce", "")
nonce_count = _ThreadingDescriptor("nonce_count", 0)
chal = _ThreadingDescriptor("chal", {})
pos = _ThreadingDescriptor("pos", None)
num_401_calls = _ThreadingDescriptor("num_401_calls", 1)
if requests.__build__ < 0x020800:
+19 -18
View File
@@ -8,6 +8,7 @@ from . import _digest_auth_compat as auth_compat, http_proxy_digest
class GuessAuth(auth.AuthBase):
"""Guesses the auth type by the WWW-Authentication header."""
def __init__(self, username, password):
self.username = username
self.password = password
@@ -23,7 +24,7 @@ class GuessAuth(auth.AuthBase):
r.content
r.raw.release_conn()
prep = r.request.copy()
if not hasattr(prep, '_cookies'):
if not hasattr(prep, "_cookies"):
prep._cookies = cookies.RequestsCookieJar()
cookies.extract_cookies_to_jar(prep._cookies, r.request, r.raw)
prep.prepare_cookies(prep._cookies)
@@ -48,8 +49,7 @@ class GuessAuth(auth.AuthBase):
# Check that the attr exists because much older versions of requests
# set this attribute lazily. For example:
# https://github.com/kennethreitz/requests/blob/33735480f77891754304e7f13e3cdf83aaaa76aa/requests/auth.py#L59
if (hasattr(self.auth, 'num_401_calls') and
self.auth.num_401_calls is None):
if hasattr(self.auth, "num_401_calls") and self.auth.num_401_calls is None:
self.auth.num_401_calls = 1
# Digest auth would resend the request by itself. We can take a
# shortcut here.
@@ -58,12 +58,12 @@ class GuessAuth(auth.AuthBase):
def handle_401(self, r, **kwargs):
"""Resends a request with auth headers, if needed."""
www_authenticate = r.headers.get('www-authenticate', '').lower()
www_authenticate = r.headers.get("www-authenticate", "").lower()
if 'basic' in www_authenticate:
if "basic" in www_authenticate:
return self._handle_basic_auth_401(r, kwargs)
if 'digest' in www_authenticate:
if "digest" in www_authenticate:
return self._handle_digest_auth_401(r, kwargs)
def __call__(self, request):
@@ -75,7 +75,7 @@ class GuessAuth(auth.AuthBase):
except AttributeError:
pass
request.register_hook('response', self.handle_401)
request.register_hook("response", self.handle_401)
return request
@@ -84,8 +84,10 @@ class GuessProxyAuth(GuessAuth):
Guesses the auth type by WWW-Authentication and Proxy-Authentication
headers
"""
def __init__(self, username=None, password=None,
proxy_username=None, proxy_password=None):
def __init__(
self, username=None, password=None, proxy_username=None, proxy_password=None
):
super(GuessProxyAuth, self).__init__(username, password)
self.proxy_username = proxy_username
self.proxy_password = proxy_password
@@ -98,13 +100,12 @@ class GuessProxyAuth(GuessAuth):
r.content
r.raw.release_conn()
prep = r.request.copy()
if not hasattr(prep, '_cookies'):
if not hasattr(prep, "_cookies"):
prep._cookies = cookies.RequestsCookieJar()
cookies.extract_cookies_to_jar(prep._cookies, r.request, r.raw)
prep.prepare_cookies(prep._cookies)
self.proxy_auth = auth.HTTPProxyAuth(self.proxy_username,
self.proxy_password)
self.proxy_auth = auth.HTTPProxyAuth(self.proxy_username, self.proxy_password)
prep = self.proxy_auth(prep)
_r = r.connection.send(prep, **kwargs)
_r.history.append(r)
@@ -114,8 +115,8 @@ class GuessProxyAuth(GuessAuth):
def _handle_digest_auth_407(self, r, kwargs):
self.proxy_auth = http_proxy_digest.HTTPProxyDigestAuth(
username=self.proxy_username,
password=self.proxy_password)
username=self.proxy_username, password=self.proxy_password
)
try:
self.auth.init_per_thread_state()
@@ -125,12 +126,12 @@ class GuessProxyAuth(GuessAuth):
return self.proxy_auth.handle_407(r, **kwargs)
def handle_407(self, r, **kwargs):
proxy_authenticate = r.headers.get('Proxy-Authenticate', '').lower()
proxy_authenticate = r.headers.get("Proxy-Authenticate", "").lower()
if 'basic' in proxy_authenticate:
if "basic" in proxy_authenticate:
return self._handle_basic_auth_407(r, kwargs)
if 'digest' in proxy_authenticate:
if "digest" in proxy_authenticate:
return self._handle_digest_auth_407(r, kwargs)
def __call__(self, request):
@@ -142,5 +143,5 @@ class GuessProxyAuth(GuessAuth):
except AttributeError:
pass
request.register_hook('response', self.handle_407)
request.register_hook("response", self.handle_407)
return super(GuessProxyAuth, self).__call__(request)
+5 -5
View File
@@ -60,7 +60,7 @@ class AuthHandler(AuthBase):
return auth(request)
def __repr__(self):
return '<AuthHandler({!r})>'.format(self.strategies)
return "<AuthHandler({!r})>".format(self.strategies)
def _make_uniform(self):
existing_strategies = list(self.strategies.items())
@@ -72,9 +72,9 @@ class AuthHandler(AuthBase):
@staticmethod
def _key_from_url(url):
parsed = urlparse(url)
return urlunparse((parsed.scheme.lower(),
parsed.netloc.lower(),
'', '', '', ''))
return urlunparse(
(parsed.scheme.lower(), parsed.netloc.lower(), "", "", "", "")
)
def add_strategy(self, domain, strategy):
"""Add a new domain and authentication strategy.
@@ -136,7 +136,7 @@ class AuthHandler(AuthBase):
class NullAuthStrategy(AuthBase):
def __repr__(self):
return '<NullAuthStrategy>'
return "<NullAuthStrategy>"
def __call__(self, r):
return r
+15 -16
View File
@@ -16,7 +16,8 @@ class HTTPProxyDigestAuth(auth.HTTPDigestAuth):
new username and password. i.e., retry build_digest_header
:type stale_rejects: int
"""
_pat = re.compile(r'digest ', flags=re.IGNORECASE)
_pat = re.compile(r"digest ", flags=re.IGNORECASE)
def __init__(self, *args, **kwargs):
super(HTTPProxyDigestAuth, self).__init__(*args, **kwargs)
@@ -26,14 +27,14 @@ class HTTPProxyDigestAuth(auth.HTTPDigestAuth):
@property
def stale_rejects(self):
thread_local = getattr(self, '_thread_local', None)
thread_local = getattr(self, "_thread_local", None)
if thread_local is None:
return self._stale_rejects
return thread_local.stale_rejects
@stale_rejects.setter
def stale_rejects(self, value):
thread_local = getattr(self, '_thread_local', None)
thread_local = getattr(self, "_thread_local", None)
if thread_local is None:
self._stale_rejects = value
else:
@@ -57,21 +58,20 @@ class HTTPProxyDigestAuth(auth.HTTPDigestAuth):
if s_auth is None:
raise IOError(
"proxy server violated RFC 7235:"
"407 response MUST contain header proxy-authenticate")
"407 response MUST contain header proxy-authenticate"
)
elif not self._pat.match(s_auth):
return r
self.chal = utils.parse_dict_header(
self._pat.sub('', s_auth, count=1))
self.chal = utils.parse_dict_header(self._pat.sub("", s_auth, count=1))
# if we present the user/passwd and still get rejected
# https://tools.ietf.org/html/rfc2617#section-3.2.1
if ('Proxy-Authorization' in r.request.headers and
'stale' in self.chal):
if self.chal['stale'].lower() == 'true': # try again
if "Proxy-Authorization" in r.request.headers and "stale" in self.chal:
if self.chal["stale"].lower() == "true": # try again
self.stale_rejects += 1
# wrong user/passwd
elif self.chal['stale'].lower() == 'false':
elif self.chal["stale"].lower() == "false":
raise IOError("User or password is invalid")
# Consume content and release the original connection
@@ -82,8 +82,9 @@ class HTTPProxyDigestAuth(auth.HTTPDigestAuth):
cookies.extract_cookies_to_jar(prep._cookies, r.request, r.raw)
prep.prepare_cookies(prep._cookies)
prep.headers['Proxy-Authorization'] = self.build_digest_header(
prep.method, prep.url)
prep.headers["Proxy-Authorization"] = self.build_digest_header(
prep.method, prep.url
)
_r = r.connection.send(prep, **kwargs)
_r.history.append(r)
_r.request = prep
@@ -96,8 +97,6 @@ class HTTPProxyDigestAuth(auth.HTTPDigestAuth):
self.init_per_thread_state()
# if we have nonce, then just use it, otherwise server will tell us
if self.last_nonce:
r.headers['Proxy-Authorization'] = self.build_digest_header(
r.method, r.url
)
r.register_hook('response', self.handle_407)
r.headers["Proxy-Authorization"] = self.build_digest_header(r.method, r.url)
r.register_hook("response", self.handle_407)
return r
+8 -8
View File
@@ -9,8 +9,8 @@ from .. import exceptions as exc
# cd2c97bb0a076da2322f11adce0b2731f9193396 L62-L64
_QUOTED_STRING_RE = r'"[^"\\]*(?:\\.[^"\\]*)*"'
_OPTION_HEADER_PIECE_RE = re.compile(
r';\s*(%s|[^\s;=]+)\s*(?:=\s*(%s|[^;]+))?\s*' % (_QUOTED_STRING_RE,
_QUOTED_STRING_RE)
r";\s*(%s|[^\s;=]+)\s*(?:=\s*(%s|[^;]+))?\s*"
% (_QUOTED_STRING_RE, _QUOTED_STRING_RE)
)
_DEFAULT_CHUNKSIZE = 512
@@ -18,7 +18,7 @@ _DEFAULT_CHUNKSIZE = 512
def _get_filename(content_disposition):
for match in _OPTION_HEADER_PIECE_RE.finditer(content_disposition):
k, v = match.groups()
if k == 'filename':
if k == "filename":
# ignore any directory paths in the filename
return os.path.split(v)[1]
return None
@@ -52,10 +52,10 @@ def get_download_file_path(response, path):
filepath = path
else:
response_filename = _get_filename(
response.headers.get('content-disposition', '')
response.headers.get("content-disposition", "")
)
if not response_filename:
raise exc.StreamingError('No filename given to stream response to')
raise exc.StreamingError("No filename given to stream response to")
if path_is_dir:
# directory to download to
@@ -157,15 +157,15 @@ def stream_response_to_file(response, path=None, chunksize=_DEFAULT_CHUNKSIZE):
pre_opened = False
fd = None
filename = None
if path and callable(getattr(path, 'write', None)):
if path and callable(getattr(path, "write", None)):
pre_opened = True
fd = path
filename = getattr(fd, 'name', None)
filename = getattr(fd, "name", None)
else:
filename = get_download_file_path(response, path)
if os.path.exists(filename):
raise exc.StreamingError("File already exists: %s" % filename)
fd = open(filename, 'wb')
fd = open(filename, "wb")
for chunk in response.iter_content(chunk_size=chunksize):
fd.write(chunk)
+17 -17
View File
@@ -3,18 +3,16 @@ import io
_DEFAULT_CHUNKSIZE = 65536
__all__ = ['tee', 'tee_to_file', 'tee_to_bytearray']
__all__ = ["tee", "tee_to_file", "tee_to_bytearray"]
def _tee(response, callback, chunksize, decode_content):
for chunk in response.raw.stream(amt=chunksize,
decode_content=decode_content):
for chunk in response.raw.stream(amt=chunksize, decode_content=decode_content):
callback(chunk)
yield chunk
def tee(response, fileobject, chunksize=_DEFAULT_CHUNKSIZE,
decode_content=None):
def tee(response, fileobject, chunksize=_DEFAULT_CHUNKSIZE, decode_content=None):
"""Stream the response both to the generator and a file.
This will stream the response body while writing the bytes to
@@ -53,17 +51,19 @@ def tee(response, fileobject, chunksize=_DEFAULT_CHUNKSIZE,
# ensure that writing to the fileobject will preserve those bytes. On
# Python3, if the user passes an io.StringIO, this will fail, so we need
# to check for BytesIO instead.
if not ('b' in getattr(fileobject, 'mode', '') or
isinstance(fileobject, io.BytesIO)):
raise TypeError('tee() will write bytes directly to this fileobject'
', it must be opened with the "b" flag if it is a file'
' or inherit from io.BytesIO.')
if not (
"b" in getattr(fileobject, "mode", "") or isinstance(fileobject, io.BytesIO)
):
raise TypeError(
"tee() will write bytes directly to this fileobject"
', it must be opened with the "b" flag if it is a file'
" or inherit from io.BytesIO."
)
return _tee(response, fileobject.write, chunksize, decode_content)
def tee_to_file(response, filename, chunksize=_DEFAULT_CHUNKSIZE,
decode_content=None):
def tee_to_file(response, filename, chunksize=_DEFAULT_CHUNKSIZE, decode_content=None):
"""Stream the response both to the generator and a file.
This will open a file named ``filename`` and stream the response body
@@ -84,13 +84,14 @@ def tee_to_file(response, filename, chunksize=_DEFAULT_CHUNKSIZE,
:param bool decode_content: (optional), If True, this will decode the
compressed content of the response.
"""
with open(filename, 'wb') as fd:
with open(filename, "wb") as fd:
for chunk in tee(response, fd, chunksize, decode_content):
yield chunk
def tee_to_bytearray(response, bytearr, chunksize=_DEFAULT_CHUNKSIZE,
decode_content=None):
def tee_to_bytearray(
response, bytearr, chunksize=_DEFAULT_CHUNKSIZE, decode_content=None
):
"""Stream the response both to the generator and a bytearray.
This will stream the response provided to the function, add them to the
@@ -118,6 +119,5 @@ def tee_to_bytearray(response, bytearr, chunksize=_DEFAULT_CHUNKSIZE,
compressed content of the response.
"""
if not isinstance(bytearr, bytearray):
raise TypeError('tee_to_bytearray() expects bytearr to be a '
'bytearray')
raise TypeError("tee_to_bytearray() expects bytearr to be a " "bytearray")
return _tee(response, bytearr.extend, chunksize, decode_content)
+4
View File
@@ -4,6 +4,7 @@
class StreamingError(Exception):
"""Used in :mod:`requests_toolbelt.downloadutils.stream`."""
pass
@@ -13,6 +14,7 @@ class VersionMismatchError(Exception):
The feature in use requires a newer version of Requests to function
appropriately but the version installed is not sufficient.
"""
pass
@@ -22,6 +24,7 @@ class RequestsVersionTooOld(Warning):
If the version of Requests is too old to support a feature, we will issue
this warning to the user.
"""
pass
@@ -34,4 +37,5 @@ class IgnoringGAECertificateValidation(Warning):
In :class:`requests_toolbelt.adapters.appengine.InsecureAppEngineAdapter`.
"""
pass
+13 -13
View File
@@ -13,19 +13,19 @@ from .decoder import MultipartDecoder
from .decoder import ImproperBodyPartContentException
from .decoder import NonMultipartContentTypeException
__title__ = 'requests-toolbelt'
__authors__ = 'Ian Cordasco, Cory Benfield'
__license__ = 'Apache v2.0'
__copyright__ = 'Copyright 2014 Ian Cordasco, Cory Benfield'
__title__ = "requests-toolbelt"
__authors__ = "Ian Cordasco, Cory Benfield"
__license__ = "Apache v2.0"
__copyright__ = "Copyright 2014 Ian Cordasco, Cory Benfield"
__all__ = [
'MultipartEncoder',
'MultipartEncoderMonitor',
'MultipartDecoder',
'ImproperBodyPartContentException',
'NonMultipartContentTypeException',
'__title__',
'__authors__',
'__license__',
'__copyright__',
"MultipartEncoder",
"MultipartEncoderMonitor",
"MultipartDecoder",
"ImproperBodyPartContentException",
"NonMultipartContentTypeException",
"__title__",
"__authors__",
"__license__",
"__copyright__",
]
+22 -25
View File
@@ -16,7 +16,7 @@ from requests.structures import CaseInsensitiveDict
def _split_on_find(content, bound):
point = content.find(bound)
return content[:point], content[point + len(bound):]
return content[:point], content[point + len(bound) :]
class ImproperBodyPartContentException(Exception):
@@ -32,10 +32,7 @@ def _header_parser(string, encoding):
if major == 3:
string = string.decode(encoding)
headers = email.parser.HeaderParser().parsestr(string).items()
return (
(encode_with(k, encoding), encode_with(v, encoding))
for k, v in headers
)
return ((encode_with(k, encoding), encode_with(v, encoding)) for k, v in headers)
class BodyPart(object):
@@ -55,13 +52,13 @@ class BodyPart(object):
self.encoding = encoding
headers = {}
# Split into header section (if any) and the content
if b'\r\n\r\n' in content:
first, self.content = _split_on_find(content, b'\r\n\r\n')
if first != b'':
if b"\r\n\r\n" in content:
first, self.content = _split_on_find(content, b"\r\n\r\n")
if first != b"":
headers = _header_parser(first.lstrip(), encoding)
else:
raise ImproperBodyPartContentException(
'content does not contain CR-LF-CR-LF'
"content does not contain CR-LF-CR-LF"
)
self.headers = CaseInsensitiveDict(headers)
@@ -100,7 +97,8 @@ class MultipartDecoder(object):
``'utf-8'``).
"""
def __init__(self, content, content_type, encoding='utf-8'):
def __init__(self, content, content_type, encoding="utf-8"):
#: Original Content-Type header
self.content_type = content_type
#: Response body encoding
@@ -111,18 +109,15 @@ class MultipartDecoder(object):
self._parse_body(content)
def _find_boundary(self):
ct_info = tuple(x.strip() for x in self.content_type.split(';'))
ct_info = tuple(x.strip() for x in self.content_type.split(";"))
mimetype = ct_info[0]
if mimetype.split('/')[0].lower() != 'multipart':
if mimetype.split("/")[0].lower() != "multipart":
raise NonMultipartContentTypeException(
"Unexpected mimetype in content-type: '{}'".format(mimetype)
)
for item in ct_info[1:]:
attr, value = _split_on_find(
item,
'='
)
if attr.lower() == 'boundary':
attr, value = _split_on_find(item, "=")
if attr.lower() == "boundary":
self.boundary = encode_with(value.strip('"'), self.encoding)
@staticmethod
@@ -134,23 +129,25 @@ class MultipartDecoder(object):
return part
def _parse_body(self, content):
boundary = b''.join((b'--', self.boundary))
boundary = b"".join((b"--", self.boundary))
def body_part(part):
fixed = MultipartDecoder._fix_first_part(part, boundary)
return BodyPart(fixed, self.encoding)
def test_part(part):
return (part != b'' and
part != b'\r\n' and
part[:4] != b'--\r\n' and
part != b'--')
return (
part != b""
and part != b"\r\n"
and part[:4] != b"--\r\n"
and part != b"--"
)
parts = content.split(b''.join((b'\r\n', boundary)))
parts = content.split(b"".join((b"\r\n", boundary)))
self.parts = tuple(body_part(x) for x in parts if test_part(x))
@classmethod
def from_response(cls, response, encoding='utf-8'):
def from_response(cls, response, encoding="utf-8"):
content = response.content
content_type = response.headers.get('content-type', None)
content_type = response.headers.get("content-type", None)
return cls(content, content_type, encoding)
+36 -36
View File
@@ -84,21 +84,23 @@ class MultipartEncoder(object):
"""
def __init__(self, fields, boundary=None, encoding='utf-8'):
def __init__(self, fields, boundary=None, encoding="utf-8"):
#: Boundary value either passed in by the user or created
self.boundary_value = boundary or uuid4().hex
# Computed boundary
self.boundary = '--{}'.format(self.boundary_value)
self.boundary = "--{}".format(self.boundary_value)
#: Encoding of the data being passed in
self.encoding = encoding
# Pre-encoded boundary
self._encoded_boundary = b''.join([
encode_with(self.boundary, self.encoding),
encode_with('\r\n', self.encoding)
])
self._encoded_boundary = b"".join(
[
encode_with(self.boundary, self.encoding),
encode_with("\r\n", self.encoding),
]
)
#: Fields provided by the user
self.fields = fields
@@ -148,7 +150,7 @@ class MultipartEncoder(object):
return self._len or self._calculate_length()
def __repr__(self):
return '<MultipartEncoder: {!r}>'.format(self.fields)
return "<MultipartEncoder: {!r}>".format(self.fields)
def _calculate_length(self):
"""
@@ -158,9 +160,11 @@ class MultipartEncoder(object):
"""
boundary_len = len(self.boundary) # Length of --{boundary}
# boundary length + header length + body length + len('\r\n') * 2
self._len = sum(
(boundary_len + total_len(p) + 4) for p in self.parts
) + boundary_len + 4
self._len = (
sum((boundary_len + total_len(p) + 4) for p in self.parts)
+ boundary_len
+ 4
)
return self._len
def _calculate_load_amount(self, read_size):
@@ -191,7 +195,7 @@ class MultipartEncoder(object):
while amount == -1 or amount > 0:
written = 0
if part and not part.bytes_left_to_write():
written += self._write(b'\r\n')
written += self._write(b"\r\n")
written += self._write_boundary()
part = self._next_part()
@@ -214,7 +218,7 @@ class MultipartEncoder(object):
def _iter_fields(self):
_fields = self.fields
if hasattr(self.fields, 'items'):
if hasattr(self.fields, "items"):
_fields = list(self.fields.items())
for k, v in _fields:
file_name = None
@@ -230,9 +234,9 @@ class MultipartEncoder(object):
else:
file_pointer = v
field = fields.RequestField(name=k, data=file_pointer,
filename=file_name,
headers=file_headers)
field = fields.RequestField(
name=k, data=file_pointer, filename=file_name, headers=file_headers
)
field.make_multipart(content_type=file_type)
yield field
@@ -263,7 +267,7 @@ class MultipartEncoder(object):
"""Write the bytes necessary to finish a multipart/form-data body."""
with reset(self._buffer):
self._buffer.seek(-2, 2)
self._buffer.write(b'--\r\n')
self._buffer.write(b"--\r\n")
return 2
def _write_headers(self, headers):
@@ -272,9 +276,7 @@ class MultipartEncoder(object):
@property
def content_type(self):
return str(
'multipart/form-data; boundary={}'.format(self.boundary_value)
)
return str("multipart/form-data; boundary={}".format(self.boundary_value))
def to_string(self):
"""Return the entirety of the data in the encoder.
@@ -385,8 +387,7 @@ class MultipartEncoderMonitor(object):
self.len = self.encoder.len
@classmethod
def from_fields(cls, fields, boundary=None, encoding='utf-8',
callback=None):
def from_fields(cls, fields, boundary=None, encoding="utf-8", callback=None):
encoder = MultipartEncoder(fields, boundary, encoding)
return cls(encoder, callback)
@@ -419,20 +420,20 @@ def encode_with(string, encoding):
def readable_data(data, encoding):
"""Coerce the data to an object with a ``read`` method."""
if hasattr(data, 'read'):
if hasattr(data, "read"):
return data
return CustomBytesIO(data, encoding)
def total_len(o):
if hasattr(o, '__len__'):
if hasattr(o, "__len__"):
return len(o)
if hasattr(o, 'len'):
if hasattr(o, "len"):
return o.len
if hasattr(o, 'fileno'):
if hasattr(o, "fileno"):
try:
fileno = o.fileno()
except io.UnsupportedOperation:
@@ -440,7 +441,7 @@ def total_len(o):
else:
return os.fstat(fileno).st_size
if hasattr(o, 'getvalue'):
if hasattr(o, "getvalue"):
# e.g. BytesIO, cStringIO.StringIO
return len(o.getvalue())
@@ -462,20 +463,20 @@ def reset(buffer):
def coerce_data(data, encoding):
"""Ensure that every object's __len__ behaves uniformly."""
if not isinstance(data, CustomBytesIO):
if hasattr(data, 'getvalue'):
if hasattr(data, "getvalue"):
return CustomBytesIO(data.getvalue(), encoding)
if hasattr(data, 'fileno'):
if hasattr(data, "fileno"):
return FileWrapper(data)
if not hasattr(data, 'read'):
if not hasattr(data, "read"):
return CustomBytesIO(data, encoding)
return data
def to_list(fields):
if hasattr(fields, 'items'):
if hasattr(fields, "items"):
return list(fields.items())
return list(fields)
@@ -531,7 +532,7 @@ class Part(object):
class CustomBytesIO(io.BytesIO):
def __init__(self, buffer=None, encoding='utf-8'):
def __init__(self, buffer=None, encoding="utf-8"):
buffer = encode_with(buffer, encoding)
super(CustomBytesIO, self).__init__(buffer)
@@ -625,18 +626,17 @@ class FileFromURLWrapper(object):
def __init__(self, file_url, session=None):
self.session = session or requests.Session()
requested_file = self._request_for_file(file_url)
self.len = int(requested_file.headers['content-length'])
self.len = int(requested_file.headers["content-length"])
self.raw_data = requested_file.raw
def _request_for_file(self, file_url):
"""Make call for file under provided URL."""
response = self.session.get(file_url, stream=True)
content_length = response.headers.get('content-length', None)
content_length = response.headers.get("content-length", None)
if content_length is None:
error_msg = (
"Data from provided URL {url} is not supported. Lack of "
"content-length Header in requested file response.".format(
url=file_url)
"content-length Header in requested file response.".format(url=file_url)
)
raise FileNotSupportedError(error_msg)
elif not content_length.isdigit():
@@ -650,6 +650,6 @@ class FileFromURLWrapper(object):
def read(self, chunk_size):
"""Read file in chunks."""
chunk_size = chunk_size if chunk_size >= 0 else self.len
chunk = self.raw_data.read(chunk_size) or b''
chunk = self.raw_data.read(chunk_size) or b""
self.len -= len(chunk) if chunk else 0 # left to read
return chunk
+1 -3
View File
@@ -61,9 +61,7 @@ class BaseUrlSession(requests.Session):
def request(self, method, url, *args, **kwargs):
"""Send the request after generating the complete URL."""
url = self.create_url(url)
return super(BaseUrlSession, self).request(
method, url, *args, **kwargs
)
return super(BaseUrlSession, self).request(method, url, *args, **kwargs)
def create_url(self, url):
"""Create the URL based off this partial path."""
+6 -8
View File
@@ -52,14 +52,12 @@ class StreamingIterator(object):
appropriately because the toolbelt will not attempt to guess that for you.
"""
def __init__(self, size, iterator, encoding='utf-8'):
def __init__(self, size, iterator, encoding="utf-8"):
#: The expected size of the upload
self.size = int(size)
if self.size < 0:
raise ValueError(
'The size of the upload must be a positive integer'
)
raise ValueError("The size of the upload must be a positive integer")
#: Attribute that requests will check to determine the length of the
#: body. See bug #80 for more details
@@ -71,7 +69,7 @@ class StreamingIterator(object):
#: The iterator used to generate the upload data
self.iterator = iterator
if hasattr(iterator, 'read'):
if hasattr(iterator, "read"):
self._file = iterator
else:
self._file = _IteratorAsBinaryFile(iterator, encoding)
@@ -81,7 +79,7 @@ class StreamingIterator(object):
class _IteratorAsBinaryFile(object):
def __init__(self, iterator, encoding='utf-8'):
def __init__(self, iterator, encoding="utf-8"):
#: The iterator used to generate the upload data
self.iterator = iterator
@@ -96,7 +94,7 @@ class _IteratorAsBinaryFile(object):
try:
return encode_with(next(self.iterator), self.encoding)
except StopIteration:
return b''
return b""
def _load_bytes(self, size):
self._buffer.smart_truncate()
@@ -110,7 +108,7 @@ class _IteratorAsBinaryFile(object):
def read(self, size=-1):
size = int(size)
if size == -1:
return b''.join(self.iterator)
return b"".join(self.iterator)
self._load_bytes(size)
return self._buffer.read(size)
+2 -2
View File
@@ -82,7 +82,7 @@ def map(requests, **kwargs):
:class:`~requests_toolbelt.threaded.pool.ThreadException`)
"""
if not (requests and all(isinstance(r, dict) for r in requests)):
raise ValueError('map expects a list of dictionaries.')
raise ValueError("map expects a list of dictionaries.")
# Build our queue of requests
job_queue = queue.Queue()
@@ -90,7 +90,7 @@ def map(requests, **kwargs):
job_queue.put(request)
# Ensure the user doesn't try to pass their own job_queue
kwargs['job_queue'] = job_queue
kwargs["job_queue"] = job_queue
threadpool = pool.Pool(**kwargs)
threadpool.join_all()
+23 -11
View File
@@ -24,8 +24,14 @@ class Pool(object):
:type session: requests.Session
"""
def __init__(self, job_queue, initializer=None, auth_generator=None,
num_processes=None, session=requests.Session):
def __init__(
self,
job_queue,
initializer=None,
auth_generator=None,
num_processes=None,
session=requests.Session,
):
if num_processes is None:
num_processes = multiprocessing.cpu_count() or 1
@@ -40,8 +46,12 @@ class Pool(object):
self._auth = auth_generator or _identity
self._session = session
self._pool = [
thread.SessionThread(self._new_session(), self._job_queue,
self._response_queue, self._exc_queue)
thread.SessionThread(
self._new_session(),
self._job_queue,
self._response_queue,
self._exc_queue,
)
for _ in range(self._processes)
]
@@ -85,12 +95,12 @@ class Pool(object):
:returns: An initialized :class:`~Pool` object.
:rtype: :class:`~Pool`
"""
request_dict = {'method': 'GET'}
request_dict = {"method": "GET"}
request_dict.update(request_kwargs or {})
job_queue = queue.Queue()
for url in urls:
job = request_dict.copy()
job.update({'url': url})
job.update({"url": url})
job_queue.put(job)
return cls(job_queue=job_queue, **kwargs)
@@ -172,8 +182,9 @@ class ThreadResponse(ThreadProxy):
json = thread_response.json()
"""
proxied_attr = 'response'
attrs = frozenset(['request_kwargs', 'response'])
proxied_attr = "response"
attrs = frozenset(["request_kwargs", "response"])
def __init__(self, request_kwargs, response):
#: The original keyword arguments provided to the queue
@@ -194,8 +205,9 @@ class ThreadException(ThreadProxy):
msg = thread_exc.message
"""
proxied_attr = 'exception'
attrs = frozenset(['request_kwargs', 'exception'])
proxied_attr = "exception"
attrs = frozenset(["request_kwargs", "exception"])
def __init__(self, request_kwargs, exception):
#: The original keyword arguments provided to the queue
@@ -208,4 +220,4 @@ def _identity(session_obj):
return session_obj
__all__ = ['ThreadException', 'ThreadResponse', 'Pool']
__all__ = ["ThreadException", "ThreadResponse", "Pool"]
+2 -6
View File
@@ -8,8 +8,7 @@ from .._compat import queue
class SessionThread(object):
def __init__(self, initialized_session, job_queue, response_queue,
exception_queue):
def __init__(self, initialized_session, job_queue, response_queue, exception_queue):
self._session = initialized_session
self._jobs = job_queue
self._create_worker()
@@ -17,10 +16,7 @@ class SessionThread(object):
self._exceptions = exception_queue
def _create_worker(self):
self._worker = threading.Thread(
target=self._make_request,
name=uuid.uuid4(),
)
self._worker = threading.Thread(target=self._make_request, name=uuid.uuid4())
self._worker.daemon = True
self._worker._state = 0
self._worker.start()
+5 -10
View File
@@ -5,17 +5,13 @@ import sys
from requests import utils
find_charset = re.compile(
br'<meta.*?charset=["\']*(.+?)["\'>]', flags=re.I
).findall
find_charset = re.compile(br'<meta.*?charset=["\']*(.+?)["\'>]', flags=re.I).findall
find_pragma = re.compile(
br'<meta.*?content=["\']*;?charset=(.+?)["\'>]', flags=re.I
).findall
find_xml = re.compile(
br'^<\?xml.*?encoding=["\']*(.+?)["\'>]'
).findall
find_xml = re.compile(br'^<\?xml.*?encoding=["\']*(.+?)["\'>]').findall
def get_encodings_from_content(content):
@@ -34,10 +30,9 @@ def get_encodings_from_content(content):
:return: encodings detected in the provided content
:rtype: list(str)
"""
encodings = (find_charset(content) + find_pragma(content)
+ find_xml(content))
encodings = find_charset(content) + find_pragma(content) + find_xml(content)
if (3, 0) <= sys.version_info < (4, 0):
encodings = [encoding.decode('utf8') for encoding in encodings]
encodings = [encoding.decode("utf8") for encoding in encodings]
return encodings
@@ -85,7 +80,7 @@ def get_unicode_from_response(response):
# Fall back:
if encoding:
try:
return str(response.content, encoding, errors='replace')
return str(response.content, encoding, errors="replace")
except TypeError:
pass
return response.text
+39 -38
View File
@@ -4,16 +4,11 @@ import collections
from requests import compat
__all__ = ('dump_response', 'dump_all')
__all__ = ("dump_response", "dump_all")
HTTP_VERSIONS = {
9: b'0.9',
10: b'1.0',
11: b'1.1',
}
HTTP_VERSIONS = {9: b"0.9", 10: b"1.0", 11: b"1.1"}
_PrefixSettings = collections.namedtuple('PrefixSettings',
['request', 'response'])
_PrefixSettings = collections.namedtuple("PrefixSettings", ["request", "response"])
class PrefixSettings(_PrefixSettings):
@@ -24,32 +19,31 @@ class PrefixSettings(_PrefixSettings):
def _get_proxy_information(response):
if getattr(response.connection, 'proxy_manager', False):
if getattr(response.connection, "proxy_manager", False):
proxy_info = {}
request_url = response.request.url
if request_url.startswith('https://'):
proxy_info['method'] = 'CONNECT'
if request_url.startswith("https://"):
proxy_info["method"] = "CONNECT"
proxy_info['request_path'] = request_url
proxy_info["request_path"] = request_url
return proxy_info
return None
def _format_header(name, value):
return (_coerce_to_bytes(name) + b': ' + _coerce_to_bytes(value) +
b'\r\n')
return _coerce_to_bytes(name) + b": " + _coerce_to_bytes(value) + b"\r\n"
def _build_request_path(url, proxy_info):
uri = compat.urlparse(url)
proxy_url = proxy_info.get('request_path')
proxy_url = proxy_info.get("request_path")
if proxy_url is not None:
request_path = _coerce_to_bytes(proxy_url)
return request_path, uri
request_path = _coerce_to_bytes(uri.path)
if uri.query:
request_path += b'?' + _coerce_to_bytes(uri.query)
request_path += b"?" + _coerce_to_bytes(uri.query)
return request_path, uri
@@ -59,29 +53,29 @@ def _dump_request_data(request, prefixes, bytearr, proxy_info=None):
proxy_info = {}
prefix = prefixes.request
method = _coerce_to_bytes(proxy_info.pop('method', request.method))
method = _coerce_to_bytes(proxy_info.pop("method", request.method))
request_path, uri = _build_request_path(request.url, proxy_info)
# <prefix><METHOD> <request-path> HTTP/1.1
bytearr.extend(prefix + method + b' ' + request_path + b' HTTP/1.1\r\n')
bytearr.extend(prefix + method + b" " + request_path + b" HTTP/1.1\r\n")
# <prefix>Host: <request-host> OR host header specified by user
headers = request.headers.copy()
host_header = _coerce_to_bytes(headers.pop('Host', uri.netloc))
bytearr.extend(prefix + b'Host: ' + host_header + b'\r\n')
host_header = _coerce_to_bytes(headers.pop("Host", uri.netloc))
bytearr.extend(prefix + b"Host: " + host_header + b"\r\n")
for name, value in headers.items():
bytearr.extend(prefix + _format_header(name, value))
bytearr.extend(prefix + b'\r\n')
bytearr.extend(prefix + b"\r\n")
if request.body:
if isinstance(request.body, compat.basestring):
bytearr.extend(prefix + _coerce_to_bytes(request.body))
else:
# In the event that the body is a file-like object, let's not try
# to read everything into memory.
bytearr.extend(b'<< Request body is not a string-like type >>')
bytearr.extend(b'\r\n')
bytearr.extend(b"<< Request body is not a string-like type >>")
bytearr.extend(b"\r\n")
def _dump_response_data(response, prefixes, bytearr):
@@ -90,32 +84,40 @@ def _dump_response_data(response, prefixes, bytearr):
raw = response.raw
# Let's convert the version int from httplib to bytes
version_str = HTTP_VERSIONS.get(raw.version, b'?')
version_str = HTTP_VERSIONS.get(raw.version, b"?")
# <prefix>HTTP/<version_str> <status_code> <reason>
bytearr.extend(prefix + b'HTTP/' + version_str + b' ' +
str(raw.status).encode('ascii') + b' ' +
_coerce_to_bytes(response.reason) + b'\r\n')
bytearr.extend(
prefix
+ b"HTTP/"
+ version_str
+ b" "
+ str(raw.status).encode("ascii")
+ b" "
+ _coerce_to_bytes(response.reason)
+ b"\r\n"
)
headers = raw.headers
for name in headers.keys():
for value in headers.getlist(name):
bytearr.extend(prefix + _format_header(name, value))
bytearr.extend(prefix + b'\r\n')
bytearr.extend(prefix + b"\r\n")
bytearr.extend(response.content)
def _coerce_to_bytes(data):
if not isinstance(data, bytes) and hasattr(data, 'encode'):
data = data.encode('utf-8')
if not isinstance(data, bytes) and hasattr(data, "encode"):
data = data.encode("utf-8")
# Don't bail out with an exception if data is None
return data if data is not None else b''
return data if data is not None else b""
def dump_response(response, request_prefix=b'< ', response_prefix=b'> ',
data_array=None):
def dump_response(
response, request_prefix=b"< ", response_prefix=b"> ", data_array=None
):
"""Dump a single request-response cycle's information.
This will take a response object and dump only the data that requests can
@@ -148,17 +150,16 @@ def dump_response(response, request_prefix=b'< ', response_prefix=b'> ',
data = data_array if data_array is not None else bytearray()
prefixes = PrefixSettings(request_prefix, response_prefix)
if not hasattr(response, 'request'):
raise ValueError('Response has no associated request')
if not hasattr(response, "request"):
raise ValueError("Response has no associated request")
proxy_info = _get_proxy_information(response)
_dump_request_data(response.request, prefixes, data,
proxy_info=proxy_info)
_dump_request_data(response.request, prefixes, data, proxy_info=proxy_info)
_dump_response_data(response, prefixes, data)
return data
def dump_all(response, request_prefix=b'< ', response_prefix=b'> '):
def dump_all(response, request_prefix=b"< ", response_prefix=b"> "):
"""Dump all requests and responses including redirects.
This takes the response returned by requests and will dump all
+7 -5
View File
@@ -4,7 +4,7 @@ from .._compat import basestring
from .._compat import urlencode as _urlencode
__all__ = ('urlencode',)
__all__ = ("urlencode",)
def urlencode(query, *args, **kwargs):
@@ -76,8 +76,10 @@ def urlencode(query, *args, **kwargs):
original_query_list = _to_kv_list(query)
if not all(_is_two_tuple(i) for i in original_query_list):
raise ValueError("Expected query to be able to be converted to a "
"list comprised of length 2 tuples.")
raise ValueError(
"Expected query to be able to be converted to a "
"list comprised of length 2 tuples."
)
query_list = original_query_list
while any(isinstance(v, expand_classes) for _, v in query_list):
@@ -87,7 +89,7 @@ def urlencode(query, *args, **kwargs):
def _to_kv_list(dict_or_list):
if hasattr(dict_or_list, 'items'):
if hasattr(dict_or_list, "items"):
return list(dict_or_list.items())
return dict_or_list
@@ -102,7 +104,7 @@ def _expand_query_values(original_query_list):
if isinstance(value, basestring):
query_list.append((key, value))
else:
key_fmt = key + '[%s]'
key_fmt = key + "[%s]"
value_list = _to_kv_list(value)
query_list.extend((key_fmt % k, v) for k, v in value_list)
return query_list
+25 -22
View File
@@ -20,12 +20,13 @@ def user_agent(name, version, extras=None):
if extras is None:
extras = []
return UserAgentBuilder(
name, version
).include_extras(
extras
).include_implementation(
).include_system().build()
return (
UserAgentBuilder(name, version)
.include_extras(extras)
.include_implementation()
.include_system()
.build()
)
class UserAgentBuilder(object):
@@ -47,7 +48,7 @@ class UserAgentBuilder(object):
"""
format_string = '%s/%s'
format_string = "%s/%s"
def __init__(self, name, version):
"""Initialize our builder with the name and version of our user agent.
@@ -76,7 +77,7 @@ class UserAgentBuilder(object):
list of tuples of extra-name and extra-version
"""
if any(len(extra) != 2 for extra in extras):
raise ValueError('Extras should be a sequence of two item tuples.')
raise ValueError("Extras should be a sequence of two item tuples.")
self._pieces.extend(extras)
return self
@@ -109,22 +110,24 @@ def _implementation_tuple():
"""
implementation = platform.python_implementation()
if implementation == 'CPython':
if implementation == "CPython":
implementation_version = platform.python_version()
elif implementation == 'PyPy':
implementation_version = '%s.%s.%s' % (sys.pypy_version_info.major,
sys.pypy_version_info.minor,
sys.pypy_version_info.micro)
if sys.pypy_version_info.releaselevel != 'final':
implementation_version = ''.join([
implementation_version, sys.pypy_version_info.releaselevel
])
elif implementation == 'Jython':
elif implementation == "PyPy":
implementation_version = "%s.%s.%s" % (
sys.pypy_version_info.major,
sys.pypy_version_info.minor,
sys.pypy_version_info.micro,
)
if sys.pypy_version_info.releaselevel != "final":
implementation_version = "".join(
[implementation_version, sys.pypy_version_info.releaselevel]
)
elif implementation == "Jython":
implementation_version = platform.python_version() # Complete Guess
elif implementation == 'IronPython':
elif implementation == "IronPython":
implementation_version = platform.python_version() # Complete Guess
else:
implementation_version = 'Unknown'
implementation_version = "Unknown"
return (implementation, implementation_version)
@@ -138,6 +141,6 @@ def _platform_tuple():
p_system = platform.system()
p_release = platform.release()
except IOError:
p_system = 'Unknown'
p_release = 'Unknown'
p_system = "Unknown"
p_release = "Unknown"
return (p_system, p_release)
+1 -1
View File
@@ -40,7 +40,7 @@ from ._basics import (
basestring,
integer_types,
proxy_bypass_environment,
getproxies_environment
getproxies_environment,
)
from .http_cookies import cookiejar_from_dict, RequestsCookieJar
from ._structures import HTTPHeaderDict
+57 -67
View File
@@ -5,131 +5,121 @@ import sys
from codecs import open
from setuptools import setup
from setuptools import setup, Command
from setuptools.command.test import test as TestCommand
here = os.path.abspath(os.path.dirname(__file__))
class Format(TestCommand):
class Format(Command):
user_options = []
def initialize_options(self):
TestCommand.initialize_options(self)
pass
def finalize_options(self):
pass
def run_tests(self):
os.system('white requests')
def run(self):
os.system("black requests3")
class PyTest(TestCommand):
user_options = [('pytest-args=', 'a', "Arguments to pass into py.test")]
class PyTest(Command):
user_options = [("pytest-args=", "a", "Arguments to pass into py.test")]
def initialize_options(self):
TestCommand.initialize_options(self)
self.pytest_args = ['-n', 'auto']
pass
def finalize_options(self):
TestCommand.finalize_options(self)
self.test_args = []
self.test_suite = True
# Command.finalize_options(self)
pass
def run_tests(self):
def run(self):
import pytest
errno = pytest.main(self.pytest_args)
errno = pytest.main(["-n", "auto"])
sys.exit(errno)
class MyPyTest(TestCommand):
user_options = [('pytest-args=', 'a', "Arguments to pass into py.test")]
class MyPyTest(Command):
user_options = [("pytest-args=", "a", "Arguments to pass into py.test")]
def initialize_options(self):
TestCommand.initialize_options(self)
self.pytest_args = ['-n', 'auto', '--mypy', 'tests']
pass
def finalize_options(self):
TestCommand.finalize_options(self)
self.test_args = []
self.test_suite = True
pass
def run_tests(self):
import pytest
errno = pytest.main(self.pytest_args)
errno = pytest.main(["-n", "auto", "--mypy", "tests"])
sys.exit(errno)
# 'setup.py publish' shortcut.
if sys.argv[-1] == 'publish':
os.system('python setup.py sdist bdist_wheel')
os.system('twine upload dist/*')
if sys.argv[-1] == "publish":
os.system("python setup.py sdist bdist_wheel")
os.system("twine upload dist/*")
sys.exit()
packages = ['requests3']
packages = ["requests3"]
requires = [
'chardet>=3.0.2,<3.1.0',
'idna>=2.5,<2.9',
'urllib3>=1.21.1,<1.25',
'certifi>=2017.4.17',
'requests-toolbelt'
"chardet>=3.0.2,<3.1.0",
"idna>=2.5,<2.9",
"urllib3>=1.21.1,<1.25",
"certifi>=2017.4.17",
]
test_requirements = [
'pytest-httpbin==0.0.7',
'pytest-cov',
'pytest-mock',
'pytest-xdist',
'PySocks>=1.5.6, !=1.5.7',
'pytest>=2.8.0'
"pytest-httpbin==0.0.7",
"pytest-cov",
"pytest-mock",
"pytest-xdist",
"PySocks>=1.5.6, !=1.5.7",
"pytest>=2.8.0",
]
about = {}
with open(os.path.join(here, 'requests3', '__version__.py'), 'r', 'utf-8') as f:
with open(os.path.join(here, "requests3", "__version__.py"), "r", "utf-8") as f:
exec(f.read(), about)
with open('README.md', 'r', 'utf-8') as f:
with open("README.md", "r", "utf-8") as f:
readme = f.read()
with open('HISTORY.md', 'r', 'utf-8') as f:
with open("HISTORY.md", "r", "utf-8") as f:
history = f.read()
setup(
name=about['__title__'],
version=about['__version__'],
description=about['__description__'],
name=about["__title__"],
version=about["__version__"],
description=about["__description__"],
long_description=readme,
long_description_content_type='text/markdown',
author=about['__author__'],
author_email=about['__author_email__'],
url=about['__url__'],
long_description_content_type="text/markdown",
author=about["__author__"],
author_email=about["__author_email__"],
url=about["__url__"],
packages=packages,
package_data={'': ['LICENSE', 'NOTICE'], 'requests': ['*.pem']},
package_dir={'requests': 'requests'},
package_data={"": ["LICENSE", "NOTICE"], "requests": ["*.pem"]},
package_dir={"requests": "requests"},
include_package_data=True,
python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*",
install_requires=requires,
license=about['__license__'],
license=about["__license__"],
zip_safe=False,
classifiers=[
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
'Natural Language :: English',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy',
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Natural Language :: English",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
],
cmdclass={'test': PyTest, 'mypy': MyPyTest, 'format': Format},
cmdclass={"test": PyTest, "mypy": MyPyTest, "format": Format},
tests_require=test_requirements,
extras_require={
'security': ['pyOpenSSL >= 0.14', 'cryptography>=1.3.4', 'idna>=2.0.0'],
'socks': ['PySocks>=1.5.6, !=1.5.7'],
'socks:sys_platform == "win32" and python_version == "2.7"': [
'win_inet_pton'
],
"security": ["pyOpenSSL >= 0.14", "cryptography>=1.3.4", "idna>=2.0.0"],
"socks": ["PySocks>=1.5.6, !=1.5.7"],
'socks:sys_platform == "win32" and python_version == "2.7"': ["win_inet_pton"],
},
)