diff --git a/tests/test_lowlevel.py b/tests/test_lowlevel.py new file mode 100644 index 00000000..eb6d6273 --- /dev/null +++ b/tests/test_lowlevel.py @@ -0,0 +1,19 @@ +import threading +import requests + +from tests.testserver.server import Server + + +def test_chunked_upload(): + """can safely send generators""" + close_server = threading.Event() + server = Server.basic_response_server(wait_to_close_event=close_server) + data = (i for i in [b'a', b'b', b'c']) + + with server as (host, port): + url = 'http://{0}:{1}/'.format(host, port) + r = requests.post(url, data=data, stream=True) + close_server.set() # release server block + + assert r.status_code == 200 + assert r.request.headers['Transfer-Encoding'] == 'chunked' diff --git a/tests/test_testserver.py b/tests/test_testserver.py new file mode 100644 index 00000000..027f8e50 --- /dev/null +++ b/tests/test_testserver.py @@ -0,0 +1,137 @@ +import threading +import socket +import time + +import pytest +import requests +from tests.testserver.server import Server + +class TestTestServer: + def test_basic(self): + """messages are sent and received properly""" + question = b"sucess?" + answer = b"yeah, success" + def handler(sock): + text = sock.recv(1000) + assert text == question + sock.sendall(answer) + + with Server(handler) as (host, port): + sock = socket.socket() + sock.connect((host, port)) + sock.sendall(question) + text = sock.recv(1000) + assert text == answer + sock.close() + + def test_server_closes(self): + """the server closes when leaving the context manager""" + with Server.basic_response_server() as (host, port): + sock = socket.socket() + sock.connect((host, port)) + + sock.close() + + with pytest.raises(socket.error): + new_sock = socket.socket() + new_sock.connect((host, port)) + + def test_text_response(self): + """the text_response_server sends the given text""" + server = Server.text_response_server( + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 6\r\n" + + "\r\nroflol" + ) + + with server as (host, port): + r = requests.get('http://{0}:{1}'.format(host, port)) + + assert r.status_code == 200 + assert r.text == u'roflol' + assert r.headers['Content-Length'] == '6' + + def test_basic_response(self): + """the basic response server returns an empty http response""" + with Server.basic_response_server() as (host, port): + r = requests.get('http://{0}:{1}'.format(host, port)) + assert r.status_code == 200 + assert r.text == u'' + assert r.headers['Content-Length'] == '0' + + def test_basic_waiting_server(self): + """the server waits for the block_server event to be set before closing""" + block_server = threading.Event() + + with Server.basic_response_server(wait_to_close_event=block_server) as (host, port): + sock = socket.socket() + sock.connect((host, port)) + sock.sendall(b'send something') + time.sleep(2.5) + sock.sendall(b'still alive') + block_server.set() # release server block + + def test_multiple_requests(self): + """multiple requests can be served""" + requests_to_handle = 5 + + server = Server.basic_response_server(requests_to_handle=requests_to_handle) + + with server as (host, port): + server_url = 'http://{0}:{1}'.format(host, port) + for _ in range(requests_to_handle): + r = requests.get(server_url) + assert r.status_code == 200 + + # the (n+1)th request fails + with pytest.raises(requests.exceptions.ConnectionError): + r = requests.get(server_url) + + def test_request_recovery(self): + """can check the requests content""" + server = Server.basic_response_server(requests_to_handle=2) + first_request = b'put your hands up in the air' + second_request = b'put your hand down in the floor' + + with server as address: + sock1 = socket.socket() + sock2 = socket.socket() + + sock1.connect(address) + sock1.sendall(first_request) + sock1.close() + + sock2.connect(address) + sock2.sendall(second_request) + sock2.close() + + assert server.handler_results[0] == first_request + assert server.handler_results[1] == second_request + + def test_requests_after_timeout_are_not_received(self): + """the basic response handler times out when receiving requests""" + server = Server.basic_response_server(request_timeout=1) + + with server as address: + sock = socket.socket() + sock.connect(address) + time.sleep(1.5) + sock.sendall(b'hehehe, not received') + sock.close() + + assert server.handler_results[0] == b'' + + + def test_request_recovery_with_bigger_timeout(self): + """a biggest timeout can be specified""" + server = Server.basic_response_server(request_timeout=3) + data = b'bananadine' + + with server as address: + sock = socket.socket() + sock.connect(address) + time.sleep(1.5) + sock.sendall(data) + sock.close() + + assert server.handler_results[0] == data diff --git a/tests/testserver/__init__.py b/tests/testserver/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/testserver/server.py b/tests/testserver/server.py new file mode 100644 index 00000000..8b9643c3 --- /dev/null +++ b/tests/testserver/server.py @@ -0,0 +1,102 @@ +import threading +import socket +import select + + +def consume_socket_content(sock, timeout=0.5): + chunks = 65536 + content = b'' + more_to_read = select.select([sock], [], [], timeout)[0] + + while more_to_read: + new_content = sock.recv(chunks) + + if not new_content: + break + + content += new_content + # stop reading if no new data is received for a while + more_to_read = select.select([sock], [], [], timeout)[0] + + return content + + +class Server(threading.Thread): + """Dummy server using for unit testing""" + WAIT_EVENT_TIMEOUT = 5 + + def __init__(self, handler, host='localhost', port=0, requests_to_handle=1, wait_to_close_event=None): + super(Server, self).__init__() + + self.handler = handler + self.handler_results = [] + + self.host = host + self.port = port + self.requests_to_handle = requests_to_handle + + self.wait_to_close_event = wait_to_close_event + self.ready_event = threading.Event() + self.stop_event = threading.Event() + + @classmethod + def text_response_server(cls, text, request_timeout=0.5, **kwargs): + def text_response_handler(sock): + request_content = consume_socket_content(sock, timeout=request_timeout) + sock.send(text.encode('utf-8')) + + return request_content + + + return Server(text_response_handler, **kwargs) + + @classmethod + def basic_response_server(cls, **kwargs): + return cls.text_response_server( + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n\r\n", + **kwargs + ) + + def run(self): + try: + sock = self._create_socket_and_bind() + # in case self.port = 0 + self.port = sock.getsockname()[1] + self.ready_event.set() + self._handle_requests(sock) + + if self.wait_to_close_event: + self.wait_to_close_event.wait(self.WAIT_EVENT_TIMEOUT) + finally: + self.ready_event.set() # just in case of exception + sock.close() + self.stop_event.set() + + def _create_socket_and_bind(self): + sock = socket.socket() + sock.bind((self.host, self.port)) + sock.listen(0) + return sock + + def _handle_requests(self, server_sock): + for _ in range(self.requests_to_handle): + sock = server_sock.accept()[0] + handler_result = self.handler(sock) + + self.handler_results.append(handler_result) + + def __enter__(self): + self.start() + self.ready_event.wait(self.WAIT_EVENT_TIMEOUT) + return self.host, self.port + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + self.stop_event.wait(self.WAIT_EVENT_TIMEOUT) + else: + if self.wait_to_close_event: + # avoid server from waiting for event timeouts + # if an exception is found in the main thread + self.wait_to_close_event.set() + return False # allow exceptions to propagate