Add valdation for header name (#6154)

This commit is contained in:
Nate Prewitt
2022-06-08 12:03:56 -06:00
committed by GitHub
parent 60865f21ae
commit e36f34597c
3 changed files with 74 additions and 64 deletions
+11
View File
@@ -5,9 +5,20 @@ requests._internal_utils
Provides utility functions that are consumed internally by Requests
which depend on extremely few external helpers (such as compat)
"""
import re
from .compat import builtin_str
_VALID_HEADER_NAME_RE_BYTE = re.compile(rb"^[^:\s][^:\r\n]*$")
_VALID_HEADER_NAME_RE_STR = re.compile(r"^[^:\s][^:\r\n]*$")
_VALID_HEADER_VALUE_RE_BYTE = re.compile(rb"^\S[^\r\n]*$|^$")
_VALID_HEADER_VALUE_RE_STR = re.compile(r"^\S[^\r\n]*$|^$")
HEADER_VALIDATORS = {
bytes: (_VALID_HEADER_NAME_RE_BYTE, _VALID_HEADER_VALUE_RE_BYTE),
str: (_VALID_HEADER_NAME_RE_STR, _VALID_HEADER_VALUE_RE_STR),
}
def to_native_string(string, encoding="ascii"):
"""Given a string object, regardless of type, returns a representation of
+16 -19
View File
@@ -25,7 +25,7 @@ from . import certs
from .__version__ import __version__
# to_native_string is unused here, but imported here for backwards compatibility
from ._internal_utils import to_native_string # noqa: F401
from ._internal_utils import HEADER_VALIDATORS, to_native_string # noqa: F401
from .compat import (
Mapping,
basestring,
@@ -1024,33 +1024,30 @@ def get_auth_from_url(url):
return auth
# Moved outside of function to avoid recompile every call
_CLEAN_HEADER_REGEX_BYTE = re.compile(b"^\\S[^\\r\\n]*$|^$")
_CLEAN_HEADER_REGEX_STR = re.compile(r"^\S[^\r\n]*$|^$")
def check_header_validity(header):
"""Verifies that header value is a string which doesn't contain
leading whitespace or return characters. This prevents unintended
header injection.
"""Verifies that header parts don't contain leading whitespace
reserved characters, or return characters.
:param header: tuple, in the format (name, value).
"""
name, value = header
if isinstance(value, bytes):
pat = _CLEAN_HEADER_REGEX_BYTE
else:
pat = _CLEAN_HEADER_REGEX_STR
try:
if not pat.match(value):
for part in header:
if type(part) not in HEADER_VALIDATORS:
raise InvalidHeader(
f"Invalid return character or leading space in header: {name}"
f"Header part ({part!r}) from {{{name!r}: {value!r}}} must be "
f"of type str or bytes, not {type(part)}"
)
except TypeError:
_validate_header_part(name, "name", HEADER_VALIDATORS[type(name)][0])
_validate_header_part(value, "value", HEADER_VALIDATORS[type(value)][1])
def _validate_header_part(header_part, header_kind, validator):
if not validator.match(header_part):
raise InvalidHeader(
f"Value for header {{{name}: {value}}} must be of type "
f"str or bytes, not {type(value)}"
f"Invalid leading whitespace, reserved character(s), or return"
f"character(s) in header {header_kind}: {header_part!r}"
)
+47 -45
View File
@@ -1096,7 +1096,7 @@ class TestRequests:
def test_custom_content_type(self, httpbin):
with open(__file__, "rb") as f1:
with open(__file__, "rb") as f2:
data={"stuff": json.dumps({"a": 123})}
data = {"stuff": json.dumps({"a": 123})}
files = {
"file1": ("test_requests.py", f1),
"file2": ("test_requests", f2, "text/py-content-type"),
@@ -1682,68 +1682,70 @@ class TestRequests:
def test_header_validation(self, httpbin):
"""Ensure prepare_headers regex isn't flagging valid header contents."""
headers_ok = {
valid_headers = {
"foo": "bar baz qux",
"bar": b"fbbq",
"baz": "",
"qux": "1",
}
r = requests.get(httpbin("get"), headers=headers_ok)
assert r.request.headers["foo"] == headers_ok["foo"]
r = requests.get(httpbin("get"), headers=valid_headers)
for key in valid_headers.keys():
valid_headers[key] == r.request.headers[key]
def test_header_value_not_str(self, httpbin):
@pytest.mark.parametrize(
"invalid_header, key",
(
({"foo": 3}, "foo"),
({"bar": {"foo": "bar"}}, "bar"),
({"baz": ["foo", "bar"]}, "baz"),
),
)
def test_header_value_not_str(self, httpbin, invalid_header, key):
"""Ensure the header value is of type string or bytes as
per discussion in GH issue #3386
"""
headers_int = {"foo": 3}
headers_dict = {"bar": {"foo": "bar"}}
headers_list = {"baz": ["foo", "bar"]}
with pytest.raises(InvalidHeader) as excinfo:
requests.get(httpbin("get"), headers=invalid_header)
assert key in str(excinfo.value)
# Test for int
with pytest.raises(InvalidHeader) as excinfo:
requests.get(httpbin("get"), headers=headers_int)
assert "foo" in str(excinfo.value)
# Test for dict
with pytest.raises(InvalidHeader) as excinfo:
requests.get(httpbin("get"), headers=headers_dict)
assert "bar" in str(excinfo.value)
# Test for list
with pytest.raises(InvalidHeader) as excinfo:
requests.get(httpbin("get"), headers=headers_list)
assert "baz" in str(excinfo.value)
def test_header_no_return_chars(self, httpbin):
@pytest.mark.parametrize(
"invalid_header",
(
{"foo": "bar\r\nbaz: qux"},
{"foo": "bar\n\rbaz: qux"},
{"foo": "bar\nbaz: qux"},
{"foo": "bar\rbaz: qux"},
{"fo\ro": "bar"},
{"fo\r\no": "bar"},
{"fo\n\ro": "bar"},
{"fo\no": "bar"},
),
)
def test_header_no_return_chars(self, httpbin, invalid_header):
"""Ensure that a header containing return character sequences raise an
exception. Otherwise, multiple headers are created from single string.
"""
headers_ret = {"foo": "bar\r\nbaz: qux"}
headers_lf = {"foo": "bar\nbaz: qux"}
headers_cr = {"foo": "bar\rbaz: qux"}
with pytest.raises(InvalidHeader):
requests.get(httpbin("get"), headers=invalid_header)
# Test for newline
with pytest.raises(InvalidHeader):
requests.get(httpbin("get"), headers=headers_ret)
# Test for line feed
with pytest.raises(InvalidHeader):
requests.get(httpbin("get"), headers=headers_lf)
# Test for carriage return
with pytest.raises(InvalidHeader):
requests.get(httpbin("get"), headers=headers_cr)
def test_header_no_leading_space(self, httpbin):
@pytest.mark.parametrize(
"invalid_header",
(
{" foo": "bar"},
{"\tfoo": "bar"},
{" foo": "bar"},
{"foo": " bar"},
{"foo": " bar"},
{"foo": "\tbar"},
{" ": "bar"},
),
)
def test_header_no_leading_space(self, httpbin, invalid_header):
"""Ensure headers containing leading whitespace raise
InvalidHeader Error before sending.
"""
headers_space = {"foo": " bar"}
headers_tab = {"foo": " bar"}
# Test for whitespace
with pytest.raises(InvalidHeader):
requests.get(httpbin("get"), headers=headers_space)
# Test for tab
with pytest.raises(InvalidHeader):
requests.get(httpbin("get"), headers=headers_tab)
requests.get(httpbin("get"), headers=invalid_header)
@pytest.mark.parametrize("files", ("foo", b"foo", bytearray(b"foo")))
def test_can_send_objects_with_files(self, httpbin, files):