diff --git a/AUTHORS.rst b/AUTHORS.rst index e18d5751..bf79ef48 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -102,7 +102,7 @@ Patches and Suggestions - Roman Haritonov <@reclosedev> - Josh Imhoff - Arup Malakar -- Danilo Bargen (gwrtheyrn) +- Danilo Bargen (dbrgn) - Torsten Landschoff - Michael Holler (apotheos) - Timnit Gebru diff --git a/requests/models.py b/requests/models.py index 00d23700..045b3a58 100644 --- a/requests/models.py +++ b/requests/models.py @@ -557,7 +557,7 @@ class Request(object): proxy = self.proxies.get(_p.scheme) if proxy and not any(map(_p.hostname.endswith, no_proxy)): - conn = poolmanager.proxy_from_url(proxy) + conn = poolmanager.ProxyManager(self.get_connection_for_url(proxy)) _proxy = urlparse(proxy) if '@' in _proxy.netloc: auth, url = _proxy.netloc.split('@', 1) @@ -565,15 +565,10 @@ class Request(object): r = self.proxy_auth(self) self.__dict__.update(r.__dict__) else: - # Check to see if keep_alive is allowed. - try: - if self.config.get('keep_alive'): - conn = self._poolmanager.connection_from_url(url) - else: - conn = connectionpool.connection_from_url(url) - self.headers['Connection'] = 'close' - except LocationParseError as e: - raise InvalidURL(e) + conn = self.get_connection_for_url(url) + + if not self.config.get('keep_alive'): + self.headers['Connection'] = 'close' if url.startswith('https') and self.verify: @@ -672,6 +667,17 @@ class Request(object): return self.sent + def get_connection_for_url(self, url): + # Check to see if keep_alive is allowed. + try: + if self.config.get('keep_alive'): + conn = self._poolmanager.connection_from_url(url) + else: + conn = connectionpool.connection_from_url(url) + return conn + except LocationParseError as e: + raise InvalidURL(e) + class Response(object): """The core :class:`Response ` object. All diff --git a/tests/dummy_server.py b/tests/dummy_server.py new file mode 100644 index 00000000..1096e25b --- /dev/null +++ b/tests/dummy_server.py @@ -0,0 +1,46 @@ +import asyncore +import threading +import socket + +class HttpServer(threading.Thread): + def __init__(self, port): + threading.Thread.__init__(self) + self.dispatcher = HttpServerDispatcher(port) + + def run(self): + asyncore.loop() + + @property + def connection_count(self): + return self.dispatcher.connection_count + + def close(self): + asyncore.close_all() + +class HttpServerDispatcher(asyncore.dispatcher): + def __init__(self, port): + asyncore.dispatcher.__init__(self) + self.connected = False + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.bind(('127.0.0.1', port)) + self.listen(1) + self.connection_count = 0 + + def handle_accept(self): + self.connection_count += 1 + self.handler = RequestHandler(self.accept()[0]) + + def handle_close(self): + self.close() + + +class RequestHandler(asyncore.dispatcher_with_send): + def __init__(self, sock): + asyncore.dispatcher_with_send.__init__(self, sock) + self.response = ("HTTP/1.1 200 OK\r\n" + "Connection: keep-alive\r\n" + "Content-Length: 0\r\n\r\n") + + def handle_read(self): + self.recv(1024) + self.send(self.response) diff --git a/tests/test_keep_alive.py b/tests/test_keep_alive.py new file mode 100644 index 00000000..447fc8e3 --- /dev/null +++ b/tests/test_keep_alive.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import sys +import unittest + +# Path hack. +sys.path.insert(0, os.path.abspath('..')) +import requests +import dummy_server + +class KeepAliveTests(unittest.TestCase): + server_and_proxy_port = 1234 + request_count = 2 + url = 'http://localhost:{0}'.format(server_and_proxy_port) + proxies={'http': url} + + def setUp(self): + self.session = requests.session() + self.proxy_server = dummy_server.HttpServer(self.server_and_proxy_port) + self.proxy_server.start() + + def tearDown(self): + self.proxy_server.close() + + def test_keep_alive_with_direct_connection(self): + self.make_requests() + self.check_each_request_are_in_same_connection() + + def test_no_keep_alive_with_direct_connection(self): + self.disable_keep_alive_in_session() + self.make_requests() + self.check_each_request_are_in_different_connection() + + def test_keep_alive_with_proxy_connection(self): + self.make_proxy_requests() + self.check_each_request_are_in_same_connection() + + def test_no_keep_alive_with_proxy_connection(self): + self.disable_keep_alive_in_session() + self.make_proxy_requests() + self.check_each_request_are_in_different_connection() + + def make_proxy_requests(self): + self.make_requests(self.proxies) + + def make_requests(self, proxies=None): + for _ in xrange(self.request_count): + self.session.get(self.url, proxies=proxies).text + + def check_each_request_are_in_same_connection(self): + """Keep-alive requests open a single connection to the server.""" + self.assertEqual(self.proxy_server.connection_count, 1) + + def check_each_request_are_in_different_connection(self): + """Keep-alive requests open a single connection to the server.""" + self.assertEqual(self.proxy_server.connection_count, self.request_count) + + def disable_keep_alive_in_session(self): + self.session.config['keep_alive'] = False + + +if __name__ == '__main__': + unittest.main()