mirror of
https://github.com/kennethreitz/responder.git
synced 2026-06-05 23:00:17 +00:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9aa99869ae | |||
| 08e0d87347 | |||
| 3f9e4057d3 | |||
| a29e40353c | |||
| 778cb2dd0f | |||
| f7d5514b94 | |||
| 954637f7b3 | |||
| 1ab46104c8 | |||
| 815776d473 | |||
| 8db1a7be90 | |||
| 7b11fa24dd | |||
| 1f0f2318d5 | |||
| 029b3e2a52 | |||
| 4fff823def | |||
| cab78275f4 | |||
| 5f60e4fedb | |||
| 96971a33a7 | |||
| 9a7409f521 | |||
| 80aa7e305b | |||
| 27d513cb01 | |||
| 9bf5cc8c03 | |||
| 7994b210cd | |||
| 46555bbe3f | |||
| 4d15dbc465 | |||
| 855d3c4320 | |||
| 4564862acc |
@@ -1,3 +1,12 @@
|
||||
# v1.1.1
|
||||
- Run sync views in a threadpoolexecutor.
|
||||
|
||||
# v1.1.0
|
||||
- Support for `before_request`.
|
||||
|
||||
# v1.0.4
|
||||
- Potential bufix for cookies.
|
||||
|
||||
# v1.0.3
|
||||
- Bugfix for redirects.
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ Assuming existing ``api.py`` and ``Pipfile.lock`` containing ``responder``.
|
||||
|
||||
``Dockerfile``::
|
||||
|
||||
from kennethreitz/pipenv
|
||||
FROM kennethreitz/pipenv
|
||||
|
||||
COPY . /app
|
||||
CMD python3 api.py
|
||||
|
||||
@@ -173,6 +173,17 @@ You can easily read a Request's session data, that can be trusted to have origin
|
||||
|
||||
api = responder.API(secret_key=os.environ['SECRET_KEY'])
|
||||
|
||||
Using ``before_request``
|
||||
------------------------
|
||||
|
||||
If you'd like a view to be executed before every request, simply do the following::
|
||||
|
||||
@api.route(before_request=True)
|
||||
def prepare_response(req, resp):
|
||||
resp.headers["X-Pizza"] = "42"
|
||||
|
||||
Now all requests to your HTTP Service will include an ``X-Pizza`` header.
|
||||
|
||||
Using Requests Test Client
|
||||
--------------------------
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "1.0.3"
|
||||
__version__ = "1.1.1"
|
||||
|
||||
+97
-61
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from uuid import uuid4
|
||||
from pathlib import Path
|
||||
from base64 import b64encode
|
||||
|
||||
import apistar
|
||||
import itsdangerous
|
||||
@@ -24,6 +26,7 @@ from starlette.websockets import WebSocket
|
||||
from whitenoise import WhiteNoise
|
||||
|
||||
from . import models, status_codes
|
||||
from .middlewares.trustedhost import TrustedHostMiddleware
|
||||
from .background import BackgroundQueue
|
||||
from .formats import get_formats
|
||||
from .routes import Route
|
||||
@@ -64,7 +67,9 @@ class API:
|
||||
enable_hsts=False,
|
||||
docs_route=None,
|
||||
cors=False,
|
||||
allowed_hosts=None
|
||||
):
|
||||
self.background = BackgroundQueue()
|
||||
|
||||
self.secret_key = secret_key
|
||||
self.title = title
|
||||
@@ -85,6 +90,15 @@ class API:
|
||||
self.hsts_enabled = enable_hsts
|
||||
self.cors = cors
|
||||
self.cors_params = DEFAULT_CORS_PARAMS
|
||||
|
||||
if not allowed_hosts:
|
||||
# if not debug:
|
||||
# raise RuntimeError(
|
||||
# "You need to specify `allowed_hosts` when debug is set to False"
|
||||
# )
|
||||
allowed_hosts = ["*"]
|
||||
self.allowed_hosts = allowed_hosts
|
||||
|
||||
# Make the static/templates directory if they don't exist.
|
||||
for _dir in (self.static_dir, self.templates_dir):
|
||||
os.makedirs(_dir, exist_ok=True)
|
||||
@@ -105,7 +119,6 @@ class API:
|
||||
|
||||
# Cached requests session.
|
||||
self._session = None
|
||||
self.background = BackgroundQueue()
|
||||
|
||||
if self.openapi_version:
|
||||
self.add_route(openapi_route, self.schema_response)
|
||||
@@ -121,6 +134,9 @@ class API:
|
||||
|
||||
if self.hsts_enabled:
|
||||
self.add_middleware(HTTPSRedirectMiddleware)
|
||||
|
||||
self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts)
|
||||
|
||||
self.lifespan_handler = LifespanHandler()
|
||||
|
||||
if self.cors:
|
||||
@@ -144,6 +160,15 @@ class API:
|
||||
def _default_wsgi_app(*args, **kwargs):
|
||||
pass
|
||||
|
||||
@property
|
||||
def before_requests(self):
|
||||
def gen():
|
||||
for route in self.routes:
|
||||
if self.routes[route].before_request:
|
||||
yield self.routes[route]
|
||||
|
||||
return [g for g in gen()]
|
||||
|
||||
@property
|
||||
def _apispec(self):
|
||||
spec = APISpec(
|
||||
@@ -242,7 +267,7 @@ class API:
|
||||
|
||||
def _prepare_cookies(self, resp):
|
||||
if resp.cookies:
|
||||
header = " ".join([f"{k}={v}" for k, v in resp.cookies.items()])
|
||||
header = " ".join([f"{k}={v};" for k, v in resp.cookies.items()])
|
||||
resp.headers["Set-Cookie"] = header
|
||||
|
||||
@property
|
||||
@@ -252,7 +277,9 @@ class API:
|
||||
def _prepare_session(self, resp):
|
||||
|
||||
if resp.session:
|
||||
data = self._signer.sign(json.dumps(resp.session).encode("utf-8"))
|
||||
data = self._signer.sign(
|
||||
b64encode(json.dumps(resp.session).encode("utf-8"))
|
||||
)
|
||||
resp.cookies[self.session_cookie] = data.decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
@@ -267,67 +294,16 @@ class API:
|
||||
route = self.path_matches_route(req.url.path)
|
||||
route = self.routes.get(route)
|
||||
|
||||
# Create the response object.
|
||||
cont = False
|
||||
if route:
|
||||
if route.uses_websocket:
|
||||
resp = WebSocket(**options)
|
||||
|
||||
else:
|
||||
resp = models.Response(req=req, formats=self.formats)
|
||||
|
||||
params = route.incoming_matches(req.url.path)
|
||||
|
||||
if route.is_function:
|
||||
try:
|
||||
try:
|
||||
# Run the view.
|
||||
r = route.endpoint(req, resp, **params)
|
||||
# If it's async, await it.
|
||||
if hasattr(r, "cr_running"):
|
||||
await r
|
||||
except TypeError as e:
|
||||
cont = True
|
||||
except Exception:
|
||||
self.default_response(req, resp, error=True)
|
||||
raise
|
||||
|
||||
elif route.is_class_based or cont:
|
||||
try:
|
||||
view = route.endpoint(**params)
|
||||
except TypeError:
|
||||
try:
|
||||
view = route.endpoint()
|
||||
except TypeError:
|
||||
view = route.endpoint
|
||||
|
||||
# Run on_request first.
|
||||
try:
|
||||
# Run the view.
|
||||
r = getattr(view, "on_request", self.no_response)(
|
||||
req, resp, **params
|
||||
)
|
||||
# If it's async, await it.
|
||||
if hasattr(r, "send"):
|
||||
await r
|
||||
except Exception:
|
||||
self.default_response(req, resp, error=True)
|
||||
raise
|
||||
|
||||
# Then run on_method.
|
||||
method = req.method
|
||||
try:
|
||||
# Run the view.
|
||||
r = getattr(view, f"on_{method}", self.no_response)(
|
||||
req, resp, **params
|
||||
)
|
||||
# If it's async, await it.
|
||||
if hasattr(r, "send"):
|
||||
await r
|
||||
except Exception as e:
|
||||
|
||||
self.default_response(req, resp, error=True)
|
||||
for before_request in self.before_requests:
|
||||
await self._execute_route(route=before_request, req=req, resp=resp)
|
||||
|
||||
await self._execute_route(route=route, req=req, resp=resp, **options)
|
||||
else:
|
||||
resp = models.Response(req=req, formats=self.formats)
|
||||
self.default_response(req, resp, notfound=True)
|
||||
@@ -338,6 +314,60 @@ class API:
|
||||
|
||||
return resp
|
||||
|
||||
async def _execute_route(self, *, route, req, resp, **options):
|
||||
|
||||
params = route.incoming_matches(req.url.path)
|
||||
|
||||
if route.is_function:
|
||||
try:
|
||||
try:
|
||||
# Run the view.
|
||||
|
||||
r = self.background(route.endpoint, req, resp, **params)
|
||||
# If it's async, await it.
|
||||
if hasattr(r, "cr_running"):
|
||||
await r
|
||||
except TypeError as e:
|
||||
cont = True
|
||||
except Exception:
|
||||
self.background(self.default_response, req, resp, error=True)
|
||||
raise
|
||||
|
||||
if route.is_class_based or cont:
|
||||
try:
|
||||
view = route.endpoint(**params)
|
||||
except TypeError:
|
||||
try:
|
||||
view = route.endpoint()
|
||||
except TypeError:
|
||||
view = route.endpoint
|
||||
pass
|
||||
|
||||
# Run on_request first.
|
||||
try:
|
||||
# Run the view.
|
||||
r = getattr(view, "on_request", self.no_response)
|
||||
r = self.background(r, req, resp, **params)
|
||||
# If it's async, await it.
|
||||
if hasattr(r, "send"):
|
||||
await r
|
||||
except Exception:
|
||||
self.background(self.default_response, req, resp, error=True)
|
||||
raise
|
||||
|
||||
# Then run on_method.
|
||||
method = req.method
|
||||
try:
|
||||
# Run the view.
|
||||
r = getattr(view, f"on_{method}", self.no_response)
|
||||
r = self.background(r, req, resp, **params)
|
||||
# If it's async, await it.
|
||||
if hasattr(r, "send"):
|
||||
await r
|
||||
except Exception as e:
|
||||
|
||||
self.background(self.default_response, req, resp, error=True)
|
||||
|
||||
def add_event_handler(self, event_type, handler):
|
||||
"""Adds an event handler to the API.
|
||||
|
||||
@@ -349,13 +379,14 @@ class API:
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
route,
|
||||
route=None,
|
||||
endpoint=None,
|
||||
*,
|
||||
default=False,
|
||||
static=False,
|
||||
check_existing=True,
|
||||
websocket=False,
|
||||
before_request=False,
|
||||
):
|
||||
"""Adds a route to the API.
|
||||
|
||||
@@ -365,6 +396,9 @@ class API:
|
||||
:param static: If ``True``, and no endpoint was passed, render "static/index.html", and it will become a default route.
|
||||
:param check_existing: If ``True``, an AssertionError will be raised, if the route is already defined.
|
||||
"""
|
||||
if route is None:
|
||||
route = f"/{uuid4().hex}"
|
||||
|
||||
if check_existing:
|
||||
assert route not in self.routes
|
||||
|
||||
@@ -375,7 +409,9 @@ class API:
|
||||
if default:
|
||||
self.default_endpoint = endpoint
|
||||
|
||||
self.routes[route] = Route(route, endpoint, websocket=websocket)
|
||||
self.routes[route] = Route(
|
||||
route, endpoint, websocket=websocket, before_request=before_request
|
||||
)
|
||||
# TODO: A better data structure or sort it once the app is loaded
|
||||
self.routes = dict(
|
||||
sorted(self.routes.items(), key=lambda item: item[1]._weight())
|
||||
@@ -454,7 +490,7 @@ class API:
|
||||
|
||||
return decorator
|
||||
|
||||
def route(self, route, **options):
|
||||
def route(self, route=None, **options):
|
||||
"""Decorator for creating new routes around function and class definitions.
|
||||
|
||||
Usage::
|
||||
@@ -486,7 +522,7 @@ class API:
|
||||
"""
|
||||
|
||||
if self._session is None:
|
||||
self._session = TestClient(self)
|
||||
self._session = TestClient(self, base_url=base_url)
|
||||
return self._session
|
||||
|
||||
def _route_for(self, endpoint):
|
||||
|
||||
+12
-2
@@ -1,6 +1,8 @@
|
||||
import traceback
|
||||
import multiprocessing
|
||||
import asyncio
|
||||
import functools
|
||||
import concurrent.futures
|
||||
import multiprocessing
|
||||
import traceback
|
||||
|
||||
|
||||
class BackgroundQueue:
|
||||
@@ -33,3 +35,11 @@ class BackgroundQueue:
|
||||
return result
|
||||
|
||||
return do_task
|
||||
|
||||
async def __call__(self, func, *args, **kwargs) -> None:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return await asyncio.ensure_future(func(*args, **kwargs))
|
||||
else:
|
||||
fn = functools.partial(func, *args, **kwargs)
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, fn)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.responses import PlainTextResponse
|
||||
|
||||
def _is_trusted_host(host, allowed_hosts):
|
||||
"""
|
||||
Check if the host matchs the pattern.
|
||||
|
||||
Any given pattern starting with a period is considered a wildcard pattern.
|
||||
"""
|
||||
host = host.lower()
|
||||
for pattern in allowed_hosts:
|
||||
if (
|
||||
pattern == "*" or pattern == host or
|
||||
pattern[0] == "." and
|
||||
(host.endswith(pattern) or host == pattern[1:])
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
class TrustedHostMiddleware:
|
||||
def __init__(self, app, allowed_hosts):
|
||||
self.app = app
|
||||
self.allowed_hosts = allowed_hosts
|
||||
self.allow_any = "*" in allowed_hosts
|
||||
|
||||
def __call__(self, scope):
|
||||
if scope["type"] in ("http", "websocket") and not self.allow_any:
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host").split(":")[0]
|
||||
if not _is_trusted_host(host, self.allowed_hosts):
|
||||
return PlainTextResponse("Invalid host header", status_code=400)
|
||||
return self.app(scope)
|
||||
+5
-1
@@ -1,6 +1,7 @@
|
||||
import io
|
||||
import json
|
||||
import gzip
|
||||
from base64 import b64decode
|
||||
from http.cookies import SimpleCookie
|
||||
|
||||
|
||||
@@ -109,8 +110,11 @@ class Request:
|
||||
def session(self):
|
||||
"""The session data, in dict form, from the Request."""
|
||||
if "Responder-Session" in self.cookies:
|
||||
|
||||
data = self.cookies[self.api.session_cookie]
|
||||
|
||||
data = self.api._signer.unsign(data)
|
||||
data = b64decode(data)
|
||||
return json.loads(data)
|
||||
return {}
|
||||
|
||||
@@ -142,7 +146,7 @@ class Request:
|
||||
def cookies(self):
|
||||
"""The cookies sent in the Request, as a dictionary."""
|
||||
cookies = RequestsCookieJar()
|
||||
cookie_header = self.headers.get("cookie", "")
|
||||
cookie_header = self.headers.get("Cookie", "")
|
||||
|
||||
bc = SimpleCookie(cookie_header)
|
||||
for k, v in bc.items():
|
||||
|
||||
+2
-1
@@ -15,10 +15,11 @@ def memoize(f):
|
||||
class Route:
|
||||
_param_pattern = re.compile(r"{([^{}]*)}")
|
||||
|
||||
def __init__(self, route, endpoint, *, websocket=False):
|
||||
def __init__(self, route, endpoint, *, websocket=False, before_request=False):
|
||||
self.route = route
|
||||
self.endpoint = endpoint
|
||||
self.uses_websocket = websocket
|
||||
self.before_request = before_request
|
||||
self._memo = {}
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
+1
-1
@@ -18,7 +18,7 @@ def current_dir():
|
||||
|
||||
@pytest.fixture
|
||||
def api():
|
||||
return responder.API()
|
||||
return responder.API(allowed_hosts=["testserver", ";"])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
+93
-3
@@ -333,7 +333,11 @@ def test_schema_generation():
|
||||
import responder
|
||||
from marshmallow import Schema, fields
|
||||
|
||||
api = responder.API(title="Web Service", openapi="3.0")
|
||||
api = responder.API(
|
||||
title="Web Service",
|
||||
openapi="3.0",
|
||||
allowed_hosts=["testserver", ";"]
|
||||
)
|
||||
|
||||
@api.schema("Pet")
|
||||
class PetSchema(Schema):
|
||||
@@ -364,7 +368,12 @@ def test_documentation():
|
||||
import responder
|
||||
from marshmallow import Schema, fields
|
||||
|
||||
api = responder.API(title="Web Service", openapi="3.0", docs_route="/docs")
|
||||
api = responder.API(
|
||||
title="Web Service",
|
||||
openapi="3.0",
|
||||
docs_route="/docs",
|
||||
allowed_hosts=["testserver", ";"]
|
||||
)
|
||||
|
||||
@api.schema("Pet")
|
||||
class PetSchema(Schema):
|
||||
@@ -424,6 +433,7 @@ def test_cookies(api):
|
||||
assert r.json() == {"cookies": {"sent": "true"}}
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_sessions(api):
|
||||
@api.route("/")
|
||||
def view(req, resp):
|
||||
@@ -526,4 +536,84 @@ def test_redirects(api, session):
|
||||
def one(req, resp):
|
||||
resp.text = "redirected"
|
||||
|
||||
assert session.get("/1").url == "http://testserver/1"
|
||||
assert session.get("/1").url == "http://;/1"
|
||||
|
||||
|
||||
def test_session_thoroughly(api, session):
|
||||
@api.route("/set")
|
||||
def set(req, resp):
|
||||
resp.session["hello"] = "world"
|
||||
api.redirect(resp, location="/get")
|
||||
|
||||
@api.route("/get")
|
||||
def get(req, resp):
|
||||
resp.media = {"session": req.session}
|
||||
|
||||
r = session.get(api.url_for(set))
|
||||
r = session.get(api.url_for(get))
|
||||
assert r.json() == {"session": {"hello": "world"}}
|
||||
|
||||
def test_before_response(api, session):
|
||||
|
||||
@api.route("/get")
|
||||
def get(req, resp):
|
||||
resp.media = req.session
|
||||
|
||||
|
||||
@api.route(before_request=True)
|
||||
def before_request(req, resp):
|
||||
resp.headers["x-pizza"] = "1"
|
||||
|
||||
r = session.get(api.url_for(get))
|
||||
assert 'x-pizza' in r.headers
|
||||
|
||||
def test_allowed_hosts():
|
||||
api = responder.API(
|
||||
allowed_hosts=[";", "tenant.;"]
|
||||
)
|
||||
|
||||
@api.route("/")
|
||||
def get(req, resp):
|
||||
pass
|
||||
|
||||
# Exact match
|
||||
r = api.requests.get(api.url_for(get))
|
||||
assert r.status_code == 200
|
||||
|
||||
# Reset the session
|
||||
api._session = None
|
||||
r = api.session(base_url="http://tenant.;").get(api.url_for(get))
|
||||
assert r.status_code == 200
|
||||
|
||||
# Reset the session
|
||||
api._session = None
|
||||
r = api.session(base_url="http://unkownhost").get(api.url_for(get))
|
||||
assert r.status_code == 400
|
||||
|
||||
# Reset the session
|
||||
api._session = None
|
||||
r = api.session(base_url="http://unkown_tenant.;").get(api.url_for(get))
|
||||
assert r.status_code == 400
|
||||
|
||||
api = responder.API(
|
||||
allowed_hosts=[".;"]
|
||||
)
|
||||
|
||||
@api.route("/")
|
||||
def get(req, resp):
|
||||
pass
|
||||
|
||||
# Wildcard domains
|
||||
# Using http://;
|
||||
r = api.requests.get(api.url_for(get))
|
||||
assert r.status_code == 200
|
||||
|
||||
# Reset the session
|
||||
api._session = None
|
||||
r = api.session(base_url="http://tenant1.;").get(api.url_for(get))
|
||||
assert r.status_code == 200
|
||||
|
||||
# Reset the session
|
||||
api._session = None
|
||||
r = api.session(base_url="http://tenant2.;").get(api.url_for(get))
|
||||
assert r.status_code == 200
|
||||
|
||||
Reference in New Issue
Block a user