Add type annotations to routes.py

Add comprehensive type hints to compile_path, BaseRoute, Route,
WebSocketRoute, and Router classes. Uses Starlette's Scope, Receive,
Send types and properly types the ASGI/WSGI union in Router.apps.
Fixes #566.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-22 21:47:21 -04:00
parent 724b769c9e
commit 74c872ed57
+75 -50
View File
@@ -1,14 +1,18 @@
from __future__ import annotations
import asyncio
import inspect
import re
import traceback
from collections import defaultdict
from collections.abc import Callable
from typing import Any, Union
__all__ = ["Route", "WebSocketRoute", "Router"]
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.types import ASGIApp
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose
from . import status_codes
@@ -28,9 +32,9 @@ _CONVERTORS = {
PARAM_RE = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
def compile_path(path):
def compile_path(path: str) -> tuple[re.Pattern, dict[str, type]]:
path_re = "^"
param_convertors = {}
param_convertors: dict[str, type] = {}
idx = 0
for match in PARAM_RE.finditer(path):
@@ -54,10 +58,10 @@ def compile_path(path):
class BaseRoute:
def matches(self, scope):
def matches(self, scope: Scope) -> tuple[bool, dict]:
raise NotImplementedError()
async def __call__(self, scope, receive, send):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
raise NotImplementedError()
@@ -68,32 +72,41 @@ class Route(BaseRoute):
``{pk:uuid}``, ``{value:float}``, ``{rest:path}``).
"""
def __init__(self, route, endpoint, *, before_request=False, methods=None):
def __init__(
self,
route: str,
endpoint: Callable,
*,
before_request: bool = False,
methods: list[str] | None = None,
) -> None:
assert route.startswith("/"), "Route path must start with '/'"
self.route = route
self.endpoint = endpoint
self.before_request = before_request
self.methods = {m.upper() for m in methods} if methods else None
self.methods: set[str] | None = {m.upper() for m in methods} if methods else None
self.path_re: re.Pattern
self.param_convertors: dict[str, type]
self.path_re, self.param_convertors = compile_path(route)
# Strip type annotations for URL generation (e.g. {id:int} -> {id})
self._url_template = PARAM_RE.sub(r"{\1}", route)
def __repr__(self):
def __repr__(self) -> str:
return f"<Route {self.route!r}={self.endpoint!r}>"
def url(self, **params):
def url(self, **params: Any) -> str:
return self._url_template.format(**params)
@property
def endpoint_name(self):
def endpoint_name(self) -> str:
return self.endpoint.__name__
@property
def description(self):
def description(self) -> str | None:
return self.endpoint.__doc__
def matches(self, scope):
def matches(self, scope: Scope) -> tuple[bool, dict]:
if scope["type"] != "http":
return False, {}
@@ -112,7 +125,7 @@ class Route(BaseRoute):
return True, {"path_params": {**matched_params}}
async def __call__(self, scope, receive, send):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive, formats=get_formats())
response = Response(req=request, formats=get_formats())
@@ -195,40 +208,46 @@ class Route(BaseRoute):
await response(scope, receive, send)
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, Route):
return NotImplemented
return self.route == other.route and self.endpoint == other.endpoint
def __hash__(self):
def __hash__(self) -> int:
return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request)
class WebSocketRoute(BaseRoute):
"""A WebSocket route that maps a URL pattern to a WebSocket handler."""
def __init__(self, route, endpoint, *, before_request=False):
def __init__(
self, route: str, endpoint: Callable, *, before_request: bool = False
) -> None:
assert route.startswith("/"), "Route path must start with '/'"
self.route = route
self.endpoint = endpoint
self.before_request = before_request
self.path_re: re.Pattern
self.param_convertors: dict[str, type]
self.path_re, self.param_convertors = compile_path(route)
self._url_template = PARAM_RE.sub(r"{\1}", route)
def __repr__(self):
def __repr__(self) -> str:
return f"<Route {self.route!r}={self.endpoint!r}>"
def url(self, **params):
def url(self, **params: Any) -> str:
return self._url_template.format(**params)
@property
def endpoint_name(self):
def endpoint_name(self) -> str:
return self.endpoint.__name__
@property
def description(self):
def description(self) -> str | None:
return self.endpoint.__doc__
def matches(self, scope):
def matches(self, scope: Scope) -> tuple[bool, dict]:
if scope["type"] != "websocket":
return False, {}
@@ -244,7 +263,7 @@ class WebSocketRoute(BaseRoute):
return True, {"path_params": {**matched_params}}
async def __call__(self, scope, receive, send):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
ws = WebSocket(scope, receive, send)
before_requests = scope.get("before_requests", [])
@@ -253,10 +272,12 @@ class WebSocketRoute(BaseRoute):
await self.endpoint(ws)
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, WebSocketRoute):
return NotImplemented
return self.route == other.route and self.endpoint == other.endpoint
def __hash__(self):
def __hash__(self) -> int:
return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request)
@@ -268,32 +289,36 @@ class Router:
"""
def __init__(
self, routes=None, default_response=None, before_requests=None, lifespan=None
):
self.routes = [] if routes is None else list(routes)
self,
routes: list[BaseRoute] | None = None,
default_response: Callable | None = None,
before_requests: dict[str, list[Callable]] | None = None,
lifespan: Callable | None = None,
) -> None:
self.routes: list[BaseRoute] = [] if routes is None else list(routes)
self.apps: dict[str, ASGIApp] = {}
self.default_endpoint = (
self.apps: dict[str, Union[ASGIApp, Any]] = {}
self.default_endpoint: Callable = (
self.default_response if default_response is None else default_response
)
self.before_requests = (
self.before_requests: dict[str, list[Callable]] = (
{"http": [], "ws": []} if before_requests is None else before_requests
)
self.after_requests: list = []
self.events = defaultdict(list)
self.after_requests: list[Callable] = []
self.events: defaultdict[str, list[Callable]] = defaultdict(list)
self._lifespan_handler = lifespan
def add_route(
self,
route=None,
endpoint=None,
route: str | None = None,
endpoint: Callable | None = None,
*,
default=False,
websocket=False,
before_request=False,
check_existing=False,
methods=None,
):
default: bool = False,
websocket: bool = False,
before_request: bool = False,
check_existing: bool = False,
methods: list[str] | None = None,
) -> None:
"""Adds a route to the router.
:param route: A string representation of the route
:param endpoint: The endpoint for the route -- can be callable, or class.
@@ -322,40 +347,40 @@ class Router:
self.routes.append(route)
def mount(self, route, app):
def mount(self, route: str, app: Any) -> None:
"""Mounts ASGI / WSGI applications at a given route"""
self.apps.update({route: app})
def add_event_handler(self, event_type, handler):
def add_event_handler(self, event_type: str, handler: Callable) -> None:
assert event_type in (
"startup",
"shutdown",
), f"Only 'startup' and 'shutdown' events are supported, not {event_type}."
self.events[event_type].append(handler)
async def trigger_event(self, event_type):
async def trigger_event(self, event_type: str) -> None:
for handler in self.events.get(event_type, []):
if asyncio.iscoroutinefunction(handler):
await handler()
else:
handler()
def before_request(self, endpoint, websocket=False):
def before_request(self, endpoint: Callable, websocket: bool = False) -> None:
if websocket:
self.before_requests.setdefault("ws", []).append(endpoint)
else:
self.before_requests.setdefault("http", []).append(endpoint)
def after_request(self, endpoint):
def after_request(self, endpoint: Callable) -> None:
self.after_requests.append(endpoint)
def url_for(self, endpoint, **params):
def url_for(self, endpoint: Callable | str, **params: Any) -> str | None:
for route in self.routes:
if endpoint in (route.endpoint, route.endpoint.__name__):
return route.url(**params)
return None
async def default_response(self, scope, receive, send):
async def default_response(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
websocket_close = WebSocketClose()
await websocket_close(scope, receive, send)
@@ -366,7 +391,7 @@ class Router:
raise HTTPException(status_code=status_codes.HTTP_404) # type: ignore[attr-defined]
def _resolve_route(self, scope):
def _resolve_route(self, scope: Scope) -> BaseRoute | None:
for route in self.routes:
matches, child_scope = route.matches(scope)
if matches:
@@ -374,7 +399,7 @@ class Router:
return route
return None
async def lifespan(self, scope, receive, send):
async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
message = await receive()
assert message["type"] == "lifespan.startup"
@@ -409,7 +434,7 @@ class Router:
await send({"type": "lifespan.shutdown.complete"})
async def __call__(self, scope, receive, send):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] in ("http", "websocket", "lifespan")
if scope["type"] == "lifespan":