diff --git a/responder/api.py b/responder/api.py index f9448ff..23f881a 100644 --- a/responder/api.py +++ b/responder/api.py @@ -210,23 +210,35 @@ class API: # Set formats on Request object. req.formats = self.formats + # Get the route. route = self.path_matches_route(req.url.path) + + # Create the response object. resp = models.Response(req=req, formats=self.formats) + # If there's a route... if route: + route = self.routes[route] try: - params = self.routes[route].incoming_matches(req.url.path) - result = self.routes[route].endpoint(req, resp, **params) - if hasattr(result, "cr_running"): - await result + if route.is_function: + params = route.incoming_matches(req.url.path) + # Run the view. + result = route.endpoint(req, resp, **params) + # If it's async, await it. + if hasattr(result, "cr_running"): + await result + # The request is using class-based views. except TypeError: try: - view = self.routes[route].endpoint(**params) + # Run the class-bsed view. + view = route.endpoint(**params) except TypeError: - view = self.routes[route].endpoint + # This is an instance of a class. + view = route.endpoint - if self.routes[route].is_graphql: + # If this is a graphql view: + if route.is_graphql: await self.graphql_response(req, resp, schema=view) else: pass @@ -276,6 +288,12 @@ class API: if default: self.default_endpoint = endpoint + + try: + endpoint.is_routed = True + except AttributeError: + pass + self.routes[route] = Route(route, endpoint) # TODO: A better datastructer or sort it once the app is loaded self.routes = dict( @@ -412,6 +430,8 @@ class API: for (route, route_object) in self.routes.items(): if route_object.endpoint == endpoint: return route_object.url(testing=testing, **params) + elif route_object.endpoint_name == endpoint: + return route_object.url(testing=testing, **params) raise ValueError def static_url(self, asset): diff --git a/responder/routes.py b/responder/routes.py index 53019f4..c89b1c5 100644 --- a/responder/routes.py +++ b/responder/routes.py @@ -31,6 +31,11 @@ class Route: # Strings. return self.does_match(other) + @property + def endpoint_name(self): + print(self.endpoint.__name__) + return self.endpoint.__name__ + @property def description(self): return self.endpoint.__doc__ @@ -67,3 +72,16 @@ class Route: @property def is_graphql(self): return hasattr(self.endpoint, "get_graphql_type") + + @property + def is_class_based(self): + return hasattr(self.endpoint, "__call__") and hasattr( + self.endpoint, "__class__" + ) + + @property + def is_class_instance(self): + return hasattr(self.endpoint, "__class__") + + def is_function(self): + return hasattr(self.endpoint, "is_routed") diff --git a/tests/test_responder.py b/tests/test_responder.py index 51659e9..d6967c8 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -395,11 +395,15 @@ def test_sessions(api, session): assert "Responder-Session" in r.cookies r = session.get(api.url_for(view)) - assert r.cookies['Responder-Session'] == '{"hello": "world"}.lJVWJULPqR9kdao_oT4pUglV281bxHfGvcKQ7XF8qNqaiIZlRcMvqKNdA1-d5z7DycAx5eqmzJZoqWPP759-Cw' + assert ( + r.cookies["Responder-Session"] + == '{"hello": "world"}.lJVWJULPqR9kdao_oT4pUglV281bxHfGvcKQ7XF8qNqaiIZlRcMvqKNdA1-d5z7DycAx5eqmzJZoqWPP759-Cw' + ) assert r.json() == {"hello": "world"} + def test_template_rendering(api, session): - @api.route('/') + @api.route("/") def view(req, resp): resp.content = api.template_string("{{ var }}", var="hello") @@ -407,3 +411,10 @@ def test_template_rendering(api, session): assert r.text == "hello" +# def test_file_uploads(api, session): +# @api.route("/") +# async def upload(req, resp): +# resp.media = {"files": await req.media("files")} + +# r = session.get(api.url_for(upload)) +# assert r.ok