From 009b80c95a85b44b5f8b7be7d4ef08e39b18c40e Mon Sep 17 00:00:00 2001 From: Nate Prewitt Date: Sun, 21 May 2017 16:40:19 -0700 Subject: [PATCH] persist session-level CookiePolicy --- requests/cookies.py | 2 +- requests/sessions.py | 7 ++++--- tests/test_requests.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/requests/cookies.py b/requests/cookies.py index 6484af6b..7268184a 100644 --- a/requests/cookies.py +++ b/requests/cookies.py @@ -412,7 +412,7 @@ class RequestsCookieJar(cookielib.CookieJar, collections.MutableMapping): def copy(self): """Return a copy of this RequestsCookieJar.""" - new_cj = RequestsCookieJar() + new_cj = RequestsCookieJar(self._policy) new_cj.update(self) return new_cj diff --git a/requests/sessions.py b/requests/sessions.py index 49037c0c..b54fa391 100755 --- a/requests/sessions.py +++ b/requests/sessions.py @@ -16,7 +16,8 @@ from datetime import timedelta from .auth import _basic_auth_str from .compat import cookielib, OrderedDict, urljoin, urlparse, is_py3, str from .cookies import ( - cookiejar_from_dict, extract_cookies_to_jar, RequestsCookieJar, merge_cookies) + cookiejar_from_dict, extract_cookies_to_jar, RequestsCookieJar, + merge_cookies, _copy_cookie_jar) from .models import Request, PreparedRequest, DEFAULT_REDIRECT_LIMIT from .hooks import default_hooks, dispatch_hook from ._internal_utils import to_native_string @@ -425,8 +426,8 @@ class Session(SessionRedirectMixin): cookies = cookiejar_from_dict(cookies) # Merge with session cookies - merged_cookies = merge_cookies( - merge_cookies(RequestsCookieJar(), self.cookies), cookies) + session_cookies = _copy_cookie_jar(self.cookies) + merged_cookies = merge_cookies(session_cookies, cookies) # Set environment's basic authentication if not explicitly set. auth = request.auth diff --git a/tests/test_requests.py b/tests/test_requests.py index 89f48e2b..460fe4a7 100755 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -482,6 +482,35 @@ class TestRequests: assert cookies['foo'] == 'bar' assert cookies['cookie'] == 'tasty' + @pytest.mark.parametrize( + 'jar', ( + requests.cookies.RequestsCookieJar(), + cookielib.CookieJar() + )) + def test_custom_cookie_policy_persistence(self, httpbin, jar): + """Verify a custom CookiePolicy is propagated on each session request.""" + + class TestCookiePolicy(cookielib.DefaultCookiePolicy): + """Policy to restrict all cookies from localhost (127.0.0.1).""" + def __init__(self): + cookielib.DefaultCookiePolicy.__init__(self, blocked_domains=['127.0.0.1']) + + # Establish session with jar and set some cookies. + s = requests.Session() + s.cookies = jar + s.get(httpbin('cookies/set?k1=v1&k2=v2')) + assert len(s.cookies) == 2 + + # Set different policy. + s.cookies.set_policy(TestCookiePolicy()) + assert isinstance(s.cookies._policy, TestCookiePolicy) + + # No cookies were sent to our blocked domain and none were set. + resp = s.get(httpbin('cookies/set?k3=v3')) + assert 'Cookie' not in resp.request.headers + assert len(s.cookies) == 2 + assert 'k3' not in s.cookies + def test_requests_in_history_are_not_overridden(self, httpbin): resp = requests.get(httpbin('redirect/3')) urls = [r.url for r in resp.history]