typing for utils.py

Signed-off-by: Kenneth Reitz <me@kennethreitz.org>
This commit is contained in:
2018-03-15 15:08:23 -04:00
parent e16fc666af
commit f64b27bfed
2 changed files with 53 additions and 80 deletions
+4
View File
@@ -64,3 +64,7 @@ Verify = Union[None, bool, Text]
Cert = Union[Text, Tuple[Text, Text]]
JSON = Optional[MutableMapping]
Help = Dict
Host = str
Sequence = List
Filename = str
KeyValueList = List[Tuple[Text, Text]]
+49 -80
View File
@@ -18,12 +18,11 @@ import re
import socket
import struct
import warnings
import typing
from .__version__ import __version__
from .import certs
# to_native_string is unused here, but imported here for backwards compatibility
from ._internal_utils import to_native_string
from .basics import parse_http_list as _parse_list_header
from .basics import (
quote,
@@ -40,7 +39,8 @@ from .basics import (
getproxies_environment,
)
from .cookies import cookiejar_from_dict
from .structures import CaseInsensitiveDict
from .structures import HTTPHeaderDict
from .cookies import RequestsCookieJar
from .exceptions import (
InvalidURL, InvalidHeader, FileModeWarning, UnrewindableBodyError
)
@@ -50,7 +50,7 @@ DEFAULT_CA_BUNDLE_PATH = certs.where()
if platform.system() == 'Windows':
# provide a proxy_bypass version on Windows without DNS lookups
def proxy_bypass_registry(host):
def proxy_bypass_registry(host: str) -> bool:
import winreg
try:
@@ -90,7 +90,7 @@ if platform.system() == 'Windows':
return False
def proxy_bypass(host): # noqa
def proxy_bypass(host: str) -> bool: # noqa
"""Return True, if the host should be bypassed.
Checks proxy settings gathered from the environment, if specified,
@@ -103,14 +103,14 @@ if platform.system() == 'Windows':
return proxy_bypass_registry(host)
def dict_to_sequence(d):
def dict_to_sequence(d: dict) -> typing.List:
"""Returns an internal sequence dictionary update."""
if hasattr(d, 'items'):
d = d.items()
return d
def super_len(o):
def super_len(o) -> int:
total_length = None
current_position = 0
if hasattr(o, '__len__'):
@@ -131,10 +131,10 @@ def super_len(o):
(
"Requests has determined the content-length for this "
"request using the binary size of the file: however, the "
"file has been opened in text mode (i.e. without the 'b' "
"file has been opened in typing.Text mode (i.e. without the 'b' "
"flag in the mode). This may lead to an incorrect "
"content-length. In Requests 3.0, support will be removed "
"for files in text mode."
"for files in typing.Text mode."
),
FileModeWarning,
)
@@ -165,7 +165,9 @@ def super_len(o):
return max(0, total_length - current_position)
def get_netrc_auth(url, raise_errors=False):
def get_netrc_auth(
url: str, raise_errors: bool = False
) -> typing.Tuple[typing.Text, typing.Text]:
"""Returns the Requests tuple auth for a given url from netrc."""
try:
from netrc import netrc, NetrcParseError
@@ -213,7 +215,7 @@ def get_netrc_auth(url, raise_errors=False):
pass
def guess_filename(obj):
def guess_filename(obj) -> str:
"""Tries to guess the filename of the given object."""
name = getattr(obj, 'name', None)
if (
@@ -250,7 +252,7 @@ def from_key_val_list(value):
return collections.OrderedDict(value)
def to_key_val_list(value):
def to_key_val_list(value) -> typing.List[typing.Tuple[typing.Text, typing.Text]]:
"""Take an object and test to see if it can be represented as a
dictionary. If it can be, return a list of tuples, e.g.,
@@ -279,7 +281,7 @@ def to_key_val_list(value):
# From mitsuhiko/werkzeug (used with permission).
def parse_list_header(value):
def parse_list_header(value: str) -> typing.List[typing.Text]:
"""Parse lists as described by RFC 2068 Section 2.
In particular, parse comma-separated lists where the elements of
@@ -313,7 +315,7 @@ def parse_list_header(value):
# From mitsuhiko/werkzeug (used with permission).
def parse_dict_header(value):
def parse_dict_header(value) -> dict:
"""Parse lists of key, value pairs as described by RFC 2068 Section 2 and
convert them into a python dict:
@@ -351,7 +353,7 @@ def parse_dict_header(value):
# From mitsuhiko/werkzeug (used with permission).
def unquote_header_value(value, is_filename=False):
def unquote_header_value(value: str, is_filename: bool = False):
r"""Unquotes a header value. (Reversal of :func:`quote_header_value`).
This does not use the real unquoting but what browsers are actually
using for quoting.
@@ -376,7 +378,7 @@ def unquote_header_value(value, is_filename=False):
return value
def dict_from_cookiejar(cj):
def dict_from_cookiejar(cj: RequestsCookieJar) -> dict:
"""Returns a key/value dictionary from a CookieJar.
:param cj: CookieJar object to extract cookies from.
@@ -388,7 +390,9 @@ def dict_from_cookiejar(cj):
return cookie_dict
def add_dict_to_cookiejar(cj, cookie_dict):
def add_dict_to_cookiejar(
cj: RequestsCookieJar, cookie_dict: dict
) -> RequestsCookieJar:
"""Returns a CookieJar from a key/value dictionary.
:param cj: CookieJar to insert cookies into.
@@ -398,7 +402,7 @@ def add_dict_to_cookiejar(cj, cookie_dict):
return cookiejar_from_dict(cookie_dict, cj)
def get_encodings_from_content(content):
def get_encodings_from_content(content: str) -> typing.List[str]:
"""Returns encodings from given content string.
:param content: bytestring to extract encodings from.
@@ -423,7 +427,7 @@ def get_encodings_from_content(content):
)
def get_encoding_from_headers(headers):
def get_encoding_from_headers(headers: typing.MutableMapping) -> str:
"""Returns encodings from given HTTP Header Dict.
:param headers: dictionary to extract encoding from.
@@ -437,7 +441,7 @@ def get_encoding_from_headers(headers):
if 'charset' in params:
return params['charset'].strip("'\"")
if 'text' in content_type:
if 'typing.Text' in content_type:
return 'ISO-8859-1'
@@ -465,50 +469,13 @@ def iter_slices(string, slice_length):
pos += slice_length
def get_unicode_from_response(r):
"""Returns the requested content back in unicode.
:param r: Response object to get unicode content from.
Tried:
1. charset from content-type
2. fall back and replace all unicode characters
:rtype: str
"""
warnings.warn(
(
'In requests 3.0, get_unicode_from_response will be removed. For '
'more information, please see the discussion on issue #2266. (This'
' warning should only appear once.)'
),
DeprecationWarning,
)
tried_encodings = []
# Try charset from content-type
encoding = get_encoding_from_headers(r.headers)
if encoding:
try:
return str(r.content, encoding)
except UnicodeError:
tried_encodings.append(encoding)
# Fall back:
try:
return str(r.content, encoding, errors='replace')
except TypeError:
return r.content
# The unreserved URI characters (RFC 3986)
UNRESERVED_SET = frozenset(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~"
)
def unquote_unreserved(uri):
def unquote_unreserved(uri: str) -> str:
"""Un-escape any percent-escape sequences in a URI that are unreserved
characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
@@ -550,7 +517,7 @@ def unquote_unreserved(uri):
return base.join(parts)
def requote_uri(uri):
def requote_uri(uri: str) -> str:
"""Re-quote the given URI.
This function passes the given URI through an unquote/quote cycle to
@@ -573,7 +540,7 @@ def requote_uri(uri):
return quote(uri, safe=safe_without_percent)
def address_in_network(ip, net):
def address_in_network(ip: str, net: str) -> bool:
"""This function allows you to check if an IP belongs to a network subnet
Example: returns True if ip = 192.168.1.1 and net = 192.168.1.0/24
@@ -590,7 +557,7 @@ def address_in_network(ip, net):
return ( ipaddr & netmask) == ( network & netmask)
def dotted_netmask(mask):
def dotted_netmask(mask: str) -> str:
"""Converts mask from /xx format to xxx.xxx.xxx.xxx
Example: if mask is 24 function returns 255.255.255.0
@@ -601,7 +568,7 @@ def dotted_netmask(mask):
return socket.inet_ntoa(struct.pack('>I', bits))
def is_ipv4_address(string_ip):
def is_ipv4_address(string_ip: str) -> bool:
"""
:rtype: bool
"""
@@ -613,7 +580,7 @@ def is_ipv4_address(string_ip):
return True
def is_valid_cidr(string_network):
def is_valid_cidr(string_network: str) -> bool:
"""
Very simple check of the cidr format in no_proxy variable.
@@ -640,7 +607,7 @@ def is_valid_cidr(string_network):
@contextlib.contextmanager
def set_environ(env_name, value):
def set_environ(env_name: str, value: typing.Optional[str]) -> None:
"""Set the environment variable 'env_name' to 'value'
Save previous value, yield, and then restore the previous value stored in
@@ -662,7 +629,7 @@ def set_environ(env_name, value):
os.environ[env_name] = old_value
def should_bypass_proxies(url, no_proxy):
def should_bypass_proxies(url: str, no_proxy: bool) -> bool:
"""
Returns whether we should bypass proxies or not.
@@ -706,7 +673,9 @@ def should_bypass_proxies(url, no_proxy):
return bool(proxy_bypass(netloc))
def get_environ_proxies(url, no_proxy=None):
def get_environ_proxies(
url: str, no_proxy: typing.Optional[bool] = None
) -> dict:
"""
Return a dict of environment proxies.
@@ -719,7 +688,7 @@ def get_environ_proxies(url, no_proxy=None):
return getproxies()
def select_proxy(url, proxies):
def select_proxy(url: str, proxies: typing.Optional[typing.MutableMapping[typing.Text, typing.Text]]):
"""Select a proxy for the url, if applicable.
:param url: The url being for the request
@@ -745,7 +714,7 @@ def select_proxy(url, proxies):
return proxy
def default_user_agent(name="python-requests"):
def default_user_agent(name: str = "python-requests") -> str:
"""
Return a string representing the default user agent.
@@ -754,11 +723,11 @@ def default_user_agent(name="python-requests"):
return '%s/%s' % (name, __version__)
def default_headers():
def default_headers() -> HTTPHeaderDict:
"""
:rtype: requests.structures.CaseInsensitiveDict
:rtype: requests.structures.HTTPHeaderDict
"""
return CaseInsensitiveDict(
return HTTPHeaderDict(
{
'User-Agent': default_user_agent(),
'Accept-Encoding': ', '.join(('gzip', 'deflate')),
@@ -768,7 +737,7 @@ def default_headers():
)
def parse_header_links(value):
def parse_header_links(value: str) -> typing.List[typing.MutableMapping]:
"""Return a list of parsed link headers proxies.
i.e. Link: <http:/.../front.jpeg>; rel=front; type="image/jpeg",<http://.../back.jpeg>; rel=back;type="image/jpeg"
@@ -798,7 +767,7 @@ def parse_header_links(value):
return links
def is_valid_location(response):
def is_valid_location(response: 'Response') -> bool:
"""Verify that multiple Location headers weren't
returned from the last response.
"""
@@ -818,7 +787,7 @@ _null2 = _null * 2
_null3 = _null * 3
def guess_json_utf(data):
def guess_json_utf(data: bytes) -> typing.Optional[str]:
"""
:rtype: str
"""
@@ -858,7 +827,7 @@ def guess_json_utf(data):
return None
def prepend_scheme_if_needed(url, new_scheme):
def prepend_scheme_if_needed(url: str, new_scheme: str) -> str:
"""Given a URL that may or may not have a scheme, prepend the given scheme.
Does not replace a present scheme with the one provided as an argument.
@@ -873,7 +842,7 @@ def prepend_scheme_if_needed(url, new_scheme):
return urlunparse((scheme, netloc, path, params, query, fragment))
def get_auth_from_url(url):
def get_auth_from_url(url: str) -> typing.Tuple[typing.Text, typing.Text]:
"""Given a url with authentication components, extract them into a tuple of
username,password.
@@ -892,7 +861,7 @@ _CLEAN_HEADER_REGEX_BYTE = re.compile(b'^\\S[^\\r\\n]*$|^$')
_CLEAN_HEADER_REGEX_STR = re.compile(r'^\S[^\r\n]*$|^$')
def check_header_validity(header):
def check_header_validity(header: typing.Tuple[typing.Text, typing.Text]) -> None:
"""Verifies that header value is a string which doesn't contain
leading whitespace or return characters. This prevents unintended
header injection.
@@ -918,7 +887,7 @@ def check_header_validity(header):
)
def urldefragauth(url):
def urldefragauth(url: str) -> str:
"""
Given a url remove the fragment and the authentication part.
@@ -932,7 +901,7 @@ def urldefragauth(url):
return urlunparse((scheme, netloc, path, params, query, ''))
def rewind_body(prepared_request):
def rewind_body(prepared_request: 'PreparedRequest') -> None:
"""Move file pointer back to its recorded starting position
so it can be read again on redirect.
"""
@@ -954,7 +923,7 @@ def rewind_body(prepared_request):
)
def is_stream(data):
def is_stream(data: bytes) -> bool:
"""Given data, determines if it should be sent as a stream."""
is_iterable = getattr(data, '__iter__', False)
is_io_type = not isinstance(