Rewrite CaseInsensitiveDict to work correctly/sanely

Fixes #649 and #1329 by making Session.headers a CaseInsensitiveDict,
and fixing the implementation of CID. Credit for the brilliant idea
to map `lowercased_key -> (cased_key, mapped_value)` goes to
@gazpachoking, thanks a bunch.

Changes from original implementation of CaseInsensitiveDict:

1.  CID is rewritten as a subclass of `collections.MutableMapping`.
2.  CID remembers the case of the last-set key, but `__setitem__`
    and `__delitem__` will handle keys without respect to case.
3.  CID returns the key case as remembered for the `keys`, `items`,
    and `__iter__` methods.
4.  Query operations (`__getitem__` and `__contains__`) are done in
    a case-insensitive manner: `cid['foo']` and `cid['FOO']` will
    return the same value.
5.  The constructor as well as `update` and `__eq__` have undefined
    behavior when given multiple keys that have the same `lower()`.
6.  The new method `lower_items` is like `iteritems`, but keys are
    all lowercased.
7.  CID raises `KeyError` for `__getitem__` as normal dicts do. The
    old implementation returned `None` for missing keys.
8.  The `__repr__` now makes it obvious that it's not a normal dict.

See PR #1333 for the discussions that led up to this implementation.
This commit is contained in:
Colin Dunklau
2013-04-30 14:52:27 -05:00
parent ab36f3cc6f
commit f7596c75dc
4 changed files with 228 additions and 29 deletions
+1
View File
@@ -124,3 +124,4 @@ Patches and Suggestions
- Wilfred Hughes <me@wilfred.me.uk> @dontYetKnow
- Dmitry Medvinsky <me@dmedvinsky.name>
- Bryce Boe <bbzbryce@gmail.com> @bboe
- Colin Dunklau <colin.dunklau@gmail.com> @cdunklau
+64 -27
View File
@@ -9,6 +9,7 @@ Data structures that power Requests.
"""
import os
import collections
from itertools import islice
@@ -33,43 +34,79 @@ class IteratorProxy(object):
return "".join(islice(self.i, None, n))
class CaseInsensitiveDict(dict):
"""Case-insensitive Dictionary
class CaseInsensitiveDict(collections.MutableMapping):
"""
A case-insensitive ``dict``-like object.
Implements all methods and operations of
``collections.MutableMapping`` as well as dict's ``copy``. Also
provides ``lower_items``.
All keys are expected to be strings. The structure remembers the
case of the last key to be set, and ``iter(instance)``,
``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()``
will contain case-sensitive keys. However, querying and contains
testing is case insensitive:
cid = CaseInsensitiveDict()
cid['Accept'] = 'application/json'
cid['aCCEPT'] == 'application/json' # True
list(cid) == ['Accept'] # True
For example, ``headers['content-encoding']`` will return the
value of a ``'Content-Encoding'`` response header."""
value of a ``'Content-Encoding'`` response header, regardless
of how the header name was originally stored.
@property
def lower_keys(self):
if not hasattr(self, '_lower_keys') or not self._lower_keys:
self._lower_keys = dict((k.lower(), k) for k in list(self.keys()))
return self._lower_keys
If the constructor, ``.update``, or equality comparison
operations are given keys that have equal ``.lower()``s, the
behavior is undefined.
def _clear_lower_keys(self):
if hasattr(self, '_lower_keys'):
self._lower_keys.clear()
"""
def __init__(self, data=None, **kwargs):
self._store = dict()
if data is None:
data = {}
self.update(data, **kwargs)
def __setitem__(self, key, value):
dict.__setitem__(self, key, value)
self._clear_lower_keys()
def __delitem__(self, key):
dict.__delitem__(self, self.lower_keys.get(key.lower(), key))
self._lower_keys.clear()
def __contains__(self, key):
return key.lower() in self.lower_keys
# Use the lowercased key for lookups, but store the actual
# key alongside the value.
self._store[key.lower()] = (key, value)
def __getitem__(self, key):
# We allow fall-through here, so values default to None
if key in self:
return dict.__getitem__(self, self.lower_keys[key.lower()])
return self._store[key.lower()][1]
def get(self, key, default=None):
if key in self:
return self[key]
def __delitem__(self, key):
del self._store[key.lower()]
def __iter__(self):
return (casedkey for casedkey, mappedvalue in self._store.values())
def __len__(self):
return len(self._store)
def lower_items(self):
"""Like iteritems(), but with all lowercase keys."""
return (
(lowerkey, keyval[1])
for (lowerkey, keyval)
in self._store.items()
)
def __eq__(self, other):
if isinstance(other, collections.Mapping):
other = CaseInsensitiveDict(other)
else:
return default
return NotImplemented
# Compare insensitively
return dict(self.lower_items()) == dict(other.lower_items())
# Copy is required
def copy(self):
return CaseInsensitiveDict(self._store.values())
def __repr__(self):
return '%s(%r)' % (self.__class__.__name__, dict(self.items()))
class LookupDict(dict):
+3 -2
View File
@@ -23,6 +23,7 @@ from . import certs
from .compat import parse_http_list as _parse_list_header
from .compat import quote, urlparse, bytes, str, OrderedDict, urlunparse
from .cookies import RequestsCookieJar, cookiejar_from_dict
from .structures import CaseInsensitiveDict
_hush_pyflakes = (RequestsCookieJar,)
@@ -449,11 +450,11 @@ def default_user_agent():
def default_headers():
return {
return CaseInsensitiveDict({
'User-Agent': default_user_agent(),
'Accept-Encoding': ', '.join(('gzip', 'deflate', 'compress')),
'Accept': '*/*'
}
})
def parse_header_links(value):
+160
View File
@@ -13,6 +13,7 @@ import requests
from requests.auth import HTTPDigestAuth
from requests.compat import str, cookielib
from requests.cookies import cookiejar_from_dict
from requests.structures import CaseInsensitiveDict
try:
import StringIO
@@ -458,6 +459,165 @@ class RequestsTestCase(unittest.TestCase):
r = s.send(r.prepare())
self.assertEqual(r.status_code, 200)
def test_fixes_1329(self):
    """A lowercase session header is readable on the sent request
    under both ``accept`` and ``Accept``.
    """
    session = requests.Session()
    session.headers.update({'accept': 'application/json'})
    response = session.get(httpbin('get'))
    sent_headers = response.request.headers
    # ASCII encode because of key comparison changes in py3
    for header_name in ('accept', 'Accept'):
        self.assertEqual(
            sent_headers[header_name.encode('ascii')],
            'application/json'
        )
class TestCaseInsensitiveDict(unittest.TestCase):
    """Unit tests for ``CaseInsensitiveDict``.

    Exercises the three construction forms, case-insensitive lookup,
    membership and deletion, ``get``/``update``/``setdefault``,
    equality, ``lower_items``, and which key casing is reported by
    ``keys()``, ``items()`` and iteration.
    """

    def test_mapping_init(self):
        """Building from a mapping keeps both entries, found in lowercase."""
        cid = CaseInsensitiveDict({'Foo': 'foo', 'BAr': 'bar'})
        self.assertEqual(len(cid), 2)
        self.assertTrue('foo' in cid)
        self.assertTrue('bar' in cid)

    def test_iterable_init(self):
        """Building from an iterable of pairs keeps both entries."""
        cid = CaseInsensitiveDict([('Foo', 'foo'), ('BAr', 'bar')])
        self.assertEqual(len(cid), 2)
        self.assertTrue('foo' in cid)
        self.assertTrue('bar' in cid)

    def test_kwargs_init(self):
        """Building from keyword arguments keeps both entries."""
        cid = CaseInsensitiveDict(FOO='foo', BAr='bar')
        self.assertEqual(len(cid), 2)
        self.assertTrue('foo' in cid)
        self.assertTrue('bar' in cid)

    def test_docstring_example(self):
        """Lookup ignores case while iteration yields the key as set."""
        cid = CaseInsensitiveDict()
        cid['Accept'] = 'application/json'
        self.assertEqual(cid['aCCEPT'], 'application/json')
        self.assertEqual(list(cid), ['Accept'])

    def test_len(self):
        """Setting an existing key in another case adds no entry."""
        cid = CaseInsensitiveDict({'a': 'a', 'b': 'b'})
        cid['A'] = 'a'
        self.assertEqual(len(cid), 2)

    def test_getitem(self):
        """Item access returns the value for any casing of the key."""
        cid = CaseInsensitiveDict({'Spam': 'blueval'})
        self.assertEqual(cid['spam'], 'blueval')
        self.assertEqual(cid['SPAM'], 'blueval')

    def test_fixes_649(self):
        """Repeated sets under different casings leave a single entry
        holding the last value under the last key casing.
        """
        cid = CaseInsensitiveDict()
        cid['spam'] = 'oneval'
        cid['Spam'] = 'twoval'
        cid['sPAM'] = 'redval'
        cid['SPAM'] = 'blueval'
        self.assertEqual(cid['spam'], 'blueval')
        self.assertEqual(cid['SPAM'], 'blueval')
        self.assertEqual(list(cid.keys()), ['SPAM'])

    def test_delitem(self):
        """Deleting with a differently cased key removes the entry."""
        cid = CaseInsensitiveDict()
        cid['Spam'] = 'someval'
        del cid['sPam']
        self.assertFalse('spam' in cid)
        self.assertEqual(len(cid), 0)

    def test_contains(self):
        """Membership ignores case and is false for an absent key."""
        cid = CaseInsensitiveDict()
        cid['Spam'] = 'someval'
        self.assertTrue('Spam' in cid)
        self.assertTrue('spam' in cid)
        self.assertTrue('SPAM' in cid)
        self.assertTrue('sPam' in cid)
        self.assertFalse('notspam' in cid)

    def test_get(self):
        """``get`` ignores case and returns the default for an absent key."""
        cid = CaseInsensitiveDict()
        cid['spam'] = 'oneval'
        cid['SPAM'] = 'blueval'
        self.assertEqual(cid.get('spam'), 'blueval')
        self.assertEqual(cid.get('SPAM'), 'blueval')
        self.assertEqual(cid.get('sPam'), 'blueval')
        self.assertEqual(cid.get('notspam', 'default'), 'default')

    def test_update(self):
        """``update`` overwrites entries whose keys differ only by case."""
        cid = CaseInsensitiveDict()
        cid['spam'] = 'blueval'
        cid.update({'sPam': 'notblueval'})
        self.assertEqual(cid['spam'], 'notblueval')
        cid = CaseInsensitiveDict({'Foo': 'foo', 'BAr': 'bar'})
        cid.update({'fOO': 'anotherfoo', 'bAR': 'anotherbar'})
        self.assertEqual(len(cid), 2)
        self.assertEqual(cid['foo'], 'anotherfoo')
        self.assertEqual(cid['bar'], 'anotherbar')

    def test_update_retains_unchanged(self):
        """``update`` leaves entries it does not mention untouched."""
        cid = CaseInsensitiveDict({'foo': 'foo', 'bar': 'bar'})
        cid.update({'foo': 'newfoo'})
        # assertEqual rather than the deprecated assertEquals alias,
        # matching every other test in this class.
        self.assertEqual(cid['bar'], 'bar')

    def test_iter(self):
        """Iteration yields the keys with the casing they were given."""
        cid = CaseInsensitiveDict({'Spam': 'spam', 'Eggs': 'eggs'})
        keys = frozenset(['Spam', 'Eggs'])
        self.assertEqual(frozenset(iter(cid)), keys)

    def test_equality(self):
        """Equality ignores key case, against another instance and
        against a plain dict, and notices a removed entry.
        """
        cid = CaseInsensitiveDict({'SPAM': 'blueval', 'Eggs': 'redval'})
        othercid = CaseInsensitiveDict({'spam': 'blueval', 'eggs': 'redval'})
        self.assertEqual(cid, othercid)
        del othercid['spam']
        self.assertNotEqual(cid, othercid)
        self.assertEqual(cid, {'spam': 'blueval', 'eggs': 'redval'})

    def test_setdefault(self):
        """``setdefault`` returns the existing value for a key that
        differs only by case, and the default for a new key.
        """
        cid = CaseInsensitiveDict({'Spam': 'blueval'})
        self.assertEqual(
            cid.setdefault('spam', 'notblueval'),
            'blueval'
        )
        self.assertEqual(
            cid.setdefault('notspam', 'notblueval'),
            'notblueval'
        )

    def test_lower_items(self):
        """``lower_items`` yields every key in lowercase."""
        cid = CaseInsensitiveDict({
            'Accept': 'application/json',
            'user-Agent': 'requests',
        })
        keyset = frozenset(lowerkey for lowerkey, v in cid.lower_items())
        lowerkeyset = frozenset(['accept', 'user-agent'])
        self.assertEqual(keyset, lowerkeyset)

    def test_preserve_key_case(self):
        """``items()``, ``keys()`` and iteration report the original casing."""
        cid = CaseInsensitiveDict({
            'Accept': 'application/json',
            'user-Agent': 'requests',
        })
        keyset = frozenset(['Accept', 'user-Agent'])
        self.assertEqual(frozenset(i[0] for i in cid.items()), keyset)
        self.assertEqual(frozenset(cid.keys()), keyset)
        self.assertEqual(frozenset(cid), keyset)

    def test_preserve_last_key_case(self):
        """After ``update`` or item assignment with a new casing, the
        most recently used casing is the one reported.
        """
        cid = CaseInsensitiveDict({
            'Accept': 'application/json',
            'user-Agent': 'requests',
        })
        cid.update({'ACCEPT': 'application/json'})
        cid['USER-AGENT'] = 'requests'
        keyset = frozenset(['ACCEPT', 'USER-AGENT'])
        self.assertEqual(frozenset(i[0] for i in cid.items()), keyset)
        self.assertEqual(frozenset(cid.keys()), keyset)
        self.assertEqual(frozenset(cid), keyset)
if __name__ == '__main__':
unittest.main()