mirror of
https://github.com/kennethreitz/responder.git
synced 2026-06-05 23:00:17 +00:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9aa99869ae | |||
| 08e0d87347 | |||
| 3f9e4057d3 | |||
| a29e40353c | |||
| 778cb2dd0f | |||
| f7d5514b94 | |||
| 954637f7b3 | |||
| 1ab46104c8 | |||
| 815776d473 | |||
| 8db1a7be90 | |||
| 7b11fa24dd | |||
| 1f0f2318d5 | |||
| 029b3e2a52 | |||
| 4fff823def | |||
| cab78275f4 |
@@ -1,3 +1,6 @@
|
||||
# v1.1.1
|
||||
- Run sync views in a threadpoolexecutor.
|
||||
|
||||
# v1.1.0
|
||||
- Support for `before_request`.
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "1.1.0"
|
||||
__version__ = "1.1.1"
|
||||
|
||||
+27
-9
@@ -26,6 +26,7 @@ from starlette.websockets import WebSocket
|
||||
from whitenoise import WhiteNoise
|
||||
|
||||
from . import models, status_codes
|
||||
from .middlewares.trustedhost import TrustedHostMiddleware
|
||||
from .background import BackgroundQueue
|
||||
from .formats import get_formats
|
||||
from .routes import Route
|
||||
@@ -66,7 +67,9 @@ class API:
|
||||
enable_hsts=False,
|
||||
docs_route=None,
|
||||
cors=False,
|
||||
allowed_hosts=None
|
||||
):
|
||||
self.background = BackgroundQueue()
|
||||
|
||||
self.secret_key = secret_key
|
||||
self.title = title
|
||||
@@ -87,6 +90,15 @@ class API:
|
||||
self.hsts_enabled = enable_hsts
|
||||
self.cors = cors
|
||||
self.cors_params = DEFAULT_CORS_PARAMS
|
||||
|
||||
if not allowed_hosts:
|
||||
# if not debug:
|
||||
# raise RuntimeError(
|
||||
# "You need to specify `allowed_hosts` when debug is set to False"
|
||||
# )
|
||||
allowed_hosts = ["*"]
|
||||
self.allowed_hosts = allowed_hosts
|
||||
|
||||
# Make the static/templates directory if they don't exist.
|
||||
for _dir in (self.static_dir, self.templates_dir):
|
||||
os.makedirs(_dir, exist_ok=True)
|
||||
@@ -107,7 +119,6 @@ class API:
|
||||
|
||||
# Cached requests session.
|
||||
self._session = None
|
||||
self.background = BackgroundQueue()
|
||||
|
||||
if self.openapi_version:
|
||||
self.add_route(openapi_route, self.schema_response)
|
||||
@@ -123,6 +134,9 @@ class API:
|
||||
|
||||
if self.hsts_enabled:
|
||||
self.add_middleware(HTTPSRedirectMiddleware)
|
||||
|
||||
self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts)
|
||||
|
||||
self.lifespan_handler = LifespanHandler()
|
||||
|
||||
if self.cors:
|
||||
@@ -308,17 +322,18 @@ class API:
|
||||
try:
|
||||
try:
|
||||
# Run the view.
|
||||
r = route.endpoint(req, resp, **params)
|
||||
|
||||
r = self.background(route.endpoint, req, resp, **params)
|
||||
# If it's async, await it.
|
||||
if hasattr(r, "cr_running"):
|
||||
await r
|
||||
except TypeError as e:
|
||||
cont = True
|
||||
except Exception:
|
||||
self.default_response(req, resp, error=True)
|
||||
self.background(self.default_response, req, resp, error=True)
|
||||
raise
|
||||
|
||||
elif route.is_class_based or cont:
|
||||
if route.is_class_based or cont:
|
||||
try:
|
||||
view = route.endpoint(**params)
|
||||
except TypeError:
|
||||
@@ -326,29 +341,32 @@ class API:
|
||||
view = route.endpoint()
|
||||
except TypeError:
|
||||
view = route.endpoint
|
||||
pass
|
||||
|
||||
# Run on_request first.
|
||||
try:
|
||||
# Run the view.
|
||||
r = getattr(view, "on_request", self.no_response)(req, resp, **params)
|
||||
r = getattr(view, "on_request", self.no_response)
|
||||
r = self.background(r, req, resp, **params)
|
||||
# If it's async, await it.
|
||||
if hasattr(r, "send"):
|
||||
await r
|
||||
except Exception:
|
||||
self.default_response(req, resp, error=True)
|
||||
self.background(self.default_response, req, resp, error=True)
|
||||
raise
|
||||
|
||||
# Then run on_method.
|
||||
method = req.method
|
||||
try:
|
||||
# Run the view.
|
||||
r = getattr(view, f"on_{method}", self.no_response)(req, resp, **params)
|
||||
r = getattr(view, f"on_{method}", self.no_response)
|
||||
r = self.background(r, req, resp, **params)
|
||||
# If it's async, await it.
|
||||
if hasattr(r, "send"):
|
||||
await r
|
||||
except Exception as e:
|
||||
|
||||
self.default_response(req, resp, error=True)
|
||||
self.background(self.default_response, req, resp, error=True)
|
||||
|
||||
def add_event_handler(self, event_type, handler):
|
||||
"""Adds an event handler to the API.
|
||||
@@ -504,7 +522,7 @@ class API:
|
||||
"""
|
||||
|
||||
if self._session is None:
|
||||
self._session = TestClient(self)
|
||||
self._session = TestClient(self, base_url=base_url)
|
||||
return self._session
|
||||
|
||||
def _route_for(self, endpoint):
|
||||
|
||||
+12
-2
@@ -1,6 +1,8 @@
|
||||
import traceback
|
||||
import multiprocessing
|
||||
import asyncio
|
||||
import functools
|
||||
import concurrent.futures
|
||||
import multiprocessing
|
||||
import traceback
|
||||
|
||||
|
||||
class BackgroundQueue:
|
||||
@@ -33,3 +35,11 @@ class BackgroundQueue:
|
||||
return result
|
||||
|
||||
return do_task
|
||||
|
||||
async def __call__(self, func, *args, **kwargs) -> None:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return await asyncio.ensure_future(func(*args, **kwargs))
|
||||
else:
|
||||
fn = functools.partial(func, *args, **kwargs)
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, fn)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.responses import PlainTextResponse
|
||||
|
||||
def _is_trusted_host(host, allowed_hosts):
|
||||
"""
|
||||
Check if the host matchs the pattern.
|
||||
|
||||
Any given pattern starting with a period is considered a wildcard pattern.
|
||||
"""
|
||||
host = host.lower()
|
||||
for pattern in allowed_hosts:
|
||||
if (
|
||||
pattern == "*" or pattern == host or
|
||||
pattern[0] == "." and
|
||||
(host.endswith(pattern) or host == pattern[1:])
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
class TrustedHostMiddleware:
|
||||
def __init__(self, app, allowed_hosts):
|
||||
self.app = app
|
||||
self.allowed_hosts = allowed_hosts
|
||||
self.allow_any = "*" in allowed_hosts
|
||||
|
||||
def __call__(self, scope):
|
||||
if scope["type"] in ("http", "websocket") and not self.allow_any:
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host").split(":")[0]
|
||||
if not _is_trusted_host(host, self.allowed_hosts):
|
||||
return PlainTextResponse("Invalid host header", status_code=400)
|
||||
return self.app(scope)
|
||||
+1
-1
@@ -18,7 +18,7 @@ def current_dir():
|
||||
|
||||
@pytest.fixture
|
||||
def api():
|
||||
return responder.API()
|
||||
return responder.API(allowed_hosts=["testserver", ";"])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
+64
-6
@@ -333,7 +333,11 @@ def test_schema_generation():
|
||||
import responder
|
||||
from marshmallow import Schema, fields
|
||||
|
||||
api = responder.API(title="Web Service", openapi="3.0")
|
||||
api = responder.API(
|
||||
title="Web Service",
|
||||
openapi="3.0",
|
||||
allowed_hosts=["testserver", ";"]
|
||||
)
|
||||
|
||||
@api.schema("Pet")
|
||||
class PetSchema(Schema):
|
||||
@@ -364,7 +368,12 @@ def test_documentation():
|
||||
import responder
|
||||
from marshmallow import Schema, fields
|
||||
|
||||
api = responder.API(title="Web Service", openapi="3.0", docs_route="/docs")
|
||||
api = responder.API(
|
||||
title="Web Service",
|
||||
openapi="3.0",
|
||||
docs_route="/docs",
|
||||
allowed_hosts=["testserver", ";"]
|
||||
)
|
||||
|
||||
@api.schema("Pet")
|
||||
class PetSchema(Schema):
|
||||
@@ -527,7 +536,7 @@ def test_redirects(api, session):
|
||||
def one(req, resp):
|
||||
resp.text = "redirected"
|
||||
|
||||
assert session.get("/1").url == "http://testserver/1"
|
||||
assert session.get("/1").url == "http://;/1"
|
||||
|
||||
|
||||
def test_session_thoroughly(api, session):
|
||||
@@ -541,12 +550,10 @@ def test_session_thoroughly(api, session):
|
||||
resp.media = {"session": req.session}
|
||||
|
||||
r = session.get(api.url_for(set))
|
||||
print(r.headers)
|
||||
r = session.get(api.url_for(get))
|
||||
print(r.request.headers)
|
||||
assert r.json() == {"session": {"hello": "world"}}
|
||||
|
||||
def test_before_responpse(api, session):
|
||||
def test_before_response(api, session):
|
||||
|
||||
@api.route("/get")
|
||||
def get(req, resp):
|
||||
@@ -559,3 +566,54 @@ def test_before_responpse(api, session):
|
||||
|
||||
r = session.get(api.url_for(get))
|
||||
assert 'x-pizza' in r.headers
|
||||
|
||||
def test_allowed_hosts():
|
||||
api = responder.API(
|
||||
allowed_hosts=[";", "tenant.;"]
|
||||
)
|
||||
|
||||
@api.route("/")
|
||||
def get(req, resp):
|
||||
pass
|
||||
|
||||
# Exact match
|
||||
r = api.requests.get(api.url_for(get))
|
||||
assert r.status_code == 200
|
||||
|
||||
# Reset the session
|
||||
api._session = None
|
||||
r = api.session(base_url="http://tenant.;").get(api.url_for(get))
|
||||
assert r.status_code == 200
|
||||
|
||||
# Reset the session
|
||||
api._session = None
|
||||
r = api.session(base_url="http://unkownhost").get(api.url_for(get))
|
||||
assert r.status_code == 400
|
||||
|
||||
# Reset the session
|
||||
api._session = None
|
||||
r = api.session(base_url="http://unkown_tenant.;").get(api.url_for(get))
|
||||
assert r.status_code == 400
|
||||
|
||||
api = responder.API(
|
||||
allowed_hosts=[".;"]
|
||||
)
|
||||
|
||||
@api.route("/")
|
||||
def get(req, resp):
|
||||
pass
|
||||
|
||||
# Wildcard domains
|
||||
# Using http://;
|
||||
r = api.requests.get(api.url_for(get))
|
||||
assert r.status_code == 200
|
||||
|
||||
# Reset the session
|
||||
api._session = None
|
||||
r = api.session(base_url="http://tenant1.;").get(api.url_for(get))
|
||||
assert r.status_code == 200
|
||||
|
||||
# Reset the session
|
||||
api._session = None
|
||||
r = api.session(base_url="http://tenant2.;").get(api.url_for(get))
|
||||
assert r.status_code == 200
|
||||
|
||||
Reference in New Issue
Block a user