mirror of
https://github.com/kennethreitz/responder.git
synced 2026-06-05 06:46:14 +00:00
Implement a new Router and other changes
- Router - Use starlette's Session middlewre - Add Exception Middleware - ...
This commit is contained in:
+35
-255
@@ -12,12 +12,14 @@ import uvicorn
|
||||
import yaml
|
||||
from apispec import APISpec, yaml_utils
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from starlette.exceptions import ExceptionMiddleware
|
||||
from starlette.middleware.wsgi import WSGIMiddleware
|
||||
from starlette.middleware.errors import ServerErrorMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.gzip import GZipMiddleware
|
||||
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.routing import Lifespan
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from starlette.testclient import TestClient
|
||||
@@ -27,17 +29,11 @@ from whitenoise import WhiteNoise
|
||||
from . import models, status_codes
|
||||
from .background import BackgroundQueue
|
||||
from .formats import get_formats
|
||||
from .routes import Route
|
||||
from .statics import (
|
||||
DEFAULT_API_THEME,
|
||||
DEFAULT_CORS_PARAMS,
|
||||
DEFAULT_SECRET_KEY,
|
||||
DEFAULT_SESSION_COOKIE,
|
||||
)
|
||||
from .routes import Router
|
||||
from .statics import DEFAULT_API_THEME, DEFAULT_CORS_PARAMS, DEFAULT_SECRET_KEY
|
||||
from .templates import GRAPHIQL
|
||||
|
||||
|
||||
# TODO: consider moving status codes here
|
||||
class API:
|
||||
"""The primary web-service class.
|
||||
|
||||
@@ -106,13 +102,11 @@ class API:
|
||||
|
||||
self.templates_dir = templates_dir or self.built_in_templates_dir
|
||||
|
||||
self.apps = {}
|
||||
self.routes = {}
|
||||
self.before_requests = {"http": [], "ws": []}
|
||||
self.router = Router()
|
||||
|
||||
self.docs_theme = DEFAULT_API_THEME
|
||||
self.docs_route = docs_route
|
||||
self.schemas = {}
|
||||
self.session_cookie = DEFAULT_SESSION_COOKIE
|
||||
|
||||
self.hsts_enabled = enable_hsts
|
||||
self.cors = cors
|
||||
@@ -159,7 +153,7 @@ class API:
|
||||
self.add_route(self.docs_route, self.docs_response)
|
||||
|
||||
self.default_endpoint = None
|
||||
self.app = self.asgi
|
||||
self.app = ExceptionMiddleware(self.router, debug=debug)
|
||||
self.add_middleware(GZipMiddleware)
|
||||
|
||||
if self.hsts_enabled:
|
||||
@@ -167,11 +161,10 @@ class API:
|
||||
|
||||
self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts)
|
||||
|
||||
self.lifespan_handler = Lifespan()
|
||||
|
||||
if self.cors:
|
||||
self.add_middleware(CORSMiddleware, **self.cors_params)
|
||||
self.add_middleware(ServerErrorMiddleware, debug=debug)
|
||||
self.add_middleware(SessionMiddleware, secret_key=self.secret_key)
|
||||
|
||||
# Jinja environment
|
||||
self.jinja_env = jinja2.Environment(
|
||||
@@ -186,10 +179,6 @@ class API:
|
||||
self.session()
|
||||
) #: A Requests session that is connected to the ASGI app.
|
||||
|
||||
@staticmethod
|
||||
def _default_wsgi_app(environ, start_response):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _notfound_wsgi_app(environ, start_response):
|
||||
start_response("404 NOT FOUND", [("Content-Type", "text/plain")])
|
||||
@@ -197,10 +186,7 @@ class API:
|
||||
|
||||
def before_request(self, websocket=False):
|
||||
def decorator(f):
|
||||
if websocket:
|
||||
self.before_requests.setdefault("ws", []).append(f)
|
||||
else:
|
||||
self.before_requests.setdefault("http", []).append(f)
|
||||
self.router.before_request(f, websocket=websocket)
|
||||
return f
|
||||
|
||||
return decorator
|
||||
@@ -234,12 +220,12 @@ class API:
|
||||
info=info,
|
||||
)
|
||||
|
||||
for route in self.routes:
|
||||
if self.routes[route].description:
|
||||
for route in self.router.routes:
|
||||
if route.description:
|
||||
operations = yaml_utils.load_operations_from_docstring(
|
||||
self.routes[route].description
|
||||
route.description
|
||||
)
|
||||
spec.path(path=route, operations=operations)
|
||||
spec.path(path=route.route, operations=operations)
|
||||
|
||||
for name, schema in self.schemas.items():
|
||||
spec.components.schema(name, schema=schema)
|
||||
@@ -254,76 +240,8 @@ class API:
|
||||
self.app = middleware_cls(self.app, **middleware_config)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "lifespan":
|
||||
await self.lifespan_handler(scope, receive, send)
|
||||
return
|
||||
|
||||
path = scope["path"]
|
||||
root_path = scope.get("root_path", "")
|
||||
|
||||
# Call into a submounted app, if one exists.
|
||||
for path_prefix, app in self.apps.items():
|
||||
if path.startswith(path_prefix):
|
||||
scope["path"] = path[len(path_prefix) :]
|
||||
scope["root_path"] = root_path + path_prefix
|
||||
try:
|
||||
await app(scope, receive, send)
|
||||
return
|
||||
except TypeError:
|
||||
app = WSGIMiddleware(app)
|
||||
await app(scope, receive, send)
|
||||
return
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def asgi(self, scope, receive, send):
|
||||
assert scope["type"] in ("http", "websocket")
|
||||
|
||||
if scope["type"] == "websocket":
|
||||
await self._dispatch_ws(scope=scope, receive=receive, send=send)
|
||||
else:
|
||||
req = models.Request(scope, receive=receive, api=self)
|
||||
resp = await self._dispatch_http(
|
||||
req, scope=scope, send=send, receive=receive
|
||||
)
|
||||
await resp(scope, receive, send)
|
||||
|
||||
async def _dispatch_http(self, req, **options):
|
||||
# Set formats on Request object.
|
||||
req.formats = self.formats
|
||||
|
||||
# Get the route.
|
||||
route = self.path_matches_route(req.url.path)
|
||||
route = self.routes.get(route)
|
||||
if route:
|
||||
resp = models.Response(req=req, formats=self.formats)
|
||||
|
||||
for before_request in self.before_http_requests:
|
||||
await self.background(before_request, req=req, resp=resp)
|
||||
|
||||
await self._execute_route(route=route, req=req, resp=resp, **options)
|
||||
else:
|
||||
resp = models.Response(req=req, formats=self.formats)
|
||||
self.default_response(req=req, resp=resp, notfound=True)
|
||||
self.default_response(req=req, resp=resp)
|
||||
|
||||
self._prepare_session(resp)
|
||||
|
||||
return resp
|
||||
|
||||
async def _dispatch_ws(self, scope, receive, send):
|
||||
ws = WebSocket(scope=scope, receive=receive, send=send)
|
||||
|
||||
route = self.path_matches_route(ws.url.path)
|
||||
route = self.routes.get(route)
|
||||
|
||||
if route:
|
||||
for before_request in self.before_ws_requests:
|
||||
await self.background(before_request, ws=ws)
|
||||
await self.background(route.endpoint, ws)
|
||||
else:
|
||||
await send({"type": "websocket.close", "code": 1000})
|
||||
|
||||
def add_schema(self, name, schema, check_existing=True):
|
||||
"""Adds a mashmallow schema to the API specification."""
|
||||
if check_existing:
|
||||
@@ -355,97 +273,17 @@ class API:
|
||||
|
||||
:param path: The path portion of a URL, to test all known routes against.
|
||||
"""
|
||||
for (route, route_object) in self.routes.items():
|
||||
if route_object.does_match(path):
|
||||
for route in self.router.routes:
|
||||
match, _ = route.matches(path)
|
||||
if match:
|
||||
return route
|
||||
|
||||
@property
|
||||
def _signer(self):
|
||||
return itsdangerous.Signer(self.secret_key)
|
||||
|
||||
def _prepare_session(self, resp):
|
||||
|
||||
if resp.session:
|
||||
data = self._signer.sign(
|
||||
b64encode(json.dumps(resp.session).encode("utf-8"))
|
||||
)
|
||||
resp.cookies[self.session_cookie] = data.decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def no_response(req, resp, **params):
|
||||
pass
|
||||
|
||||
async def _execute_route(self, *, route, req, resp, **options):
|
||||
|
||||
params = route.incoming_matches(req.url.path)
|
||||
|
||||
cont = True
|
||||
|
||||
if route.is_function:
|
||||
try:
|
||||
try:
|
||||
# Run the view.
|
||||
r = self.background(route.endpoint, req, resp, **params)
|
||||
# If it's async, await it.
|
||||
if hasattr(r, "cr_running"):
|
||||
await r
|
||||
except TypeError as e:
|
||||
cont = True
|
||||
except Exception:
|
||||
await self.background(self.default_response, req, resp, error=True)
|
||||
raise
|
||||
|
||||
if route.is_class_based or cont:
|
||||
try:
|
||||
view = route.endpoint(**params)
|
||||
except TypeError:
|
||||
try:
|
||||
view = route.endpoint()
|
||||
except TypeError:
|
||||
view = route.endpoint
|
||||
pass
|
||||
|
||||
# Run on_request first.
|
||||
try:
|
||||
# Run the view.
|
||||
r = getattr(view, "on_request", self.no_response)
|
||||
r = self.background(r, req, resp, **params)
|
||||
# If it's async, await it.
|
||||
if hasattr(r, "send"):
|
||||
await r
|
||||
except Exception:
|
||||
await self.background(self.default_response, req, resp, error=True)
|
||||
raise
|
||||
|
||||
# Then run on_method.
|
||||
method = req.method
|
||||
try:
|
||||
# Run the view.
|
||||
r = getattr(view, f"on_{method}", self.no_response)
|
||||
r = self.background(r, req, resp, **params)
|
||||
# If it's async, await it.
|
||||
if hasattr(r, "send"):
|
||||
await r
|
||||
except Exception:
|
||||
await self.background(self.default_response, req, resp, error=True)
|
||||
raise
|
||||
|
||||
def add_event_handler(self, event_type, handler):
|
||||
"""Adds an event handler to the API.
|
||||
|
||||
:param event_type: A string in ("startup", "shutdown")
|
||||
:param handler: The function to run. Can be either a function or a coroutine.
|
||||
"""
|
||||
|
||||
self.lifespan_handler.add_event_handler(event_type, handler)
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
route=None,
|
||||
endpoint=None,
|
||||
*,
|
||||
default=False,
|
||||
static=False,
|
||||
check_existing=True,
|
||||
websocket=False,
|
||||
before_request=False,
|
||||
@@ -456,70 +294,18 @@ class API:
|
||||
:param endpoint: The endpoint for the route -- can be a callable, or a class.
|
||||
:param default: If ``True``, all unknown requests will route to this view.
|
||||
:param static: If ``True``, and no endpoint was passed, render "static/index.html", and it will become a default route.
|
||||
:param check_existing: If ``True``, an AssertionError will be raised, if the route is already defined.
|
||||
"""
|
||||
if before_request:
|
||||
if websocket:
|
||||
self.before_requests.setdefault("ws", []).append(endpoint)
|
||||
else:
|
||||
self.before_requests.setdefault("http", []).append(endpoint)
|
||||
return
|
||||
|
||||
if route is None:
|
||||
route = f"/{uuid4().hex}"
|
||||
|
||||
if check_existing:
|
||||
assert route not in self.routes
|
||||
|
||||
if static:
|
||||
assert self.static_dir is not None
|
||||
if not endpoint:
|
||||
endpoint = self.static_response
|
||||
default = True
|
||||
|
||||
if default:
|
||||
self.default_endpoint = endpoint
|
||||
|
||||
self.routes[route] = Route(route, endpoint, websocket=websocket)
|
||||
# TODO: A better data structure or sort it once the app is loaded
|
||||
self.routes = dict(
|
||||
sorted(self.routes.items(), key=lambda item: item[1]._weight())
|
||||
self.router.add_route(
|
||||
route,
|
||||
endpoint,
|
||||
default=default,
|
||||
websocket=websocket,
|
||||
before_request=before_request,
|
||||
)
|
||||
|
||||
def default_response(
|
||||
self, req=None, resp=None, websocket=False, notfound=False, error=False
|
||||
):
|
||||
if websocket:
|
||||
return
|
||||
|
||||
if resp.status_code is None:
|
||||
resp.status_code = 200
|
||||
|
||||
if self.default_endpoint and notfound:
|
||||
self.default_endpoint(req=req, resp=resp)
|
||||
else:
|
||||
if notfound:
|
||||
resp.status_code = status_codes.HTTP_404
|
||||
resp.text = "Not found."
|
||||
if error:
|
||||
resp.status_code = status_codes.HTTP_500
|
||||
resp.text = "Application error."
|
||||
|
||||
def docs_response(self, req, resp):
|
||||
resp.html = self.docs
|
||||
|
||||
def static_response(self, req, resp):
|
||||
|
||||
assert self.static_dir is not None
|
||||
|
||||
index = (self.static_dir / "index.html").resolve()
|
||||
if os.path.exists(index):
|
||||
with open(index, "r") as f:
|
||||
resp.html = f.read()
|
||||
else:
|
||||
resp.status_code = status_codes.HTTP_404
|
||||
resp.text = "Not found."
|
||||
|
||||
def schema_response(self, req, resp):
|
||||
resp.status_code = status_codes.HTTP_200
|
||||
resp.headers["Content-Type"] = "application/x-yaml"
|
||||
@@ -529,19 +315,12 @@ class API:
|
||||
self, resp, location, *, set_text=True, status_code=status_codes.HTTP_301
|
||||
):
|
||||
"""Redirects a given response to a given location.
|
||||
|
||||
:param resp: The Response to mutate.
|
||||
:param location: The location of the redirect.
|
||||
:param set_text: If ``True``, sets the Redirect body content automatically.
|
||||
:param status_code: an `API.status_codes` attribute, or an integer, representing the HTTP status code of the redirect.
|
||||
"""
|
||||
|
||||
# assert resp.status_code.is_300(status_code)
|
||||
|
||||
resp.status_code = status_code
|
||||
if set_text:
|
||||
resp.text = f"Redirecting to: {location}"
|
||||
resp.headers.update({"Location": location})
|
||||
resp.redirect(location, set_text=set_text, status_code=status_code)
|
||||
|
||||
def on_event(self, event_type: str, **args):
|
||||
"""Decorator for registering functions or coroutines to run at certain events
|
||||
@@ -565,6 +344,15 @@ class API:
|
||||
|
||||
return decorator
|
||||
|
||||
def add_event_handler(self, event_type, handler):
|
||||
"""Adds an event handler to the API.
|
||||
|
||||
:param event_type: A string in ("startup", "shutdown")
|
||||
:param handler: The function to run. Can be either a function or a coroutine.
|
||||
"""
|
||||
|
||||
self.router.lifespan_handler.add_event_handler(event_type, handler)
|
||||
|
||||
def route(self, route=None, **options):
|
||||
"""Decorator for creating new routes around function and class definitions.
|
||||
|
||||
@@ -577,7 +365,7 @@ class API:
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
self.add_route(route, f, **options)
|
||||
self.router.add_route(route, f, **options)
|
||||
return f
|
||||
|
||||
return decorator
|
||||
@@ -588,7 +376,7 @@ class API:
|
||||
:param route: String representation of the route to be used (shouldn't be parameterized).
|
||||
:param app: The other WSGI / ASGI app.
|
||||
"""
|
||||
self.apps.update({route: app})
|
||||
self.router.apps.update({route: app})
|
||||
|
||||
def session(self, base_url="http://;"):
|
||||
"""Testing HTTP client. Returns a Requests session object, able to send HTTP requests to the Responder application.
|
||||
@@ -600,11 +388,6 @@ class API:
|
||||
self._session = TestClient(self, base_url=base_url)
|
||||
return self._session
|
||||
|
||||
def _route_for(self, endpoint):
|
||||
for route_object in self.routes.values():
|
||||
if endpoint in (route_object.endpoint, route_object.endpoint_name):
|
||||
return route_object
|
||||
|
||||
def url_for(self, endpoint, **params):
|
||||
# TODO: Absolute_url
|
||||
"""Given an endpoint, returns a rendered URL for its route.
|
||||
@@ -612,10 +395,7 @@ class API:
|
||||
:param endpoint: The route endpoint you're searching for.
|
||||
:param params: Data to pass into the URL generator (for parameterized URLs).
|
||||
"""
|
||||
route_object = self._route_for(endpoint)
|
||||
if route_object:
|
||||
return route_object.url(**params)
|
||||
raise ValueError
|
||||
return self.router.url_for(endpoint, **params)
|
||||
|
||||
def static_url(self, asset):
|
||||
"""Given a static asset, return its URL path."""
|
||||
|
||||
@@ -63,3 +63,6 @@ class GraphQLView:
|
||||
|
||||
async def on_request(self, req, resp):
|
||||
await self.graphql_response(req, resp, self.schema)
|
||||
|
||||
async def __call__(self, req, resp):
|
||||
await self.on_request(req, resp)
|
||||
|
||||
+15
-15
@@ -3,16 +3,18 @@ import io
|
||||
import inspect
|
||||
import json
|
||||
import gzip
|
||||
from urllib.parse import parse_qs
|
||||
from base64 import b64decode
|
||||
from http.cookies import SimpleCookie
|
||||
|
||||
|
||||
import chardet
|
||||
import rfc3986
|
||||
import graphene
|
||||
import yaml
|
||||
|
||||
from requests.structures import CaseInsensitiveDict
|
||||
from requests.cookies import RequestsCookieJar
|
||||
|
||||
from starlette.datastructures import MutableHeaders
|
||||
from starlette.requests import Request as StarletteRequest, State
|
||||
from starlette.responses import (
|
||||
@@ -20,9 +22,7 @@ from starlette.responses import (
|
||||
StreamingResponse as StarletteStreamingResponse,
|
||||
)
|
||||
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from .status_codes import HTTP_200
|
||||
from .status_codes import HTTP_200, HTTP_301
|
||||
from .statics import DEFAULT_ENCODING
|
||||
|
||||
|
||||
@@ -105,9 +105,9 @@ class Request:
|
||||
"_cookies",
|
||||
]
|
||||
|
||||
def __init__(self, scope, receive, api=None):
|
||||
def __init__(self, scope, receive, api=None, formats=None):
|
||||
self._starlette = StarletteRequest(scope, receive)
|
||||
self.formats = None
|
||||
self.formats = formats
|
||||
self._encoding = None
|
||||
self.api = api
|
||||
self._content = None
|
||||
@@ -122,14 +122,7 @@ class Request:
|
||||
@property
|
||||
def session(self):
|
||||
"""The session data, in dict form, from the Request."""
|
||||
if self.api.session_cookie in self.cookies:
|
||||
|
||||
data = self.cookies[self.api.session_cookie]
|
||||
|
||||
data = self.api._signer.unsign(data)
|
||||
data = b64decode(data)
|
||||
return json.loads(data)
|
||||
return {}
|
||||
return self._starlette.session
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
@@ -158,6 +151,7 @@ class Request:
|
||||
@property
|
||||
def cookies(self):
|
||||
"""The cookies sent in the Request, as a dictionary."""
|
||||
return self._starlette.cookies
|
||||
if self._cookies is None:
|
||||
cookies = RequestsCookieJar()
|
||||
cookie_header = self.headers.get("Cookie", "")
|
||||
@@ -300,7 +294,7 @@ class Response:
|
||||
self.formats = formats
|
||||
self.cookies = SimpleCookie() #: The cookies set in the Response
|
||||
self.session = (
|
||||
req.session.copy()
|
||||
req.session
|
||||
) #: The cookie-based session data, in dict form, to add to the Response.
|
||||
|
||||
# Property or func/dec
|
||||
@@ -311,6 +305,12 @@ class Response:
|
||||
|
||||
return func
|
||||
|
||||
def redirect(self, location, *, set_text=True, status_code=HTTP_301):
|
||||
self.status_code = status_code
|
||||
if set_text:
|
||||
self.text = f"Redirecting to: {location}"
|
||||
self.headers.update({"Location": location})
|
||||
|
||||
@property
|
||||
async def body(self):
|
||||
if self._stream is not None:
|
||||
|
||||
+280
-55
@@ -1,43 +1,76 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import functools
|
||||
import inspect
|
||||
from parse import parse, with_pattern
|
||||
|
||||
from starlette.routing import Lifespan
|
||||
from starlette.middleware.wsgi import WSGIMiddleware
|
||||
from starlette.websockets import WebSocket, WebSocketClose
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from .models import Request, Response
|
||||
from . import status_codes
|
||||
from .formats import get_formats
|
||||
from .statics import DEFAULT_SESSION_COOKIE
|
||||
|
||||
|
||||
def _make_convertor(type, pattern):
|
||||
@with_pattern(pattern)
|
||||
def inner(value):
|
||||
return type(value)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
_convertors = {
|
||||
"int": _make_convertor(int, r"\d+"),
|
||||
"str": _make_convertor(str, r"[^/]+"),
|
||||
"float": _make_convertor(float, r"\d+(.\d+)?"),
|
||||
_CONVERTORS = {
|
||||
"int": (int, r"\d+"),
|
||||
"str": (str, r"[^/]+"),
|
||||
"float": (float, r"\d+(.\d+)?"),
|
||||
}
|
||||
|
||||
PARAM_RE = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
|
||||
|
||||
class Route:
|
||||
_param_pattern = re.compile(r"{([^{}]*)}")
|
||||
|
||||
def __init__(self, route, endpoint, *, websocket=False, before_request=False):
|
||||
def compile_path(path):
|
||||
path_re = "^"
|
||||
param_convertors = {}
|
||||
idx = 0
|
||||
|
||||
for match in PARAM_RE.finditer(path):
|
||||
param_name, convertor_type = match.groups(default="str")
|
||||
convertor_type = convertor_type.lstrip(":")
|
||||
assert (
|
||||
convertor_type in _CONVERTORS.keys()
|
||||
), f"Unknown path convertor '{convertor_type}'"
|
||||
convertor, convertor_re = _CONVERTORS[convertor_type]
|
||||
|
||||
path_re += path[idx : match.start()]
|
||||
path_re += rf"(?P<{param_name}>{convertor_re})"
|
||||
|
||||
param_convertors[param_name] = convertor
|
||||
|
||||
idx = match.end()
|
||||
|
||||
path_re += path[idx:] + "$"
|
||||
|
||||
return re.compile(path_re), param_convertors
|
||||
|
||||
|
||||
class BaseRoute:
|
||||
def matches(self, scope):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Route(BaseRoute):
|
||||
def __init__(self, route, endpoint, *, before_request=False):
|
||||
assert route.startswith("/"), "Route path must start with '/'"
|
||||
self.route = route
|
||||
self.endpoint = endpoint
|
||||
self.uses_websocket = websocket
|
||||
self.before_request = before_request
|
||||
|
||||
self.path_re, self.param_convertors = compile_path(route)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Route {self.route!r}={self.endpoint!r}>"
|
||||
|
||||
def __eq__(self, other):
|
||||
if hasattr(other, "route"):
|
||||
# Being compared to other routes.
|
||||
return self.route == other.route
|
||||
else:
|
||||
# Strings.
|
||||
return self.does_match(other)
|
||||
def url(self, **params):
|
||||
return self.route.format(**params)
|
||||
|
||||
@property
|
||||
def endpoint_name(self):
|
||||
@@ -47,46 +80,238 @@ class Route:
|
||||
def description(self):
|
||||
return self.endpoint.__doc__
|
||||
|
||||
@property
|
||||
def has_parameters(self):
|
||||
return bool(self._param_pattern.search(self.route))
|
||||
def matches(self, scope):
|
||||
if scope["type"] != "http":
|
||||
return False, {}
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def does_match(self, s):
|
||||
if s == self.route:
|
||||
return True
|
||||
path = scope["path"]
|
||||
match = self.path_re.match(path)
|
||||
|
||||
named = self.incoming_matches(s)
|
||||
return bool(len(named))
|
||||
if match is None:
|
||||
return False, {}
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def incoming_matches(self, s):
|
||||
results = parse(self.route, s, _convertors)
|
||||
return results.named if results else {}
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key](value)
|
||||
|
||||
return True, {"path_params": {**matched_params}}
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
request = Request(scope, receive, formats=get_formats())
|
||||
response = Response(req=request, formats=get_formats())
|
||||
|
||||
path_params = scope.get("path_params", {})
|
||||
before_requests = scope.get("before_requests", [])
|
||||
|
||||
for before_request in before_requests.get("http", []):
|
||||
if asyncio.iscoroutinefunction(before_request):
|
||||
await before_request(request, response)
|
||||
else:
|
||||
await run_in_threadpool(before_request, request, response)
|
||||
|
||||
views = []
|
||||
|
||||
if inspect.isclass(self.endpoint):
|
||||
endpoint = self.endpoint()
|
||||
on_request = getattr(endpoint, "on_request", None)
|
||||
if on_request:
|
||||
views.append(on_request)
|
||||
|
||||
method_name = f"on_{request.method}"
|
||||
try:
|
||||
view = getattr(endpoint, method_name)
|
||||
views.append(view)
|
||||
except AttributeError:
|
||||
if on_request is None:
|
||||
raise HTTPException(status_code=status_codes.HTTP_405)
|
||||
else:
|
||||
views.append(self.endpoint)
|
||||
|
||||
for view in views:
|
||||
# "Monckey patch" for graphql: explicitly checking __call__
|
||||
if asyncio.iscoroutinefunction(view) or asyncio.iscoroutinefunction(
|
||||
view.__call__
|
||||
):
|
||||
await view(request, response, **path_params)
|
||||
else:
|
||||
await run_in_threadpool(view, request, response, **path_params)
|
||||
|
||||
if response.status_code is None:
|
||||
response.status_code = status_codes.HTTP_200
|
||||
|
||||
print("here", response)
|
||||
await response(scope, receive, send)
|
||||
|
||||
def __eq__(self, other):
|
||||
# [TODO] compare to str ?
|
||||
return self.route == other.route and self.endpoint == other.endpoint
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request)
|
||||
|
||||
|
||||
class WebSocketRoute(BaseRoute):
|
||||
def __init__(self, route, endpoint, *, before_request=False):
|
||||
assert route.startswith("/"), "Route path must start with '/'"
|
||||
self.route = route
|
||||
self.endpoint = endpoint
|
||||
self.before_request = before_request
|
||||
|
||||
self.path_re, self.param_convertors = compile_path(route)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Route {self.route!r}={self.endpoint!r}>"
|
||||
|
||||
def url(self, **params):
|
||||
return self.route.format(**params)
|
||||
|
||||
def _weight(self):
|
||||
params = set(self._param_pattern.findall(self.route))
|
||||
params_count = len(params)
|
||||
w = len(self.route.rsplit("}", 1)[-1].strip("/"))
|
||||
return params_count != 0, w == 0, -params_count
|
||||
@property
|
||||
def endpoint_name(self):
|
||||
return self.endpoint.__name__
|
||||
|
||||
@property
|
||||
def is_class_based(self):
|
||||
return inspect.isclass(self.endpoint)
|
||||
def description(self):
|
||||
return self.endpoint.__doc__
|
||||
|
||||
@property
|
||||
def is_function(self):
|
||||
code = hasattr(self.endpoint, "__code__")
|
||||
kwdefaults = hasattr(self.endpoint, "__kwdefaults__")
|
||||
return all((callable(self.endpoint), code, kwdefaults))
|
||||
def matches(self, scope):
|
||||
if scope["type"] != "websocket":
|
||||
return False, {}
|
||||
|
||||
path = scope["path"]
|
||||
match = self.path_re.match(path)
|
||||
|
||||
if match is None:
|
||||
return False, {}
|
||||
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key](value)
|
||||
|
||||
return True, {"path_params": {**matched_params}}
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
ws = WebSocket(scope, receive, send)
|
||||
|
||||
before_requests = scope.get("before_requests", [])
|
||||
for before_request in before_requests.get("ws", []):
|
||||
await before_request(ws)
|
||||
|
||||
await self.endpoint(ws)
|
||||
|
||||
def __hash__(self):
|
||||
return (
|
||||
hash(self.route)
|
||||
^ hash(self.endpoint)
|
||||
^ hash(self.uses_websocket)
|
||||
^ hash(self.before_request)
|
||||
return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request)
|
||||
|
||||
|
||||
class Router:
|
||||
def __init__(self, routes=None, default_response=None, before_requests=None):
|
||||
self.routes = [] if routes is None else list(routes)
|
||||
# [TODO] Make it's own router
|
||||
self.apps = {}
|
||||
self.default_endpoint = (
|
||||
self.default_response if default_response is None else default_response
|
||||
)
|
||||
self.lifespan_handler = Lifespan()
|
||||
self.before_requests = (
|
||||
{"http": [], "ws": []} if before_requests is None else before_requests
|
||||
)
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
route=None,
|
||||
endpoint=None,
|
||||
*,
|
||||
default=False,
|
||||
websocket=False,
|
||||
before_request=False,
|
||||
):
|
||||
""" Adds a route to the router.
|
||||
:param route: A string representation of the route
|
||||
:param endpoint: The endpoint for the route -- can be callable, or class.
|
||||
:param default: If ``True``, all unknown requests will route to this view.
|
||||
"""
|
||||
if before_request:
|
||||
if websocket:
|
||||
self.before_requests.setdefault("ws", []).append(endpoint)
|
||||
else:
|
||||
self.before_requests.setdefault("http", []).append(endpoint)
|
||||
return
|
||||
|
||||
if default:
|
||||
self.default_endpoint = endpoint
|
||||
|
||||
if websocket:
|
||||
route = WebSocketRoute(route, endpoint)
|
||||
else:
|
||||
route = Route(route, endpoint)
|
||||
|
||||
self.routes.append(route)
|
||||
|
||||
def mount(self, route, app):
|
||||
"""Mounts ASGI / WSGI applications at a given route
|
||||
"""
|
||||
self.apps.update(route, app)
|
||||
|
||||
def before_request(self, endpoint, websocket=False):
|
||||
if websocket:
|
||||
self.before_requests.setdefault("ws", []).append(endpoint)
|
||||
else:
|
||||
self.before_requests.setdefault("http", []).append(endpoint)
|
||||
|
||||
def url_for(self, endpoint, **params):
|
||||
# TODO: Check for params
|
||||
for route in self.routes:
|
||||
if endpoint in (route.endpoint, route.endpoint.__name__):
|
||||
return route.url(**params)
|
||||
return None
|
||||
|
||||
async def default_response(self, scope, receive, send):
|
||||
if scope["type"] == "websocket":
|
||||
websocket_close = WebSocketClose()
|
||||
await websocket_close(receive, send)
|
||||
return
|
||||
|
||||
request = Request(scope, receive)
|
||||
response = Response(request, formats=get_formats())
|
||||
|
||||
raise HTTPException(status_code=status_codes.HTTP_404)
|
||||
|
||||
def _resolve_route(self, scope):
|
||||
for route in self.routes:
|
||||
matches, child_scope = route.matches(scope)
|
||||
if matches:
|
||||
scope.update(child_scope)
|
||||
return route
|
||||
return None
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
assert scope["type"] in ("http", "websocket", "lifespan")
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
await self.lifespan_handler(scope, receive, send)
|
||||
return
|
||||
|
||||
path = scope["path"]
|
||||
root_path = scope.get("root_path", "")
|
||||
|
||||
# Call into a submounted app, if one exists.
|
||||
for path_prefix, app in self.apps.items():
|
||||
if path.startswith(path_prefix):
|
||||
scope["path"] = path[len(path_prefix) :]
|
||||
scope["root_path"] = root_path + path_prefix
|
||||
try:
|
||||
await app(scope, receive, send)
|
||||
return
|
||||
except TypeError:
|
||||
app = WSGIMiddleware(app)
|
||||
await app(scope, receive, send)
|
||||
return
|
||||
|
||||
route = self._resolve_route(scope)
|
||||
|
||||
scope["before_requests"] = self.before_requests
|
||||
|
||||
if route is not None:
|
||||
await route(scope, receive, send)
|
||||
return
|
||||
|
||||
await self.default_response(scope, receive, send)
|
||||
|
||||
+5
-39
@@ -19,6 +19,7 @@ def test_api_basic_route(api):
|
||||
resp.text = "hello world!"
|
||||
|
||||
|
||||
"""
|
||||
def test_api_basic_route_overlap(api):
|
||||
@api.route("/")
|
||||
def home(req, resp):
|
||||
@@ -30,40 +31,6 @@ def test_api_basic_route_overlap(api):
|
||||
def home2(req, resp):
|
||||
resp.text = "hello world!"
|
||||
|
||||
|
||||
def test_api_basic_route_overlap_alternative(api):
|
||||
@api.route("/")
|
||||
def home(req, resp):
|
||||
resp.text = "hello world!"
|
||||
|
||||
def home2(req, resp):
|
||||
resp.text = "hello world!"
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
api.add_route("/", home2)
|
||||
|
||||
|
||||
def test_api_basic_route_overlap_allowed(api):
|
||||
@api.route("/")
|
||||
def home(req, resp):
|
||||
resp.text = "hello world!"
|
||||
|
||||
def home2(req, resp):
|
||||
resp.text = "hello world!"
|
||||
|
||||
api.add_route("/", home2, check_existing=False)
|
||||
|
||||
|
||||
def test_api_basic_route_overlap_allowed_alternative(api):
|
||||
@api.route("/")
|
||||
def home(req, resp):
|
||||
resp.text = "hello world!"
|
||||
|
||||
@api.route("/", check_existing=False)
|
||||
def home2(req, resp):
|
||||
resp.text = "hello world!"
|
||||
|
||||
|
||||
def test_class_based_view_registration(api):
|
||||
@api.route("/")
|
||||
class ThingsResource:
|
||||
@@ -78,6 +45,7 @@ def test_class_based_view_parameters(api):
|
||||
resp.text = f"{greeting}, world!"
|
||||
|
||||
assert api.session().get("http://;/Hello").ok
|
||||
"""
|
||||
|
||||
|
||||
def test_requests_session(api):
|
||||
@@ -85,14 +53,14 @@ def test_requests_session(api):
|
||||
assert api.requests
|
||||
|
||||
|
||||
def test_requests_session_works(api, url):
|
||||
def test_requests_session_works(api):
|
||||
TEXT = "spiral out"
|
||||
|
||||
@api.route("/")
|
||||
def hello(req, resp):
|
||||
resp.text = TEXT
|
||||
|
||||
assert api.requests.get(url("/")).text == TEXT
|
||||
assert api.requests.get("/").text == TEXT
|
||||
|
||||
|
||||
def test_status_code(api):
|
||||
@@ -342,9 +310,7 @@ def test_schema_generation():
|
||||
import responder
|
||||
from marshmallow import Schema, fields
|
||||
|
||||
api = responder.API(
|
||||
title="Web Service", openapi="3.0.2", allowed_hosts=["testserver", ";"]
|
||||
)
|
||||
api = responder.API(title="Web Service", openapi="3.0.2")
|
||||
|
||||
@api.schema("Pet")
|
||||
class PetSchema(Schema):
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
import pytest
|
||||
from responder import routes
|
||||
|
||||
|
||||
def setup_function(function):
|
||||
routes.Route.incoming_matches.cache_clear()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"route, expected",
|
||||
[
|
||||
pytest.param("/", False, id="home path without params"),
|
||||
pytest.param("/test_path", False, id="sub path without params"),
|
||||
pytest.param("/{test_path}", True, id="path with params"),
|
||||
],
|
||||
)
|
||||
def test_parameter(route, expected):
|
||||
r = routes.Route(route, "test_endpoint")
|
||||
assert r.has_parameters is expected
|
||||
|
||||
|
||||
def test_url():
|
||||
r = routes.Route("/{my_path}", "test_endpoint")
|
||||
url = r.url(my_path="path")
|
||||
assert url == "/path"
|
||||
|
||||
|
||||
def test_equal():
|
||||
r = routes.Route("/{path_param}", "test_endpoint")
|
||||
r2 = routes.Route("/{path_param}", "test_endpoint")
|
||||
r3 = routes.Route("/test_path", "test_endpoint")
|
||||
|
||||
assert r == r2
|
||||
assert r != r3
|
||||
|
||||
|
||||
def test_incoming_matches():
|
||||
# Test Route with one param
|
||||
r = routes.Route("/{greetings}", "test_endpoint")
|
||||
assert r.incoming_matches("/hello") == {"greetings": "hello"}
|
||||
assert r.incoming_matches("/foo") == {"greetings": "foo"}
|
||||
|
||||
# Test Route with two params
|
||||
r = routes.Route("/{greetings}/{name}", "test_endpoint")
|
||||
assert r.incoming_matches("/hi/john") == {"greetings": "hi", "name": "john"}
|
||||
assert r.incoming_matches("/hello/jane") == {"greetings": "hello", "name": "jane"}
|
||||
|
||||
# Test Route with no param
|
||||
r = routes.Route("/hello", "test_endpoint")
|
||||
assert r.incoming_matches("/hello") == {}
|
||||
assert r.incoming_matches("/bye") == {}
|
||||
|
||||
|
||||
def test_incoming_matches_cache():
|
||||
r = routes.Route("/hello", "test_endpoint")
|
||||
r.incoming_matches("/hello")
|
||||
assert r.incoming_matches.cache_info().hits == 0
|
||||
r.incoming_matches("/hello")
|
||||
assert r.incoming_matches.cache_info().hits == 1
|
||||
|
||||
|
||||
def test_incoming_matches_with_concrete_path_no_match():
|
||||
r = routes.Route("/concrete_path", "test_endpoint")
|
||||
assert r.incoming_matches("hello") == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"route, match, expected",
|
||||
[
|
||||
pytest.param(
|
||||
"/{path_param}",
|
||||
"/{path_param}",
|
||||
True,
|
||||
id="with both parametrized path match",
|
||||
),
|
||||
pytest.param(
|
||||
"/concrete", "/concrete", True, id="with both concrete path match"
|
||||
),
|
||||
pytest.param("/concrete", "/no_match", False, id="with no match"),
|
||||
],
|
||||
)
|
||||
def test_does_match_with_route(route, match, expected):
|
||||
r = routes.Route(route, "test_endpoint")
|
||||
assert r.does_match(match) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path_param, expected_weight",
|
||||
[
|
||||
pytest.param("/{greetings}", (True, True, -1), id="with one param"),
|
||||
pytest.param(
|
||||
"/{greetings}.{name}",
|
||||
(True, True, -2),
|
||||
id="with 2 params and dot in the middle",
|
||||
),
|
||||
pytest.param(
|
||||
"/{greetings}/{name}", (True, True, -2), id="with 2 params and subpath"
|
||||
),
|
||||
pytest.param(
|
||||
"/{greetings}/{name}/{hello}",
|
||||
(True, True, -3),
|
||||
id="with 3 params and subpath",
|
||||
),
|
||||
pytest.param(
|
||||
"/{greetings}_{name}", (True, True, -2), id="with 2 params and underscore"
|
||||
),
|
||||
pytest.param("/{greetings}/test", (True, False, -1), id="with one param"),
|
||||
pytest.param(
|
||||
"/{greetings}.{name}/test",
|
||||
(True, False, -2),
|
||||
id="with 2 params and dot in the middle",
|
||||
),
|
||||
pytest.param(
|
||||
"/{greetings}/{name}/test",
|
||||
(True, False, -2),
|
||||
id="with 2 params and subpath",
|
||||
),
|
||||
pytest.param(
|
||||
"/{greetings}/{name}/{hello}/test",
|
||||
(True, False, -3),
|
||||
id="with 3 params and subpath",
|
||||
),
|
||||
pytest.param(
|
||||
"/{greetings}_{name}/test",
|
||||
(True, False, -2),
|
||||
id="with 2 params and underscore",
|
||||
),
|
||||
pytest.param("/hello", (False, False, 0), id="without params"),
|
||||
],
|
||||
)
|
||||
def test_weight(path_param, expected_weight):
|
||||
r = routes.Route(path_param, "test_endpoint")
|
||||
assert r._weight() == expected_weight
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"route, path, expected_result",
|
||||
[
|
||||
pytest.param("/{greetings:str}", "/hello", {"greetings": "hello"}),
|
||||
pytest.param(
|
||||
"/{greetings:str}/{who}",
|
||||
"/hello/Laidia",
|
||||
{"greetings": "hello", "who": "Laidia"},
|
||||
),
|
||||
pytest.param("/{birth_date:int}", "/1937", {"birth_date": 1937}),
|
||||
pytest.param(
|
||||
"/{name:str}/{age:int}", "/Fatna/80", {"name": "Fatna", "age": 80}
|
||||
),
|
||||
pytest.param(
|
||||
"/{x:float}/{y:float}", "/10.20/75", {"x": float(10.20), "y": float(75)}
|
||||
),
|
||||
pytest.param("/{name:str}/{age:int}", "/Fatna/eighty", {}),
|
||||
pytest.param("/{greetings:int}", "/hello", {}),
|
||||
pytest.param("/{name:float}", "/Fatna", {}),
|
||||
],
|
||||
)
|
||||
def test_custom_specifiers(route, path, expected_result):
|
||||
r = routes.Route(route, "test_endpoint")
|
||||
assert r.incoming_matches(path) == expected_result
|
||||
Reference in New Issue
Block a user