new urllib3

This commit is contained in:
Kenneth Reitz
2012-12-17 04:08:36 -05:00
parent e7598e006c
commit 2c5b0207f5
8 changed files with 191 additions and 73 deletions
+38 -17
View File
@@ -6,8 +6,9 @@
import logging
import socket
import errno
from socket import timeout as SocketTimeout
from socket import error as SocketError, timeout as SocketTimeout
try: # Python 3
from http.client import HTTPConnection, HTTPException
@@ -41,7 +42,7 @@ except (ImportError, AttributeError): # Platform-specific: No SSL.
from .request import RequestMethods
from .response import HTTPResponse
from .util import get_host, is_connection_dropped
from .util import get_host, is_connection_dropped, ssl_wrap_socket
from .exceptions import (
ClosedPoolError,
EmptyPoolError,
@@ -76,6 +77,7 @@ class VerifiedHTTPSConnection(HTTPSConnection):
"""
cert_reqs = None
ca_certs = None
ssl_version = None
def set_cert(self, key_file=None, cert_file=None,
cert_reqs='CERT_NONE', ca_certs=None):
@@ -96,9 +98,12 @@ class VerifiedHTTPSConnection(HTTPSConnection):
# Wrap socket using verification with the root certs in
# trusted_root_certs
self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file,
self.sock = ssl_wrap_socket(sock, self.key_file, self.cert_file,
cert_reqs=self.cert_reqs,
ca_certs=self.ca_certs)
ca_certs=self.ca_certs,
server_hostname=self.host,
ssl_version=self.ssl_version)
if self.ca_certs:
match_hostname(self.sock.getpeercert(), self.host)
@@ -166,13 +171,13 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
def __init__(self, host, port=None, strict=False, timeout=None, maxsize=1,
block=False, headers=None):
super(HTTPConnectionPool, self).__init__(host, port)
ConnectionPool.__init__(self, host, port)
RequestMethods.__init__(self, headers)
self.strict = strict
self.timeout = timeout
self.pool = self.QueueCls(maxsize)
self.block = block
self.headers = headers or {}
# Fill the queue up so that doing get() on it will block properly
for _ in xrange(maxsize):
@@ -189,7 +194,9 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
self.num_connections += 1
log.info("Starting new HTTP connection (%d): %s" %
(self.num_connections, self.host))
return HTTPConnection(host=self.host, port=self.port)
return HTTPConnection(host=self.host,
port=self.port,
strict=self.strict)
def _get_conn(self, timeout=None):
"""
@@ -449,12 +456,15 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# Name mismatch
raise SSLError(e)
except HTTPException as e:
except (HTTPException, SocketError) as e:
# Connection broken, discard. It will be replaced next _get_conn().
conn = None
# This is necessary so we can access e below
err = e
if retries == 0:
raise MaxRetryError(self, url, e)
finally:
if release_conn:
# Put the connection back to be reused. If the connection is
@@ -491,11 +501,11 @@ class HTTPSConnectionPool(HTTPConnectionPool):
When Python is compiled with the :mod:`ssl` module, then
:class:`.VerifiedHTTPSConnection` is used, which *can* verify certificates,
instead of :class:httplib.HTTPSConnection`.
instead of :class:`httplib.HTTPSConnection`.
The ``key_file``, ``cert_file``, ``cert_reqs``, and ``ca_certs`` parameters
The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs``, and ``ssl_version``
are only used if :mod:`ssl` is available and are fed into
:meth:`ssl.wrap_socket` to upgrade the connection socket into an SSL socket.
:meth:`urllib3.util.ssl_wrap_socket` to upgrade the connection socket into an SSL socket.
"""
scheme = 'https'
@@ -504,15 +514,16 @@ class HTTPSConnectionPool(HTTPConnectionPool):
strict=False, timeout=None, maxsize=1,
block=False, headers=None,
key_file=None, cert_file=None,
cert_reqs='CERT_NONE', ca_certs=None):
cert_reqs='CERT_NONE', ca_certs=None, ssl_version=None):
super(HTTPSConnectionPool, self).__init__(host, port,
strict, timeout, maxsize,
block, headers)
HTTPConnectionPool.__init__(self, host, port,
strict, timeout, maxsize,
block, headers)
self.key_file = key_file
self.cert_file = cert_file
self.cert_reqs = cert_reqs
self.ca_certs = ca_certs
self.ssl_version = ssl_version
def _new_conn(self):
"""
@@ -527,11 +538,21 @@ class HTTPSConnectionPool(HTTPConnectionPool):
raise SSLError("Can't connect to HTTPS URL because the SSL "
"module is not available.")
return HTTPSConnection(host=self.host, port=self.port)
return HTTPSConnection(host=self.host,
port=self.port,
strict=self.strict)
connection = VerifiedHTTPSConnection(host=self.host, port=self.port)
connection = VerifiedHTTPSConnection(host=self.host,
port=self.port,
strict=self.strict)
connection.set_cert(key_file=self.key_file, cert_file=self.cert_file,
cert_reqs=self.cert_reqs, ca_certs=self.ca_certs)
if self.ssl_version is None:
connection.ssl_version = ssl.PROTOCOL_SSLv23
else:
connection.ssl_version = self.ssl_version
return connection
+14 -4
View File
@@ -18,6 +18,10 @@ class PoolError(HTTPError):
self.pool = pool
HTTPError.__init__(self, "%s: %s" % (pool, message))
def __reduce__(self):
# For pickling purposes.
return self.__class__, (None, self.url)
class SSLError(HTTPError):
"Raised when SSL certificate fails in an HTTPS connection."
@@ -34,10 +38,16 @@ class DecodeError(HTTPError):
class MaxRetryError(PoolError):
"Raised when the maximum number of retries is exceeded."
def __init__(self, pool, url):
message = "Max retries exceeded with url: %s" % url
PoolError.__init__(self, pool, message)
def __init__(self, pool, url, reason=None):
self.reason = reason
message = "Max retries exceeded with url: %s" % url
if reason:
message += " (Caused by %s: %s)" % (type(reason), reason)
else:
message += " (Caused by redirect)"
PoolError.__init__(self, pool, message)
self.url = url
@@ -72,6 +82,6 @@ class LocationParseError(ValueError, HTTPError):
def __init__(self, location):
message = "Failed to parse: %s" % location
super(LocationParseError, self).__init__(self, message)
HTTPError.__init__(self, message)
self.location = location
+15 -8
View File
@@ -41,13 +41,16 @@ def iter_fields(fields):
def encode_multipart_formdata(fields, boundary=None):
"""
Encode a dictionary of ``fields`` using the multipart/form-data mime format.
Encode a dictionary of ``fields`` using the multipart/form-data MIME format.
:param fields:
Dictionary of fields or list of (key, value) field tuples. The key is
treated as the field name, and the value as the body of the form-data
bytes. If the value is a tuple of two elements, then the first element
is treated as the filename of the form-data section.
Dictionary of fields or list of (key, value) or (key, value, MIME type)
field tuples. The key is treated as the field name, and the value as
the body of the form-data bytes. If the value is a tuple of two
elements, then the first element is treated as the filename of the
form-data section and a suitable MIME type is guessed based on the
filename. If the value is a tuple of three elements, then the third
element is treated as an explicit MIME type of the form-data section.
Field names and filenames must be unicode.
@@ -63,16 +66,20 @@ def encode_multipart_formdata(fields, boundary=None):
body.write(b('--%s\r\n' % (boundary)))
if isinstance(value, tuple):
filename, data = value
if len(value) == 3:
filename, data, content_type = value
else:
filename, data = value
content_type = get_content_type(filename)
writer(body).write('Content-Disposition: form-data; name="%s"; '
'filename="%s"\r\n' % (fieldname, filename))
body.write(b('Content-Type: %s\r\n\r\n' %
(get_content_type(filename))))
(content_type,)))
else:
data = value
writer(body).write('Content-Disposition: form-data; name="%s"\r\n'
% (fieldname))
body.write(b'Content-Type: text/plain\r\n\r\n')
body.write(b'\r\n')
if isinstance(data, int):
data = str(data) # Backwards compatibility
+33 -20
View File
@@ -24,7 +24,7 @@ import sys
import types
__author__ = "Benjamin Peterson <benjamin@python.org>"
__version__ = "1.1.0"
__version__ = "1.2.0" # Revision 41c74fef2ded
# True if we are running on Python 3.
@@ -45,19 +45,23 @@ else:
text_type = unicode
binary_type = str
# It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object):
def __len__(self):
return 1 << 31
try:
len(X())
except OverflowError:
# 32-bit
if sys.platform.startswith("java"):
# Jython always uses 32 bits.
MAXSIZE = int((1 << 31) - 1)
else:
# 64-bit
MAXSIZE = int((1 << 63) - 1)
del X
# It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object):
def __len__(self):
return 1 << 31
try:
len(X())
except OverflowError:
# 32-bit
MAXSIZE = int((1 << 31) - 1)
else:
# 64-bit
MAXSIZE = int((1 << 63) - 1)
del X
def _add_doc(func, doc):
@@ -132,6 +136,7 @@ class _MovedItems(types.ModuleType):
_moved_attributes = [
MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
MovedAttribute("map", "itertools", "builtins", "imap", "map"),
MovedAttribute("reload_module", "__builtin__", "imp", "reload"),
MovedAttribute("reduce", "__builtin__", "functools"),
@@ -178,7 +183,7 @@ for attr in _moved_attributes:
setattr(_MovedItems, attr.name, attr)
del attr
moves = sys.modules["six.moves"] = _MovedItems("moves")
moves = sys.modules[__name__ + ".moves"] = _MovedItems("moves")
def add_move(move):
@@ -219,12 +224,19 @@ else:
_iteritems = "iteritems"
try:
advance_iterator = next
except NameError:
def advance_iterator(it):
return it.next()
next = advance_iterator
if PY3:
def get_unbound_function(unbound):
return unbound
advance_iterator = next
Iterator = object
def callable(obj):
return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
@@ -232,9 +244,10 @@ else:
def get_unbound_function(unbound):
return unbound.im_func
class Iterator(object):
def advance_iterator(it):
return it.next()
def next(self):
return type(self).__next__(self)
callable = callable
_add_doc(get_unbound_function,
@@ -249,15 +262,15 @@ get_function_defaults = operator.attrgetter(_func_defaults)
def iterkeys(d):
"""Return an iterator over the keys of a dictionary."""
return getattr(d, _iterkeys)()
return iter(getattr(d, _iterkeys)())
def itervalues(d):
"""Return an iterator over the values of a dictionary."""
return getattr(d, _itervalues)()
return iter(getattr(d, _itervalues)())
def iteritems(d):
"""Return an iterator over the (key, value) pairs of a dictionary."""
return getattr(d, _iteritems)()
return iter(getattr(d, _iteritems)())
if PY3:
+18 -13
View File
@@ -30,8 +30,12 @@ class PoolManager(RequestMethods):
necessary connection pools for you.
:param num_pools:
Number of connection pools to cache before discarding the least recently
used pool.
Number of connection pools to cache before discarding the least
recently used pool.
:param headers:
Headers to include with all requests, unless other headers are given
explicitly.
:param \**connection_pool_kw:
Additional parameters are used to create fresh
@@ -40,15 +44,16 @@ class PoolManager(RequestMethods):
Example: ::
>>> manager = PoolManager(num_pools=2)
>>> r = manager.urlopen("http://google.com/")
>>> r = manager.urlopen("http://google.com/mail")
>>> r = manager.urlopen("http://yahoo.com/")
>>> r = manager.request('GET', 'http://google.com/')
>>> r = manager.request('GET', 'http://google.com/mail')
>>> r = manager.request('GET', 'http://yahoo.com/')
>>> len(manager.pools)
2
"""
def __init__(self, num_pools=10, **connection_pool_kw):
def __init__(self, num_pools=10, headers=None, **connection_pool_kw):
RequestMethods.__init__(self, headers)
self.connection_pool_kw = connection_pool_kw
self.pools = RecentlyUsedContainer(num_pools,
dispose_func=lambda p: p.close())
@@ -113,6 +118,8 @@ class PoolManager(RequestMethods):
kw['assert_same_host'] = False
kw['redirect'] = False
if 'headers' not in kw:
kw['headers'] = self.headers
response = conn.urlopen(method, u.request_uri, **kw)
@@ -124,7 +131,7 @@ class PoolManager(RequestMethods):
method = 'GET'
log.info("Redirecting %s -> %s" % (url, redirect_location))
kw['retries'] = kw.get('retries', 3) - 1 # Persist retries countdown
kw['retries'] = kw.get('retries', 3) - 1 # Persist retries countdown
return self.urlopen(method, redirect_location, **kw)
@@ -138,13 +145,11 @@ class ProxyManager(RequestMethods):
self.proxy_pool = proxy_pool
def _set_proxy_headers(self, headers=None):
headers = headers or {}
headers_ = {'Accept': '*/*'}
if headers:
headers_.update(headers)
# Same headers are curl passes for --proxy1.0
headers['Accept'] = '*/*'
headers['Proxy-Connection'] = 'Keep-Alive'
return headers
return headers_
def urlopen(self, method, url, **kw):
"Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute."
+20 -6
View File
@@ -36,12 +36,20 @@ class RequestMethods(object):
:meth:`.request` is for making any kind of request, it will look up the
appropriate encoding format and use one of the above two methods to make
the request.
Initializer parameters:
:param headers:
Headers to include with all requests, unless other headers are given
explicitly.
"""
_encode_url_methods = set(['DELETE', 'GET', 'HEAD', 'OPTIONS'])
_encode_body_methods = set(['PATCH', 'POST', 'PUT', 'TRACE'])
def __init__(self, headers=None):
self.headers = headers or {}
def urlopen(self, method, url, body=None, headers=None,
encode_multipart=True, multipart_boundary=None,
**kw): # Abstract
@@ -97,13 +105,16 @@ class RequestMethods(object):
such as with OAuth.
Supports an optional ``fields`` parameter of key/value strings AND
key/filetuple. A filetuple is a (filename, data) tuple. For example: ::
key/filetuple. A filetuple is a (filename, data, MIME type) tuple where
the MIME type is optional. For example: ::
fields = {
'foo': 'bar',
'fakefile': ('foofile.txt', 'contents of foofile'),
'realfile': ('barfile.txt', open('realfile').read()),
'nonamefile': ('contents of nonamefile field'),
'typedfile': ('bazfile.bin', open('bazfile').read(),
'image/jpeg'),
'nonamefile': 'contents of nonamefile field',
}
When uploading a file, providing a filename (the first parameter of the
@@ -121,8 +132,11 @@ class RequestMethods(object):
body, content_type = (urlencode(fields or {}),
'application/x-www-form-urlencoded')
headers = headers or {}
headers.update({'Content-Type': content_type})
if headers is None:
headers = self.headers
return self.urlopen(method, url, body=body, headers=headers,
headers_ = {'Content-Type': content_type}
headers_.update(headers)
return self.urlopen(method, url, body=body, headers=headers_,
**urlopen_kw)
+3 -1
View File
@@ -130,7 +130,9 @@ class HTTPResponse(object):
after having ``.read()`` the file object. (Overridden if ``amt`` is
set.)
"""
content_encoding = self.headers.get('content-encoding')
# Note: content-encoding value should be case-insensitive, per RFC 2616
# Section 3.5
content_encoding = self.headers.get('content-encoding', '').lower()
decoder = self.CONTENT_DECODERS.get(content_encoding)
if decode_content is None:
decode_content = self._decode_content
+50 -4
View File
@@ -11,13 +11,24 @@ from socket import error as SocketError
try:
from select import poll, POLLIN
except ImportError: # `poll` doesn't exist on OSX and other platforms
except ImportError: # `poll` doesn't exist on OSX and other platforms
poll = False
try:
from select import select
except ImportError: # `select` doesn't exist on AppEngine.
except ImportError: # `select` doesn't exist on AppEngine.
select = False
try: # Test for SSL features
SSLContext = None
HAS_SNI = False
from ssl import wrap_socket, CERT_NONE, SSLError, PROTOCOL_SSLv23
from ssl import SSLContext # Modern SSL?
from ssl import HAS_SNI # Has SNI?
except ImportError:
pass
from .packages import six
from .exceptions import LocationParseError
@@ -92,9 +103,9 @@ def parse_url(url):
>>> parse_url('http://google.com/mail/')
Url(scheme='http', host='google.com', port=None, path='/', ...)
>>> prase_url('google.com:80')
>>> parse_url('google.com:80')
Url(scheme=None, host='google.com', port=80, path=None, ...)
>>> prase_url('/foo?bar')
>>> parse_url('/foo?bar')
Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...)
"""
@@ -250,3 +261,38 @@ def is_connection_dropped(conn):
if fno == sock.fileno():
# Either data is buffered (bad), or the connection is dropped.
return True
if SSLContext is not None: # Python 3.2+
def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=CERT_NONE,
ca_certs=None, server_hostname=None,
ssl_version=PROTOCOL_SSLv23):
"""
All arguments except `server_hostname` have the same meaning as for
:func:`ssl.wrap_socket`
:param server_hostname:
Hostname of the expected certificate
"""
context = SSLContext(ssl_version)
context.verify_mode = cert_reqs
if ca_certs:
try:
context.load_verify_locations(ca_certs)
except TypeError as e: # Reraise as SSLError
# FIXME: This block needs a test.
raise SSLError(e)
if certfile:
# FIXME: This block needs a test.
context.load_cert_chain(certfile, keyfile)
if HAS_SNI: # Platform-specific: OpenSSL with enabled SNI
return context.wrap_socket(sock, server_hostname=server_hostname)
return context.wrap_socket(sock)
else: # Python 3.1 and earlier
def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=CERT_NONE,
ca_certs=None, server_hostname=None,
ssl_version=PROTOCOL_SSLv23):
return wrap_socket(sock, keyfile=keyfile, certfile=certfile,
ca_certs=ca_certs, cert_reqs=cert_reqs,
ssl_version=ssl_version)