diff --git a/responder/api.py b/responder/api.py index ec8604e..eced91e 100644 --- a/responder/api.py +++ b/responder/api.py @@ -26,6 +26,7 @@ from starlette.websockets import WebSocket from whitenoise import WhiteNoise from . import models, status_codes +from .middlewares.trustedhost import TrustedHostMiddleware from .background import BackgroundQueue from .formats import get_formats from .routes import Route @@ -66,6 +67,7 @@ class API: enable_hsts=False, docs_route=None, cors=False, + allowed_hosts=None ): self.background = BackgroundQueue() @@ -88,6 +90,15 @@ class API: self.hsts_enabled = enable_hsts self.cors = cors self.cors_params = DEFAULT_CORS_PARAMS + + if not allowed_hosts: + # if not debug: + # raise RuntimeError( + # "You need to specify `allowed_hosts` when debug is set to False" + # ) + allowed_hosts = ["*"] + self.allowed_hosts = allowed_hosts + # Make the static/templates directory if they don't exist. for _dir in (self.static_dir, self.templates_dir): os.makedirs(_dir, exist_ok=True) @@ -123,6 +134,9 @@ class API: if self.hsts_enabled: self.add_middleware(HTTPSRedirectMiddleware) + + self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts) + self.lifespan_handler = LifespanHandler() if self.cors: @@ -508,7 +522,7 @@ class API: """ if self._session is None: - self._session = TestClient(self) + self._session = TestClient(self, base_url=base_url) return self._session def _route_for(self, endpoint): diff --git a/responder/middlewares/__init__.py b/responder/middlewares/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/responder/middlewares/trustedhost.py b/responder/middlewares/trustedhost.py new file mode 100644 index 0000000..de51fad --- /dev/null +++ b/responder/middlewares/trustedhost.py @@ -0,0 +1,32 @@ +from starlette.datastructures import Headers +from starlette.responses import PlainTextResponse + +def _is_trusted_host(host, allowed_hosts): + """ + Check if the host matchs the pattern. + + Any given pattern starting with a period is considered a wildcard pattern. + """ + host = host.lower() + for pattern in allowed_hosts: + if ( + pattern == "*" or pattern == host or + pattern[0] == "." and + (host.endswith(pattern) or host == pattern[1:]) + ): + return True + return False + +class TrustedHostMiddleware: + def __init__(self, app, allowed_hosts): + self.app = app + self.allowed_hosts = allowed_hosts + self.allow_any = "*" in allowed_hosts + + def __call__(self, scope): + if scope["type"] in ("http", "websocket") and not self.allow_any: + headers = Headers(scope=scope) + host = headers.get("host").split(":")[0] + if not _is_trusted_host(host, self.allowed_hosts): + return PlainTextResponse("Invalid host header", status_code=400) + return self.app(scope) diff --git a/tests/conftest.py b/tests/conftest.py index bb302b0..1334d2f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,7 @@ def current_dir(): @pytest.fixture def api(): - return responder.API() + return responder.API(allowed_hosts=["testserver", ";"]) @pytest.fixture diff --git a/tests/test_responder.py b/tests/test_responder.py index 3fdf48b..3a61b97 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -333,7 +333,11 @@ def test_schema_generation(): import responder from marshmallow import Schema, fields - api = responder.API(title="Web Service", openapi="3.0") + api = responder.API( + title="Web Service", + openapi="3.0", + allowed_hosts=["testserver", ";"] + ) @api.schema("Pet") class PetSchema(Schema): @@ -364,7 +368,12 @@ def test_documentation(): import responder from marshmallow import Schema, fields - api = responder.API(title="Web Service", openapi="3.0", docs_route="/docs") + api = responder.API( + title="Web Service", + openapi="3.0", + docs_route="/docs", + allowed_hosts=["testserver", ";"] + ) @api.schema("Pet") class PetSchema(Schema): @@ -527,7 +536,7 @@ def test_redirects(api, session): def one(req, resp): resp.text = "redirected" - assert session.get("/1").url == "http://testserver/1" + assert session.get("/1").url == "http://;/1" def test_session_thoroughly(api, session): @@ -541,9 +550,7 @@ def test_session_thoroughly(api, session): resp.media = {"session": req.session} r = session.get(api.url_for(set)) - print(r.headers) r = session.get(api.url_for(get)) - print(r.request.headers) assert r.json() == {"session": {"hello": "world"}} def test_before_response(api, session): @@ -559,3 +566,54 @@ def test_before_response(api, session): r = session.get(api.url_for(get)) assert 'x-pizza' in r.headers + +def test_allowed_hosts(): + api = responder.API( + allowed_hosts=[";", "tenant.;"] + ) + + @api.route("/") + def get(req, resp): + pass + + # Exact match + r = api.requests.get(api.url_for(get)) + assert r.status_code == 200 + + # Reset the session + api._session = None + r = api.session(base_url="http://tenant.;").get(api.url_for(get)) + assert r.status_code == 200 + + # Reset the session + api._session = None + r = api.session(base_url="http://unkownhost").get(api.url_for(get)) + assert r.status_code == 400 + + # Reset the session + api._session = None + r = api.session(base_url="http://unkown_tenant.;").get(api.url_for(get)) + assert r.status_code == 400 + + api = responder.API( + allowed_hosts=[".;"] + ) + + @api.route("/") + def get(req, resp): + pass + + # Wildcard domains + # Using http://; + r = api.requests.get(api.url_for(get)) + assert r.status_code == 200 + + # Reset the session + api._session = None + r = api.session(base_url="http://tenant1.;").get(api.url_for(get)) + assert r.status_code == 200 + + # Reset the session + api._session = None + r = api.session(base_url="http://tenant2.;").get(api.url_for(get)) + assert r.status_code == 200