diff --git a/requests/utils.py b/requests/utils.py index dea323ef..397a655e 100644 --- a/requests/utils.py +++ b/requests/utils.py @@ -739,8 +739,9 @@ _CLEAN_HEADER_REGEX_BYTE = re.compile(b'^\\S[^\\r\\n]*$|^$') _CLEAN_HEADER_REGEX_STR = re.compile(r'^\S[^\r\n]*$|^$') def check_header_validity(header): - """Verifies that header value doesn't contain leading whitespace or - return characters. This prevents unintended header injection. + """Verifies that header value is a string which doesn't contain + leading whitespace or return characters. This prevents unintended + header injection. :param header: tuple, in the format (name, value). """ @@ -750,8 +751,12 @@ def check_header_validity(header): pat = _CLEAN_HEADER_REGEX_BYTE else: pat = _CLEAN_HEADER_REGEX_STR - if not pat.match(value): - raise InvalidHeader("Invalid return character or leading space in header: %s" % name) + try: + if not pat.match(value): + raise InvalidHeader("Invalid return character or leading space in header: %s" % name) + except TypeError: + raise InvalidHeader("Header value %s must be of type str or bytes, " + "not %s" % (value, type(value))) def urldefragauth(url): """ diff --git a/tests/test_requests.py b/tests/test_requests.py index 4250a8f9..a7d3a75b 100755 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1142,15 +1142,33 @@ class TestRequests: assert 'unicode' in p.headers.keys() assert 'byte' in p.headers.keys() - def test_header_validation(self,httpbin): + def test_header_validation(self, httpbin): """Ensure prepare_headers regex isn't flagging valid header contents.""" headers_ok = {'foo': 'bar baz qux', - 'bar': '1', + 'bar': u'fbbq'.encode('utf8'), 'baz': '', - 'qux': str.encode(u'fbbq')} + 'qux': '1'} r = requests.get(httpbin('get'), headers=headers_ok) assert r.request.headers['foo'] == headers_ok['foo'] + def test_header_value_not_str(self, httpbin): + """Ensure the header value is of type string or bytes as + per discussion in GH issue #3386 + """ + headers_int = {'foo': 3} + headers_dict = {'bar': {'foo':'bar'}} + headers_list = {'baz': ['foo', 'bar']} + + # Test for int + with pytest.raises(InvalidHeader): + r = requests.get(httpbin('get'), headers=headers_int) + # Test for dict + with pytest.raises(InvalidHeader): + r = requests.get(httpbin('get'), headers=headers_dict) + # Test for list + with pytest.raises(InvalidHeader): + r = requests.get(httpbin('get'), headers=headers_list) + def test_header_no_return_chars(self, httpbin): """Ensure that a header containing return character sequences raise an exception. Otherwise, multiple headers are created from single string.