diff --git a/requests/api.py b/requests/api.py index 9cea79af..f192b8f5 100644 --- a/requests/api.py +++ b/requests/api.py @@ -38,10 +38,19 @@ def request(method, url, **kwargs): :param cert: (optional) if String, path to ssl client cert file (.pem). If Tuple, ('cert', 'key') pair. """ - s = kwargs.pop('session') if 'session' in kwargs else sessions.session() - return s.request(method=method, url=url, **kwargs) - + # if this session was passed in, leave it open (and retain pooled connections); + # if we're making it just for this call, then close it when we're done. + adhoc_session = False + session = kwargs.pop('session', None) + if session is None: + session = sessions.session() + adhoc_session = True + try: + return session.request(method=method, url=url, **kwargs) + finally: + if adhoc_session: + session.close() def get(url, **kwargs): """Sends a GET request. Returns :class:`Response` object. diff --git a/requests/models.py b/requests/models.py index 2c0c7bdd..e42ab366 100644 --- a/requests/models.py +++ b/requests/models.py @@ -57,7 +57,7 @@ class Request(object): proxies=None, hooks=None, config=None, - prefetch=False, + prefetch=True, _poolmanager=None, verify=None, session=None, @@ -458,7 +458,7 @@ class Request(object): except ValueError: return False - def send(self, anyway=False, prefetch=False): + def send(self, anyway=False, prefetch=True): """Sends the request. Returns True if successful, False if not. If there was an HTTPError during transmission, self.response.status_code will contain the HTTPError code. @@ -774,6 +774,8 @@ class Response(object): self._content = None self._content_consumed = True + # don't need to release the connection; that's been handled by urllib3 + # since we exhausted the data. return self._content @property diff --git a/requests/sessions.py b/requests/sessions.py index 3113c787..73c7b17b 100644 --- a/requests/sessions.py +++ b/requests/sessions.py @@ -66,7 +66,7 @@ class Session(object): hooks=None, params=None, config=None, - prefetch=False, + prefetch=True, verify=True, cert=None): @@ -105,7 +105,15 @@ class Session(object): return self def __exit__(self, *args): - pass + self.close() + + def close(self): + """Dispose of any internal state. + + Currently, this just closes the PoolManager, which closes pooled + connections. + """ + self.poolmanager.clear() def request(self, method, url, params=None, @@ -120,7 +128,7 @@ class Session(object): hooks=None, return_response=True, config=None, - prefetch=False, + prefetch=None, verify=None, cert=None): @@ -140,7 +148,7 @@ class Session(object): :param proxies: (optional) Dictionary mapping protocol to the URL of the proxy. :param return_response: (optional) If False, an un-sent Request object will returned. :param config: (optional) A configuration dictionary. See ``request.defaults`` for allowed keys and their default values. - :param prefetch: (optional) if ``True``, the response content will be immediately downloaded. + :param prefetch: (optional) whether to immediately download the response content. Defaults to ``True``. :param verify: (optional) if ``True``, the SSL cert will be verified. A CA_BUNDLE path can also be provided. :param cert: (optional) if String, path to ssl client cert file (.pem). If Tuple, ('cert', 'key') pair. """ @@ -153,7 +161,7 @@ class Session(object): headers = {} if headers is None else headers params = {} if params is None else params hooks = {} if hooks is None else hooks - prefetch = self.prefetch or prefetch + prefetch = prefetch if prefetch is not None else self.prefetch # use session's hooks as defaults for key, cb in list(self.hooks.items()): diff --git a/tests/informal/test_leaked_connections.py b/tests/informal/test_leaked_connections.py new file mode 100644 index 00000000..5357bf2f --- /dev/null +++ b/tests/informal/test_leaked_connections.py @@ -0,0 +1,26 @@ +""" +This is an informal test originally written by Bluehorn; +it verifies that Requests does not leak connections when +the body of the request is not read. +""" + +import gc, os, subprocess, requests, sys + +def main(): + gc.disable() + + for x in range(20): + requests.head("http://www.google.com/") + + print("Open sockets after 20 head requests:") + pid = os.getpid() + subprocess.call("lsof -p%d -a -iTCP" % (pid,), shell=True) + + gcresult = gc.collect() + print("Garbage collection result: %s" % (gcresult,)) + + print("Open sockets after garbage collection:") + subprocess.call("lsof -p%d -a -iTCP" % (pid,), shell=True) + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/test_requests.py b/tests/test_requests.py index 9ddc58b4..6b7b57a3 100755 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -807,9 +807,9 @@ class RequestsTestSuite(TestSetup, TestBaseMixin, unittest.TestCase): assert 'k' in c ds1 = pickle.loads(pickle.dumps(requests.session())) - ds2 = pickle.loads(pickle.dumps(requests.session(prefetch=True))) - assert not ds1.prefetch - assert ds2.prefetch + ds2 = pickle.loads(pickle.dumps(requests.session(prefetch=False))) + assert ds1.prefetch + assert not ds2.prefetch # def test_invalid_content(self): # # WARNING: if you're using a terrible DNS provider (comcast), @@ -858,7 +858,7 @@ class RequestsTestSuite(TestSetup, TestBaseMixin, unittest.TestCase): ) # Make a request and monkey-patch its contents - r = get(httpbin('get')) + r = get(httpbin('get'), prefetch=False) r.raw = StringIO(quote) lines = list(r.iter_lines())