Move the nested parse_header function up to module level and rename it to _parse_content_type_header. Add more tests for the function.

This commit is contained in:
dbairaktaris1
2018-01-01 14:20:55 -06:00
parent 2dc51c887f
commit 1988d9cf72
2 changed files with 67 additions and 16 deletions
+31 -15
View File
@@ -446,33 +446,49 @@ def get_encodings_from_content(content):
xml_re.findall(content))
def _parse_content_type_header(header):
"""Returns content type and parameters from given header
:param header: string
:return: tuple containing content type and dictionary of
parameters
"""
if not header:
return None
# append delimiter on end to ensure at least two elements when split by ';'
header += ';'
# split content type's main value from params
tokens = header.split(';', 1)
content_type_index = 0
params_index = 1
content_type = tokens[content_type_index].strip()
params = tokens[params_index]
params_dict = dict()
for param in params.split(';'):
if param and not param.isspace():
param = param.strip()
key, value = param, True
if '=' in param:
param_tokens = [x.strip('\'" ') for x in param.split('=', 1)]
key, value = param_tokens[0], param_tokens[1]
params_dict[key] = value
return content_type, params_dict
def get_encoding_from_headers(headers):
"""Returns encodings from given HTTP Header Dict.
:param headers: dictionary to extract encoding from.
:rtype: str
"""
def parse_header(content_type):
#Inner function to parse header
#append delimiter on end to ensure atleast two elements when split by ';'
content_type_and_params_delimiter = ';'
content_type += content_type_and_params_delimiter
tokens = content_type.split(content_type_and_params_delimiter)
content_type_index = 0
params_index = 1
content_type = tokens[content_type_index]
params = tokens[params_index]
params_dict = dict(param.split('=') for param in params.split())
return content_type,params_dict
content_type = headers.get('content-type')
if not content_type:
return None
content_type, params = parse_header(content_type)
content_type, params = _parse_content_type_header(content_type)
if 'charset' in params:
return params['charset'].strip("'\"")
+36 -1
View File
@@ -13,7 +13,7 @@ from requests.cookies import RequestsCookieJar
from requests.structures import CaseInsensitiveDict
from requests.utils import (
address_in_network, dotted_netmask, extract_zipped_paths,
get_auth_from_url, get_encoding_from_headers,
get_auth_from_url, _parse_content_type_header, get_encoding_from_headers,
get_encodings_from_content, get_environ_proxies,
guess_filename, guess_json_utf, is_ipv4_address,
is_valid_cidr, iter_slices, parse_dict_header,
@@ -470,6 +470,41 @@ def test_parse_dict_header(value, expected):
assert parse_dict_header(value) == expected
@pytest.mark.parametrize(
    'value, expected', (
        # missing or empty header yields None
        (None, None),
        ('', None),
        # bare content types carry no parameters
        ('application/xml', ('application/xml', {})),
        ('text/plain', ('text/plain', {})),
        # whitespace around the content type and ';' is ignored
        (
            'application/json ; charset=utf-8',
            ('application/json', {'charset': 'utf-8'}),
        ),
        # spaces around '=', quoted segments and value-less flags
        (
            'multipart/form-data; boundary = something ; \'boundary2=something_else\' ; no_equals ',
            (
                'multipart/form-data',
                {
                    'boundary': 'something',
                    'boundary2': 'something_else',
                    'no_equals': True,
                },
            ),
        ),
        # empty and whitespace-only segments are skipped
        ('application/json ;; ; ', ('application/json', {})),
    ))
def test__parse_content_type_header(value, expected):
    """Check the parsed (content type, params) result for each header."""
    parsed = _parse_content_type_header(value)
    assert parsed == expected
@pytest.mark.parametrize(
'value, expected', (
(