Compare commits

..

23 Commits

Author SHA1 Message Date
kennethreitz 9aa99869ae next version 2018-10-29 08:00:44 -04:00
kennethreitz 08e0d87347 Merge branch 'master' of github.com:kennethreitz/responder 2018-10-29 07:57:19 -04:00
kennethreitz 3f9e4057d3 run sync tasks in threadpoolexecutor 2018-10-29 07:54:59 -04:00
kennethreitz a29e40353c Merge pull request #170 from taoufik07/trusted_hosts
Trusted hosts support
2018-10-28 14:47:28 -04:00
taoufik07 778cb2dd0f Add Tests 2018-10-28 18:26:42 +00:00
taoufik07 f7d5514b94 Fix test base_url 2018-10-28 18:12:21 +00:00
taoufik07 954637f7b3 Pass base_url to the TestClient 2018-10-28 18:09:52 +00:00
taoufik07 1ab46104c8 Allow all hosts by default 2018-10-28 14:51:24 +00:00
kennethreitz 815776d473 Merge branch 'master' of github.com:kennethreitz/responder 2018-10-28 05:25:56 -04:00
kennethreitz 8db1a7be90 Merge pull request #169 from taoufik07/patch-13
typo
2018-10-28 05:23:56 -04:00
taoufik07 7b11fa24dd Silence for now 2018-10-28 01:38:17 +01:00
taoufik07 1f0f2318d5 cleanup 2018-10-28 01:34:26 +01:00
taoufik07 029b3e2a52 Tests 2018-10-28 00:46:50 +01:00
taoufik07 4fff823def Trusted host 2018-10-28 00:46:39 +01:00
Taoufik cab78275f4 typo 2018-10-27 22:25:11 +01:00
kennethreitz 5f60e4fedb before_request 2018-10-27 09:22:17 -04:00
kennethreitz 96971a33a7 tour 2018-10-27 09:20:18 -04:00
kennethreitz 9a7409f521 test for before_request 2018-10-27 09:18:07 -04:00
kennethreitz 80aa7e305b before_request=True 2018-10-27 09:15:52 -04:00
kennethreitz 27d513cb01 Merge branch 'master' of github.com:kennethreitz/responder 2018-10-27 09:04:29 -04:00
kennethreitz 9bf5cc8c03 before_request, v1 2018-10-27 09:04:19 -04:00
kennethreitz 7994b210cd Merge pull request #166 from repodevs/fix-deployment-example
DOC: change dockerfile instruction `from` to use UPPERCASE
2018-10-27 07:33:47 -04:00
Edi Santoso 46555bbe3f DOC: change dockerfile instruction from to use UPPERCASE 2018-10-27 18:29:43 +07:00
11 changed files with 235 additions and 70 deletions
+6
View File
@@ -1,3 +1,9 @@
# v1.1.1
- Run sync views in a threadpoolexecutor.
# v1.1.0
- Support for `before_request`.
# v1.0.4
- Potential bufix for cookies.
+1 -1
View File
@@ -10,7 +10,7 @@ Assuming existing ``api.py`` and ``Pipfile.lock`` containing ``responder``.
``Dockerfile``::
from kennethreitz/pipenv
FROM kennethreitz/pipenv
COPY . /app
CMD python3 api.py
+11
View File
@@ -173,6 +173,17 @@ You can easily read a Request's session data, that can be trusted to have origin
api = responder.API(secret_key=os.environ['SECRET_KEY'])
Using ``before_request``
------------------------
If you'd like a view to be executed before every request, simply do the following::
@api.route(before_request=True)
def prepare_response(req, resp):
resp.headers["X-Pizza"] = "42"
Now all requests to your HTTP Service will include an ``X-Pizza`` header.
Using Requests Test Client
--------------------------
+1 -1
View File
@@ -1 +1 @@
__version__ = "1.0.4"
__version__ = "1.1.1"
+92 -59
View File
@@ -1,6 +1,7 @@
import json
import os
from uuid import uuid4
from pathlib import Path
from base64 import b64encode
@@ -25,6 +26,7 @@ from starlette.websockets import WebSocket
from whitenoise import WhiteNoise
from . import models, status_codes
from .middlewares.trustedhost import TrustedHostMiddleware
from .background import BackgroundQueue
from .formats import get_formats
from .routes import Route
@@ -65,7 +67,9 @@ class API:
enable_hsts=False,
docs_route=None,
cors=False,
allowed_hosts=None
):
self.background = BackgroundQueue()
self.secret_key = secret_key
self.title = title
@@ -86,6 +90,15 @@ class API:
self.hsts_enabled = enable_hsts
self.cors = cors
self.cors_params = DEFAULT_CORS_PARAMS
if not allowed_hosts:
# if not debug:
# raise RuntimeError(
# "You need to specify `allowed_hosts` when debug is set to False"
# )
allowed_hosts = ["*"]
self.allowed_hosts = allowed_hosts
# Make the static/templates directory if they don't exist.
for _dir in (self.static_dir, self.templates_dir):
os.makedirs(_dir, exist_ok=True)
@@ -106,7 +119,6 @@ class API:
# Cached requests session.
self._session = None
self.background = BackgroundQueue()
if self.openapi_version:
self.add_route(openapi_route, self.schema_response)
@@ -122,6 +134,9 @@ class API:
if self.hsts_enabled:
self.add_middleware(HTTPSRedirectMiddleware)
self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts)
self.lifespan_handler = LifespanHandler()
if self.cors:
@@ -145,6 +160,15 @@ class API:
def _default_wsgi_app(*args, **kwargs):
pass
@property
def before_requests(self):
def gen():
for route in self.routes:
if self.routes[route].before_request:
yield self.routes[route]
return [g for g in gen()]
@property
def _apispec(self):
spec = APISpec(
@@ -270,67 +294,16 @@ class API:
route = self.path_matches_route(req.url.path)
route = self.routes.get(route)
# Create the response object.
cont = False
if route:
if route.uses_websocket:
resp = WebSocket(**options)
else:
resp = models.Response(req=req, formats=self.formats)
params = route.incoming_matches(req.url.path)
if route.is_function:
try:
try:
# Run the view.
r = route.endpoint(req, resp, **params)
# If it's async, await it.
if hasattr(r, "cr_running"):
await r
except TypeError as e:
cont = True
except Exception:
self.default_response(req, resp, error=True)
raise
elif route.is_class_based or cont:
try:
view = route.endpoint(**params)
except TypeError:
try:
view = route.endpoint()
except TypeError:
view = route.endpoint
# Run on_request first.
try:
# Run the view.
r = getattr(view, "on_request", self.no_response)(
req, resp, **params
)
# If it's async, await it.
if hasattr(r, "send"):
await r
except Exception:
self.default_response(req, resp, error=True)
raise
# Then run on_method.
method = req.method
try:
# Run the view.
r = getattr(view, f"on_{method}", self.no_response)(
req, resp, **params
)
# If it's async, await it.
if hasattr(r, "send"):
await r
except Exception as e:
self.default_response(req, resp, error=True)
for before_request in self.before_requests:
await self._execute_route(route=before_request, req=req, resp=resp)
await self._execute_route(route=route, req=req, resp=resp, **options)
else:
resp = models.Response(req=req, formats=self.formats)
self.default_response(req, resp, notfound=True)
@@ -341,6 +314,60 @@ class API:
return resp
async def _execute_route(self, *, route, req, resp, **options):
params = route.incoming_matches(req.url.path)
if route.is_function:
try:
try:
# Run the view.
r = self.background(route.endpoint, req, resp, **params)
# If it's async, await it.
if hasattr(r, "cr_running"):
await r
except TypeError as e:
cont = True
except Exception:
self.background(self.default_response, req, resp, error=True)
raise
if route.is_class_based or cont:
try:
view = route.endpoint(**params)
except TypeError:
try:
view = route.endpoint()
except TypeError:
view = route.endpoint
pass
# Run on_request first.
try:
# Run the view.
r = getattr(view, "on_request", self.no_response)
r = self.background(r, req, resp, **params)
# If it's async, await it.
if hasattr(r, "send"):
await r
except Exception:
self.background(self.default_response, req, resp, error=True)
raise
# Then run on_method.
method = req.method
try:
# Run the view.
r = getattr(view, f"on_{method}", self.no_response)
r = self.background(r, req, resp, **params)
# If it's async, await it.
if hasattr(r, "send"):
await r
except Exception as e:
self.background(self.default_response, req, resp, error=True)
def add_event_handler(self, event_type, handler):
"""Adds an event handler to the API.
@@ -352,13 +379,14 @@ class API:
def add_route(
self,
route,
route=None,
endpoint=None,
*,
default=False,
static=False,
check_existing=True,
websocket=False,
before_request=False,
):
"""Adds a route to the API.
@@ -368,6 +396,9 @@ class API:
:param static: If ``True``, and no endpoint was passed, render "static/index.html", and it will become a default route.
:param check_existing: If ``True``, an AssertionError will be raised, if the route is already defined.
"""
if route is None:
route = f"/{uuid4().hex}"
if check_existing:
assert route not in self.routes
@@ -378,7 +409,9 @@ class API:
if default:
self.default_endpoint = endpoint
self.routes[route] = Route(route, endpoint, websocket=websocket)
self.routes[route] = Route(
route, endpoint, websocket=websocket, before_request=before_request
)
# TODO: A better data structure or sort it once the app is loaded
self.routes = dict(
sorted(self.routes.items(), key=lambda item: item[1]._weight())
@@ -457,7 +490,7 @@ class API:
return decorator
def route(self, route, **options):
def route(self, route=None, **options):
"""Decorator for creating new routes around function and class definitions.
Usage::
@@ -489,7 +522,7 @@ class API:
"""
if self._session is None:
self._session = TestClient(self)
self._session = TestClient(self, base_url=base_url)
return self._session
def _route_for(self, endpoint):
+12 -2
View File
@@ -1,6 +1,8 @@
import traceback
import multiprocessing
import asyncio
import functools
import concurrent.futures
import multiprocessing
import traceback
class BackgroundQueue:
@@ -33,3 +35,11 @@ class BackgroundQueue:
return result
return do_task
async def __call__(self, func, *args, **kwargs) -> None:
if asyncio.iscoroutinefunction(func):
return await asyncio.ensure_future(func(*args, **kwargs))
else:
fn = functools.partial(func, *args, **kwargs)
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, fn)
View File
+32
View File
@@ -0,0 +1,32 @@
from starlette.datastructures import Headers
from starlette.responses import PlainTextResponse
def _is_trusted_host(host, allowed_hosts):
"""
Check if the host matchs the pattern.
Any given pattern starting with a period is considered a wildcard pattern.
"""
host = host.lower()
for pattern in allowed_hosts:
if (
pattern == "*" or pattern == host or
pattern[0] == "." and
(host.endswith(pattern) or host == pattern[1:])
):
return True
return False
class TrustedHostMiddleware:
def __init__(self, app, allowed_hosts):
self.app = app
self.allowed_hosts = allowed_hosts
self.allow_any = "*" in allowed_hosts
def __call__(self, scope):
if scope["type"] in ("http", "websocket") and not self.allow_any:
headers = Headers(scope=scope)
host = headers.get("host").split(":")[0]
if not _is_trusted_host(host, self.allowed_hosts):
return PlainTextResponse("Invalid host header", status_code=400)
return self.app(scope)
+2 -1
View File
@@ -15,10 +15,11 @@ def memoize(f):
class Route:
_param_pattern = re.compile(r"{([^{}]*)}")
def __init__(self, route, endpoint, *, websocket=False):
def __init__(self, route, endpoint, *, websocket=False, before_request=False):
self.route = route
self.endpoint = endpoint
self.uses_websocket = websocket
self.before_request = before_request
self._memo = {}
def __repr__(self):
+1 -1
View File
@@ -18,7 +18,7 @@ def current_dir():
@pytest.fixture
def api():
return responder.API()
return responder.API(allowed_hosts=["testserver", ";"])
@pytest.fixture
+77 -5
View File
@@ -333,7 +333,11 @@ def test_schema_generation():
import responder
from marshmallow import Schema, fields
api = responder.API(title="Web Service", openapi="3.0")
api = responder.API(
title="Web Service",
openapi="3.0",
allowed_hosts=["testserver", ";"]
)
@api.schema("Pet")
class PetSchema(Schema):
@@ -364,7 +368,12 @@ def test_documentation():
import responder
from marshmallow import Schema, fields
api = responder.API(title="Web Service", openapi="3.0", docs_route="/docs")
api = responder.API(
title="Web Service",
openapi="3.0",
docs_route="/docs",
allowed_hosts=["testserver", ";"]
)
@api.schema("Pet")
class PetSchema(Schema):
@@ -527,7 +536,7 @@ def test_redirects(api, session):
def one(req, resp):
resp.text = "redirected"
assert session.get("/1").url == "http://testserver/1"
assert session.get("/1").url == "http://;/1"
def test_session_thoroughly(api, session):
@@ -541,7 +550,70 @@ def test_session_thoroughly(api, session):
resp.media = {"session": req.session}
r = session.get(api.url_for(set))
print(r.headers)
r = session.get(api.url_for(get))
print(r.request.headers)
assert r.json() == {"session": {"hello": "world"}}
def test_before_response(api, session):
@api.route("/get")
def get(req, resp):
resp.media = req.session
@api.route(before_request=True)
def before_request(req, resp):
resp.headers["x-pizza"] = "1"
r = session.get(api.url_for(get))
assert 'x-pizza' in r.headers
def test_allowed_hosts():
api = responder.API(
allowed_hosts=[";", "tenant.;"]
)
@api.route("/")
def get(req, resp):
pass
# Exact match
r = api.requests.get(api.url_for(get))
assert r.status_code == 200
# Reset the session
api._session = None
r = api.session(base_url="http://tenant.;").get(api.url_for(get))
assert r.status_code == 200
# Reset the session
api._session = None
r = api.session(base_url="http://unkownhost").get(api.url_for(get))
assert r.status_code == 400
# Reset the session
api._session = None
r = api.session(base_url="http://unkown_tenant.;").get(api.url_for(get))
assert r.status_code == 400
api = responder.API(
allowed_hosts=[".;"]
)
@api.route("/")
def get(req, resp):
pass
# Wildcard domains
# Using http://;
r = api.requests.get(api.url_for(get))
assert r.status_code == 200
# Reset the session
api._session = None
r = api.session(base_url="http://tenant1.;").get(api.url_for(get))
assert r.status_code == 200
# Reset the session
api._session = None
r = api.session(base_url="http://tenant2.;").get(api.url_for(get))
assert r.status_code == 200