From 21896aa171b2d6da4c3f69a802ca43597aad2f80 Mon Sep 17 00:00:00 2001 From: taoufik07 Date: Fri, 22 Feb 2019 03:45:51 +0100 Subject: [PATCH] Support stream response --- responder/models.py | 30 ++++++++++++++++++++++++++---- tests/test_responder.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/responder/models.py b/responder/models.py index 5b164df..3e3b2fc 100644 --- a/responder/models.py +++ b/responder/models.py @@ -1,4 +1,6 @@ +import functools import io +import inspect import json import gzip from base64 import b64decode @@ -13,7 +15,10 @@ from requests.structures import CaseInsensitiveDict from requests.cookies import RequestsCookieJar from starlette.datastructures import MutableHeaders from starlette.requests import Request as StarletteRequest -from starlette.responses import Response as StarletteResponse +from starlette.responses import ( + Response as StarletteResponse, + StreamingResponse as StarletteStreamingResponse, +) from urllib.parse import parse_qs @@ -244,6 +249,7 @@ class Response: "formats", "cookies", "session", + "_stream", ] def __init__(self, req, *, formats): @@ -255,6 +261,7 @@ class Response: self.media = ( None ) #: A Python object that will be content-negotiated and sent back to the client. Typically, in JSON formatting. + self._stream = None self.headers = ( {} ) #: A Python dictionary of ``{key: value}``, representing the headers of the response. @@ -264,8 +271,19 @@ class Response: req.session.copy() ) #: The cookie-based session data, in dict form, to add to the Response. + # Property or func/dec + def stream(self, func, *args, **kwargs): + assert inspect.isasyncgenfunction(func) + + self._stream = functools.partial(func, *args, **kwargs) + + return func + @property async def body(self): + if self._stream is not None: + return (self._stream(), {}) + if self.content is not None: return (self.content, {}) @@ -287,7 +305,11 @@ class Response: if self.headers: headers.update(self.headers) - response = StarletteResponse( - body, status_code=self.status_code, headers=headers - ) + if self._stream is not None: + response_cls = StarletteStreamingResponse + else: + response_cls = StarletteResponse + + response = response_cls(body, status_code=self.status_code, headers=headers) + await response(receive, send) diff --git a/tests/test_responder.py b/tests/test_responder.py index 7048d1b..81efc64 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -680,3 +680,41 @@ def test_staticfiles_custom_route(tmpdir): # Not found on dir listing r = session.get(f"{static_route}") assert r.status_code == api.status_codes.HTTP_404 + + +def test_stream(api, session): + async def shout_stream(who): + for c in who.upper(): + yield c + + @api.route("/{who}") + async def greeting(req, resp, *, who): + + resp.stream(shout_stream, who) + + r = session.get("/morocco") + assert r.text == "MOROCCO" + + @api.route("/") + async def home(req, resp): + # Raise when it's not an async generator + with pytest.raises(AssertionError): + + def foo(): + pass + + res.stream(foo) + + with pytest.raises(AssertionError): + + async def foo(): + pass + + res.stream(foo) + + with pytest.raises(AssertionError): + + def foo(): + yield "oopsie" + + res.stream(foo)