diff --git a/Pipfile.lock b/Pipfile.lock
index 6998c54..6a3e346 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -23,10 +23,10 @@
},
"aniso8601": {
"hashes": [
- "sha256:b8a6a9b24611fc50cf2d9b45d371bfdc4fd0581d1cc52254f5502130a776d4af",
- "sha256:bb167645c79f7a438f9dfab6161af9bed75508c645b1f07d1158240841d22673"
+ "sha256:513d2b6637b7853806ae79ffaca6f3e8754bdd547048f5ccc1420aec4b714f1e",
+ "sha256:d10a4bf949f619f719b227ef5386e31f49a2b6d453004b21f02661ccc8670c7b"
],
- "version": "==6.0.0"
+ "version": "==7.0.0"
},
"apispec": {
"hashes": [
@@ -70,10 +70,10 @@
},
"graphene": {
"hashes": [
- "sha256:77d61618132ccd084c343e64c22d806cee18dce73cc86e0f427378dbdeeac287",
- "sha256:acf808d50d053b94f7958414d511489a9e490a7f9563b9be80f6875fc5723d2a"
+ "sha256:09165f03e1591b76bf57b133482db9be6dac72c74b0a628d3c93182af9c5a896",
+ "sha256:2cbe6d4ef15cfc7b7805e0760a0e5b80747161ce1b0f990dfdc0d2cf497c12f9"
],
- "version": "==2.1.7"
+ "version": "==2.1.8"
},
"graphql-core": {
"hashes": [
@@ -116,13 +116,6 @@
],
"version": "==2.8"
},
- "itsdangerous": {
- "hashes": [
- "sha256:321b033d07f2a4136d3ec762eac9f16a10ccd60f53c0c91af90217ace7ba1f19",
- "sha256:b12271b2047cb23eeb98c8b5622e2e5c5e9abd9784a153e9d8ef9cb4dd09d749"
- ],
- "version": "==1.1.0"
- },
"jinja2": {
"hashes": [
"sha256:065c4f02ebe7f7cf559e49ee5a95fb800a9e4528727aec6f24402a5374c65013",
@@ -165,16 +158,10 @@
},
"marshmallow": {
"hashes": [
- "sha256:7ea8540fc7e35be3b0af8b017313944b984d5acdb118b4ba3c270ac9611765c7",
- "sha256:bc91e3f90e86133241ac62ea0dd35217d631a207e8628430bc66c347dbe12f7d"
+ "sha256:23f684b54b1955ebd5bdfbdda4062e438ef86218f14f1a356f570cdf0c016ab3",
+ "sha256:fcfc9ffd75a883da06f30f604a4e81dd0b56eb9438f4d0a8de6bbaa163ce9ec3"
],
- "version": "==3.0.0rc9"
- },
- "parse": {
- "hashes": [
- "sha256:1b68657434d371e5156048ca4a0c5aea5afc6ca59a2fea4dd1a575354f617142"
- ],
- "version": "==1.12.0"
+ "version": "==3.0.1"
},
"promise": {
"hashes": [
@@ -248,9 +235,9 @@
},
"starlette": {
"hashes": [
- "sha256:849553e6dc8dd4faac5d4383219bdf49b199f3f09d60ce374c2a00a42ab7c77d"
+ "sha256:f600bf9d0beeeeebcb143e6d0c4f8858c2b05067d5a4feb446ba7400ba5e5dc5"
],
- "version": "==0.12.7"
+ "version": "==0.12.8"
},
"typesystem": {
"hashes": [
@@ -474,6 +461,7 @@
"sha256:23d3d873e008a513952355379d93cbcab874c58f4f034ff657c7a87422fa64e8",
"sha256:80d2de76188eabfbfcf27e6a37342c2827801e59c4cc14b0371c56fed43820e3"
],
+ "markers": "python_version < '3.8'",
"version": "==0.19"
},
"itsdangerous": {
@@ -525,10 +513,10 @@
},
"marshmallow": {
"hashes": [
- "sha256:7ea8540fc7e35be3b0af8b017313944b984d5acdb118b4ba3c270ac9611765c7",
- "sha256:bc91e3f90e86133241ac62ea0dd35217d631a207e8628430bc66c347dbe12f7d"
+ "sha256:23f684b54b1955ebd5bdfbdda4062e438ef86218f14f1a356f570cdf0c016ab3",
+ "sha256:fcfc9ffd75a883da06f30f604a4e81dd0b56eb9438f4d0a8de6bbaa163ce9ec3"
],
- "version": "==3.0.0rc9"
+ "version": "==3.0.1"
},
"mccabe": {
"hashes": [
@@ -602,11 +590,11 @@
},
"pytest": {
"hashes": [
- "sha256:6ef6d06de77ce2961156013e9dff62f1b2688aa04d0dc244299fe7d67e09370d",
- "sha256:a736fed91c12681a7b34617c8fcefe39ea04599ca72c608751c31d89579a3f77"
+ "sha256:95d13143cc14174ca1a01ec68e84d76ba5d9d493ac02716fd9706c949a505210",
+ "sha256:b78fe2881323bd44fd9bd76e5317173d4316577e7b1cddebae9136a4495ec865"
],
"index": "pypi",
- "version": "==5.0.1"
+ "version": "==5.1.2"
},
"pytest-cov": {
"hashes": [
@@ -659,11 +647,11 @@
},
"sphinx": {
"hashes": [
- "sha256:22538e1bbe62b407cf5a8aabe1bb15848aa66bb79559f42f5202bbce6b757a69",
- "sha256:f9a79e746b87921cabc3baa375199c6076d1270cee53915dbd24fdbeaaacc427"
+ "sha256:0d586b0f8c2fc3cc6559c5e8fd6124628110514fda0e5d7c82e682d749d2e845",
+ "sha256:839a3ed6f6b092bb60f492024489cc9e6991360fb9f52ed6361acd510d261069"
],
"index": "pypi",
- "version": "==2.1.2"
+ "version": "==2.2.0"
},
"sphinxcontrib-applehelp": {
"hashes": [
@@ -716,10 +704,10 @@
},
"tqdm": {
"hashes": [
- "sha256:14a285392c32b6f8222ecfbcd217838f88e11630affe9006cd0e94c7eff3cb61",
- "sha256:25d4c0ea02a305a688e7e9c2cdc8f862f989ef2a4701ab28ee963295f5b109ab"
+ "sha256:1be3e4e3198f2d0e47b928e9d9a8ec1b63525db29095cec1467f4c5a4ea8ebf9",
+ "sha256:7e39a30e3d34a7a6539378e39d7490326253b7ee354878a92255656dc4284457"
],
- "version": "==4.32.2"
+ "version": "==4.35.0"
},
"twine": {
"hashes": [
@@ -759,10 +747,10 @@
},
"zipp": {
"hashes": [
- "sha256:4970c3758f4e89a7857a973b1e2a5d75bcdc47794442f2e2dd4fe8e0466e809a",
- "sha256:8a5712cfd3bb4248015eb3b0b3c54a5f6ee3f2425963ef2a0125b8bc40aafaec"
+ "sha256:3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e",
+ "sha256:f06903e9f1f43b12d371004b4ac7b06ab39a44adc747266928ae6debfa7b3335"
],
- "version": "==0.5.2"
+ "version": "==0.6.0"
}
}
}
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index e07d227..f8a040c 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -73,13 +73,29 @@ If the client requests YAML instead (with a header of ``Accept: application/x-ya
Rendering a Template
--------------------
-If you want to render a template, simply use ``api.template``. No need for additional imports::
+Responder provides a built-in light `jinja2 `_ wrapper ``templates.Templates``
+
+Usage::
+
+ from responder.templates import Templates
+
+ templates = Templates()
+
+ @api.route("/hello/{name}/html")
+ def hello(req, resp, name):
+ resp.html = templates.render("hello.html", name=name)
+
+
+Also a ``render_async`` is available::
+
+ resp.html = await templates.render_async("hello.html", who=who)
+
+You can also use the existing ``api.template(filename, *args, **kwargs)`` to render templates::
@api.route("/hello/{who}/html")
def hello_html(req, resp, *, who):
resp.html = api.template('hello.html', who=who)
-The ``api`` instance is available as an object during template rendering.
Setting Response Status Code
----------------------------
diff --git a/docs/source/tour.rst b/docs/source/tour.rst
index 8e1a22a..9c16ede 100644
--- a/docs/source/tour.rst
+++ b/docs/source/tour.rst
@@ -58,13 +58,14 @@ You can make use of Responder's Request and Response objects in your GraphQL res
OpenAPI Schema Support
----------------------
-Responder comes with built-in support for OpenAPI / marshmallow::
+Responder comes with built-in support for OpenAPI / marshmallow
+
+New in Responder `1.4.0`::
import responder
+ from responder.ext.schema import Schema as OpenAPISchema
from marshmallow import Schema, fields
- description = "This is a sample server for a pet store."
- terms_of_service = "http://example.com/terms/"
contact = {
"name": "API Support",
"url": "http://www.example.com/support",
@@ -74,19 +75,21 @@ Responder comes with built-in support for OpenAPI / marshmallow::
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0.html",
}
+
+ api = responder.API()
- api = responder.API(
+ schema = OpenAPISchema(
+ app=api,
title="Web Service",
version="1.0",
openapi="3.0.2",
- description=description,
- terms_of_service=terms_of_service,
+ description="A simple pet store",
+ terms_of_service="http://example.com/terms/",
contact=contact,
license=license,
)
-
- @api.schema("Pet")
+ @schema.schema("Pet")
class PetSchema(Schema):
name = fields.Str()
@@ -108,6 +111,51 @@ Responder comes with built-in support for OpenAPI / marshmallow::
resp.media = PetSchema().dump({"name": "little orange"})
+Old way *It's recommended to use the code above* ::
+
+ import responder
+ from marshmallow import Schema, fields
+
+ contact = {
+ "name": "API Support",
+ "url": "http://www.example.com/support",
+ "email": "support@example.com",
+ }
+ license = {
+ "name": "Apache 2.0",
+ "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
+ }
+
+ api = responder.API(
+ title="Web Service",
+ version="1.0",
+ openapi="3.0.2",
+ description="A simple pet store",
+ terms_of_service="http://example.com/terms/",
+ contact=contact,
+ license=license,
+ )
+
+ @api.schema("Pet")
+ class PetSchema(Schema):
+ name = fields.Str()
+
+ @api.route("/")
+ def route(req, resp):
+ """A cute furry animal endpoint.
+ ---
+ get:
+ description: Get a random pet
+ responses:
+ 200:
+ description: A pet to be returned
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/Pet'
+ """
+ resp.media = PetSchema().dump({"name": "little orange"})
+
::
>>> r = api.session().get("http://;/schema.yml")
@@ -142,7 +190,31 @@ Responder comes with built-in support for OpenAPI / marshmallow::
Interactive Documentation
-------------------------
-Responder can automatically supply API Documentation for you. Using the example above::
+Responder can automatically supply API Documentation for you. Using the example above
+
+The new and recommended way::
+
+ ...
+ from responder.ext.schema import Schema
+ ...
+ api = responder.API()
+
+ schema = Schema(
+ app=api,
+ title="Web Service",
+ version="1.0",
+ openapi="3.0.2",
+ ...
+ docs_route='/docs',
+ ...
+ description=description,
+ terms_of_service=terms_of_service,
+ contact=contact,
+ license=license,
+ )
+
+
+The old way ::
api = responder.API(
title="Web Service",
@@ -157,8 +229,8 @@ Responder can automatically supply API Documentation for you. Using the example
This will make ``/docs`` render interactive documentation for your API.
-Mount a WSGI App (e.g. Flask)
------------------------------
+Mount a WSGI / ASGI Apps (e.g. Flask, Starlette,...)
+----------------------------------------------------
Responder gives you the ability to mount another ASGI / WSGI app at a subroute::
diff --git a/responder/api.py b/responder/api.py
index 08ab569..56b85ea 100644
--- a/responder/api.py
+++ b/responder/api.py
@@ -1,43 +1,33 @@
import json
import os
-from uuid import uuid4
from pathlib import Path
-from base64 import b64encode
-import apistar
-import itsdangerous
import jinja2
import uvicorn
-import yaml
-from apispec import APISpec, yaml_utils
-from apispec.ext.marshmallow import MarshmallowPlugin
+from starlette.exceptions import ExceptionMiddleware
from starlette.middleware.wsgi import WSGIMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
+from starlette.middleware.sessions import SessionMiddleware
from starlette.routing import Lifespan
from starlette.staticfiles import StaticFiles
from starlette.testclient import TestClient
from starlette.websockets import WebSocket
-from whitenoise import WhiteNoise
from . import models, status_codes
from .background import BackgroundQueue
from .formats import get_formats
-from .routes import Route
-from .statics import (
- DEFAULT_API_THEME,
- DEFAULT_CORS_PARAMS,
- DEFAULT_SECRET_KEY,
- DEFAULT_SESSION_COOKIE,
-)
-from .templates import GRAPHIQL
+from .routes import Router
+from .statics import DEFAULT_API_THEME, DEFAULT_CORS_PARAMS, DEFAULT_SECRET_KEY
+from .ext.schema import Schema as OpenAPISchema
+from .staticfiles import StaticFiles
+from .templates import Templates
-# TODO: consider moving status codes here
class API:
"""The primary web-service class.
@@ -45,12 +35,6 @@ class API:
:param templates_dir: The directory to use for templates. Will be created for you if it doesn't already exist.
:param auto_escape: If ``True``, HTML and XML templates will automatically be escaped.
:param enable_hsts: If ``True``, send all responses to HTTPS URLs.
- :param title: The title of the application (OpenAPI Info Object)
- :param version: The version of the OpenAPI document (OpenAPI Info Object)
- :param description: The description of the OpenAPI document (OpenAPI Info Object)
- :param terms_of_service: A URL to the Terms of Service for the API (OpenAPI Info Object)
- :param contact: The contact dictionary of the application (OpenAPI Contact Object)
- :param license: The license information of the exposed API (OpenAPI License Object)
"""
status_codes = status_codes
@@ -81,13 +65,8 @@ class API:
self.background = BackgroundQueue()
self.secret_key = secret_key
- self.title = title
- self.version = version
- self.description = description
- self.terms_of_service = terms_of_service
- self.contact = contact
- self.license = license
- self.openapi_version = openapi
+
+ self.router = Router()
if static_dir is not None:
if static_route is None:
@@ -106,14 +85,6 @@ class API:
self.templates_dir = templates_dir or self.built_in_templates_dir
- self.apps = {}
- self.routes = {}
- self.before_requests = {"http": [], "ws": []}
- self.docs_theme = DEFAULT_API_THEME
- self.docs_route = docs_route
- self.schemas = {}
- self.session_cookie = DEFAULT_SESSION_COOKIE
-
self.hsts_enabled = enable_hsts
self.cors = cors
self.cors_params = cors_params
@@ -133,33 +104,15 @@ class API:
os.makedirs(_dir, exist_ok=True)
if self.static_dir is not None:
- self.whitenoise = WhiteNoise(application=self._notfound_wsgi_app)
- self.whitenoise.add_files(str(self.static_dir))
-
- self.whitenoise.add_files(
- (
- Path(apistar.__file__).parent
- / "themes"
- / self.docs_theme
- / "static"
- ).resolve()
- )
-
- self.mount(self.static_route, self.whitenoise)
+ self.mount(self.static_route, self.static_app)
self.formats = get_formats()
# Cached requests session.
self._session = None
- if self.openapi_version:
- self.add_route(openapi_route, self.schema_response)
-
- if self.docs_route:
- self.add_route(self.docs_route, self.docs_response)
-
self.default_endpoint = None
- self.app = self.asgi
+ self.app = ExceptionMiddleware(self.router, debug=debug)
self.add_middleware(GZipMiddleware)
if self.hsts_enabled:
@@ -167,28 +120,38 @@ class API:
self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts)
- self.lifespan_handler = Lifespan()
-
if self.cors:
self.add_middleware(CORSMiddleware, **self.cors_params)
self.add_middleware(ServerErrorMiddleware, debug=debug)
+ self.add_middleware(SessionMiddleware, secret_key=self.secret_key)
- # Jinja environment
- self.jinja_env = jinja2.Environment(
- loader=jinja2.FileSystemLoader(
- [str(self.templates_dir), str(self.built_in_templates_dir)],
- followlinks=True,
- ),
- autoescape=jinja2.select_autoescape(["html", "xml"] if auto_escape else []),
- )
- self.jinja_values_base = {"api": self} # Give reference to self.
+ if openapi or docs_route:
+ self.openapi = OpenAPISchema(
+ app=self,
+ title="Web Service",
+ version="1.0",
+ openapi="3.0.2",
+ docs_route=docs_route,
+ description=description,
+ terms_of_service=terms_of_service,
+ contact=contact,
+ license=license,
+ openapi_route=openapi_route,
+ static_route=static_route,
+ )
+
+ # TODO: Update docs for templates
+ self.templates = Templates(directory=templates_dir)
self.requests = (
self.session()
) #: A Requests session that is connected to the ASGI app.
- @staticmethod
- def _default_wsgi_app(environ, start_response):
- pass
+ @property
+ def static_app(self):
+ if not hasattr(self, "_static_app"):
+ assert self.static_dir is not None
+ self._static_app = StaticFiles(directory=self.static_dir)
+ return self._static_app
@staticmethod
def _notfound_wsgi_app(environ, start_response):
@@ -197,10 +160,7 @@ class API:
def before_request(self, websocket=False):
def decorator(f):
- if websocket:
- self.before_requests.setdefault("ws", []).append(f)
- else:
- self.before_requests.setdefault("http", []).append(f)
+ self.router.before_request(f, websocket=websocket)
return f
return decorator
@@ -213,139 +173,20 @@ class API:
def before_ws_requests(self):
return self.before_requests.get("ws", [])
- @property
- def _apispec(self):
-
- info = {}
- if self.description is not None:
- info["description"] = self.description
- if self.terms_of_service is not None:
- info["termsOfService"] = self.terms_of_service
- if self.contact is not None:
- info["contact"] = self.contact
- if self.license is not None:
- info["license"] = self.license
-
- spec = APISpec(
- title=self.title,
- version=self.version,
- openapi_version=self.openapi_version,
- plugins=[MarshmallowPlugin()],
- info=info,
- )
-
- for route in self.routes:
- if self.routes[route].description:
- operations = yaml_utils.load_operations_from_docstring(
- self.routes[route].description
- )
- spec.path(path=route, operations=operations)
-
- for name, schema in self.schemas.items():
- spec.components.schema(name, schema=schema)
-
- return spec
-
- @property
- def openapi(self):
- return self._apispec.to_yaml()
-
def add_middleware(self, middleware_cls, **middleware_config):
self.app = middleware_cls(self.app, **middleware_config)
- async def __call__(self, scope, receive, send):
- if scope["type"] == "lifespan":
- await self.lifespan_handler(scope, receive, send)
- return
-
- path = scope["path"]
- root_path = scope.get("root_path", "")
-
- # Call into a submounted app, if one exists.
- for path_prefix, app in self.apps.items():
- if path.startswith(path_prefix):
- scope["path"] = path[len(path_prefix) :]
- scope["root_path"] = root_path + path_prefix
- try:
- await app(scope, receive, send)
- return
- except TypeError:
- app = WSGIMiddleware(app)
- await app(scope, receive, send)
- return
-
- await self.app(scope, receive, send)
-
- async def asgi(self, scope, receive, send):
- assert scope["type"] in ("http", "websocket")
-
- if scope["type"] == "websocket":
- await self._dispatch_ws(scope=scope, receive=receive, send=send)
- else:
- req = models.Request(scope, receive=receive, api=self)
- resp = await self._dispatch_http(
- req, scope=scope, send=send, receive=receive
- )
- await resp(scope, receive, send)
-
- async def _dispatch_http(self, req, **options):
- # Set formats on Request object.
- req.formats = self.formats
-
- # Get the route.
- route = self.path_matches_route(req.url.path)
- route = self.routes.get(route)
- if route:
- resp = models.Response(req=req, formats=self.formats)
-
- for before_request in self.before_http_requests:
- await self.background(before_request, req=req, resp=resp)
-
- await self._execute_route(route=route, req=req, resp=resp, **options)
- else:
- resp = models.Response(req=req, formats=self.formats)
- self.default_response(req=req, resp=resp, notfound=True)
- self.default_response(req=req, resp=resp)
-
- self._prepare_session(resp)
-
- return resp
-
- async def _dispatch_ws(self, scope, receive, send):
- ws = WebSocket(scope=scope, receive=receive, send=send)
-
- route = self.path_matches_route(ws.url.path)
- route = self.routes.get(route)
-
- if route:
- for before_request in self.before_ws_requests:
- await self.background(before_request, ws=ws)
- await self.background(route.endpoint, ws)
- else:
- await send({"type": "websocket.close", "code": 1000})
-
- def add_schema(self, name, schema, check_existing=True):
- """Adds a mashmallow schema to the API specification."""
- if check_existing:
- assert name not in self.schemas
-
- self.schemas[name] = schema
-
def schema(self, name, **options):
"""Decorator for creating new routes around function and class definitions.
-
Usage::
-
from marshmallow import Schema, fields
-
@api.schema("Pet")
class PetSchema(Schema):
name = fields.Str()
-
"""
def decorator(f):
- self.add_schema(name=name, schema=f, **options)
+ self.openapi.add_schema(name=name, schema=f, **options)
return f
return decorator
@@ -355,97 +196,18 @@ class API:
:param path: The path portion of a URL, to test all known routes against.
"""
- for (route, route_object) in self.routes.items():
- if route_object.does_match(path):
+ for route in self.router.routes:
+ match, _ = route.matches(path)
+ if match:
return route
- @property
- def _signer(self):
- return itsdangerous.Signer(self.secret_key)
-
- def _prepare_session(self, resp):
-
- if resp.session:
- data = self._signer.sign(
- b64encode(json.dumps(resp.session).encode("utf-8"))
- )
- resp.cookies[self.session_cookie] = data.decode("utf-8")
-
- @staticmethod
- def no_response(req, resp, **params):
- pass
-
- async def _execute_route(self, *, route, req, resp, **options):
-
- params = route.incoming_matches(req.url.path)
-
- cont = True
-
- if route.is_function:
- try:
- try:
- # Run the view.
- r = self.background(route.endpoint, req, resp, **params)
- # If it's async, await it.
- if hasattr(r, "cr_running"):
- await r
- except TypeError as e:
- cont = True
- except Exception:
- await self.background(self.default_response, req, resp, error=True)
- raise
-
- if route.is_class_based or cont:
- try:
- view = route.endpoint(**params)
- except TypeError:
- try:
- view = route.endpoint()
- except TypeError:
- view = route.endpoint
- pass
-
- # Run on_request first.
- try:
- # Run the view.
- r = getattr(view, "on_request", self.no_response)
- r = self.background(r, req, resp, **params)
- # If it's async, await it.
- if hasattr(r, "send"):
- await r
- except Exception:
- await self.background(self.default_response, req, resp, error=True)
- raise
-
- # Then run on_method.
- method = req.method
- try:
- # Run the view.
- r = getattr(view, f"on_{method}", self.no_response)
- r = self.background(r, req, resp, **params)
- # If it's async, await it.
- if hasattr(r, "send"):
- await r
- except Exception:
- await self.background(self.default_response, req, resp, error=True)
- raise
-
- def add_event_handler(self, event_type, handler):
- """Adds an event handler to the API.
-
- :param event_type: A string in ("startup", "shutdown")
- :param handler: The function to run. Can be either a function or a coroutine.
- """
-
- self.lifespan_handler.add_event_handler(event_type, handler)
-
def add_route(
self,
route=None,
endpoint=None,
*,
default=False,
- static=False,
+ static=True,
check_existing=True,
websocket=False,
before_request=False,
@@ -456,92 +218,45 @@ class API:
:param endpoint: The endpoint for the route -- can be a callable, or a class.
:param default: If ``True``, all unknown requests will route to this view.
:param static: If ``True``, and no endpoint was passed, render "static/index.html", and it will become a default route.
- :param check_existing: If ``True``, an AssertionError will be raised, if the route is already defined.
"""
- if before_request:
- if websocket:
- self.before_requests.setdefault("ws", []).append(endpoint)
- else:
- self.before_requests.setdefault("http", []).append(endpoint)
- return
-
- if route is None:
- route = f"/{uuid4().hex}"
-
- if check_existing:
- assert route not in self.routes
+ # Path
if static:
assert self.static_dir is not None
if not endpoint:
- endpoint = self.static_response
+ endpoint = self._static_response
default = True
- if default:
- self.default_endpoint = endpoint
-
- self.routes[route] = Route(route, endpoint, websocket=websocket)
- # TODO: A better data structure or sort it once the app is loaded
- self.routes = dict(
- sorted(self.routes.items(), key=lambda item: item[1]._weight())
+ self.router.add_route(
+ route,
+ endpoint,
+ default=default,
+ websocket=websocket,
+ before_request=before_request,
+ check_existing=check_existing,
)
- def default_response(
- self, req=None, resp=None, websocket=False, notfound=False, error=False
- ):
- if websocket:
- return
-
- if resp.status_code is None:
- resp.status_code = 200
-
- if self.default_endpoint and notfound:
- self.default_endpoint(req=req, resp=resp)
- else:
- if notfound:
- resp.status_code = status_codes.HTTP_404
- resp.text = "Not found."
- if error:
- resp.status_code = status_codes.HTTP_500
- resp.text = "Application error."
-
- def docs_response(self, req, resp):
- resp.html = self.docs
-
- def static_response(self, req, resp):
-
+ async def _static_response(self, req, resp):
assert self.static_dir is not None
index = (self.static_dir / "index.html").resolve()
if os.path.exists(index):
with open(index, "r") as f:
- resp.html = f.read()
+ resp.html = "Hello world !"
else:
resp.status_code = status_codes.HTTP_404
resp.text = "Not found."
- def schema_response(self, req, resp):
- resp.status_code = status_codes.HTTP_200
- resp.headers["Content-Type"] = "application/x-yaml"
- resp.content = self.openapi
-
def redirect(
self, resp, location, *, set_text=True, status_code=status_codes.HTTP_301
):
"""Redirects a given response to a given location.
-
:param resp: The Response to mutate.
:param location: The location of the redirect.
:param set_text: If ``True``, sets the Redirect body content automatically.
:param status_code: an `API.status_codes` attribute, or an integer, representing the HTTP status code of the redirect.
"""
-
- # assert resp.status_code.is_300(status_code)
-
- resp.status_code = status_code
- if set_text:
- resp.text = f"Redirecting to: {location}"
- resp.headers.update({"Location": location})
+ resp.redirect(location, set_text=set_text, status_code=status_code)
def on_event(self, event_type: str, **args):
"""Decorator for registering functions or coroutines to run at certain events
@@ -565,6 +280,15 @@ class API:
return decorator
+ def add_event_handler(self, event_type, handler):
+ """Adds an event handler to the API.
+
+ :param event_type: A string in ("startup", "shutdown")
+ :param handler: The function to run. Can be either a function or a coroutine.
+ """
+
+ self.router.lifespan_handler.add_event_handler(event_type, handler)
+
def route(self, route=None, **options):
"""Decorator for creating new routes around function and class definitions.
@@ -588,7 +312,7 @@ class API:
:param route: String representation of the route to be used (shouldn't be parameterized).
:param app: The other WSGI / ASGI app.
"""
- self.apps.update({route: app})
+ self.router.apps.update({route: app})
def session(self, base_url="http://;"):
"""Testing HTTP client. Returns a Requests session object, able to send HTTP requests to the Responder application.
@@ -600,11 +324,6 @@ class API:
self._session = TestClient(self, base_url=base_url)
return self._session
- def _route_for(self, endpoint):
- for route_object in self.routes.values():
- if endpoint in (route_object.endpoint, route_object.endpoint_name):
- return route_object
-
def url_for(self, endpoint, **params):
# TODO: Absolute_url
"""Given an endpoint, returns a rendered URL for its route.
@@ -612,67 +331,25 @@ class API:
:param endpoint: The route endpoint you're searching for.
:param params: Data to pass into the URL generator (for parameterized URLs).
"""
- route_object = self._route_for(endpoint)
- if route_object:
- return route_object.url(**params)
- raise ValueError
+ return self.router.url_for(endpoint, **params)
- def static_url(self, asset):
- """Given a static asset, return its URL path."""
- assert None not in (self.static_dir, self.static_route)
- return f"{self.static_route}/{str(asset)}"
-
- @property
- def docs(self):
-
- loader = jinja2.PrefixLoader(
- {
- self.docs_theme: jinja2.PackageLoader(
- "apistar", os.path.join("themes", self.docs_theme, "templates")
- )
- }
- )
- env = jinja2.Environment(autoescape=True, loader=loader)
- document = apistar.document.Document()
- document.content = yaml.safe_load(self.openapi)
-
- template = env.get_template("/".join([self.docs_theme, "index.html"]))
-
- return template.render(
- document=document,
- langs=["javascript", "python"],
- code_style=None,
- static_url=self.static_url,
- schema_url="/schema.yml",
- )
-
- def template(self, name_, **values):
+ def template(self, filename, *args, **kwargs):
"""Renders the given `jinja2 `_ template, with provided values supplied.
-
Note: The current ``api`` instance is by default passed into the view. This is set in the dict ``api.jinja_values_base``.
-
- :param name_: The filename of the jinja2 template, in ``templates_dir``.
- :param values: Data to pass into the template.
+ :param filename: The filename of the jinja2 template, in ``templates_dir``.
+ :param *args: Data to pass into the template.
+ :param *kwargs: Date to pass into the template.
"""
- # Prepopulate values with base
- values = {**self.jinja_values_base, **values}
+ return self.templates.render(filename, *args, **kwargs)
- template = self.jinja_env.get_template(name_)
- return template.render(**values)
-
- def template_string(self, s_, **values):
+ def template_string(self, source, *args, **kwargs):
"""Renders the given `jinja2 `_ template string, with provided values supplied.
-
Note: The current ``api`` instance is by default passed into the view. This is set in the dict ``api.jinja_values_base``.
-
- :param s_: The template to use.
- :param values: Data to pass into the template.
+ :param source: The template to use.
+ :param *args: Data to pass into the template.
+ :param **kwargs: Data to pass into the template.
"""
- # Prepopulate values with base
- values = {**self.jinja_values_base, **values}
-
- template = self.jinja_env.from_string(s_)
- return template.render(**values)
+ return self.templates.render_string(source, *args, **kwargs)
def serve(self, *, address=None, port=None, debug=False, **options):
"""Runs the application with uvicorn. If the ``PORT`` environment
@@ -704,3 +381,6 @@ class API:
if "debug" not in kwargs:
kwargs.update({"debug": self.debug})
self.serve(**kwargs)
+
+ async def __call__(self, scope, receive, send):
+ await self.app(scope, receive, send)
diff --git a/responder/ext/graphql.py b/responder/ext/graphql/__init__.py
similarity index 90%
rename from responder/ext/graphql.py
rename to responder/ext/graphql/__init__.py
index 806d43e..36f3b32 100644
--- a/responder/ext/graphql.py
+++ b/responder/ext/graphql/__init__.py
@@ -3,7 +3,7 @@ from functools import partial
from graphql_server import default_format_error, encode_execution_results, json_encode
-from ..templates import GRAPHIQL
+from .templates import GRAPHIQL
class GraphQLView:
@@ -44,7 +44,9 @@ class GraphQLView:
show_graphiql = req.method == "get" and req.accepts("text/html")
if show_graphiql:
- resp.content = self.api.template_string(GRAPHIQL, endpoint=req.url.path)
+ resp.content = self.api.templates.render_string(
+ GRAPHIQL, endpoint=req.url.path
+ )
return
query, variables, operation_name = await self._resolve_graphql_query(req)
@@ -63,3 +65,6 @@ class GraphQLView:
async def on_request(self, req, resp):
await self.graphql_response(req, resp, self.schema)
+
+ async def __call__(self, req, resp):
+ await self.on_request(req, resp)
diff --git a/responder/templates/__init__.py b/responder/ext/graphql/templates.py
similarity index 100%
rename from responder/templates/__init__.py
rename to responder/ext/graphql/templates.py
diff --git a/responder/ext/schema/__init__.py b/responder/ext/schema/__init__.py
new file mode 100644
index 0000000..be18b14
--- /dev/null
+++ b/responder/ext/schema/__init__.py
@@ -0,0 +1,160 @@
+import os
+from pathlib import Path
+
+import apistar
+import jinja2
+import yaml
+from apispec import APISpec, yaml_utils
+from apispec.ext.marshmallow import MarshmallowPlugin
+
+from responder.statics import DEFAULT_API_THEME
+from responder.staticfiles import StaticFiles
+from responder import status_codes
+
+
+class Schema:
+ def __init__(
+ self,
+ app,
+ title,
+ version,
+ plugins=None,
+ description=None,
+ terms_of_service=None,
+ contact=None,
+ license=None,
+ openapi=None,
+ openapi_route="/schema.yml",
+ docs_route="/docs/",
+ static_route="/static",
+ ):
+ self.app = app
+ self.schemas = {}
+ self.title = title
+ self.version = version
+ self.description = description
+ self.terms_of_service = terms_of_service
+ self.contact = contact
+ self.license = license
+
+ self.openapi_version = openapi
+ self.openapi_route = openapi_route
+
+ self.docs_theme = DEFAULT_API_THEME
+ self.docs_route = docs_route
+
+ self.plugins = [MarshmallowPlugin()] if plugins is None else plugins
+
+ if self.openapi_version is not None:
+ self.app.add_route(self.openapi_route, self.schema_response)
+
+ if self.docs_route is not None:
+ self.app.add_route(self.docs_route, self.docs_response)
+
+ theme_path = (
+ Path(apistar.__file__).parent / "themes" / self.docs_theme / "static"
+ ).resolve()
+
+ self.static_route = static_route
+
+ self.app.static_app.add_directory(theme_path)
+
+ @property
+ def _apispec(self):
+
+ info = {}
+ if self.description is not None:
+ info["description"] = self.description
+ if self.terms_of_service is not None:
+ info["termsOfService"] = self.terms_of_service
+ if self.contact is not None:
+ info["contact"] = self.contact
+ if self.license is not None:
+ info["license"] = self.license
+
+ spec = APISpec(
+ title=self.title,
+ version=self.version,
+ openapi_version=self.openapi_version,
+ plugins=self.plugins,
+ info=info,
+ )
+
+ for route in self.app.router.routes:
+ if route.description:
+ operations = yaml_utils.load_operations_from_docstring(
+ route.description
+ )
+ spec.path(path=route.route, operations=operations)
+
+ for name, schema in self.schemas.items():
+ spec.components.schema(name, schema=schema)
+
+ return spec
+
+ @property
+ def openapi(self):
+ return self._apispec.to_yaml()
+
+ def add_schema(self, name, schema, check_existing=True):
+ """Adds a mashmallow schema to the API specification."""
+ if check_existing:
+ assert name not in self.schemas
+
+ self.schemas[name] = schema
+
+ def schema(self, name, **options):
+ """Decorator for creating new routes around function and class definitions.
+
+ Usage::
+
+ from marshmallow import Schema, fields
+
+ @api.schema("Pet")
+ class PetSchema(Schema):
+ name = fields.Str()
+
+ """
+
+ def decorator(f):
+ self.add_schema(name=name, schema=f, **options)
+ return f
+
+ return decorator
+
+ @property
+ def docs(self):
+
+ loader = jinja2.PrefixLoader(
+ {
+ self.docs_theme: jinja2.PackageLoader(
+ "apistar", os.path.join("themes", self.docs_theme, "templates")
+ )
+ }
+ )
+ env = jinja2.Environment(autoescape=True, loader=loader)
+ document = apistar.document.Document()
+ document.content = yaml.safe_load(self.openapi)
+
+ template = env.get_template("/".join([self.docs_theme, "index.html"]))
+
+ return template.render(
+ document=document,
+ langs=["javascript", "python"],
+ code_style=None,
+ static_url=self.static_url,
+ schema_url="/schema.yml",
+ )
+
+ def static_url(self, asset):
+ """Given a static asset, return its URL path."""
+ assert self.static_route is not None
+ return f"{self.static_route}/{str(asset)}"
+
+ def docs_response(self, req, resp):
+ resp.html = self.docs
+
+ def schema_response(self, req, resp):
+ resp.status_code = status_codes.HTTP_200
+ resp.headers["Content-Type"] = "application/x-yaml"
+ resp.content = self.openapi
diff --git a/responder/models.py b/responder/models.py
index bbf8a97..d39f318 100644
--- a/responder/models.py
+++ b/responder/models.py
@@ -3,16 +3,18 @@ import io
import inspect
import json
import gzip
+from urllib.parse import parse_qs
from base64 import b64decode
from http.cookies import SimpleCookie
-
import chardet
import rfc3986
import graphene
import yaml
+
from requests.structures import CaseInsensitiveDict
from requests.cookies import RequestsCookieJar
+
from starlette.datastructures import MutableHeaders
from starlette.requests import Request as StarletteRequest, State
from starlette.responses import (
@@ -20,9 +22,7 @@ from starlette.responses import (
StreamingResponse as StarletteStreamingResponse,
)
-from urllib.parse import parse_qs
-
-from .status_codes import HTTP_200
+from .status_codes import HTTP_200, HTTP_301
from .statics import DEFAULT_ENCODING
@@ -105,9 +105,9 @@ class Request:
"_cookies",
]
- def __init__(self, scope, receive, api=None):
+ def __init__(self, scope, receive, api=None, formats=None):
self._starlette = StarletteRequest(scope, receive)
- self.formats = None
+ self.formats = formats
self._encoding = None
self.api = api
self._content = None
@@ -122,14 +122,7 @@ class Request:
@property
def session(self):
"""The session data, in dict form, from the Request."""
- if self.api.session_cookie in self.cookies:
-
- data = self.cookies[self.api.session_cookie]
-
- data = self.api._signer.unsign(data)
- data = b64decode(data)
- return json.loads(data)
- return {}
+ return self._starlette.session
@property
def headers(self):
@@ -182,12 +175,12 @@ class Request:
def state(self) -> State:
"""
Use the state to store additional information.
-
+
This can be a very helpful feature, if you want to hand over
information from a middelware or a route decorator to the
actual route handler.
- For example: ``request.state.time_started = time.time()``
+ Usage: ``request.state.time_started = time.time()``
"""
return self._starlette.state
@@ -300,7 +293,7 @@ class Response:
self.formats = formats
self.cookies = SimpleCookie() #: The cookies set in the Response
self.session = (
- req.session.copy()
+ req.session
) #: The cookie-based session data, in dict form, to add to the Response.
# Property or func/dec
@@ -311,6 +304,12 @@ class Response:
return func
+ def redirect(self, location, *, set_text=True, status_code=HTTP_301):
+ self.status_code = status_code
+ if set_text:
+ self.text = f"Redirecting to: {location}"
+ self.headers.update({"Location": location})
+
@property
async def body(self):
if self._stream is not None:
diff --git a/responder/routes.py b/responder/routes.py
index 997fe05..6ff6535 100644
--- a/responder/routes.py
+++ b/responder/routes.py
@@ -1,43 +1,76 @@
+import asyncio
+import json
import re
-import functools
import inspect
-from parse import parse, with_pattern
+
+from starlette.routing import Lifespan
+from starlette.middleware.wsgi import WSGIMiddleware
+from starlette.websockets import WebSocket, WebSocketClose
+from starlette.concurrency import run_in_threadpool
+from starlette.exceptions import HTTPException
+
+from .models import Request, Response
+from . import status_codes
+from .formats import get_formats
+from .statics import DEFAULT_SESSION_COOKIE
-def _make_convertor(type, pattern):
- @with_pattern(pattern)
- def inner(value):
- return type(value)
-
- return inner
-
-
-_convertors = {
- "int": _make_convertor(int, r"\d+"),
- "str": _make_convertor(str, r"[^/]+"),
- "float": _make_convertor(float, r"\d+(.\d+)?"),
+_CONVERTORS = {
+ "int": (int, r"\d+"),
+ "str": (str, r"[^/]+"),
+ "float": (float, r"\d+(.\d+)?"),
}
+PARAM_RE = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
-class Route:
- _param_pattern = re.compile(r"{([^{}]*)}")
- def __init__(self, route, endpoint, *, websocket=False, before_request=False):
+def compile_path(path):
+ path_re = "^"
+ param_convertors = {}
+ idx = 0
+
+ for match in PARAM_RE.finditer(path):
+ param_name, convertor_type = match.groups(default="str")
+ convertor_type = convertor_type.lstrip(":")
+ assert (
+ convertor_type in _CONVERTORS.keys()
+ ), f"Unknown path convertor '{convertor_type}'"
+ convertor, convertor_re = _CONVERTORS[convertor_type]
+
+ path_re += path[idx : match.start()]
+ path_re += rf"(?P<{param_name}>{convertor_re})"
+
+ param_convertors[param_name] = convertor
+
+ idx = match.end()
+
+ path_re += path[idx:] + "$"
+
+ return re.compile(path_re), param_convertors
+
+
+class BaseRoute:
+ def matches(self, scope):
+ raise NotImplementedError()
+
+ async def __call__(self, scope, receive, send):
+ raise NotImplementedError()
+
+
+class Route(BaseRoute):
+ def __init__(self, route, endpoint, *, before_request=False):
+ assert route.startswith("/"), "Route path must start with '/'"
self.route = route
self.endpoint = endpoint
- self.uses_websocket = websocket
self.before_request = before_request
+ self.path_re, self.param_convertors = compile_path(route)
+
def __repr__(self):
return f""
- def __eq__(self, other):
- if hasattr(other, "route"):
- # Being compared to other routes.
- return self.route == other.route
- else:
- # Strings.
- return self.does_match(other)
+ def url(self, **params):
+ return self.route.format(**params)
@property
def endpoint_name(self):
@@ -47,46 +80,247 @@ class Route:
def description(self):
return self.endpoint.__doc__
- @property
- def has_parameters(self):
- return bool(self._param_pattern.search(self.route))
+ def matches(self, scope):
+ if scope["type"] != "http":
+ return False, {}
- @functools.lru_cache(maxsize=None)
- def does_match(self, s):
- if s == self.route:
- return True
+ path = scope["path"]
+ match = self.path_re.match(path)
- named = self.incoming_matches(s)
- return bool(len(named))
+ if match is None:
+ return False, {}
- @functools.lru_cache(maxsize=None)
- def incoming_matches(self, s):
- results = parse(self.route, s, _convertors)
- return results.named if results else {}
+ matched_params = match.groupdict()
+ for key, value in matched_params.items():
+ matched_params[key] = self.param_convertors[key](value)
+
+ return True, {"path_params": {**matched_params}}
+
+ async def __call__(self, scope, receive, send):
+ request = Request(scope, receive, formats=get_formats())
+ response = Response(req=request, formats=get_formats())
+
+ path_params = scope.get("path_params", {})
+ before_requests = scope.get("before_requests", [])
+
+ for before_request in before_requests.get("http", []):
+ if asyncio.iscoroutinefunction(before_request):
+ await before_request(request, response)
+ else:
+ await run_in_threadpool(before_request, request, response)
+
+ views = []
+
+ if inspect.isclass(self.endpoint):
+ endpoint = self.endpoint()
+ on_request = getattr(endpoint, "on_request", None)
+ if on_request:
+ views.append(on_request)
+
+ method_name = f"on_{request.method}"
+ try:
+ view = getattr(endpoint, method_name)
+ views.append(view)
+ except AttributeError:
+ if on_request is None:
+ raise HTTPException(status_code=status_codes.HTTP_405)
+ else:
+ views.append(self.endpoint)
+
+ for view in views:
+ # "Monckey patch" for graphql: explicitly checking __call__
+ if asyncio.iscoroutinefunction(view) or asyncio.iscoroutinefunction(
+ view.__call__
+ ):
+ await view(request, response, **path_params)
+ else:
+ await run_in_threadpool(view, request, response, **path_params)
+
+ if response.status_code is None:
+ response.status_code = status_codes.HTTP_200
+
+ await response(scope, receive, send)
+
+ def __eq__(self, other):
+ # [TODO] compare to str ?
+ return self.route == other.route and self.endpoint == other.endpoint
+
+ def __hash__(self):
+ return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request)
+
+
+class WebSocketRoute(BaseRoute):
+ def __init__(self, route, endpoint, *, before_request=False):
+ assert route.startswith("/"), "Route path must start with '/'"
+ self.route = route
+ self.endpoint = endpoint
+ self.before_request = before_request
+
+ self.path_re, self.param_convertors = compile_path(route)
+
+ def __repr__(self):
+ return f""
def url(self, **params):
return self.route.format(**params)
- def _weight(self):
- params = set(self._param_pattern.findall(self.route))
- params_count = len(params)
- w = len(self.route.rsplit("}", 1)[-1].strip("/"))
- return params_count != 0, w == 0, -params_count
+ @property
+ def endpoint_name(self):
+ return self.endpoint.__name__
@property
- def is_class_based(self):
- return inspect.isclass(self.endpoint)
+ def description(self):
+ return self.endpoint.__doc__
- @property
- def is_function(self):
- code = hasattr(self.endpoint, "__code__")
- kwdefaults = hasattr(self.endpoint, "__kwdefaults__")
- return all((callable(self.endpoint), code, kwdefaults))
+ def matches(self, scope):
+ if scope["type"] != "websocket":
+ return False, {}
+
+ path = scope["path"]
+ match = self.path_re.match(path)
+
+ if match is None:
+ return False, {}
+
+ matched_params = match.groupdict()
+ for key, value in matched_params.items():
+ matched_params[key] = self.param_convertors[key](value)
+
+ return True, {"path_params": {**matched_params}}
+
+ async def __call__(self, scope, receive, send):
+ ws = WebSocket(scope, receive, send)
+
+ before_requests = scope.get("before_requests", [])
+ for before_request in before_requests.get("ws", []):
+ await before_request(ws)
+
+ await self.endpoint(ws)
+
+ def __eq__(self, other):
+ # [TODO] compare to str ?
+ return self.route == other.route and self.endpoint == other.endpoint
def __hash__(self):
- return (
- hash(self.route)
- ^ hash(self.endpoint)
- ^ hash(self.uses_websocket)
- ^ hash(self.before_request)
+ return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request)
+
+
+class Router:
+ def __init__(self, routes=None, default_response=None, before_requests=None):
+ self.routes = [] if routes is None else list(routes)
+ # [TODO] Make its own router
+ self.apps = {}
+ self.default_endpoint = (
+ self.default_response if default_response is None else default_response
)
+ self.lifespan_handler = Lifespan()
+ self.before_requests = (
+ {"http": [], "ws": []} if before_requests is None else before_requests
+ )
+
+ def add_route(
+ self,
+ route=None,
+ endpoint=None,
+ *,
+ default=False,
+ websocket=False,
+ before_request=False,
+ check_existing=False,
+ ):
+ """ Adds a route to the router.
+ :param route: A string representation of the route
+ :param endpoint: The endpoint for the route -- can be callable, or class.
+ :param default: If ``True``, all unknown requests will route to this view.
+ """
+ if before_request:
+ if websocket:
+ self.before_requests.setdefault("ws", []).append(endpoint)
+ else:
+ self.before_requests.setdefault("http", []).append(endpoint)
+ return
+
+ if check_existing:
+ assert not self.routes or route not in (
+ item.route for item in self.routes
+ ), f"Route '{route}' already exists"
+
+ if default:
+ self.default_endpoint = endpoint
+
+ if websocket:
+ route = WebSocketRoute(route, endpoint)
+ else:
+ route = Route(route, endpoint)
+
+ self.routes.append(route)
+
+ def mount(self, route, app):
+ """Mounts ASGI / WSGI applications at a given route
+ """
+ self.apps.update(route, app)
+
+ def before_request(self, endpoint, websocket=False):
+ if websocket:
+ self.before_requests.setdefault("ws", []).append(endpoint)
+ else:
+ self.before_requests.setdefault("http", []).append(endpoint)
+
+ def url_for(self, endpoint, **params):
+ # TODO: Check for params
+ for route in self.routes:
+ if endpoint in (route.endpoint, route.endpoint.__name__):
+ return route.url(**params)
+ return None
+
+ async def default_response(self, scope, receive, send):
+ if scope["type"] == "websocket":
+ websocket_close = WebSocketClose()
+ await websocket_close(receive, send)
+ return
+
+ request = Request(scope, receive)
+ response = Response(request, formats=get_formats())
+
+ raise HTTPException(status_code=status_codes.HTTP_404)
+
+ def _resolve_route(self, scope):
+ for route in self.routes:
+ matches, child_scope = route.matches(scope)
+ if matches:
+ scope.update(child_scope)
+ return route
+ return None
+
+ async def __call__(self, scope, receive, send):
+ assert scope["type"] in ("http", "websocket", "lifespan")
+
+ if scope["type"] == "lifespan":
+ await self.lifespan_handler(scope, receive, send)
+ return
+
+ path = scope["path"]
+ root_path = scope.get("root_path", "")
+
+ # Call into a submounted app, if one exists.
+ for path_prefix, app in self.apps.items():
+ if path.startswith(path_prefix):
+ scope["path"] = path[len(path_prefix) :]
+ scope["root_path"] = root_path + path_prefix
+ try:
+ await app(scope, receive, send)
+ return
+ except TypeError:
+ app = WSGIMiddleware(app)
+ await app(scope, receive, send)
+ return
+
+ route = self._resolve_route(scope)
+
+ scope["before_requests"] = self.before_requests
+
+ if route is not None:
+ await route(scope, receive, send)
+ return
+
+ await self.default_response(scope, receive, send)
diff --git a/responder/staticfiles.py b/responder/staticfiles.py
new file mode 100644
index 0000000..66a57c1
--- /dev/null
+++ b/responder/staticfiles.py
@@ -0,0 +1,18 @@
+import typing
+
+from starlette.staticfiles import StaticFiles
+
+
+class StaticFiles(StaticFiles):
+ """I've created an issue to disccuss allowing multiple directories in starletter's `StaticFiles`.
+
+ https://github.com/encode/starlette/issues/625
+
+ I've also made a PR to add this method to starlette StaticFiles
+ Once accepted we will remove this.
+
+ https://github.com/encode/starlette/pull/626
+ """
+
+ def add_directory(self, directory: str) -> None:
+ self.all_directories = [*self.all_directories, *self.get_directories(directory)]
diff --git a/responder/templates.py b/responder/templates.py
new file mode 100644
index 0000000..6419a02
--- /dev/null
+++ b/responder/templates.py
@@ -0,0 +1,55 @@
+from contextlib import contextmanager
+
+import jinja2
+
+
+class Templates:
+ def __init__(self, directory="templates", autoescape=True, context=None):
+ self.directory = directory
+ self._env = jinja2.Environment(
+ loader=jinja2.FileSystemLoader([str(self.directory)]), autoescape=autoescape
+ )
+ self.default_context = {} if context is None else {**context}
+ self._env.globals.update(self.default_context)
+
+ @property
+ def context(self):
+ return self._env.globals
+
+ @context.setter
+ def context(self, context):
+ self._env.globals = {**self.default_context, **context}
+
+ def get_template(self, name):
+ return self._env.get_template(name)
+
+ def render(self, template, *args, **kwargs):
+ """Renders the given `jinja2 `_ template, with provided values supplied.
+
+ :param template: The filename of the jinja2 template.
+ :param **kwargs: Data to pass into the template.
+ :param **kwargs: Data to pass into the template.
+ """
+ return self.get_template(template).render(*args, **kwargs)
+
+ @contextmanager
+ def _async(self):
+ self._env.is_async = True
+ try:
+ yield
+ finally:
+ self._env.is_async = False
+
+ async def render_async(self, template, *args, **kwargs):
+ with self._async():
+ return await self.get_template(template).render_async(*args, **kwargs)
+
+ def render_string(self, source, *args, **kwargs):
+ """Renders the given `jinja2 `_ template string, with provided values supplied.
+
+ :param source: The template to use.
+ :param *args, **kwargs: Data to pass into the template.
+ :param **kwargs: Data to pass into the template.
+ """
+ template = self._env.from_string(source)
+ return template.render(*args, **kwargs)
diff --git a/setup.py b/setup.py
index df69c91..1d5dcd3 100644
--- a/setup.py
+++ b/setup.py
@@ -27,10 +27,9 @@ required = [
"aiofiles",
"pyyaml",
"requests",
- "graphene",
+ "graphene<3.0",
"graphql-server-core>=1.1",
"jinja2",
- "parse",
"uvloop; sys_platform != 'win32' and sys_platform != 'cygwin' and sys_platform != 'cli'",
"rfc3986",
"python-multipart",
@@ -39,7 +38,6 @@ required = [
"marshmallow",
"whitenoise",
"docopt",
- "itsdangerous",
"requests-toolbelt",
"apistar",
]
diff --git a/tests/test_responder.py b/tests/test_responder.py
index 2debc2b..2935306 100644
--- a/tests/test_responder.py
+++ b/tests/test_responder.py
@@ -7,6 +7,7 @@ import responder
import requests
import string
import io
+from responder.routes import Router, Route, WebSocketRoute
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
@@ -19,6 +20,47 @@ def test_api_basic_route(api):
resp.text = "hello world!"
+def test_route_repr():
+ def home(req, resp):
+ """Home page
+ """
+ resp.text = "Hello !"
+
+ route = Route("/", home)
+
+ assert route.__repr__() == f""
+
+ assert route.endpoint_name == home.__name__
+ assert route.description == home.__doc__
+
+
+def test_websocket_route_repr():
+ def chat_endpoint(ws):
+ """Chat
+ """
+ pass
+
+ route = WebSocketRoute("/", chat_endpoint)
+
+ assert route.__repr__() == f""
+
+ assert route.endpoint_name == chat_endpoint.__name__
+ assert route.description == chat_endpoint.__doc__
+
+
+def test_route_eq():
+ def home(req, resp):
+ resp.text = "Hello !"
+
+ assert Route("/", home) == Route("/", home)
+
+ def chat(ws):
+ pass
+
+ assert WebSocketRoute("/", home) == WebSocketRoute("/", home)
+
+
+"""
def test_api_basic_route_overlap(api):
@api.route("/")
def home(req, resp):
@@ -29,39 +71,7 @@ def test_api_basic_route_overlap(api):
@api.route("/")
def home2(req, resp):
resp.text = "hello world!"
-
-
-def test_api_basic_route_overlap_alternative(api):
- @api.route("/")
- def home(req, resp):
- resp.text = "hello world!"
-
- def home2(req, resp):
- resp.text = "hello world!"
-
- with pytest.raises(AssertionError):
- api.add_route("/", home2)
-
-
-def test_api_basic_route_overlap_allowed(api):
- @api.route("/")
- def home(req, resp):
- resp.text = "hello world!"
-
- def home2(req, resp):
- resp.text = "hello world!"
-
- api.add_route("/", home2, check_existing=False)
-
-
-def test_api_basic_route_overlap_allowed_alternative(api):
- @api.route("/")
- def home(req, resp):
- resp.text = "hello world!"
-
- @api.route("/", check_existing=False)
- def home2(req, resp):
- resp.text = "hello world!"
+"""
def test_class_based_view_registration(api):
@@ -74,10 +84,10 @@ def test_class_based_view_registration(api):
def test_class_based_view_parameters(api):
@api.route("/{greeting}")
class Greeting:
- def on_request(self, req, resp, *, greeting):
- resp.text = f"{greeting}, world!"
+ pass
- assert api.session().get("http://;/Hello").ok
+ resp = api.session().get("http://;/Hello")
+ assert resp.status_code == api.status_codes.HTTP_405
def test_requests_session(api):
@@ -85,14 +95,14 @@ def test_requests_session(api):
assert api.requests
-def test_requests_session_works(api, url):
+def test_requests_session_works(api):
TEXT = "spiral out"
@api.route("/")
def hello(req, resp):
resp.text = TEXT
- assert api.requests.get(url("/")).text == TEXT
+ assert api.requests.get("/").text == TEXT
def test_status_code(api):
@@ -338,13 +348,45 @@ def test_yaml_downloads(api):
assert yaml.safe_load(r.content) == dump
+def test_schema_generation_explicit():
+ import responder
+ from responder.ext.schema import Schema as OpenAPISchema
+ import marshmallow
+
+ api = responder.API()
+
+ schema = OpenAPISchema(app=api, title="Web Service", version="1.0", openapi="3.0.2")
+
+ @schema.schema("Pet")
+ class PetSchema(marshmallow.Schema):
+ name = marshmallow.fields.Str()
+
+ @api.route("/")
+ def route(req, resp):
+ """A cute furry animal endpoint.
+ ---
+ get:
+ description: Get a random pet
+ responses:
+ 200:
+ description: A pet to be returned
+ schema:
+ $ref: "#/components/schemas/Pet"
+ """
+ resp.media = PetSchema().dump({"name": "little orange"})
+
+ r = api.requests.get("http://;/schema.yml")
+ dump = yaml.safe_load(r.content)
+
+ assert dump
+ assert dump["openapi"] == "3.0.2"
+
+
def test_schema_generation():
import responder
from marshmallow import Schema, fields
- api = responder.API(
- title="Web Service", openapi="3.0.2", allowed_hosts=["testserver", ";"]
- )
+ api = responder.API(title="Web Service", openapi="3.0.2")
@api.schema("Pet")
class PetSchema(Schema):
@@ -371,6 +413,60 @@ def test_schema_generation():
assert dump["openapi"] == "3.0.2"
+def test_documentation_explicit():
+ import responder
+ from responder.ext.schema import Schema as OpenAPISchema
+
+ import marshmallow
+
+ description = "This is a sample server for a pet store."
+ terms_of_service = "http://example.com/terms/"
+ contact = {
+ "name": "API Support",
+ "url": "http://www.example.com/support",
+ "email": "support@example.com",
+ }
+ license = {
+ "name": "Apache 2.0",
+ "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
+ }
+
+ api = responder.API(allowed_hosts=["testserver", ";"])
+
+ schema = OpenAPISchema(
+ app=api,
+ title="Web Service",
+ version="1.0",
+ openapi="3.0.2",
+ docs_route="/docs",
+ description=description,
+ terms_of_service=terms_of_service,
+ contact=contact,
+ license=license,
+ )
+
+ @schema.schema("Pet")
+ class PetSchema(marshmallow.Schema):
+ name = marshmallow.fields.Str()
+
+ @api.route("/")
+ def route(req, resp):
+ """A cute furry animal endpoint.
+ ---
+ get:
+ description: Get a random pet
+ responses:
+ 200:
+ description: A pet to be returned
+ schema:
+ $ref: "#/components/schemas/Pet"
+ """
+ resp.media = PetSchema().dump({"name": "little orange"})
+
+ r = api.requests.get("/docs")
+ assert "html" in r.text
+
+
def test_documentation():
import responder
from marshmallow import Schema, fields
@@ -485,7 +581,7 @@ def test_sessions(api):
assert r.json() == {"hello": "world"}
-def test_template_rendering(api):
+def test_template_string_rendering(api):
@api.route("/")
def view(req, resp):
resp.content = api.template_string("{{ var }}", var="hello")
@@ -582,7 +678,7 @@ def test_before_websockets(api):
await ws.send_json(payload)
await ws.close()
- @api.route(before_request=True, websocket=True)
+ @api.before_request(websocket=True)
async def before_request(ws):
await ws.accept()
await ws.send_json({"before": "request"})
@@ -643,6 +739,10 @@ def test_before_response(api, session):
def get(req, resp):
resp.media = req.session
+ @api.route(before_request=True)
+ async def async_before_request(req, resp):
+ resp.headers["x-pizza"] = "1"
+
@api.route(before_request=True)
def before_request(req, resp):
resp.headers["x-pizza"] = "1"
@@ -794,25 +894,6 @@ def test_staticfiles_none_dir(tmpdir):
api.add_route("/spa", static=True)
-def test_staticfiles_none_dir_route(tmpdir):
- api = responder.API(static_dir=None, static_route=None)
- session = api.session()
-
- static_dir = tmpdir.mkdir("static")
-
- asset = create_asset(static_dir)
-
- static_route = api.static_route
-
- # ok
- r = session.get(f"{static_route}/{asset.basename}")
- assert r.status_code == api.status_codes.HTTP_404
-
- # dir listing
- r = session.get(f"{static_route}")
- assert r.status_code == api.status_codes.HTTP_404
-
-
def test_response_html_property(api):
@api.route("/")
def view(req, resp):
diff --git a/tests/test_routes.py b/tests/test_routes.py
deleted file mode 100644
index 0c19e30..0000000
--- a/tests/test_routes.py
+++ /dev/null
@@ -1,159 +0,0 @@
-import pytest
-from responder import routes
-
-
-def setup_function(function):
- routes.Route.incoming_matches.cache_clear()
-
-
-@pytest.mark.parametrize(
- "route, expected",
- [
- pytest.param("/", False, id="home path without params"),
- pytest.param("/test_path", False, id="sub path without params"),
- pytest.param("/{test_path}", True, id="path with params"),
- ],
-)
-def test_parameter(route, expected):
- r = routes.Route(route, "test_endpoint")
- assert r.has_parameters is expected
-
-
-def test_url():
- r = routes.Route("/{my_path}", "test_endpoint")
- url = r.url(my_path="path")
- assert url == "/path"
-
-
-def test_equal():
- r = routes.Route("/{path_param}", "test_endpoint")
- r2 = routes.Route("/{path_param}", "test_endpoint")
- r3 = routes.Route("/test_path", "test_endpoint")
-
- assert r == r2
- assert r != r3
-
-
-def test_incoming_matches():
- # Test Route with one param
- r = routes.Route("/{greetings}", "test_endpoint")
- assert r.incoming_matches("/hello") == {"greetings": "hello"}
- assert r.incoming_matches("/foo") == {"greetings": "foo"}
-
- # Test Route with two params
- r = routes.Route("/{greetings}/{name}", "test_endpoint")
- assert r.incoming_matches("/hi/john") == {"greetings": "hi", "name": "john"}
- assert r.incoming_matches("/hello/jane") == {"greetings": "hello", "name": "jane"}
-
- # Test Route with no param
- r = routes.Route("/hello", "test_endpoint")
- assert r.incoming_matches("/hello") == {}
- assert r.incoming_matches("/bye") == {}
-
-
-def test_incoming_matches_cache():
- r = routes.Route("/hello", "test_endpoint")
- r.incoming_matches("/hello")
- assert r.incoming_matches.cache_info().hits == 0
- r.incoming_matches("/hello")
- assert r.incoming_matches.cache_info().hits == 1
-
-
-def test_incoming_matches_with_concrete_path_no_match():
- r = routes.Route("/concrete_path", "test_endpoint")
- assert r.incoming_matches("hello") == {}
-
-
-@pytest.mark.parametrize(
- "route, match, expected",
- [
- pytest.param(
- "/{path_param}",
- "/{path_param}",
- True,
- id="with both parametrized path match",
- ),
- pytest.param(
- "/concrete", "/concrete", True, id="with both concrete path match"
- ),
- pytest.param("/concrete", "/no_match", False, id="with no match"),
- ],
-)
-def test_does_match_with_route(route, match, expected):
- r = routes.Route(route, "test_endpoint")
- assert r.does_match(match) == expected
-
-
-@pytest.mark.parametrize(
- "path_param, expected_weight",
- [
- pytest.param("/{greetings}", (True, True, -1), id="with one param"),
- pytest.param(
- "/{greetings}.{name}",
- (True, True, -2),
- id="with 2 params and dot in the middle",
- ),
- pytest.param(
- "/{greetings}/{name}", (True, True, -2), id="with 2 params and subpath"
- ),
- pytest.param(
- "/{greetings}/{name}/{hello}",
- (True, True, -3),
- id="with 3 params and subpath",
- ),
- pytest.param(
- "/{greetings}_{name}", (True, True, -2), id="with 2 params and underscore"
- ),
- pytest.param("/{greetings}/test", (True, False, -1), id="with one param"),
- pytest.param(
- "/{greetings}.{name}/test",
- (True, False, -2),
- id="with 2 params and dot in the middle",
- ),
- pytest.param(
- "/{greetings}/{name}/test",
- (True, False, -2),
- id="with 2 params and subpath",
- ),
- pytest.param(
- "/{greetings}/{name}/{hello}/test",
- (True, False, -3),
- id="with 3 params and subpath",
- ),
- pytest.param(
- "/{greetings}_{name}/test",
- (True, False, -2),
- id="with 2 params and underscore",
- ),
- pytest.param("/hello", (False, False, 0), id="without params"),
- ],
-)
-def test_weight(path_param, expected_weight):
- r = routes.Route(path_param, "test_endpoint")
- assert r._weight() == expected_weight
-
-
-@pytest.mark.parametrize(
- "route, path, expected_result",
- [
- pytest.param("/{greetings:str}", "/hello", {"greetings": "hello"}),
- pytest.param(
- "/{greetings:str}/{who}",
- "/hello/Laidia",
- {"greetings": "hello", "who": "Laidia"},
- ),
- pytest.param("/{birth_date:int}", "/1937", {"birth_date": 1937}),
- pytest.param(
- "/{name:str}/{age:int}", "/Fatna/80", {"name": "Fatna", "age": 80}
- ),
- pytest.param(
- "/{x:float}/{y:float}", "/10.20/75", {"x": float(10.20), "y": float(75)}
- ),
- pytest.param("/{name:str}/{age:int}", "/Fatna/eighty", {}),
- pytest.param("/{greetings:int}", "/hello", {}),
- pytest.param("/{name:float}", "/Fatna", {}),
- ],
-)
-def test_custom_specifiers(route, path, expected_result):
- r = routes.Route(route, "test_endpoint")
- assert r.incoming_matches(path) == expected_result