Refactor merge_kwargs for clarity and to fix a few bugs

This commit is contained in:
Chase Sterling
2013-05-02 12:46:59 -04:00
parent 8eb4243f12
commit 98114245c6
3 changed files with 47 additions and 45 deletions
+31 -43
View File
@@ -9,14 +9,16 @@ requests (cookies, auth, proxies).
"""
import os
from collections import Mapping
from datetime import datetime
from .compat import cookielib, OrderedDict, urljoin, urlparse
from .cookies import cookiejar_from_dict, extract_cookies_to_jar, RequestsCookieJar
from .models import Request, PreparedRequest
from .hooks import default_hooks, dispatch_hook
from .utils import from_key_val_list, default_headers
from .utils import to_key_val_list, default_headers
from .exceptions import TooManyRedirects, InvalidSchema
from .structures import CaseInsensitiveDict
from .adapters import HTTPAdapter
@@ -32,49 +34,35 @@ REDIRECT_STATI = (
DEFAULT_REDIRECT_LIMIT = 30
def merge_kwargs(local_kwarg, default_kwarg):
"""Merges kwarg dictionaries.
If a local key in the dictionary is set to None, it will be removed.
def merge_setting(request_setting, session_setting, dict_class=OrderedDict):
"""
Determines appropriate setting for a given request, taking into account the
explicit setting on that request, and the setting in the session. If a
setting is a dictionary, they will be merged together using `dict_class`
"""
if default_kwarg is None:
return local_kwarg
if session_setting is None:
return request_setting
if isinstance(local_kwarg, str):
return local_kwarg
if request_setting is None:
return session_setting
if local_kwarg is None:
return default_kwarg
# Bypass if not a dictionary (e.g. verify)
if not (
isinstance(session_setting, Mapping) and
isinstance(request_setting, Mapping)
):
return request_setting
# Bypass if not a dictionary (e.g. timeout)
if not hasattr(default_kwarg, 'items'):
return local_kwarg
default_kwarg = from_key_val_list(default_kwarg)
local_kwarg = from_key_val_list(local_kwarg)
# Update new values in a case-insensitive way
def get_original_key(original_keys, new_key):
"""
Finds the key from original_keys that case-insensitive matches new_key.
"""
for original_key in original_keys:
if key.lower() == original_key.lower():
return original_key
return new_key
kwargs = default_kwarg.copy()
original_keys = kwargs.keys()
for key, value in local_kwarg.items():
kwargs[get_original_key(original_keys, key)] = value
merged_setting = dict_class(to_key_val_list(session_setting))
merged_setting.update(to_key_val_list(request_setting))
# Remove keys that are set to None.
for (k, v) in local_kwarg.items():
for (k, v) in request_setting.items():
if v is None:
del kwargs[k]
del merged_setting[k]
return kwargs
return merged_setting
class SessionRedirectMixin(object):
@@ -311,14 +299,14 @@ class Session(SessionRedirectMixin):
verify = os.environ.get('CURL_CA_BUNDLE')
# Merge all the kwargs.
params = merge_kwargs(params, self.params)
headers = merge_kwargs(headers, self.headers)
auth = merge_kwargs(auth, self.auth)
proxies = merge_kwargs(proxies, self.proxies)
hooks = merge_kwargs(hooks, self.hooks)
stream = merge_kwargs(stream, self.stream)
verify = merge_kwargs(verify, self.verify)
cert = merge_kwargs(cert, self.cert)
params = merge_setting(params, self.params)
headers = merge_setting(headers, self.headers, dict_class=CaseInsensitiveDict)
auth = merge_setting(auth, self.auth)
proxies = merge_setting(proxies, self.proxies)
hooks = merge_setting(hooks, self.hooks)
stream = merge_setting(stream, self.stream)
verify = merge_setting(verify, self.verify)
cert = merge_setting(cert, self.cert)
# Create the Request.
req = Request()
+2 -2
View File
@@ -11,11 +11,11 @@ that are also useful for external consumption.
import cgi
import codecs
import collections
import os
import platform
import re
import sys
import zlib
from netrc import netrc, NetrcParseError
from . import __version__
@@ -135,7 +135,7 @@ def to_key_val_list(value):
if isinstance(value, (str, bytes, bool, int)):
raise ValueError('cannot encode objects that are not 2-tuples')
if isinstance(value, dict):
if isinstance(value, collections.Mapping):
value = value.items()
return list(value)
+14
View File
@@ -521,6 +521,20 @@ class RequestsTestCase(unittest.TestCase):
self.assertTrue('http://' in s2.adapters)
self.assertTrue('https://' in s2.adapters)
def test_header_remove_is_case_insensitive(self):
# From issue #1321
s = requests.Session()
s.headers['foo'] = 'bar'
r = s.get(httpbin('get'), headers={'FOO': None})
assert 'foo' not in r.request.headers
def test_params_are_merged_case_sensitive(self):
s = requests.Session()
s.params['foo'] = 'bar'
r = s.get(httpbin('get'), params={'FOO': 'bar'})
assert r.json()['args'] == {'foo': 'bar', 'FOO': 'bar'}
def test_long_authinfo_in_url(self):
url = 'http://{0}:{1}@{2}:9000/path?query#frag'.format(
'E8A3BE87-9E3F-4620-8858-95478E385B5B',