Files
responder/tests/test_new_features.py
Andreas Motl ff6d530338 Chore: Code formatting (#594)
## About
A few cosmetic adjustments, a.k.a. code formatting.
Also validate the outcome on CI/GHA.
Feel free to improve on this now or later, at your discretion.

## Details
The updates are based on using the most recent versions of pyproject-fmt
and ruff.
In particular, spots marked with `noqa` may need further attention, again
at your discretion.

---------

Co-authored-by: Kenneth Reitz <me@kennethreitz.org>
2026-03-24 15:21:04 -04:00

288 lines
7.3 KiB
Python

"""Tests for new features: validation, SSE, after_request, route groups, etc."""
from pydantic import BaseModel
import responder
from responder.ext.ratelimit import RateLimiter
# --- Pydantic auto-validation ---
class ItemIn(BaseModel):
    """Request body model passed as ``request_model`` in the validation tests.

    Neither field has a default, so both are required. The tests expect two
    kinds of body to be rejected with HTTP 422: one missing ``price``, and one
    whose ``price`` is not numeric.
    """

    name: str
    price: float
class ItemOut(BaseModel):
    """Response model passed as ``response_model`` in the serialization test.

    Only ``id``, ``name`` and ``price`` are declared. The test expects any other
    key the handler puts on ``resp.media`` (e.g. ``"secret"``) to be absent
    from the serialized response.
    """

    # ``id`` mirrors the JSON key the handler emits, even though it shadows
    # the builtin name inside the class namespace.
    id: int
    name: str
    price: float
def test_pydantic_request_validation():
    """A route's ``request_model`` validates the body before the handler runs."""
    api = responder.API(allowed_hosts=[";"])
    url = "http://;/items"

    @api.route("/items", methods=["POST"], request_model=ItemIn)
    async def create(req, resp):
        resp.media = {"id": 1, **(await req.media())}

    # A well-formed body reaches the handler and is echoed back.
    ok = api.requests.post(url, json={"name": "widget", "price": 9.99})
    assert ok.status_code == 200
    assert ok.json()["name"] == "widget"

    # Missing the required ``price`` field -> 422 with an error payload.
    missing = api.requests.post(url, json={"name": "widget"})
    assert missing.status_code == 422
    assert "errors" in missing.json()

    # A ``price`` that cannot be read as a float -> 422.
    wrong_type = api.requests.post(
        url, json={"name": "widget", "price": "not_a_number"}
    )
    assert wrong_type.status_code == 422
def test_pydantic_response_serialization():
    """A route's ``response_model`` filters what the handler sends back."""
    api = responder.API(allowed_hosts=[";"])

    @api.route(
        "/items", methods=["POST"], request_model=ItemIn, response_model=ItemOut
    )
    async def create(req, resp):
        payload = await req.media()
        # "secret" is not a field of ItemOut, so it must not reach the client.
        resp.media = {"id": 1, "secret": "hidden", **payload}

    response = api.requests.post(
        "http://;/items", json={"name": "widget", "price": 9.99}
    )
    assert response.status_code == 200

    body = response.json()
    assert "secret" not in body
    assert body == {"id": 1, "name": "widget", "price": 9.99}
def test_pydantic_validation_skipped_for_get():
    """GET requests bypass ``request_model`` body validation."""
    api = responder.API(allowed_hosts=[";"])

    @api.route("/items", methods=["GET"], request_model=ItemIn)
    def list_items(req, resp):
        resp.media = []

    r = api.requests.get("http://;/items")
    assert r.status_code == 200
    # Check that the handler actually produced the response, not just a bare 200.
    assert r.json() == []
# --- SSE streaming ---
def test_sse_streaming(api):
    """``resp.sse`` turns an async generator into a Server-Sent Events stream."""

    @api.route("/events")
    async def events(req, resp):
        @resp.sse
        async def stream():
            yield {"data": "hello"}
            yield {"event": "update", "data": "world"}
            yield "simple"

    response = api.requests.get(api.url_for(events))
    assert response.status_code == 200
    assert "text/event-stream" in response.headers.get("content-type", "")

    # Dict events map keys to SSE fields; a bare string becomes a data line.
    text = response.text
    for fragment in ("data: hello", "event: update", "data: world", "data: simple"):
        assert fragment in text
def test_sse_with_id_and_retry(api):
    """SSE events carry optional ``id`` and ``retry`` fields alongside ``data``."""

    @api.route("/events")
    async def events(req, resp):
        @resp.sse
        async def stream():
            yield {"data": "msg", "id": "1", "retry": "5000"}

    r = api.requests.get(api.url_for(events))
    assert r.status_code == 200
    # The payload itself must still be emitted next to the extra fields.
    assert "data: msg" in r.text
    assert "id: 1" in r.text
    assert "retry: 5000" in r.text
# --- stream_file ---
def test_stream_file(api, tmp_path):
    """``resp.stream_file`` serves a file without loading it into memory first."""
    expected = b"x" * 10000
    big_file = tmp_path / "data.bin"
    big_file.write_bytes(expected)

    @api.route("/download")
    def download(req, resp):
        resp.stream_file(big_file)

    body = api.requests.get(api.url_for(download)).content
    assert len(body) == 10000
    assert body == expected
def test_stream_file_content_type(api, tmp_path):
    """``resp.stream_file`` sends a CSS Content-Type and body for a ``.css`` file."""
    stylesheet = "body { color: red; }"
    css = tmp_path / "style.css"
    # Explicit encoding so the bytes on disk don't depend on the host locale.
    css.write_text(stylesheet, encoding="utf-8")

    @api.route("/css")
    def serve_css(req, resp):
        resp.stream_file(css)

    r = api.requests.get(api.url_for(serve_css))
    assert r.status_code == 200
    assert "text/css" in r.headers.get("content-type", "")
    # The header alone is not enough: the file's contents must be streamed too.
    assert r.text == stylesheet
# --- after_request hooks ---
def test_after_request(api):
    """A sync ``after_request`` hook can add headers once the view has run."""

    @api.after_request()
    def add_header(req, resp):
        resp.headers["X-After"] = "yes"

    @api.route("/")
    def view(req, resp):
        resp.text = "hello"

    response = api.requests.get(api.url_for(view))
    # The view's body survives and the hook's header is present.
    assert response.text == "hello"
    assert response.headers["X-After"] == "yes"
def test_after_request_async(api):
    """An ``async def`` after_request hook runs and can add headers."""

    @api.after_request()
    async def add_header(req, resp):
        resp.headers["X-Async-After"] = "yes"

    @api.route("/")
    def view(req, resp):
        resp.text = "hello"

    r = api.requests.get(api.url_for(view))
    # Same check as test_after_request: the hook must not disturb the body.
    assert r.text == "hello"
    assert r.headers["X-Async-After"] == "yes"
# --- Route groups ---
def test_route_group(api):
    """Routes registered on ``api.group(prefix)`` are served under that prefix."""
    v1 = api.group("/v1")

    @v1.route("/users")
    def list_users(req, resp):
        resp.media = [{"name": "alice"}]

    @v1.route("/users/{user_id:int}")
    def get_user(req, resp, *, user_id):
        resp.media = {"id": user_id}

    users = api.requests.get("http://;/v1/users").json()
    assert users == [{"name": "alice"}]

    # The ``int`` converter hands the handler 42, not the raw string "42".
    user = api.requests.get("http://;/v1/users/42").json()
    assert user == {"id": 42}
def test_multiple_route_groups(api):
    """Two groups can register the same sub-path without clashing."""
    v1, v2 = api.group("/v1"), api.group("/v2")

    @v1.route("/status")
    def v1_status(req, resp):
        resp.media = {"version": 1}

    @v2.route("/status")
    def v2_status(req, resp):
        resp.media = {"version": 2}

    expected = {
        "http://;/v1/status": {"version": 1},
        "http://;/v2/status": {"version": 2},
    }
    for url, body in expected.items():
        assert api.requests.get(url).json() == body
# --- Request ID ---
def test_request_id():
    """With ``request_id=True``, a response carries a generated X-Request-ID."""
    api = responder.API(request_id=True, allowed_hosts=[";"])

    @api.route("/")
    def view(req, resp):
        resp.text = "ok"

    headers = api.requests.get("http://;/").headers
    assert "X-Request-ID" in headers
    # Require a non-empty value, not just the header's presence.
    assert headers["X-Request-ID"]
def test_request_id_forwarded():
    """A client-supplied X-Request-ID is echoed back unchanged."""
    api = responder.API(request_id=True, allowed_hosts=[";"])

    @api.route("/")
    def view(req, resp):
        resp.text = "ok"

    trace_id = "my-trace-123"
    response = api.requests.get("http://;/", headers={"X-Request-ID": trace_id})
    assert response.headers["X-Request-ID"] == trace_id
# --- Rate Limiting ---
def test_rate_limiter():
    """``RateLimiter`` answers 429 once the request budget for a period is spent."""
    api = responder.API(allowed_hosts=[";"])
    rate_limiter = RateLimiter(requests=3, period=60)
    rate_limiter.install(api)

    @api.route("/")
    def view(req, resp):
        resp.text = "ok"

    # The first three requests fit within the configured budget of 3.
    for _ in range(3):
        allowed = api.requests.get("http://;/")
        assert allowed.status_code == 200
        assert "X-RateLimit-Remaining" in allowed.headers

    # The fourth exceeds it and must tell the client when to retry.
    blocked = api.requests.get("http://;/")
    assert blocked.status_code == 429
    assert "Retry-After" in blocked.headers
# --- MessagePack ---
def test_msgpack_format(api):
    """``req.media("msgpack")`` decodes a MessagePack request body."""
    import msgpack

    @api.route("/")
    async def view(req, resp):
        resp.media = await req.media("msgpack")

    payload = {"hello": "world", "number": 42}
    response = api.requests.post(
        api.url_for(view),
        content=msgpack.packb(payload),
        headers={"Content-Type": "application/x-msgpack"},
    )
    # The decoded data is echoed back via ``resp.media`` and read here as JSON.
    assert response.json() == payload