From 4fff823def896c5ed8eaa2f13ed5ae6bcbcb00fe Mon Sep 17 00:00:00 2001 From: taoufik07 Date: Sun, 28 Oct 2018 00:46:39 +0100 Subject: [PATCH 1/8] Trusted host --- responder/api.py | 13 +++++++++++ responder/middlewares/__init__.py | 0 responder/middlewares/trustedhost.py | 34 ++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+) create mode 100644 responder/middlewares/__init__.py create mode 100644 responder/middlewares/trustedhost.py diff --git a/responder/api.py b/responder/api.py index 3bb896a..683e4e9 100644 --- a/responder/api.py +++ b/responder/api.py @@ -26,6 +26,7 @@ from starlette.websockets import WebSocket from whitenoise import WhiteNoise from . import models, status_codes +from .middlewares.trustedhost import TrustedHostMiddleware from .background import BackgroundQueue from .formats import get_formats from .routes import Route @@ -66,6 +67,7 @@ class API: enable_hsts=False, docs_route=None, cors=False, + allowed_hosts=None ): self.secret_key = secret_key @@ -87,6 +89,14 @@ class API: self.hsts_enabled = enable_hsts self.cors = cors self.cors_params = DEFAULT_CORS_PARAMS + + if not allowed_hosts: + if not debug: + raise RuntimeError("You need to specify `allowed_hosts` when debug is set to False") + allowed_hosts = ['localhost', '127.0.0.1'] + allowed_hosts = ["localhost", "127.0.0.1", ";", "testserver"] + self.allowed_hosts = allowed_hosts + # Make the static/templates directory if they don't exist. for _dir in (self.static_dir, self.templates_dir): os.makedirs(_dir, exist_ok=True) @@ -123,6 +133,9 @@ class API: if self.hsts_enabled: self.add_middleware(HTTPSRedirectMiddleware) + + self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts) + self.lifespan_handler = LifespanHandler() if self.cors: diff --git a/responder/middlewares/__init__.py b/responder/middlewares/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/responder/middlewares/trustedhost.py b/responder/middlewares/trustedhost.py new file mode 100644 index 0000000..505c569 --- /dev/null +++ b/responder/middlewares/trustedhost.py @@ -0,0 +1,34 @@ +from starlette.datastructures import Headers +from starlette.responses import PlainTextResponse + +def _is_trusted_host(host, allowed_hosts): + """ + Check if the host matchs the pattern. + + Any given pattern starting with a period is considered a wildcard pattern. + """ + host = host.lower() + for pattern in allowed_hosts: + if ( + pattern == "*" or pattern == host or + pattern[0] == "." and + (host.endswith(pattern) or host == pattern[1:]) + ): + return True + return False + +class TrustedHostMiddleware: + def __init__(self, app, allowed_hosts): + self.app = app + self.allowed_hosts = allowed_hosts + self.allow_any = "*" in allowed_hosts + + def __call__(self, scope): + if scope["type"] in ("http", "websocket") and not self.allow_any: + headers = Headers(scope=scope) + host = headers.get("host").split(":")[0] + print("HH", host, self.allowed_hosts) + print("EE", _is_trusted_host(host, self.allowed_hosts)) + if not _is_trusted_host(host, self.allowed_hosts): + return PlainTextResponse("Invalid host header", status_code=400) + return self.app(scope) From 029b3e2a52d678e33a02af95d08583b5058cafb3 Mon Sep 17 00:00:00 2001 From: taoufik07 Date: Sun, 28 Oct 2018 00:46:50 +0100 Subject: [PATCH 2/8] Tests --- tests/conftest.py | 2 +- tests/test_responder.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bb302b0..1334d2f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,7 @@ def current_dir(): @pytest.fixture def api(): - return responder.API() + return responder.API(allowed_hosts=["testserver", ";"]) @pytest.fixture diff --git a/tests/test_responder.py b/tests/test_responder.py index 2267a1d..fef382f 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -333,7 +333,11 @@ def test_schema_generation(): import responder from marshmallow import Schema, fields - api = responder.API(title="Web Service", openapi="3.0") + api = responder.API( + title="Web Service", + openapi="3.0", + allowed_hosts=["testserver", ";"] + ) @api.schema("Pet") class PetSchema(Schema): @@ -364,7 +368,12 @@ def test_documentation(): import responder from marshmallow import Schema, fields - api = responder.API(title="Web Service", openapi="3.0", docs_route="/docs") + api = responder.API( + title="Web Service", + openapi="3.0", + docs_route="/docs", + allowed_hosts=["testserver", ";"] + ) @api.schema("Pet") class PetSchema(Schema): From 1f0f2318d5418eb68313999ebce54c610b1ded3a Mon Sep 17 00:00:00 2001 From: taoufik07 Date: Sun, 28 Oct 2018 01:02:37 +0100 Subject: [PATCH 3/8] cleanup --- responder/api.py | 7 ++++--- responder/middlewares/trustedhost.py | 2 -- tests/test_responder.py | 2 -- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/responder/api.py b/responder/api.py index 683e4e9..052c085 100644 --- a/responder/api.py +++ b/responder/api.py @@ -92,9 +92,10 @@ class API: if not allowed_hosts: if not debug: - raise RuntimeError("You need to specify `allowed_hosts` when debug is set to False") - allowed_hosts = ['localhost', '127.0.0.1'] - allowed_hosts = ["localhost", "127.0.0.1", ";", "testserver"] + raise RuntimeError( + "You need to specify `allowed_hosts` when debug is set to False" + ) + allowed_hosts = ["localhost", "127.0.0.1"] self.allowed_hosts = allowed_hosts # Make the static/templates directory if they don't exist. diff --git a/responder/middlewares/trustedhost.py b/responder/middlewares/trustedhost.py index 505c569..de51fad 100644 --- a/responder/middlewares/trustedhost.py +++ b/responder/middlewares/trustedhost.py @@ -27,8 +27,6 @@ class TrustedHostMiddleware: if scope["type"] in ("http", "websocket") and not self.allow_any: headers = Headers(scope=scope) host = headers.get("host").split(":")[0] - print("HH", host, self.allowed_hosts) - print("EE", _is_trusted_host(host, self.allowed_hosts)) if not _is_trusted_host(host, self.allowed_hosts): return PlainTextResponse("Invalid host header", status_code=400) return self.app(scope) diff --git a/tests/test_responder.py b/tests/test_responder.py index fef382f..a6f3964 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -550,9 +550,7 @@ def test_session_thoroughly(api, session): resp.media = {"session": req.session} r = session.get(api.url_for(set)) - print(r.headers) r = session.get(api.url_for(get)) - print(r.request.headers) assert r.json() == {"session": {"hello": "world"}} def test_before_responpse(api, session): From 7b11fa24dd8f999b3113e93cd9698afcc63ade61 Mon Sep 17 00:00:00 2001 From: taoufik07 Date: Sun, 28 Oct 2018 01:38:17 +0100 Subject: [PATCH 4/8] Silence for now --- responder/api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/responder/api.py b/responder/api.py index 052c085..3eb51dc 100644 --- a/responder/api.py +++ b/responder/api.py @@ -91,10 +91,10 @@ class API: self.cors_params = DEFAULT_CORS_PARAMS if not allowed_hosts: - if not debug: - raise RuntimeError( - "You need to specify `allowed_hosts` when debug is set to False" - ) + # if not debug: + # raise RuntimeError( + # "You need to specify `allowed_hosts` when debug is set to False" + # ) allowed_hosts = ["localhost", "127.0.0.1"] self.allowed_hosts = allowed_hosts From 1ab46104c87050fb4e3f61c511c57341b578f818 Mon Sep 17 00:00:00 2001 From: taoufik07 Date: Sun, 28 Oct 2018 14:51:24 +0000 Subject: [PATCH 5/8] Allow all hosts by default --- responder/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/responder/api.py b/responder/api.py index 3eb51dc..2c8ccf2 100644 --- a/responder/api.py +++ b/responder/api.py @@ -95,7 +95,7 @@ class API: # raise RuntimeError( # "You need to specify `allowed_hosts` when debug is set to False" # ) - allowed_hosts = ["localhost", "127.0.0.1"] + allowed_hosts = ["*"] self.allowed_hosts = allowed_hosts # Make the static/templates directory if they don't exist. From 954637f7b3688ef8cbb4eb646bf6e43e70ceb11a Mon Sep 17 00:00:00 2001 From: taoufik07 Date: Sun, 28 Oct 2018 18:09:52 +0000 Subject: [PATCH 6/8] Pass base_url to the TestClient --- responder/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/responder/api.py b/responder/api.py index 2c8ccf2..4c53450 100644 --- a/responder/api.py +++ b/responder/api.py @@ -518,7 +518,7 @@ class API: """ if self._session is None: - self._session = TestClient(self) + self._session = TestClient(self, base_url=base_url) return self._session def _route_for(self, endpoint): From f7d5514b94872720a032024d5a0f7e89490e652b Mon Sep 17 00:00:00 2001 From: taoufik07 Date: Sun, 28 Oct 2018 18:12:21 +0000 Subject: [PATCH 7/8] Fix test base_url --- tests/test_responder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_responder.py b/tests/test_responder.py index a6f3964..e68c9f8 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -536,7 +536,7 @@ def test_redirects(api, session): def one(req, resp): resp.text = "redirected" - assert session.get("/1").url == "http://testserver/1" + assert session.get("/1").url == "http://;/1" def test_session_thoroughly(api, session): From 778cb2dd0fc4a83e366871e4e602922ee472ca29 Mon Sep 17 00:00:00 2001 From: taoufik07 Date: Sun, 28 Oct 2018 18:26:42 +0000 Subject: [PATCH 8/8] Add Tests --- tests/test_responder.py | 51 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/test_responder.py b/tests/test_responder.py index e68c9f8..58c8317 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -566,3 +566,54 @@ def test_before_responpse(api, session): r = session.get(api.url_for(get)) assert 'x-pizza' in r.headers + +def test_allowed_hosts(): + api = responder.API( + allowed_hosts=[";", "tenant.;"] + ) + + @api.route("/") + def get(req, resp): + pass + + # Exact match + r = api.requests.get(api.url_for(get)) + assert r.status_code == 200 + + # Reset the session + api._session = None + r = api.session(base_url="http://tenant.;").get(api.url_for(get)) + assert r.status_code == 200 + + # Reset the session + api._session = None + r = api.session(base_url="http://unkownhost").get(api.url_for(get)) + assert r.status_code == 400 + + # Reset the session + api._session = None + r = api.session(base_url="http://unkown_tenant.;").get(api.url_for(get)) + assert r.status_code == 400 + + api = responder.API( + allowed_hosts=[".;"] + ) + + @api.route("/") + def get(req, resp): + pass + + # Wildcard domains + # Using http://; + r = api.requests.get(api.url_for(get)) + assert r.status_code == 200 + + # Reset the session + api._session = None + r = api.session(base_url="http://tenant1.;").get(api.url_for(get)) + assert r.status_code == 200 + + # Reset the session + api._session = None + r = api.session(base_url="http://tenant2.;").get(api.url_for(get)) + assert r.status_code == 200