From 74c872ed5721921b3974b2db1fc941dbc6a58c7e Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Sun, 22 Mar 2026 21:47:21 -0400 Subject: [PATCH] Add type annotations to routes.py Add comprehensive type hints to compile_path, BaseRoute, Route, WebSocketRoute, and Router classes. Uses Starlette's Scope, Receive, Send types and properly types the ASGI/WSGI union in Router.apps. Fixes #566. Co-Authored-By: Claude Opus 4.6 (1M context) --- responder/routes.py | 125 ++++++++++++++++++++++++++------------------ 1 file changed, 75 insertions(+), 50 deletions(-) diff --git a/responder/routes.py b/responder/routes.py index 8cc03ed..954e858 100644 --- a/responder/routes.py +++ b/responder/routes.py @@ -1,14 +1,18 @@ +from __future__ import annotations + import asyncio import inspect import re import traceback from collections import defaultdict +from collections.abc import Callable +from typing import Any, Union __all__ = ["Route", "WebSocketRoute", "Router"] from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose from . import status_codes @@ -28,9 +32,9 @@ _CONVERTORS = { PARAM_RE = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}") -def compile_path(path): +def compile_path(path: str) -> tuple[re.Pattern, dict[str, type]]: path_re = "^" - param_convertors = {} + param_convertors: dict[str, type] = {} idx = 0 for match in PARAM_RE.finditer(path): @@ -54,10 +58,10 @@ def compile_path(path): class BaseRoute: - def matches(self, scope): + def matches(self, scope: Scope) -> tuple[bool, dict]: raise NotImplementedError() - async def __call__(self, scope, receive, send): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: raise NotImplementedError() @@ -68,32 +72,41 @@ class Route(BaseRoute): ``{pk:uuid}``, ``{value:float}``, ``{rest:path}``). """ - def __init__(self, route, endpoint, *, before_request=False, methods=None): + def __init__( + self, + route: str, + endpoint: Callable, + *, + before_request: bool = False, + methods: list[str] | None = None, + ) -> None: assert route.startswith("/"), "Route path must start with '/'" self.route = route self.endpoint = endpoint self.before_request = before_request - self.methods = {m.upper() for m in methods} if methods else None + self.methods: set[str] | None = {m.upper() for m in methods} if methods else None + self.path_re: re.Pattern + self.param_convertors: dict[str, type] self.path_re, self.param_convertors = compile_path(route) # Strip type annotations for URL generation (e.g. {id:int} -> {id}) self._url_template = PARAM_RE.sub(r"{\1}", route) - def __repr__(self): + def __repr__(self) -> str: return f"" - def url(self, **params): + def url(self, **params: Any) -> str: return self._url_template.format(**params) @property - def endpoint_name(self): + def endpoint_name(self) -> str: return self.endpoint.__name__ @property - def description(self): + def description(self) -> str | None: return self.endpoint.__doc__ - def matches(self, scope): + def matches(self, scope: Scope) -> tuple[bool, dict]: if scope["type"] != "http": return False, {} @@ -112,7 +125,7 @@ class Route(BaseRoute): return True, {"path_params": {**matched_params}} - async def __call__(self, scope, receive, send): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive, formats=get_formats()) response = Response(req=request, formats=get_formats()) @@ -195,40 +208,46 @@ class Route(BaseRoute): await response(scope, receive, send) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, Route): + return NotImplemented return self.route == other.route and self.endpoint == other.endpoint - def __hash__(self): + def __hash__(self) -> int: return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request) class WebSocketRoute(BaseRoute): """A WebSocket route that maps a URL pattern to a WebSocket handler.""" - def __init__(self, route, endpoint, *, before_request=False): + def __init__( + self, route: str, endpoint: Callable, *, before_request: bool = False + ) -> None: assert route.startswith("/"), "Route path must start with '/'" self.route = route self.endpoint = endpoint self.before_request = before_request + self.path_re: re.Pattern + self.param_convertors: dict[str, type] self.path_re, self.param_convertors = compile_path(route) self._url_template = PARAM_RE.sub(r"{\1}", route) - def __repr__(self): + def __repr__(self) -> str: return f"" - def url(self, **params): + def url(self, **params: Any) -> str: return self._url_template.format(**params) @property - def endpoint_name(self): + def endpoint_name(self) -> str: return self.endpoint.__name__ @property - def description(self): + def description(self) -> str | None: return self.endpoint.__doc__ - def matches(self, scope): + def matches(self, scope: Scope) -> tuple[bool, dict]: if scope["type"] != "websocket": return False, {} @@ -244,7 +263,7 @@ class WebSocketRoute(BaseRoute): return True, {"path_params": {**matched_params}} - async def __call__(self, scope, receive, send): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ws = WebSocket(scope, receive, send) before_requests = scope.get("before_requests", []) @@ -253,10 +272,12 @@ class WebSocketRoute(BaseRoute): await self.endpoint(ws) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, WebSocketRoute): + return NotImplemented return self.route == other.route and self.endpoint == other.endpoint - def __hash__(self): + def __hash__(self) -> int: return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request) @@ -268,32 +289,36 @@ class Router: """ def __init__( - self, routes=None, default_response=None, before_requests=None, lifespan=None - ): - self.routes = [] if routes is None else list(routes) + self, + routes: list[BaseRoute] | None = None, + default_response: Callable | None = None, + before_requests: dict[str, list[Callable]] | None = None, + lifespan: Callable | None = None, + ) -> None: + self.routes: list[BaseRoute] = [] if routes is None else list(routes) - self.apps: dict[str, ASGIApp] = {} - self.default_endpoint = ( + self.apps: dict[str, Union[ASGIApp, Any]] = {} + self.default_endpoint: Callable = ( self.default_response if default_response is None else default_response ) - self.before_requests = ( + self.before_requests: dict[str, list[Callable]] = ( {"http": [], "ws": []} if before_requests is None else before_requests ) - self.after_requests: list = [] - self.events = defaultdict(list) + self.after_requests: list[Callable] = [] + self.events: defaultdict[str, list[Callable]] = defaultdict(list) self._lifespan_handler = lifespan def add_route( self, - route=None, - endpoint=None, + route: str | None = None, + endpoint: Callable | None = None, *, - default=False, - websocket=False, - before_request=False, - check_existing=False, - methods=None, - ): + default: bool = False, + websocket: bool = False, + before_request: bool = False, + check_existing: bool = False, + methods: list[str] | None = None, + ) -> None: """Adds a route to the router. :param route: A string representation of the route :param endpoint: The endpoint for the route -- can be callable, or class. @@ -322,40 +347,40 @@ class Router: self.routes.append(route) - def mount(self, route, app): + def mount(self, route: str, app: Any) -> None: """Mounts ASGI / WSGI applications at a given route""" self.apps.update({route: app}) - def add_event_handler(self, event_type, handler): + def add_event_handler(self, event_type: str, handler: Callable) -> None: assert event_type in ( "startup", "shutdown", ), f"Only 'startup' and 'shutdown' events are supported, not {event_type}." self.events[event_type].append(handler) - async def trigger_event(self, event_type): + async def trigger_event(self, event_type: str) -> None: for handler in self.events.get(event_type, []): if asyncio.iscoroutinefunction(handler): await handler() else: handler() - def before_request(self, endpoint, websocket=False): + def before_request(self, endpoint: Callable, websocket: bool = False) -> None: if websocket: self.before_requests.setdefault("ws", []).append(endpoint) else: self.before_requests.setdefault("http", []).append(endpoint) - def after_request(self, endpoint): + def after_request(self, endpoint: Callable) -> None: self.after_requests.append(endpoint) - def url_for(self, endpoint, **params): + def url_for(self, endpoint: Callable | str, **params: Any) -> str | None: for route in self.routes: if endpoint in (route.endpoint, route.endpoint.__name__): return route.url(**params) return None - async def default_response(self, scope, receive, send): + async def default_response(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "websocket": websocket_close = WebSocketClose() await websocket_close(scope, receive, send) @@ -366,7 +391,7 @@ class Router: raise HTTPException(status_code=status_codes.HTTP_404) # type: ignore[attr-defined] - def _resolve_route(self, scope): + def _resolve_route(self, scope: Scope) -> BaseRoute | None: for route in self.routes: matches, child_scope = route.matches(scope) if matches: @@ -374,7 +399,7 @@ class Router: return route return None - async def lifespan(self, scope, receive, send): + async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: message = await receive() assert message["type"] == "lifespan.startup" @@ -409,7 +434,7 @@ class Router: await send({"type": "lifespan.shutdown.complete"}) - async def __call__(self, scope, receive, send): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: assert scope["type"] in ("http", "websocket", "lifespan") if scope["type"] == "lifespan":