Trusted host

This commit is contained in:
taoufik07
2018-10-28 00:46:39 +01:00
parent 96971a33a7
commit 4fff823def
3 changed files with 47 additions and 0 deletions
+13
View File
@@ -26,6 +26,7 @@ from starlette.websockets import WebSocket
from whitenoise import WhiteNoise
from . import models, status_codes
from .middlewares.trustedhost import TrustedHostMiddleware
from .background import BackgroundQueue
from .formats import get_formats
from .routes import Route
@@ -66,6 +67,7 @@ class API:
enable_hsts=False,
docs_route=None,
cors=False,
allowed_hosts=None
):
self.secret_key = secret_key
@@ -87,6 +89,14 @@ class API:
self.hsts_enabled = enable_hsts
self.cors = cors
self.cors_params = DEFAULT_CORS_PARAMS
if not allowed_hosts:
if not debug:
raise RuntimeError("You need to specify `allowed_hosts` when debug is set to False")
allowed_hosts = ['localhost', '127.0.0.1']
allowed_hosts = ["localhost", "127.0.0.1", ";", "testserver"]
self.allowed_hosts = allowed_hosts
# Make the static/templates directory if they don't exist.
for _dir in (self.static_dir, self.templates_dir):
os.makedirs(_dir, exist_ok=True)
@@ -123,6 +133,9 @@ class API:
if self.hsts_enabled:
self.add_middleware(HTTPSRedirectMiddleware)
self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts)
self.lifespan_handler = LifespanHandler()
if self.cors:
View File
+34
View File
@@ -0,0 +1,34 @@
from starlette.datastructures import Headers
from starlette.responses import PlainTextResponse
def _is_trusted_host(host, allowed_hosts):
"""
Check if the host matchs the pattern.
Any given pattern starting with a period is considered a wildcard pattern.
"""
host = host.lower()
for pattern in allowed_hosts:
if (
pattern == "*" or pattern == host or
pattern[0] == "." and
(host.endswith(pattern) or host == pattern[1:])
):
return True
return False
class TrustedHostMiddleware:
def __init__(self, app, allowed_hosts):
self.app = app
self.allowed_hosts = allowed_hosts
self.allow_any = "*" in allowed_hosts
def __call__(self, scope):
if scope["type"] in ("http", "websocket") and not self.allow_any:
headers = Headers(scope=scope)
host = headers.get("host").split(":")[0]
print("HH", host, self.allowed_hosts)
print("EE", _is_trusted_host(host, self.allowed_hosts))
if not _is_trusted_host(host, self.allowed_hosts):
return PlainTextResponse("Invalid host header", status_code=400)
return self.app(scope)