mirror of
https://github.com/kennethreitz/responder.git
synced 2026-06-05 14:50:19 +00:00
Support stream response
This commit is contained in:
+26
-4
@@ -1,4 +1,6 @@
|
||||
import functools
|
||||
import io
|
||||
import inspect
|
||||
import json
|
||||
import gzip
|
||||
from base64 import b64decode
|
||||
@@ -13,7 +15,10 @@ from requests.structures import CaseInsensitiveDict
|
||||
from requests.cookies import RequestsCookieJar
|
||||
from starlette.datastructures import MutableHeaders
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
from starlette.responses import (
|
||||
Response as StarletteResponse,
|
||||
StreamingResponse as StarletteStreamingResponse,
|
||||
)
|
||||
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
@@ -244,6 +249,7 @@ class Response:
|
||||
"formats",
|
||||
"cookies",
|
||||
"session",
|
||||
"_stream",
|
||||
]
|
||||
|
||||
def __init__(self, req, *, formats):
|
||||
@@ -255,6 +261,7 @@ class Response:
|
||||
self.media = (
|
||||
None
|
||||
) #: A Python object that will be content-negotiated and sent back to the client. Typically, in JSON formatting.
|
||||
self._stream = None
|
||||
self.headers = (
|
||||
{}
|
||||
) #: A Python dictionary of ``{key: value}``, representing the headers of the response.
|
||||
@@ -264,8 +271,19 @@ class Response:
|
||||
req.session.copy()
|
||||
) #: The cookie-based session data, in dict form, to add to the Response.
|
||||
|
||||
# Property or func/dec
|
||||
def stream(self, func, *args, **kwargs):
|
||||
assert inspect.isasyncgenfunction(func)
|
||||
|
||||
self._stream = functools.partial(func, *args, **kwargs)
|
||||
|
||||
return func
|
||||
|
||||
@property
|
||||
async def body(self):
|
||||
if self._stream is not None:
|
||||
return (self._stream(), {})
|
||||
|
||||
if self.content is not None:
|
||||
return (self.content, {})
|
||||
|
||||
@@ -287,7 +305,11 @@ class Response:
|
||||
if self.headers:
|
||||
headers.update(self.headers)
|
||||
|
||||
response = StarletteResponse(
|
||||
body, status_code=self.status_code, headers=headers
|
||||
)
|
||||
if self._stream is not None:
|
||||
response_cls = StarletteStreamingResponse
|
||||
else:
|
||||
response_cls = StarletteResponse
|
||||
|
||||
response = response_cls(body, status_code=self.status_code, headers=headers)
|
||||
|
||||
await response(receive, send)
|
||||
|
||||
@@ -680,3 +680,41 @@ def test_staticfiles_custom_route(tmpdir):
|
||||
# Not found on dir listing
|
||||
r = session.get(f"{static_route}")
|
||||
assert r.status_code == api.status_codes.HTTP_404
|
||||
|
||||
|
||||
def test_stream(api, session):
|
||||
async def shout_stream(who):
|
||||
for c in who.upper():
|
||||
yield c
|
||||
|
||||
@api.route("/{who}")
|
||||
async def greeting(req, resp, *, who):
|
||||
|
||||
resp.stream(shout_stream, who)
|
||||
|
||||
r = session.get("/morocco")
|
||||
assert r.text == "MOROCCO"
|
||||
|
||||
@api.route("/")
|
||||
async def home(req, resp):
|
||||
# Raise when it's not an async generator
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
def foo():
|
||||
pass
|
||||
|
||||
res.stream(foo)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
async def foo():
|
||||
pass
|
||||
|
||||
res.stream(foo)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
def foo():
|
||||
yield "oopsie"
|
||||
|
||||
res.stream(foo)
|
||||
|
||||
Reference in New Issue
Block a user