From 6d154b0c788e8e4d448ff722419b97a290185b2c Mon Sep 17 00:00:00 2001 From: taoufik07 Date: Sat, 17 Nov 2018 20:49:26 +0100 Subject: [PATCH 1/2] WIP websocket --- responder/api.py | 50 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/responder/api.py b/responder/api.py index 005c153..4273c32 100644 --- a/responder/api.py +++ b/responder/api.py @@ -220,14 +220,40 @@ class API: async def asgi(receive, send): nonlocal scope, self - req = models.Request(scope, receive=receive, api=self) - resp = await self._dispatch_request( - req, scope=scope, send=send, receive=receive - ) - await resp(receive, send) + if scope["type"] == "websocket": + ws = WebSocket(scope=scope, receive=receive, send=send) + await self._dispatch_ws(ws) + else: + req = models.Request(scope, receive=receive, api=self) + resp = await self._dispatch_request( + req, scope=scope, send=send, receive=receive + ) + await resp(receive, send) return asgi + async def _dispatch_ws(self, ws): + route = self.path_matches_route(ws.url.path) + route = self.routes.get(route) + # await self._dispatch(route, ws=ws) + try: + try: + # Run the view. + + r = self.background(route.endpoint, ws) + # If it's async, await it. + if hasattr(r, "cr_running"): + await r + except TypeError as e: + cont = True + except Exception: + self.background( + self.default_response, + websocket=route.uses_websocket, + error=True + ) + raise + def add_schema(self, name, schema, check_existing=True): """Adds a mashmallow schema to the API specification.""" if check_existing: @@ -284,7 +310,7 @@ class API: def no_response(req, resp, **params): pass - async def _dispatch_request(self, req, **options): + async def _dispatch_request(self, req=None, **options): # Set formats on Request object. req.formats = self.formats @@ -303,8 +329,8 @@ class API: await self._execute_route(route=route, req=req, resp=resp, **options) else: resp = models.Response(req=req, formats=self.formats) - self.default_response(req, resp, notfound=True) - self.default_response(req, resp) + self.default_response(req=req, resp=resp, notfound=True) + self.default_response(req=req, resp=resp) self._prepare_session(resp) self._prepare_cookies(resp) @@ -319,7 +345,6 @@ class API: try: try: # Run the view. - r = self.background(route.endpoint, req, resp, **params) # If it's async, await it. if hasattr(r, "cr_running"): @@ -414,12 +439,15 @@ class API: sorted(self.routes.items(), key=lambda item: item[1]._weight()) ) - def default_response(self, req, resp, notfound=False, error=False): + def default_response(self, req=None, resp=None, websocket=False, notfound=False, error=False): + if websocket: + return + if resp.status_code is None: resp.status_code = 200 if self.default_endpoint and notfound: - self.default_endpoint(req, resp) + self.default_endpoint(req=req, resp=resp) else: if notfound: resp.status_code = status_codes.HTTP_404 From 983cbcc71166ffda1685b4d41b503cd1c6535683 Mon Sep 17 00:00:00 2001 From: taoufik07 Date: Sat, 17 Nov 2018 21:31:41 +0100 Subject: [PATCH 2/2] cleanup --- responder/api.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/responder/api.py b/responder/api.py index 4273c32..8156b30 100644 --- a/responder/api.py +++ b/responder/api.py @@ -239,7 +239,6 @@ class API: try: try: # Run the view. - r = self.background(route.endpoint, ws) # If it's async, await it. if hasattr(r, "cr_running"): @@ -310,7 +309,7 @@ class API: def no_response(req, resp, **params): pass - async def _dispatch_request(self, req=None, **options): + async def _dispatch_request(self, req, **options): # Set formats on Request object. req.formats = self.formats @@ -318,10 +317,7 @@ class API: route = self.path_matches_route(req.url.path) route = self.routes.get(route) if route: - if route.uses_websocket: - resp = WebSocket(**options) - else: - resp = models.Response(req=req, formats=self.formats) + resp = models.Response(req=req, formats=self.formats) for before_request in self.before_requests: await self._execute_route(route=before_request, req=req, resp=resp) @@ -439,7 +435,9 @@ class API: sorted(self.routes.items(), key=lambda item: item[1]._weight()) ) - def default_response(self, req=None, resp=None, websocket=False, notfound=False, error=False): + def default_response( + self, req=None, resp=None, websocket=False, notfound=False, error=False + ): if websocket: return