Support stream response

This commit is contained in:
taoufik07
2019-02-22 03:45:51 +01:00
parent d60b5ee39e
commit 21896aa171
2 changed files with 64 additions and 4 deletions
+26 -4
View File
@@ -1,4 +1,6 @@
import functools
import io
import inspect
import json
import gzip
from base64 import b64decode
@@ -13,7 +15,10 @@ from requests.structures import CaseInsensitiveDict
from requests.cookies import RequestsCookieJar
from starlette.datastructures import MutableHeaders
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response as StarletteResponse
from starlette.responses import (
Response as StarletteResponse,
StreamingResponse as StarletteStreamingResponse,
)
from urllib.parse import parse_qs
@@ -244,6 +249,7 @@ class Response:
"formats",
"cookies",
"session",
"_stream",
]
def __init__(self, req, *, formats):
@@ -255,6 +261,7 @@ class Response:
self.media = (
None
) #: A Python object that will be content-negotiated and sent back to the client. Typically, in JSON formatting.
self._stream = None
self.headers = (
{}
) #: A Python dictionary of ``{key: value}``, representing the headers of the response.
@@ -264,8 +271,19 @@ class Response:
req.session.copy()
) #: The cookie-based session data, in dict form, to add to the Response.
# Property or func/dec
def stream(self, func, *args, **kwargs):
assert inspect.isasyncgenfunction(func)
self._stream = functools.partial(func, *args, **kwargs)
return func
@property
async def body(self):
if self._stream is not None:
return (self._stream(), {})
if self.content is not None:
return (self.content, {})
@@ -287,7 +305,11 @@ class Response:
if self.headers:
headers.update(self.headers)
response = StarletteResponse(
body, status_code=self.status_code, headers=headers
)
if self._stream is not None:
response_cls = StarletteStreamingResponse
else:
response_cls = StarletteResponse
response = response_cls(body, status_code=self.status_code, headers=headers)
await response(receive, send)
+38
View File
@@ -680,3 +680,41 @@ def test_staticfiles_custom_route(tmpdir):
# Not found on dir listing
r = session.get(f"{static_route}")
assert r.status_code == api.status_codes.HTTP_404
def test_stream(api, session):
async def shout_stream(who):
for c in who.upper():
yield c
@api.route("/{who}")
async def greeting(req, resp, *, who):
resp.stream(shout_stream, who)
r = session.get("/morocco")
assert r.text == "MOROCCO"
@api.route("/")
async def home(req, resp):
# Raise when it's not an async generator
with pytest.raises(AssertionError):
def foo():
pass
res.stream(foo)
with pytest.raises(AssertionError):
async def foo():
pass
res.stream(foo)
with pytest.raises(AssertionError):
def foo():
yield "oopsie"
res.stream(foo)