Merge branch 'master' of github.com:kennethreitz/responder

This commit is contained in:
2018-10-29 07:57:19 -04:00
5 changed files with 111 additions and 7 deletions
+15 -1
View File
@@ -26,6 +26,7 @@ from starlette.websockets import WebSocket
from whitenoise import WhiteNoise
from . import models, status_codes
from .middlewares.trustedhost import TrustedHostMiddleware
from .background import BackgroundQueue
from .formats import get_formats
from .routes import Route
@@ -66,6 +67,7 @@ class API:
enable_hsts=False,
docs_route=None,
cors=False,
allowed_hosts=None
):
self.background = BackgroundQueue()
@@ -88,6 +90,15 @@ class API:
self.hsts_enabled = enable_hsts
self.cors = cors
self.cors_params = DEFAULT_CORS_PARAMS
if not allowed_hosts:
# if not debug:
# raise RuntimeError(
# "You need to specify `allowed_hosts` when debug is set to False"
# )
allowed_hosts = ["*"]
self.allowed_hosts = allowed_hosts
# Make the static/templates directory if they don't exist.
for _dir in (self.static_dir, self.templates_dir):
os.makedirs(_dir, exist_ok=True)
@@ -123,6 +134,9 @@ class API:
if self.hsts_enabled:
self.add_middleware(HTTPSRedirectMiddleware)
self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts)
self.lifespan_handler = LifespanHandler()
if self.cors:
@@ -508,7 +522,7 @@ class API:
"""
if self._session is None:
self._session = TestClient(self)
self._session = TestClient(self, base_url=base_url)
return self._session
def _route_for(self, endpoint):
View File
+32
View File
@@ -0,0 +1,32 @@
from starlette.datastructures import Headers
from starlette.responses import PlainTextResponse
def _is_trusted_host(host, allowed_hosts):
"""
Check if the host matchs the pattern.
Any given pattern starting with a period is considered a wildcard pattern.
"""
host = host.lower()
for pattern in allowed_hosts:
if (
pattern == "*" or pattern == host or
pattern[0] == "." and
(host.endswith(pattern) or host == pattern[1:])
):
return True
return False
class TrustedHostMiddleware:
def __init__(self, app, allowed_hosts):
self.app = app
self.allowed_hosts = allowed_hosts
self.allow_any = "*" in allowed_hosts
def __call__(self, scope):
if scope["type"] in ("http", "websocket") and not self.allow_any:
headers = Headers(scope=scope)
host = headers.get("host").split(":")[0]
if not _is_trusted_host(host, self.allowed_hosts):
return PlainTextResponse("Invalid host header", status_code=400)
return self.app(scope)
+1 -1
View File
@@ -18,7 +18,7 @@ def current_dir():
@pytest.fixture
def api():
return responder.API()
return responder.API(allowed_hosts=["testserver", ";"])
@pytest.fixture
+63 -5
View File
@@ -333,7 +333,11 @@ def test_schema_generation():
import responder
from marshmallow import Schema, fields
api = responder.API(title="Web Service", openapi="3.0")
api = responder.API(
title="Web Service",
openapi="3.0",
allowed_hosts=["testserver", ";"]
)
@api.schema("Pet")
class PetSchema(Schema):
@@ -364,7 +368,12 @@ def test_documentation():
import responder
from marshmallow import Schema, fields
api = responder.API(title="Web Service", openapi="3.0", docs_route="/docs")
api = responder.API(
title="Web Service",
openapi="3.0",
docs_route="/docs",
allowed_hosts=["testserver", ";"]
)
@api.schema("Pet")
class PetSchema(Schema):
@@ -527,7 +536,7 @@ def test_redirects(api, session):
def one(req, resp):
resp.text = "redirected"
assert session.get("/1").url == "http://testserver/1"
assert session.get("/1").url == "http://;/1"
def test_session_thoroughly(api, session):
@@ -541,9 +550,7 @@ def test_session_thoroughly(api, session):
resp.media = {"session": req.session}
r = session.get(api.url_for(set))
print(r.headers)
r = session.get(api.url_for(get))
print(r.request.headers)
assert r.json() == {"session": {"hello": "world"}}
def test_before_response(api, session):
@@ -559,3 +566,54 @@ def test_before_response(api, session):
r = session.get(api.url_for(get))
assert 'x-pizza' in r.headers
def test_allowed_hosts():
api = responder.API(
allowed_hosts=[";", "tenant.;"]
)
@api.route("/")
def get(req, resp):
pass
# Exact match
r = api.requests.get(api.url_for(get))
assert r.status_code == 200
# Reset the session
api._session = None
r = api.session(base_url="http://tenant.;").get(api.url_for(get))
assert r.status_code == 200
# Reset the session
api._session = None
r = api.session(base_url="http://unkownhost").get(api.url_for(get))
assert r.status_code == 400
# Reset the session
api._session = None
r = api.session(base_url="http://unkown_tenant.;").get(api.url_for(get))
assert r.status_code == 400
api = responder.API(
allowed_hosts=[".;"]
)
@api.route("/")
def get(req, resp):
pass
# Wildcard domains
# Using http://;
r = api.requests.get(api.url_for(get))
assert r.status_code == 200
# Reset the session
api._session = None
r = api.session(base_url="http://tenant1.;").get(api.url_for(get))
assert r.status_code == 200
# Reset the session
api._session = None
r = api.session(base_url="http://tenant2.;").get(api.url_for(get))
assert r.status_code == 200