diff --git a/requests_async/__init__.py b/requests_async/__init__.py index 2974361..adaa405 100644 --- a/requests_async/__init__.py +++ b/requests_async/__init__.py @@ -1,6 +1,7 @@ from .adapters import HTTPAdapter from .sessions import Session from .api import request, get, head, post, patch, put, delete, options +from .asgi import ASGISession __version__ = "0.1.0" __all__ = [ @@ -13,4 +14,5 @@ __all__ = [ "delete", "options", "Session", + "ASGISession" ] diff --git a/requests_async/asgi.py b/requests_async/asgi.py new file mode 100644 index 0000000..6bff52d --- /dev/null +++ b/requests_async/asgi.py @@ -0,0 +1,205 @@ +import asyncio +import http +import inspect +import io +import json +import queue +import threading +import types +import typing +from urllib.parse import unquote, urljoin, urlsplit +import requests +from .sessions import Session + + +class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): + def get_all(self, key: str, default: str) -> str: + return self.getheaders(key) + + +class _MockOriginalResponse: + """ + We have to jump through some hoops to present the response as if + it was made using urllib3. + """ + + def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None: + self.msg = _HeaderDict(headers) + self.closed = False + + def isclosed(self) -> bool: + return self.closed + + +def _get_reason_phrase(status_code: int) -> str: + try: + return http.HTTPStatus(status_code).phrase + except ValueError: + return "" + + +class ASGIAdapter(requests.adapters.HTTPAdapter): + def __init__(self, app, raise_server_exceptions: bool = True) -> None: + self.app = app + self.raise_server_exceptions = raise_server_exceptions + + async def send( # type: ignore + self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any + ) -> requests.Response: + scheme, netloc, path, query, fragment = urlsplit(request.url) # type: ignore + + default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] + + if ":" in netloc: + host, port_string = netloc.split(":", 1) + port = int(port_string) + else: + host = netloc + port = default_port + + # Include the 'host' header. + if "host" in request.headers: + headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + elif port == default_port: + headers = [(b"host", host.encode())] + else: + headers = [(b"host", (f"{host}:{port}").encode())] + + # Include other request headers. + headers += [ + (key.lower().encode(), value.encode()) + for key, value in request.headers.items() + ] + + scope = { + "type": "http", + "http_version": "1.1", + "method": request.method, + "path": unquote(path), + "root_path": "", + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ["testclient", 50000], + "server": [host, port], + "extensions": {"http.response.template": {}}, + } + + async def receive(): + nonlocal request_complete, response_complete + + if request_complete: + while not response_complete: + await asyncio.sleep(0.0001) + return {"type": "http.disconnect"} + + body = request.body + if isinstance(body, str): + body_bytes = body.encode("utf-8") # type: bytes + elif body is None: + body_bytes = b"" + elif isinstance(body, types.GeneratorType): + try: + chunk = body.send(None) + if isinstance(chunk, str): + chunk = chunk.encode("utf-8") + return {"type": "http.request", "body": chunk, "more_body": True} + except StopIteration: + request_complete = True + return {"type": "http.request", "body": b""} + else: + body_bytes = body + + request_complete = True + return {"type": "http.request", "body": body_bytes} + + async def send(message) -> None: + nonlocal raw_kwargs, response_started, response_complete, template, context + + if message["type"] == "http.response.start": + assert ( + not response_started + ), 'Received multiple "http.response.start" messages.' + raw_kwargs["version"] = 11 + raw_kwargs["status"] = message["status"] + raw_kwargs["reason"] = _get_reason_phrase(message["status"]) + raw_kwargs["headers"] = [ + (key.decode(), value.decode()) for key, value in message["headers"] + ] + raw_kwargs["preload_content"] = False + raw_kwargs["original_response"] = _MockOriginalResponse( + raw_kwargs["headers"] + ) + response_started = True + elif message["type"] == "http.response.body": + assert ( + response_started + ), 'Received "http.response.body" without "http.response.start".' + assert ( + not response_complete + ), 'Received "http.response.body" after response completed.' + body = message.get("body", b"") + more_body = message.get("more_body", False) + if request.method != "HEAD": + raw_kwargs["body"].write(body) + if not more_body: + raw_kwargs["body"].seek(0) + response_complete = True + elif message["type"] == "http.response.template": + template = message["template"] + context = message["context"] + + request_complete = False + response_started = False + response_complete = False + raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any] + template = None + context = None + + try: + await self.app(scope, receive, send) + except BaseException as exc: + if self.raise_server_exceptions: + raise exc from None + + if self.raise_server_exceptions: + assert response_started, "TestClient did not receive any response." + elif not response_started: + raw_kwargs = { + "version": 11, + "status": 500, + "reason": "Internal Server Error", + "headers": [], + "preload_content": False, + "original_response": _MockOriginalResponse([]), + "body": io.BytesIO(), + } + + raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) + response = self.build_response(request, raw) + if template is not None: + response.template = template + response.context = context + return response + + +class ASGISession(Session): + def __init__( + self, + app, + base_url: str = "http://mockserver", + raise_server_exceptions: bool = True, + ) -> None: + super(ASGISession, self).__init__() + adapter = ASGIAdapter( + app, raise_server_exceptions=raise_server_exceptions + ) + self.mount("http://", adapter) + self.mount("https://", adapter) + self.headers.update({"user-agent": "testclient"}) + self.app = app + self.base_url = base_url + + async def request(self, method, url, *args, **kwargs) -> requests.Response: + url = urljoin(self.base_url, url) + return await super().request(method, url, *args, **kwargs) diff --git a/requirements.txt b/requirements.txt index 59110bc..28e1e03 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,5 +8,5 @@ pytest pytest-asyncio pytest-cov python-multipart -starlette +starlette==0.12.0b1 uvicorn diff --git a/tests/test_asgi.py b/tests/test_asgi.py new file mode 100644 index 0000000..bd26176 --- /dev/null +++ b/tests/test_asgi.py @@ -0,0 +1,13 @@ +from starlette.responses import JSONResponse +import requests_async as requests +import pytest + + +app = JSONResponse({"hello": "world"}) + + +@pytest.mark.asyncio +async def test_the_test_client(): + client = requests.ASGISession(app) + response = await client.get('/') + assert response.status_code == 200