mirror of
https://github.com/kennethreitz/requests.git
synced 2026-06-05 22:50:18 +00:00
Rewrite CaseInsensitiveDict to work correctly/sanely
Fixes #649 and #1329 by making Session.headers a CaseInsensitiveDict, and fixing the implementation of CID. Credit for the brilliant idea to map `lowercased_key -> (cased_key, mapped_value)` goes to @gazpachoking, thanks a bunch. Changes from original implementation of CaseInsensitiveDict: 1. CID is rewritten as a subclass of `collections.MutableMapping`. 2. CID remembers the case of the last-set key, but `__setitem__` and `__delitem__` will handle keys without respect to case. 3. CID returns the key case as remembered for the `keys`, `items`, and `__iter__` methods. 4. Query operations (`__getitem__` and `__contains__`) are done in a case-insensitive manner: `cid['foo']` and `cid['FOO']` will return the same value. 5. The constructor as well as `update` and `__eq__` have undefined behavior when given multiple keys that have the same `lower()`. 6. The new method `lower_items` is like `iteritems`, but keys are all lowercased. 7. CID raises `KeyError` for `__getitem__` as normal dicts do. The old implementation returned 6. The `__repr__` now makes it obvious that it's not a normal dict. See PR #1333 for the discussions that lead up to this implementation
This commit is contained in:
@@ -124,3 +124,4 @@ Patches and Suggestions
|
||||
- Wilfred Hughes <me@wilfred.me.uk> @dontYetKnow
|
||||
- Dmitry Medvinsky <me@dmedvinsky.name>
|
||||
- Bryce Boe <bbzbryce@gmail.com> @bboe
|
||||
- Colin Dunklau <colin.dunklau@gmail.com> @cdunklau
|
||||
|
||||
+64
-27
@@ -9,6 +9,7 @@ Data structures that power Requests.
|
||||
"""
|
||||
|
||||
import os
|
||||
import collections
|
||||
from itertools import islice
|
||||
|
||||
|
||||
@@ -33,43 +34,79 @@ class IteratorProxy(object):
|
||||
return "".join(islice(self.i, None, n))
|
||||
|
||||
|
||||
class CaseInsensitiveDict(dict):
|
||||
"""Case-insensitive Dictionary
|
||||
class CaseInsensitiveDict(collections.MutableMapping):
|
||||
"""
|
||||
A case-insensitive ``dict``-like object.
|
||||
|
||||
Implements all methods and operations of
|
||||
``collections.MutableMapping`` as well as dict's ``copy``. Also
|
||||
provides ``lower_items``.
|
||||
|
||||
All keys are expected to be strings. The structure remembers the
|
||||
case of the last key to be set, and ``iter(instance)``,
|
||||
``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()``
|
||||
will contain case-sensitive keys. However, querying and contains
|
||||
testing is case insensitive:
|
||||
|
||||
cid = CaseInsensitiveDict()
|
||||
cid['Accept'] = 'application/json'
|
||||
cid['aCCEPT'] == 'application/json' # True
|
||||
list(cid) == ['Accept'] # True
|
||||
|
||||
For example, ``headers['content-encoding']`` will return the
|
||||
value of a ``'Content-Encoding'`` response header."""
|
||||
value of a ``'Content-Encoding'`` response header, regardless
|
||||
of how the header name was originally stored.
|
||||
|
||||
@property
|
||||
def lower_keys(self):
|
||||
if not hasattr(self, '_lower_keys') or not self._lower_keys:
|
||||
self._lower_keys = dict((k.lower(), k) for k in list(self.keys()))
|
||||
return self._lower_keys
|
||||
If the constructor, ``.update``, or equality comparison
|
||||
operations are given keys that have equal ``.lower()``s, the
|
||||
behavior is undefined.
|
||||
|
||||
def _clear_lower_keys(self):
|
||||
if hasattr(self, '_lower_keys'):
|
||||
self._lower_keys.clear()
|
||||
"""
|
||||
def __init__(self, data=None, **kwargs):
|
||||
self._store = dict()
|
||||
if data is None:
|
||||
data = {}
|
||||
self.update(data, **kwargs)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
dict.__setitem__(self, key, value)
|
||||
self._clear_lower_keys()
|
||||
|
||||
def __delitem__(self, key):
|
||||
dict.__delitem__(self, self.lower_keys.get(key.lower(), key))
|
||||
self._lower_keys.clear()
|
||||
|
||||
def __contains__(self, key):
|
||||
return key.lower() in self.lower_keys
|
||||
# Use the lowercased key for lookups, but store the actual
|
||||
# key alongside the value.
|
||||
self._store[key.lower()] = (key, value)
|
||||
|
||||
def __getitem__(self, key):
|
||||
# We allow fall-through here, so values default to None
|
||||
if key in self:
|
||||
return dict.__getitem__(self, self.lower_keys[key.lower()])
|
||||
return self._store[key.lower()][1]
|
||||
|
||||
def get(self, key, default=None):
|
||||
if key in self:
|
||||
return self[key]
|
||||
def __delitem__(self, key):
|
||||
del self._store[key.lower()]
|
||||
|
||||
def __iter__(self):
|
||||
return (casedkey for casedkey, mappedvalue in self._store.values())
|
||||
|
||||
def __len__(self):
|
||||
return len(self._store)
|
||||
|
||||
def lower_items(self):
|
||||
"""Like iteritems(), but with all lowercase keys."""
|
||||
return (
|
||||
(lowerkey, keyval[1])
|
||||
for (lowerkey, keyval)
|
||||
in self._store.items()
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, collections.Mapping):
|
||||
other = CaseInsensitiveDict(other)
|
||||
else:
|
||||
return default
|
||||
return NotImplemented
|
||||
# Compare insensitively
|
||||
return dict(self.lower_items()) == dict(other.lower_items())
|
||||
|
||||
# Copy is required
|
||||
def copy(self):
|
||||
return CaseInsensitiveDict(self._store.values())
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%r)' % (self.__class__.__name__, dict(self.items()))
|
||||
|
||||
|
||||
class LookupDict(dict):
|
||||
|
||||
+3
-2
@@ -23,6 +23,7 @@ from . import certs
|
||||
from .compat import parse_http_list as _parse_list_header
|
||||
from .compat import quote, urlparse, bytes, str, OrderedDict, urlunparse
|
||||
from .cookies import RequestsCookieJar, cookiejar_from_dict
|
||||
from .structures import CaseInsensitiveDict
|
||||
|
||||
_hush_pyflakes = (RequestsCookieJar,)
|
||||
|
||||
@@ -449,11 +450,11 @@ def default_user_agent():
|
||||
|
||||
|
||||
def default_headers():
|
||||
return {
|
||||
return CaseInsensitiveDict({
|
||||
'User-Agent': default_user_agent(),
|
||||
'Accept-Encoding': ', '.join(('gzip', 'deflate', 'compress')),
|
||||
'Accept': '*/*'
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
def parse_header_links(value):
|
||||
|
||||
@@ -13,6 +13,7 @@ import requests
|
||||
from requests.auth import HTTPDigestAuth
|
||||
from requests.compat import str, cookielib
|
||||
from requests.cookies import cookiejar_from_dict
|
||||
from requests.structures import CaseInsensitiveDict
|
||||
|
||||
try:
|
||||
import StringIO
|
||||
@@ -458,6 +459,165 @@ class RequestsTestCase(unittest.TestCase):
|
||||
r = s.send(r.prepare())
|
||||
self.assertEqual(r.status_code, 200)
|
||||
|
||||
def test_fixes_1329(self):
|
||||
s = requests.Session()
|
||||
s.headers.update({'accept': 'application/json'})
|
||||
r = s.get(httpbin('get'))
|
||||
headers = r.request.headers
|
||||
# ASCII encode because of key comparison changes in py3
|
||||
self.assertEqual(
|
||||
headers['accept'.encode('ascii')],
|
||||
'application/json'
|
||||
)
|
||||
self.assertEqual(
|
||||
headers['Accept'.encode('ascii')],
|
||||
'application/json'
|
||||
)
|
||||
|
||||
|
||||
class TestCaseInsensitiveDict(unittest.TestCase):
|
||||
|
||||
def test_mapping_init(self):
|
||||
cid = CaseInsensitiveDict({'Foo': 'foo','BAr': 'bar'})
|
||||
self.assertEqual(len(cid), 2)
|
||||
self.assertTrue('foo' in cid)
|
||||
self.assertTrue('bar' in cid)
|
||||
|
||||
def test_iterable_init(self):
|
||||
cid = CaseInsensitiveDict([('Foo', 'foo'), ('BAr', 'bar')])
|
||||
self.assertEqual(len(cid), 2)
|
||||
self.assertTrue('foo' in cid)
|
||||
self.assertTrue('bar' in cid)
|
||||
|
||||
def test_kwargs_init(self):
|
||||
cid = CaseInsensitiveDict(FOO='foo', BAr='bar')
|
||||
self.assertEqual(len(cid), 2)
|
||||
self.assertTrue('foo' in cid)
|
||||
self.assertTrue('bar' in cid)
|
||||
|
||||
def test_docstring_example(self):
|
||||
cid = CaseInsensitiveDict()
|
||||
cid['Accept'] = 'application/json'
|
||||
self.assertEqual(cid['aCCEPT'], 'application/json')
|
||||
self.assertEqual(list(cid), ['Accept'])
|
||||
|
||||
def test_len(self):
|
||||
cid = CaseInsensitiveDict({'a': 'a', 'b': 'b'})
|
||||
cid['A'] = 'a'
|
||||
self.assertEqual(len(cid), 2)
|
||||
|
||||
def test_getitem(self):
|
||||
cid = CaseInsensitiveDict({'Spam': 'blueval'})
|
||||
self.assertEqual(cid['spam'], 'blueval')
|
||||
self.assertEqual(cid['SPAM'], 'blueval')
|
||||
|
||||
def test_fixes_649(self):
|
||||
cid = CaseInsensitiveDict()
|
||||
cid['spam'] = 'oneval'
|
||||
cid['Spam'] = 'twoval'
|
||||
cid['sPAM'] = 'redval'
|
||||
cid['SPAM'] = 'blueval'
|
||||
self.assertEqual(cid['spam'], 'blueval')
|
||||
self.assertEqual(cid['SPAM'], 'blueval')
|
||||
self.assertEqual(list(cid.keys()), ['SPAM'])
|
||||
|
||||
def test_delitem(self):
|
||||
cid = CaseInsensitiveDict()
|
||||
cid['Spam'] = 'someval'
|
||||
del cid['sPam']
|
||||
self.assertFalse('spam' in cid)
|
||||
self.assertEqual(len(cid), 0)
|
||||
|
||||
def test_contains(self):
|
||||
cid = CaseInsensitiveDict()
|
||||
cid['Spam'] = 'someval'
|
||||
self.assertTrue('Spam' in cid)
|
||||
self.assertTrue('spam' in cid)
|
||||
self.assertTrue('SPAM' in cid)
|
||||
self.assertTrue('sPam' in cid)
|
||||
self.assertFalse('notspam' in cid)
|
||||
|
||||
def test_get(self):
|
||||
cid = CaseInsensitiveDict()
|
||||
cid['spam'] = 'oneval'
|
||||
cid['SPAM'] = 'blueval'
|
||||
self.assertEqual(cid.get('spam'), 'blueval')
|
||||
self.assertEqual(cid.get('SPAM'), 'blueval')
|
||||
self.assertEqual(cid.get('sPam'), 'blueval')
|
||||
self.assertEqual(cid.get('notspam', 'default'), 'default')
|
||||
|
||||
def test_update(self):
|
||||
cid = CaseInsensitiveDict()
|
||||
cid['spam'] = 'blueval'
|
||||
cid.update({'sPam': 'notblueval'})
|
||||
self.assertEqual(cid['spam'], 'notblueval')
|
||||
cid = CaseInsensitiveDict({'Foo': 'foo','BAr': 'bar'})
|
||||
cid.update({'fOO': 'anotherfoo', 'bAR': 'anotherbar'})
|
||||
self.assertEqual(len(cid), 2)
|
||||
self.assertEqual(cid['foo'], 'anotherfoo')
|
||||
self.assertEqual(cid['bar'], 'anotherbar')
|
||||
|
||||
def test_update_retains_unchanged(self):
|
||||
cid = CaseInsensitiveDict({'foo': 'foo', 'bar': 'bar'})
|
||||
cid.update({'foo': 'newfoo'})
|
||||
self.assertEquals(cid['bar'], 'bar')
|
||||
|
||||
def test_iter(self):
|
||||
cid = CaseInsensitiveDict({'Spam': 'spam', 'Eggs': 'eggs'})
|
||||
keys = frozenset(['Spam', 'Eggs'])
|
||||
self.assertEqual(frozenset(iter(cid)), keys)
|
||||
|
||||
def test_equality(self):
|
||||
cid = CaseInsensitiveDict({'SPAM': 'blueval', 'Eggs': 'redval'})
|
||||
othercid = CaseInsensitiveDict({'spam': 'blueval', 'eggs': 'redval'})
|
||||
self.assertEqual(cid, othercid)
|
||||
del othercid['spam']
|
||||
self.assertNotEqual(cid, othercid)
|
||||
self.assertEqual(cid, {'spam': 'blueval', 'eggs': 'redval'})
|
||||
|
||||
def test_setdefault(self):
|
||||
cid = CaseInsensitiveDict({'Spam': 'blueval'})
|
||||
self.assertEqual(
|
||||
cid.setdefault('spam', 'notblueval'),
|
||||
'blueval'
|
||||
)
|
||||
self.assertEqual(
|
||||
cid.setdefault('notspam', 'notblueval'),
|
||||
'notblueval'
|
||||
)
|
||||
|
||||
def test_lower_items(self):
|
||||
cid = CaseInsensitiveDict({
|
||||
'Accept': 'application/json',
|
||||
'user-Agent': 'requests',
|
||||
})
|
||||
keyset = frozenset(lowerkey for lowerkey, v in cid.lower_items())
|
||||
lowerkeyset = frozenset(['accept', 'user-agent'])
|
||||
self.assertEqual(keyset, lowerkeyset)
|
||||
|
||||
def test_preserve_key_case(self):
|
||||
cid = CaseInsensitiveDict({
|
||||
'Accept': 'application/json',
|
||||
'user-Agent': 'requests',
|
||||
})
|
||||
keyset = frozenset(['Accept', 'user-Agent'])
|
||||
self.assertEqual(frozenset(i[0] for i in cid.items()), keyset)
|
||||
self.assertEqual(frozenset(cid.keys()), keyset)
|
||||
self.assertEqual(frozenset(cid), keyset)
|
||||
|
||||
def test_preserve_last_key_case(self):
|
||||
cid = CaseInsensitiveDict({
|
||||
'Accept': 'application/json',
|
||||
'user-Agent': 'requests',
|
||||
})
|
||||
cid.update({'ACCEPT': 'application/json'})
|
||||
cid['USER-AGENT'] = 'requests'
|
||||
keyset = frozenset(['ACCEPT', 'USER-AGENT'])
|
||||
self.assertEqual(frozenset(i[0] for i in cid.items()), keyset)
|
||||
self.assertEqual(frozenset(cid.keys()), keyset)
|
||||
self.assertEqual(frozenset(cid), keyset)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user