mirror of
https://github.com/kennethreitz/responder.git
synced 2026-06-21 15:00:57 +00:00
61f7f24256
Implements the remaining backlog features:
- Dependency injection: register providers with @api.dependency() and
declare them as view parameters by name. Supports sync/async functions
and generators (post-yield code runs as teardown after the response).
Providers taking a parameter receive the current Request. Resolved at
most once per request; path params take precedence.
- Per-route rate limiting via RateLimiter.limit decorator
- WebSocket before-request hooks can reject connections (closing the
socket skips the handler) and may now be sync functions
Also fixes bugs found along the way:
- float path convertor regex had an unescaped dot, matching garbage
like "1a5" and crashing with a 500
- literal route characters weren't regex-escaped (/file.json matched
/fileXjson)
- BackgroundQueue.results grew without bound; completed futures are
now pruned
- req.media("form") crashed when Content-Type was missing
- custom formats registered on api.formats were ignored; they now
thread through the router to request parsing and response negotiation
- Accept headers matching encode-incapable formats (e.g. form) returned
an empty body; negotiation now falls through to JSON
Performance: Request.url and Request.params are cached; format
registries are no longer rebuilt twice per request.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
156 lines
4.5 KiB
Python
156 lines
4.5 KiB
Python
"""Tests for route matching fixes, per-route rate limiting, WebSocket
|
|
short-circuiting, and custom format registration."""
|
|
|
|
import pytest
|
|
from starlette.testclient import TestClient as StarletteTestClient
|
|
from starlette.websockets import WebSocketDisconnect
|
|
|
|
from responder.ext.ratelimit import RateLimiter
|
|
|
|
# --- route matching fixes ---
|
|
|
|
|
|
def test_float_convertor_rejects_garbage(api):
|
|
@api.route("/measure/{value:float}")
|
|
def view(req, resp, *, value):
|
|
resp.media = {"value": value}
|
|
|
|
assert api.requests.get("/measure/1.5").json() == {"value": 1.5}
|
|
assert api.requests.get("/measure/10").json() == {"value": 10.0}
|
|
# "1a5" must not match the float pattern (previously the unescaped
|
|
# dot allowed any character and the convertor crashed with a 500).
|
|
assert api.requests.get("/measure/1a5").status_code == 404
|
|
|
|
|
|
def test_literal_dots_in_routes_are_escaped(api):
|
|
@api.route("/file.json")
|
|
def view(req, resp):
|
|
resp.text = "ok"
|
|
|
|
assert api.requests.get("/file.json").status_code == 200
|
|
# A literal "." in the route must not act as a regex wildcard.
|
|
assert api.requests.get("/fileXjson").status_code == 404
|
|
|
|
|
|
# --- per-route rate limiting ---
|
|
|
|
|
|
def test_per_route_rate_limit(api):
|
|
limiter = RateLimiter(requests=2, period=60)
|
|
|
|
@api.route("/limited")
|
|
@limiter.limit
|
|
def limited(req, resp):
|
|
resp.text = "ok"
|
|
|
|
@api.route("/open")
|
|
def unlimited(req, resp):
|
|
resp.text = "always"
|
|
|
|
assert api.requests.get("/limited").status_code == 200
|
|
assert api.requests.get("/limited").status_code == 200
|
|
third = api.requests.get("/limited")
|
|
assert third.status_code == 429
|
|
assert "Retry-After" in third.headers
|
|
|
|
# Other routes are unaffected.
|
|
assert api.requests.get("/open").status_code == 200
|
|
|
|
|
|
def test_per_route_rate_limit_async(api):
|
|
limiter = RateLimiter(requests=1, period=60)
|
|
|
|
@api.route("/limited")
|
|
@limiter.limit
|
|
async def limited(req, resp):
|
|
resp.text = "ok"
|
|
|
|
assert api.requests.get("/limited").status_code == 200
|
|
assert api.requests.get("/limited").status_code == 429
|
|
|
|
|
|
# --- WebSocket before_request short-circuit ---
|
|
|
|
|
|
def test_websocket_before_request_short_circuit(api):
|
|
endpoint_called = []
|
|
|
|
@api.before_request(websocket=True)
|
|
async def reject_unauthorized(ws):
|
|
if "Authorization" not in ws.headers:
|
|
await ws.close(code=4401)
|
|
|
|
@api.route("/ws", websocket=True)
|
|
async def ws_endpoint(ws):
|
|
endpoint_called.append(True)
|
|
await ws.accept()
|
|
await ws.send_text("hello")
|
|
await ws.close()
|
|
|
|
client = StarletteTestClient(api)
|
|
|
|
# Without auth: the hook closes the socket and the endpoint never runs.
|
|
with pytest.raises(WebSocketDisconnect):
|
|
with client.websocket_connect("ws://;/ws"):
|
|
pass
|
|
assert endpoint_called == []
|
|
|
|
# With auth: the connection proceeds normally.
|
|
with client.websocket_connect(
|
|
"ws://;/ws", headers={"Authorization": "Bearer token"}
|
|
) as ws:
|
|
assert ws.receive_text() == "hello"
|
|
assert endpoint_called == [True]
|
|
|
|
|
|
def test_websocket_sync_before_request(api):
|
|
seen = []
|
|
|
|
@api.before_request(websocket=True)
|
|
def observe(ws):
|
|
seen.append(ws.url.path)
|
|
|
|
@api.route("/ws", websocket=True)
|
|
async def ws_endpoint(ws):
|
|
await ws.accept()
|
|
await ws.send_text("hi")
|
|
await ws.close()
|
|
|
|
client = StarletteTestClient(api)
|
|
with client.websocket_connect("ws://;/ws") as ws:
|
|
assert ws.receive_text() == "hi"
|
|
assert seen == ["/ws"]
|
|
|
|
|
|
# --- custom formats ---
|
|
|
|
|
|
def test_custom_format_registration(api):
|
|
async def format_csv(r, encode=False):
|
|
if encode:
|
|
r.headers["Content-Type"] = "text/csv"
|
|
rows = r.media
|
|
return "\n".join(",".join(str(v) for v in row) for row in rows)
|
|
return [line.split(",") for line in (await r.text).splitlines()]
|
|
|
|
api.formats["csv"] = format_csv
|
|
|
|
@api.route("/report")
|
|
def report(req, resp):
|
|
resp.media = [["a", 1], ["b", 2]]
|
|
|
|
r = api.requests.get("/report", headers={"Accept": "text/csv"})
|
|
assert r.headers["Content-Type"].startswith("text/csv")
|
|
assert r.text == "a,1\nb,2"
|
|
|
|
|
|
def test_form_accept_header_falls_back_to_json(api):
|
|
@api.route("/data")
|
|
def data(req, resp):
|
|
resp.media = {"key": "value"}
|
|
|
|
# "form" can't encode responses; negotiation should fall through to JSON
|
|
# instead of returning an empty body.
|
|
r = api.requests.get("/data", headers={"Accept": "multipart/form-data"})
|
|
assert r.json() == {"key": "value"}
|