Implement a new Router and other changes

- Router
- Use starlette's Session middlewre
- Add Exception Middleware
- ...
This commit is contained in:
taoufik
2019-08-15 23:24:42 +02:00
parent d820f0277f
commit b31b742787
6 changed files with 338 additions and 523 deletions
+35 -255
View File
@@ -12,12 +12,14 @@ import uvicorn
import yaml
from apispec import APISpec, yaml_utils
from apispec.ext.marshmallow import MarshmallowPlugin
from starlette.exceptions import ExceptionMiddleware
from starlette.middleware.wsgi import WSGIMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.routing import Lifespan
from starlette.staticfiles import StaticFiles
from starlette.testclient import TestClient
@@ -27,17 +29,11 @@ from whitenoise import WhiteNoise
from . import models, status_codes
from .background import BackgroundQueue
from .formats import get_formats
from .routes import Route
from .statics import (
DEFAULT_API_THEME,
DEFAULT_CORS_PARAMS,
DEFAULT_SECRET_KEY,
DEFAULT_SESSION_COOKIE,
)
from .routes import Router
from .statics import DEFAULT_API_THEME, DEFAULT_CORS_PARAMS, DEFAULT_SECRET_KEY
from .templates import GRAPHIQL
# TODO: consider moving status codes here
class API:
"""The primary web-service class.
@@ -106,13 +102,11 @@ class API:
self.templates_dir = templates_dir or self.built_in_templates_dir
self.apps = {}
self.routes = {}
self.before_requests = {"http": [], "ws": []}
self.router = Router()
self.docs_theme = DEFAULT_API_THEME
self.docs_route = docs_route
self.schemas = {}
self.session_cookie = DEFAULT_SESSION_COOKIE
self.hsts_enabled = enable_hsts
self.cors = cors
@@ -159,7 +153,7 @@ class API:
self.add_route(self.docs_route, self.docs_response)
self.default_endpoint = None
self.app = self.asgi
self.app = ExceptionMiddleware(self.router, debug=debug)
self.add_middleware(GZipMiddleware)
if self.hsts_enabled:
@@ -167,11 +161,10 @@ class API:
self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts)
self.lifespan_handler = Lifespan()
if self.cors:
self.add_middleware(CORSMiddleware, **self.cors_params)
self.add_middleware(ServerErrorMiddleware, debug=debug)
self.add_middleware(SessionMiddleware, secret_key=self.secret_key)
# Jinja environment
self.jinja_env = jinja2.Environment(
@@ -186,10 +179,6 @@ class API:
self.session()
) #: A Requests session that is connected to the ASGI app.
@staticmethod
def _default_wsgi_app(environ, start_response):
pass
@staticmethod
def _notfound_wsgi_app(environ, start_response):
start_response("404 NOT FOUND", [("Content-Type", "text/plain")])
@@ -197,10 +186,7 @@ class API:
def before_request(self, websocket=False):
def decorator(f):
if websocket:
self.before_requests.setdefault("ws", []).append(f)
else:
self.before_requests.setdefault("http", []).append(f)
self.router.before_request(f, websocket=websocket)
return f
return decorator
@@ -234,12 +220,12 @@ class API:
info=info,
)
for route in self.routes:
if self.routes[route].description:
for route in self.router.routes:
if route.description:
operations = yaml_utils.load_operations_from_docstring(
self.routes[route].description
route.description
)
spec.path(path=route, operations=operations)
spec.path(path=route.route, operations=operations)
for name, schema in self.schemas.items():
spec.components.schema(name, schema=schema)
@@ -254,76 +240,8 @@ class API:
self.app = middleware_cls(self.app, **middleware_config)
async def __call__(self, scope, receive, send):
if scope["type"] == "lifespan":
await self.lifespan_handler(scope, receive, send)
return
path = scope["path"]
root_path = scope.get("root_path", "")
# Call into a submounted app, if one exists.
for path_prefix, app in self.apps.items():
if path.startswith(path_prefix):
scope["path"] = path[len(path_prefix) :]
scope["root_path"] = root_path + path_prefix
try:
await app(scope, receive, send)
return
except TypeError:
app = WSGIMiddleware(app)
await app(scope, receive, send)
return
await self.app(scope, receive, send)
async def asgi(self, scope, receive, send):
assert scope["type"] in ("http", "websocket")
if scope["type"] == "websocket":
await self._dispatch_ws(scope=scope, receive=receive, send=send)
else:
req = models.Request(scope, receive=receive, api=self)
resp = await self._dispatch_http(
req, scope=scope, send=send, receive=receive
)
await resp(scope, receive, send)
async def _dispatch_http(self, req, **options):
# Set formats on Request object.
req.formats = self.formats
# Get the route.
route = self.path_matches_route(req.url.path)
route = self.routes.get(route)
if route:
resp = models.Response(req=req, formats=self.formats)
for before_request in self.before_http_requests:
await self.background(before_request, req=req, resp=resp)
await self._execute_route(route=route, req=req, resp=resp, **options)
else:
resp = models.Response(req=req, formats=self.formats)
self.default_response(req=req, resp=resp, notfound=True)
self.default_response(req=req, resp=resp)
self._prepare_session(resp)
return resp
async def _dispatch_ws(self, scope, receive, send):
ws = WebSocket(scope=scope, receive=receive, send=send)
route = self.path_matches_route(ws.url.path)
route = self.routes.get(route)
if route:
for before_request in self.before_ws_requests:
await self.background(before_request, ws=ws)
await self.background(route.endpoint, ws)
else:
await send({"type": "websocket.close", "code": 1000})
def add_schema(self, name, schema, check_existing=True):
"""Adds a mashmallow schema to the API specification."""
if check_existing:
@@ -355,97 +273,17 @@ class API:
:param path: The path portion of a URL, to test all known routes against.
"""
for (route, route_object) in self.routes.items():
if route_object.does_match(path):
for route in self.router.routes:
match, _ = route.matches(path)
if match:
return route
@property
def _signer(self):
return itsdangerous.Signer(self.secret_key)
def _prepare_session(self, resp):
if resp.session:
data = self._signer.sign(
b64encode(json.dumps(resp.session).encode("utf-8"))
)
resp.cookies[self.session_cookie] = data.decode("utf-8")
@staticmethod
def no_response(req, resp, **params):
pass
async def _execute_route(self, *, route, req, resp, **options):
params = route.incoming_matches(req.url.path)
cont = True
if route.is_function:
try:
try:
# Run the view.
r = self.background(route.endpoint, req, resp, **params)
# If it's async, await it.
if hasattr(r, "cr_running"):
await r
except TypeError as e:
cont = True
except Exception:
await self.background(self.default_response, req, resp, error=True)
raise
if route.is_class_based or cont:
try:
view = route.endpoint(**params)
except TypeError:
try:
view = route.endpoint()
except TypeError:
view = route.endpoint
pass
# Run on_request first.
try:
# Run the view.
r = getattr(view, "on_request", self.no_response)
r = self.background(r, req, resp, **params)
# If it's async, await it.
if hasattr(r, "send"):
await r
except Exception:
await self.background(self.default_response, req, resp, error=True)
raise
# Then run on_method.
method = req.method
try:
# Run the view.
r = getattr(view, f"on_{method}", self.no_response)
r = self.background(r, req, resp, **params)
# If it's async, await it.
if hasattr(r, "send"):
await r
except Exception:
await self.background(self.default_response, req, resp, error=True)
raise
def add_event_handler(self, event_type, handler):
"""Adds an event handler to the API.
:param event_type: A string in ("startup", "shutdown")
:param handler: The function to run. Can be either a function or a coroutine.
"""
self.lifespan_handler.add_event_handler(event_type, handler)
def add_route(
self,
route=None,
endpoint=None,
*,
default=False,
static=False,
check_existing=True,
websocket=False,
before_request=False,
@@ -456,70 +294,18 @@ class API:
:param endpoint: The endpoint for the route -- can be a callable, or a class.
:param default: If ``True``, all unknown requests will route to this view.
:param static: If ``True``, and no endpoint was passed, render "static/index.html", and it will become a default route.
:param check_existing: If ``True``, an AssertionError will be raised, if the route is already defined.
"""
if before_request:
if websocket:
self.before_requests.setdefault("ws", []).append(endpoint)
else:
self.before_requests.setdefault("http", []).append(endpoint)
return
if route is None:
route = f"/{uuid4().hex}"
if check_existing:
assert route not in self.routes
if static:
assert self.static_dir is not None
if not endpoint:
endpoint = self.static_response
default = True
if default:
self.default_endpoint = endpoint
self.routes[route] = Route(route, endpoint, websocket=websocket)
# TODO: A better data structure or sort it once the app is loaded
self.routes = dict(
sorted(self.routes.items(), key=lambda item: item[1]._weight())
self.router.add_route(
route,
endpoint,
default=default,
websocket=websocket,
before_request=before_request,
)
def default_response(
self, req=None, resp=None, websocket=False, notfound=False, error=False
):
if websocket:
return
if resp.status_code is None:
resp.status_code = 200
if self.default_endpoint and notfound:
self.default_endpoint(req=req, resp=resp)
else:
if notfound:
resp.status_code = status_codes.HTTP_404
resp.text = "Not found."
if error:
resp.status_code = status_codes.HTTP_500
resp.text = "Application error."
def docs_response(self, req, resp):
resp.html = self.docs
def static_response(self, req, resp):
assert self.static_dir is not None
index = (self.static_dir / "index.html").resolve()
if os.path.exists(index):
with open(index, "r") as f:
resp.html = f.read()
else:
resp.status_code = status_codes.HTTP_404
resp.text = "Not found."
def schema_response(self, req, resp):
resp.status_code = status_codes.HTTP_200
resp.headers["Content-Type"] = "application/x-yaml"
@@ -529,19 +315,12 @@ class API:
self, resp, location, *, set_text=True, status_code=status_codes.HTTP_301
):
"""Redirects a given response to a given location.
:param resp: The Response to mutate.
:param location: The location of the redirect.
:param set_text: If ``True``, sets the Redirect body content automatically.
:param status_code: an `API.status_codes` attribute, or an integer, representing the HTTP status code of the redirect.
"""
# assert resp.status_code.is_300(status_code)
resp.status_code = status_code
if set_text:
resp.text = f"Redirecting to: {location}"
resp.headers.update({"Location": location})
resp.redirect(location, set_text=set_text, status_code=status_code)
def on_event(self, event_type: str, **args):
"""Decorator for registering functions or coroutines to run at certain events
@@ -565,6 +344,15 @@ class API:
return decorator
def add_event_handler(self, event_type, handler):
"""Adds an event handler to the API.
:param event_type: A string in ("startup", "shutdown")
:param handler: The function to run. Can be either a function or a coroutine.
"""
self.router.lifespan_handler.add_event_handler(event_type, handler)
def route(self, route=None, **options):
"""Decorator for creating new routes around function and class definitions.
@@ -577,7 +365,7 @@ class API:
"""
def decorator(f):
self.add_route(route, f, **options)
self.router.add_route(route, f, **options)
return f
return decorator
@@ -588,7 +376,7 @@ class API:
:param route: String representation of the route to be used (shouldn't be parameterized).
:param app: The other WSGI / ASGI app.
"""
self.apps.update({route: app})
self.router.apps.update({route: app})
def session(self, base_url="http://;"):
"""Testing HTTP client. Returns a Requests session object, able to send HTTP requests to the Responder application.
@@ -600,11 +388,6 @@ class API:
self._session = TestClient(self, base_url=base_url)
return self._session
def _route_for(self, endpoint):
for route_object in self.routes.values():
if endpoint in (route_object.endpoint, route_object.endpoint_name):
return route_object
def url_for(self, endpoint, **params):
# TODO: Absolute_url
"""Given an endpoint, returns a rendered URL for its route.
@@ -612,10 +395,7 @@ class API:
:param endpoint: The route endpoint you're searching for.
:param params: Data to pass into the URL generator (for parameterized URLs).
"""
route_object = self._route_for(endpoint)
if route_object:
return route_object.url(**params)
raise ValueError
return self.router.url_for(endpoint, **params)
def static_url(self, asset):
"""Given a static asset, return its URL path."""
+3
View File
@@ -63,3 +63,6 @@ class GraphQLView:
async def on_request(self, req, resp):
await self.graphql_response(req, resp, self.schema)
async def __call__(self, req, resp):
await self.on_request(req, resp)
+15 -15
View File
@@ -3,16 +3,18 @@ import io
import inspect
import json
import gzip
from urllib.parse import parse_qs
from base64 import b64decode
from http.cookies import SimpleCookie
import chardet
import rfc3986
import graphene
import yaml
from requests.structures import CaseInsensitiveDict
from requests.cookies import RequestsCookieJar
from starlette.datastructures import MutableHeaders
from starlette.requests import Request as StarletteRequest, State
from starlette.responses import (
@@ -20,9 +22,7 @@ from starlette.responses import (
StreamingResponse as StarletteStreamingResponse,
)
from urllib.parse import parse_qs
from .status_codes import HTTP_200
from .status_codes import HTTP_200, HTTP_301
from .statics import DEFAULT_ENCODING
@@ -105,9 +105,9 @@ class Request:
"_cookies",
]
def __init__(self, scope, receive, api=None):
def __init__(self, scope, receive, api=None, formats=None):
self._starlette = StarletteRequest(scope, receive)
self.formats = None
self.formats = formats
self._encoding = None
self.api = api
self._content = None
@@ -122,14 +122,7 @@ class Request:
@property
def session(self):
"""The session data, in dict form, from the Request."""
if self.api.session_cookie in self.cookies:
data = self.cookies[self.api.session_cookie]
data = self.api._signer.unsign(data)
data = b64decode(data)
return json.loads(data)
return {}
return self._starlette.session
@property
def headers(self):
@@ -158,6 +151,7 @@ class Request:
@property
def cookies(self):
"""The cookies sent in the Request, as a dictionary."""
return self._starlette.cookies
if self._cookies is None:
cookies = RequestsCookieJar()
cookie_header = self.headers.get("Cookie", "")
@@ -300,7 +294,7 @@ class Response:
self.formats = formats
self.cookies = SimpleCookie() #: The cookies set in the Response
self.session = (
req.session.copy()
req.session
) #: The cookie-based session data, in dict form, to add to the Response.
# Property or func/dec
@@ -311,6 +305,12 @@ class Response:
return func
def redirect(self, location, *, set_text=True, status_code=HTTP_301):
self.status_code = status_code
if set_text:
self.text = f"Redirecting to: {location}"
self.headers.update({"Location": location})
@property
async def body(self):
if self._stream is not None:
+280 -55
View File
@@ -1,43 +1,76 @@
import asyncio
import json
import re
import functools
import inspect
from parse import parse, with_pattern
from starlette.routing import Lifespan
from starlette.middleware.wsgi import WSGIMiddleware
from starlette.websockets import WebSocket, WebSocketClose
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from .models import Request, Response
from . import status_codes
from .formats import get_formats
from .statics import DEFAULT_SESSION_COOKIE
def _make_convertor(type, pattern):
@with_pattern(pattern)
def inner(value):
return type(value)
return inner
_convertors = {
"int": _make_convertor(int, r"\d+"),
"str": _make_convertor(str, r"[^/]+"),
"float": _make_convertor(float, r"\d+(.\d+)?"),
_CONVERTORS = {
"int": (int, r"\d+"),
"str": (str, r"[^/]+"),
"float": (float, r"\d+(.\d+)?"),
}
PARAM_RE = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
class Route:
_param_pattern = re.compile(r"{([^{}]*)}")
def __init__(self, route, endpoint, *, websocket=False, before_request=False):
def compile_path(path):
path_re = "^"
param_convertors = {}
idx = 0
for match in PARAM_RE.finditer(path):
param_name, convertor_type = match.groups(default="str")
convertor_type = convertor_type.lstrip(":")
assert (
convertor_type in _CONVERTORS.keys()
), f"Unknown path convertor '{convertor_type}'"
convertor, convertor_re = _CONVERTORS[convertor_type]
path_re += path[idx : match.start()]
path_re += rf"(?P<{param_name}>{convertor_re})"
param_convertors[param_name] = convertor
idx = match.end()
path_re += path[idx:] + "$"
return re.compile(path_re), param_convertors
class BaseRoute:
def matches(self, scope):
raise NotImplementedError()
async def __call__(self, scope, receive, send):
raise NotImplementedError()
class Route(BaseRoute):
def __init__(self, route, endpoint, *, before_request=False):
assert route.startswith("/"), "Route path must start with '/'"
self.route = route
self.endpoint = endpoint
self.uses_websocket = websocket
self.before_request = before_request
self.path_re, self.param_convertors = compile_path(route)
def __repr__(self):
return f"<Route {self.route!r}={self.endpoint!r}>"
def __eq__(self, other):
if hasattr(other, "route"):
# Being compared to other routes.
return self.route == other.route
else:
# Strings.
return self.does_match(other)
def url(self, **params):
return self.route.format(**params)
@property
def endpoint_name(self):
@@ -47,46 +80,238 @@ class Route:
def description(self):
return self.endpoint.__doc__
@property
def has_parameters(self):
return bool(self._param_pattern.search(self.route))
def matches(self, scope):
if scope["type"] != "http":
return False, {}
@functools.lru_cache(maxsize=None)
def does_match(self, s):
if s == self.route:
return True
path = scope["path"]
match = self.path_re.match(path)
named = self.incoming_matches(s)
return bool(len(named))
if match is None:
return False, {}
@functools.lru_cache(maxsize=None)
def incoming_matches(self, s):
results = parse(self.route, s, _convertors)
return results.named if results else {}
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key](value)
return True, {"path_params": {**matched_params}}
async def __call__(self, scope, receive, send):
request = Request(scope, receive, formats=get_formats())
response = Response(req=request, formats=get_formats())
path_params = scope.get("path_params", {})
before_requests = scope.get("before_requests", [])
for before_request in before_requests.get("http", []):
if asyncio.iscoroutinefunction(before_request):
await before_request(request, response)
else:
await run_in_threadpool(before_request, request, response)
views = []
if inspect.isclass(self.endpoint):
endpoint = self.endpoint()
on_request = getattr(endpoint, "on_request", None)
if on_request:
views.append(on_request)
method_name = f"on_{request.method}"
try:
view = getattr(endpoint, method_name)
views.append(view)
except AttributeError:
if on_request is None:
raise HTTPException(status_code=status_codes.HTTP_405)
else:
views.append(self.endpoint)
for view in views:
# "Monckey patch" for graphql: explicitly checking __call__
if asyncio.iscoroutinefunction(view) or asyncio.iscoroutinefunction(
view.__call__
):
await view(request, response, **path_params)
else:
await run_in_threadpool(view, request, response, **path_params)
if response.status_code is None:
response.status_code = status_codes.HTTP_200
print("here", response)
await response(scope, receive, send)
def __eq__(self, other):
# [TODO] compare to str ?
return self.route == other.route and self.endpoint == other.endpoint
def __hash__(self):
return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request)
class WebSocketRoute(BaseRoute):
def __init__(self, route, endpoint, *, before_request=False):
assert route.startswith("/"), "Route path must start with '/'"
self.route = route
self.endpoint = endpoint
self.before_request = before_request
self.path_re, self.param_convertors = compile_path(route)
def __repr__(self):
return f"<Route {self.route!r}={self.endpoint!r}>"
def url(self, **params):
return self.route.format(**params)
def _weight(self):
params = set(self._param_pattern.findall(self.route))
params_count = len(params)
w = len(self.route.rsplit("}", 1)[-1].strip("/"))
return params_count != 0, w == 0, -params_count
@property
def endpoint_name(self):
return self.endpoint.__name__
@property
def is_class_based(self):
return inspect.isclass(self.endpoint)
def description(self):
return self.endpoint.__doc__
@property
def is_function(self):
code = hasattr(self.endpoint, "__code__")
kwdefaults = hasattr(self.endpoint, "__kwdefaults__")
return all((callable(self.endpoint), code, kwdefaults))
def matches(self, scope):
if scope["type"] != "websocket":
return False, {}
path = scope["path"]
match = self.path_re.match(path)
if match is None:
return False, {}
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key](value)
return True, {"path_params": {**matched_params}}
async def __call__(self, scope, receive, send):
ws = WebSocket(scope, receive, send)
before_requests = scope.get("before_requests", [])
for before_request in before_requests.get("ws", []):
await before_request(ws)
await self.endpoint(ws)
def __hash__(self):
return (
hash(self.route)
^ hash(self.endpoint)
^ hash(self.uses_websocket)
^ hash(self.before_request)
return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request)
class Router:
def __init__(self, routes=None, default_response=None, before_requests=None):
self.routes = [] if routes is None else list(routes)
# [TODO] Make it's own router
self.apps = {}
self.default_endpoint = (
self.default_response if default_response is None else default_response
)
self.lifespan_handler = Lifespan()
self.before_requests = (
{"http": [], "ws": []} if before_requests is None else before_requests
)
def add_route(
self,
route=None,
endpoint=None,
*,
default=False,
websocket=False,
before_request=False,
):
""" Adds a route to the router.
:param route: A string representation of the route
:param endpoint: The endpoint for the route -- can be callable, or class.
:param default: If ``True``, all unknown requests will route to this view.
"""
if before_request:
if websocket:
self.before_requests.setdefault("ws", []).append(endpoint)
else:
self.before_requests.setdefault("http", []).append(endpoint)
return
if default:
self.default_endpoint = endpoint
if websocket:
route = WebSocketRoute(route, endpoint)
else:
route = Route(route, endpoint)
self.routes.append(route)
def mount(self, route, app):
"""Mounts ASGI / WSGI applications at a given route
"""
self.apps.update(route, app)
def before_request(self, endpoint, websocket=False):
if websocket:
self.before_requests.setdefault("ws", []).append(endpoint)
else:
self.before_requests.setdefault("http", []).append(endpoint)
def url_for(self, endpoint, **params):
# TODO: Check for params
for route in self.routes:
if endpoint in (route.endpoint, route.endpoint.__name__):
return route.url(**params)
return None
async def default_response(self, scope, receive, send):
if scope["type"] == "websocket":
websocket_close = WebSocketClose()
await websocket_close(receive, send)
return
request = Request(scope, receive)
response = Response(request, formats=get_formats())
raise HTTPException(status_code=status_codes.HTTP_404)
def _resolve_route(self, scope):
for route in self.routes:
matches, child_scope = route.matches(scope)
if matches:
scope.update(child_scope)
return route
return None
async def __call__(self, scope, receive, send):
assert scope["type"] in ("http", "websocket", "lifespan")
if scope["type"] == "lifespan":
await self.lifespan_handler(scope, receive, send)
return
path = scope["path"]
root_path = scope.get("root_path", "")
# Call into a submounted app, if one exists.
for path_prefix, app in self.apps.items():
if path.startswith(path_prefix):
scope["path"] = path[len(path_prefix) :]
scope["root_path"] = root_path + path_prefix
try:
await app(scope, receive, send)
return
except TypeError:
app = WSGIMiddleware(app)
await app(scope, receive, send)
return
route = self._resolve_route(scope)
scope["before_requests"] = self.before_requests
if route is not None:
await route(scope, receive, send)
return
await self.default_response(scope, receive, send)
+5 -39
View File
@@ -19,6 +19,7 @@ def test_api_basic_route(api):
resp.text = "hello world!"
"""
def test_api_basic_route_overlap(api):
@api.route("/")
def home(req, resp):
@@ -30,40 +31,6 @@ def test_api_basic_route_overlap(api):
def home2(req, resp):
resp.text = "hello world!"
def test_api_basic_route_overlap_alternative(api):
@api.route("/")
def home(req, resp):
resp.text = "hello world!"
def home2(req, resp):
resp.text = "hello world!"
with pytest.raises(AssertionError):
api.add_route("/", home2)
def test_api_basic_route_overlap_allowed(api):
@api.route("/")
def home(req, resp):
resp.text = "hello world!"
def home2(req, resp):
resp.text = "hello world!"
api.add_route("/", home2, check_existing=False)
def test_api_basic_route_overlap_allowed_alternative(api):
@api.route("/")
def home(req, resp):
resp.text = "hello world!"
@api.route("/", check_existing=False)
def home2(req, resp):
resp.text = "hello world!"
def test_class_based_view_registration(api):
@api.route("/")
class ThingsResource:
@@ -78,6 +45,7 @@ def test_class_based_view_parameters(api):
resp.text = f"{greeting}, world!"
assert api.session().get("http://;/Hello").ok
"""
def test_requests_session(api):
@@ -85,14 +53,14 @@ def test_requests_session(api):
assert api.requests
def test_requests_session_works(api, url):
def test_requests_session_works(api):
TEXT = "spiral out"
@api.route("/")
def hello(req, resp):
resp.text = TEXT
assert api.requests.get(url("/")).text == TEXT
assert api.requests.get("/").text == TEXT
def test_status_code(api):
@@ -342,9 +310,7 @@ def test_schema_generation():
import responder
from marshmallow import Schema, fields
api = responder.API(
title="Web Service", openapi="3.0.2", allowed_hosts=["testserver", ";"]
)
api = responder.API(title="Web Service", openapi="3.0.2")
@api.schema("Pet")
class PetSchema(Schema):
-159
View File
@@ -1,159 +0,0 @@
import pytest
from responder import routes
def setup_function(function):
routes.Route.incoming_matches.cache_clear()
@pytest.mark.parametrize(
"route, expected",
[
pytest.param("/", False, id="home path without params"),
pytest.param("/test_path", False, id="sub path without params"),
pytest.param("/{test_path}", True, id="path with params"),
],
)
def test_parameter(route, expected):
r = routes.Route(route, "test_endpoint")
assert r.has_parameters is expected
def test_url():
r = routes.Route("/{my_path}", "test_endpoint")
url = r.url(my_path="path")
assert url == "/path"
def test_equal():
r = routes.Route("/{path_param}", "test_endpoint")
r2 = routes.Route("/{path_param}", "test_endpoint")
r3 = routes.Route("/test_path", "test_endpoint")
assert r == r2
assert r != r3
def test_incoming_matches():
# Test Route with one param
r = routes.Route("/{greetings}", "test_endpoint")
assert r.incoming_matches("/hello") == {"greetings": "hello"}
assert r.incoming_matches("/foo") == {"greetings": "foo"}
# Test Route with two params
r = routes.Route("/{greetings}/{name}", "test_endpoint")
assert r.incoming_matches("/hi/john") == {"greetings": "hi", "name": "john"}
assert r.incoming_matches("/hello/jane") == {"greetings": "hello", "name": "jane"}
# Test Route with no param
r = routes.Route("/hello", "test_endpoint")
assert r.incoming_matches("/hello") == {}
assert r.incoming_matches("/bye") == {}
def test_incoming_matches_cache():
r = routes.Route("/hello", "test_endpoint")
r.incoming_matches("/hello")
assert r.incoming_matches.cache_info().hits == 0
r.incoming_matches("/hello")
assert r.incoming_matches.cache_info().hits == 1
def test_incoming_matches_with_concrete_path_no_match():
r = routes.Route("/concrete_path", "test_endpoint")
assert r.incoming_matches("hello") == {}
@pytest.mark.parametrize(
"route, match, expected",
[
pytest.param(
"/{path_param}",
"/{path_param}",
True,
id="with both parametrized path match",
),
pytest.param(
"/concrete", "/concrete", True, id="with both concrete path match"
),
pytest.param("/concrete", "/no_match", False, id="with no match"),
],
)
def test_does_match_with_route(route, match, expected):
r = routes.Route(route, "test_endpoint")
assert r.does_match(match) == expected
@pytest.mark.parametrize(
"path_param, expected_weight",
[
pytest.param("/{greetings}", (True, True, -1), id="with one param"),
pytest.param(
"/{greetings}.{name}",
(True, True, -2),
id="with 2 params and dot in the middle",
),
pytest.param(
"/{greetings}/{name}", (True, True, -2), id="with 2 params and subpath"
),
pytest.param(
"/{greetings}/{name}/{hello}",
(True, True, -3),
id="with 3 params and subpath",
),
pytest.param(
"/{greetings}_{name}", (True, True, -2), id="with 2 params and underscore"
),
pytest.param("/{greetings}/test", (True, False, -1), id="with one param"),
pytest.param(
"/{greetings}.{name}/test",
(True, False, -2),
id="with 2 params and dot in the middle",
),
pytest.param(
"/{greetings}/{name}/test",
(True, False, -2),
id="with 2 params and subpath",
),
pytest.param(
"/{greetings}/{name}/{hello}/test",
(True, False, -3),
id="with 3 params and subpath",
),
pytest.param(
"/{greetings}_{name}/test",
(True, False, -2),
id="with 2 params and underscore",
),
pytest.param("/hello", (False, False, 0), id="without params"),
],
)
def test_weight(path_param, expected_weight):
r = routes.Route(path_param, "test_endpoint")
assert r._weight() == expected_weight
@pytest.mark.parametrize(
"route, path, expected_result",
[
pytest.param("/{greetings:str}", "/hello", {"greetings": "hello"}),
pytest.param(
"/{greetings:str}/{who}",
"/hello/Laidia",
{"greetings": "hello", "who": "Laidia"},
),
pytest.param("/{birth_date:int}", "/1937", {"birth_date": 1937}),
pytest.param(
"/{name:str}/{age:int}", "/Fatna/80", {"name": "Fatna", "age": 80}
),
pytest.param(
"/{x:float}/{y:float}", "/10.20/75", {"x": float(10.20), "y": float(75)}
),
pytest.param("/{name:str}/{age:int}", "/Fatna/eighty", {}),
pytest.param("/{greetings:int}", "/hello", {}),
pytest.param("/{name:float}", "/Fatna", {}),
],
)
def test_custom_specifiers(route, path, expected_result):
r = routes.Route(route, "test_endpoint")
assert r.incoming_matches(path) == expected_result