Merge pull request #233 from taoufik07/websocket-x.x

WebSocket returns
This commit is contained in:
2018-12-12 03:59:24 -05:00
committed by GitHub
+40 -14
View File
@@ -221,14 +221,39 @@ class API:
async def asgi(receive, send):
nonlocal scope, self
req = models.Request(scope, receive=receive, api=self)
resp = await self._dispatch_request(
req, scope=scope, send=send, receive=receive
)
await resp(receive, send)
if scope["type"] == "websocket":
ws = WebSocket(scope=scope, receive=receive, send=send)
await self._dispatch_ws(ws)
else:
req = models.Request(scope, receive=receive, api=self)
resp = await self._dispatch_request(
req, scope=scope, send=send, receive=receive
)
await resp(receive, send)
return asgi
async def _dispatch_ws(self, ws):
route = self.path_matches_route(ws.url.path)
route = self.routes.get(route)
# await self._dispatch(route, ws=ws)
try:
try:
# Run the view.
r = self.background(route.endpoint, ws)
# If it's async, await it.
if hasattr(r, "cr_running"):
await r
except TypeError as e:
cont = True
except Exception:
self.background(
self.default_response,
websocket=route.uses_websocket,
error=True
)
raise
def add_schema(self, name, schema, check_existing=True):
"""Adds a mashmallow schema to the API specification."""
if check_existing:
@@ -293,10 +318,7 @@ class API:
route = self.path_matches_route(req.url.path)
route = self.routes.get(route)
if route:
if route.uses_websocket:
resp = WebSocket(**options)
else:
resp = models.Response(req=req, formats=self.formats)
resp = models.Response(req=req, formats=self.formats)
for before_request in self.before_requests:
await self._execute_route(route=before_request, req=req, resp=resp)
@@ -304,8 +326,8 @@ class API:
await self._execute_route(route=route, req=req, resp=resp, **options)
else:
resp = models.Response(req=req, formats=self.formats)
self.default_response(req, resp, notfound=True)
self.default_response(req, resp)
self.default_response(req=req, resp=resp, notfound=True)
self.default_response(req=req, resp=resp)
self._prepare_session(resp)
self._prepare_cookies(resp)
@@ -322,7 +344,6 @@ class API:
try:
try:
# Run the view.
r = self.background(route.endpoint, req, resp, **params)
# If it's async, await it.
if hasattr(r, "cr_running"):
@@ -417,12 +438,17 @@ class API:
sorted(self.routes.items(), key=lambda item: item[1]._weight())
)
def default_response(self, req, resp, notfound=False, error=False):
def default_response(
self, req=None, resp=None, websocket=False, notfound=False, error=False
):
if websocket:
return
if resp.status_code is None:
resp.status_code = 200
if self.default_endpoint and notfound:
self.default_endpoint(req, resp)
self.default_endpoint(req=req, resp=resp)
else:
if notfound:
resp.status_code = status_codes.HTTP_404