diff --git a/Pipfile.lock b/Pipfile.lock index 6998c54..6a3e346 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -23,10 +23,10 @@ }, "aniso8601": { "hashes": [ - "sha256:b8a6a9b24611fc50cf2d9b45d371bfdc4fd0581d1cc52254f5502130a776d4af", - "sha256:bb167645c79f7a438f9dfab6161af9bed75508c645b1f07d1158240841d22673" + "sha256:513d2b6637b7853806ae79ffaca6f3e8754bdd547048f5ccc1420aec4b714f1e", + "sha256:d10a4bf949f619f719b227ef5386e31f49a2b6d453004b21f02661ccc8670c7b" ], - "version": "==6.0.0" + "version": "==7.0.0" }, "apispec": { "hashes": [ @@ -70,10 +70,10 @@ }, "graphene": { "hashes": [ - "sha256:77d61618132ccd084c343e64c22d806cee18dce73cc86e0f427378dbdeeac287", - "sha256:acf808d50d053b94f7958414d511489a9e490a7f9563b9be80f6875fc5723d2a" + "sha256:09165f03e1591b76bf57b133482db9be6dac72c74b0a628d3c93182af9c5a896", + "sha256:2cbe6d4ef15cfc7b7805e0760a0e5b80747161ce1b0f990dfdc0d2cf497c12f9" ], - "version": "==2.1.7" + "version": "==2.1.8" }, "graphql-core": { "hashes": [ @@ -116,13 +116,6 @@ ], "version": "==2.8" }, - "itsdangerous": { - "hashes": [ - "sha256:321b033d07f2a4136d3ec762eac9f16a10ccd60f53c0c91af90217ace7ba1f19", - "sha256:b12271b2047cb23eeb98c8b5622e2e5c5e9abd9784a153e9d8ef9cb4dd09d749" - ], - "version": "==1.1.0" - }, "jinja2": { "hashes": [ "sha256:065c4f02ebe7f7cf559e49ee5a95fb800a9e4528727aec6f24402a5374c65013", @@ -165,16 +158,10 @@ }, "marshmallow": { "hashes": [ - "sha256:7ea8540fc7e35be3b0af8b017313944b984d5acdb118b4ba3c270ac9611765c7", - "sha256:bc91e3f90e86133241ac62ea0dd35217d631a207e8628430bc66c347dbe12f7d" + "sha256:23f684b54b1955ebd5bdfbdda4062e438ef86218f14f1a356f570cdf0c016ab3", + "sha256:fcfc9ffd75a883da06f30f604a4e81dd0b56eb9438f4d0a8de6bbaa163ce9ec3" ], - "version": "==3.0.0rc9" - }, - "parse": { - "hashes": [ - "sha256:1b68657434d371e5156048ca4a0c5aea5afc6ca59a2fea4dd1a575354f617142" - ], - "version": "==1.12.0" + "version": "==3.0.1" }, "promise": { "hashes": [ @@ -248,9 +235,9 @@ }, "starlette": { "hashes": [ - "sha256:849553e6dc8dd4faac5d4383219bdf49b199f3f09d60ce374c2a00a42ab7c77d" + "sha256:f600bf9d0beeeeebcb143e6d0c4f8858c2b05067d5a4feb446ba7400ba5e5dc5" ], - "version": "==0.12.7" + "version": "==0.12.8" }, "typesystem": { "hashes": [ @@ -474,6 +461,7 @@ "sha256:23d3d873e008a513952355379d93cbcab874c58f4f034ff657c7a87422fa64e8", "sha256:80d2de76188eabfbfcf27e6a37342c2827801e59c4cc14b0371c56fed43820e3" ], + "markers": "python_version < '3.8'", "version": "==0.19" }, "itsdangerous": { @@ -525,10 +513,10 @@ }, "marshmallow": { "hashes": [ - "sha256:7ea8540fc7e35be3b0af8b017313944b984d5acdb118b4ba3c270ac9611765c7", - "sha256:bc91e3f90e86133241ac62ea0dd35217d631a207e8628430bc66c347dbe12f7d" + "sha256:23f684b54b1955ebd5bdfbdda4062e438ef86218f14f1a356f570cdf0c016ab3", + "sha256:fcfc9ffd75a883da06f30f604a4e81dd0b56eb9438f4d0a8de6bbaa163ce9ec3" ], - "version": "==3.0.0rc9" + "version": "==3.0.1" }, "mccabe": { "hashes": [ @@ -602,11 +590,11 @@ }, "pytest": { "hashes": [ - "sha256:6ef6d06de77ce2961156013e9dff62f1b2688aa04d0dc244299fe7d67e09370d", - "sha256:a736fed91c12681a7b34617c8fcefe39ea04599ca72c608751c31d89579a3f77" + "sha256:95d13143cc14174ca1a01ec68e84d76ba5d9d493ac02716fd9706c949a505210", + "sha256:b78fe2881323bd44fd9bd76e5317173d4316577e7b1cddebae9136a4495ec865" ], "index": "pypi", - "version": "==5.0.1" + "version": "==5.1.2" }, "pytest-cov": { "hashes": [ @@ -659,11 +647,11 @@ }, "sphinx": { "hashes": [ - "sha256:22538e1bbe62b407cf5a8aabe1bb15848aa66bb79559f42f5202bbce6b757a69", - "sha256:f9a79e746b87921cabc3baa375199c6076d1270cee53915dbd24fdbeaaacc427" + "sha256:0d586b0f8c2fc3cc6559c5e8fd6124628110514fda0e5d7c82e682d749d2e845", + "sha256:839a3ed6f6b092bb60f492024489cc9e6991360fb9f52ed6361acd510d261069" ], "index": "pypi", - "version": "==2.1.2" + "version": "==2.2.0" }, "sphinxcontrib-applehelp": { "hashes": [ @@ -716,10 +704,10 @@ }, "tqdm": { "hashes": [ - "sha256:14a285392c32b6f8222ecfbcd217838f88e11630affe9006cd0e94c7eff3cb61", - "sha256:25d4c0ea02a305a688e7e9c2cdc8f862f989ef2a4701ab28ee963295f5b109ab" + "sha256:1be3e4e3198f2d0e47b928e9d9a8ec1b63525db29095cec1467f4c5a4ea8ebf9", + "sha256:7e39a30e3d34a7a6539378e39d7490326253b7ee354878a92255656dc4284457" ], - "version": "==4.32.2" + "version": "==4.35.0" }, "twine": { "hashes": [ @@ -759,10 +747,10 @@ }, "zipp": { "hashes": [ - "sha256:4970c3758f4e89a7857a973b1e2a5d75bcdc47794442f2e2dd4fe8e0466e809a", - "sha256:8a5712cfd3bb4248015eb3b0b3c54a5f6ee3f2425963ef2a0125b8bc40aafaec" + "sha256:3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e", + "sha256:f06903e9f1f43b12d371004b4ac7b06ab39a44adc747266928ae6debfa7b3335" ], - "version": "==0.5.2" + "version": "==0.6.0" } } } diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index e07d227..f8a040c 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -73,13 +73,29 @@ If the client requests YAML instead (with a header of ``Accept: application/x-ya Rendering a Template -------------------- -If you want to render a template, simply use ``api.template``. No need for additional imports:: +Responder provides a built-in light `jinja2 `_ wrapper ``templates.Templates`` + +Usage:: + + from responder.templates import Templates + + templates = Templates() + + @api.route("/hello/{name}/html") + def hello(req, resp, name): + resp.html = templates.render("hello.html", name=name) + + +Also a ``render_async`` is available:: + + resp.html = await templates.render_async("hello.html", who=who) + +You can also use the existing ``api.template(filename, *args, **kwargs)`` to render templates:: @api.route("/hello/{who}/html") def hello_html(req, resp, *, who): resp.html = api.template('hello.html', who=who) -The ``api`` instance is available as an object during template rendering. Setting Response Status Code ---------------------------- diff --git a/docs/source/tour.rst b/docs/source/tour.rst index 8e1a22a..9c16ede 100644 --- a/docs/source/tour.rst +++ b/docs/source/tour.rst @@ -58,13 +58,14 @@ You can make use of Responder's Request and Response objects in your GraphQL res OpenAPI Schema Support ---------------------- -Responder comes with built-in support for OpenAPI / marshmallow:: +Responder comes with built-in support for OpenAPI / marshmallow + +New in Responder `1.4.0`:: import responder + from responder.ext.schema import Schema as OpenAPISchema from marshmallow import Schema, fields - description = "This is a sample server for a pet store." - terms_of_service = "http://example.com/terms/" contact = { "name": "API Support", "url": "http://www.example.com/support", @@ -74,19 +75,21 @@ Responder comes with built-in support for OpenAPI / marshmallow:: "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0.html", } + + api = responder.API() - api = responder.API( + schema = OpenAPISchema( + app=api, title="Web Service", version="1.0", openapi="3.0.2", - description=description, - terms_of_service=terms_of_service, + description="A simple pet store", + terms_of_service="http://example.com/terms/", contact=contact, license=license, ) - - @api.schema("Pet") + @schema.schema("Pet") class PetSchema(Schema): name = fields.Str() @@ -108,6 +111,51 @@ Responder comes with built-in support for OpenAPI / marshmallow:: resp.media = PetSchema().dump({"name": "little orange"}) +Old way *It's recommended to use the code above* :: + + import responder + from marshmallow import Schema, fields + + contact = { + "name": "API Support", + "url": "http://www.example.com/support", + "email": "support@example.com", + } + license = { + "name": "Apache 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0.html", + } + + api = responder.API( + title="Web Service", + version="1.0", + openapi="3.0.2", + description="A simple pet store", + terms_of_service="http://example.com/terms/", + contact=contact, + license=license, + ) + + @api.schema("Pet") + class PetSchema(Schema): + name = fields.Str() + + @api.route("/") + def route(req, resp): + """A cute furry animal endpoint. + --- + get: + description: Get a random pet + responses: + 200: + description: A pet to be returned + content: + application/json: + schema: + $ref: '#/components/schemas/Pet' + """ + resp.media = PetSchema().dump({"name": "little orange"}) + :: >>> r = api.session().get("http://;/schema.yml") @@ -142,7 +190,31 @@ Responder comes with built-in support for OpenAPI / marshmallow:: Interactive Documentation ------------------------- -Responder can automatically supply API Documentation for you. Using the example above:: +Responder can automatically supply API Documentation for you. Using the example above + +The new and recommended way:: + + ... + from responder.ext.schema import Schema + ... + api = responder.API() + + schema = Schema( + app=api, + title="Web Service", + version="1.0", + openapi="3.0.2", + ... + docs_route='/docs', + ... + description=description, + terms_of_service=terms_of_service, + contact=contact, + license=license, + ) + + +The old way :: api = responder.API( title="Web Service", @@ -157,8 +229,8 @@ Responder can automatically supply API Documentation for you. Using the example This will make ``/docs`` render interactive documentation for your API. -Mount a WSGI App (e.g. Flask) ------------------------------ +Mount a WSGI / ASGI Apps (e.g. Flask, Starlette,...) +---------------------------------------------------- Responder gives you the ability to mount another ASGI / WSGI app at a subroute:: diff --git a/responder/api.py b/responder/api.py index 08ab569..56b85ea 100644 --- a/responder/api.py +++ b/responder/api.py @@ -1,43 +1,33 @@ import json import os -from uuid import uuid4 from pathlib import Path -from base64 import b64encode -import apistar -import itsdangerous import jinja2 import uvicorn -import yaml -from apispec import APISpec, yaml_utils -from apispec.ext.marshmallow import MarshmallowPlugin +from starlette.exceptions import ExceptionMiddleware from starlette.middleware.wsgi import WSGIMiddleware from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.cors import CORSMiddleware from starlette.middleware.gzip import GZipMiddleware from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware +from starlette.middleware.sessions import SessionMiddleware from starlette.routing import Lifespan from starlette.staticfiles import StaticFiles from starlette.testclient import TestClient from starlette.websockets import WebSocket -from whitenoise import WhiteNoise from . import models, status_codes from .background import BackgroundQueue from .formats import get_formats -from .routes import Route -from .statics import ( - DEFAULT_API_THEME, - DEFAULT_CORS_PARAMS, - DEFAULT_SECRET_KEY, - DEFAULT_SESSION_COOKIE, -) -from .templates import GRAPHIQL +from .routes import Router +from .statics import DEFAULT_API_THEME, DEFAULT_CORS_PARAMS, DEFAULT_SECRET_KEY +from .ext.schema import Schema as OpenAPISchema +from .staticfiles import StaticFiles +from .templates import Templates -# TODO: consider moving status codes here class API: """The primary web-service class. @@ -45,12 +35,6 @@ class API: :param templates_dir: The directory to use for templates. Will be created for you if it doesn't already exist. :param auto_escape: If ``True``, HTML and XML templates will automatically be escaped. :param enable_hsts: If ``True``, send all responses to HTTPS URLs. - :param title: The title of the application (OpenAPI Info Object) - :param version: The version of the OpenAPI document (OpenAPI Info Object) - :param description: The description of the OpenAPI document (OpenAPI Info Object) - :param terms_of_service: A URL to the Terms of Service for the API (OpenAPI Info Object) - :param contact: The contact dictionary of the application (OpenAPI Contact Object) - :param license: The license information of the exposed API (OpenAPI License Object) """ status_codes = status_codes @@ -81,13 +65,8 @@ class API: self.background = BackgroundQueue() self.secret_key = secret_key - self.title = title - self.version = version - self.description = description - self.terms_of_service = terms_of_service - self.contact = contact - self.license = license - self.openapi_version = openapi + + self.router = Router() if static_dir is not None: if static_route is None: @@ -106,14 +85,6 @@ class API: self.templates_dir = templates_dir or self.built_in_templates_dir - self.apps = {} - self.routes = {} - self.before_requests = {"http": [], "ws": []} - self.docs_theme = DEFAULT_API_THEME - self.docs_route = docs_route - self.schemas = {} - self.session_cookie = DEFAULT_SESSION_COOKIE - self.hsts_enabled = enable_hsts self.cors = cors self.cors_params = cors_params @@ -133,33 +104,15 @@ class API: os.makedirs(_dir, exist_ok=True) if self.static_dir is not None: - self.whitenoise = WhiteNoise(application=self._notfound_wsgi_app) - self.whitenoise.add_files(str(self.static_dir)) - - self.whitenoise.add_files( - ( - Path(apistar.__file__).parent - / "themes" - / self.docs_theme - / "static" - ).resolve() - ) - - self.mount(self.static_route, self.whitenoise) + self.mount(self.static_route, self.static_app) self.formats = get_formats() # Cached requests session. self._session = None - if self.openapi_version: - self.add_route(openapi_route, self.schema_response) - - if self.docs_route: - self.add_route(self.docs_route, self.docs_response) - self.default_endpoint = None - self.app = self.asgi + self.app = ExceptionMiddleware(self.router, debug=debug) self.add_middleware(GZipMiddleware) if self.hsts_enabled: @@ -167,28 +120,38 @@ class API: self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts) - self.lifespan_handler = Lifespan() - if self.cors: self.add_middleware(CORSMiddleware, **self.cors_params) self.add_middleware(ServerErrorMiddleware, debug=debug) + self.add_middleware(SessionMiddleware, secret_key=self.secret_key) - # Jinja environment - self.jinja_env = jinja2.Environment( - loader=jinja2.FileSystemLoader( - [str(self.templates_dir), str(self.built_in_templates_dir)], - followlinks=True, - ), - autoescape=jinja2.select_autoescape(["html", "xml"] if auto_escape else []), - ) - self.jinja_values_base = {"api": self} # Give reference to self. + if openapi or docs_route: + self.openapi = OpenAPISchema( + app=self, + title="Web Service", + version="1.0", + openapi="3.0.2", + docs_route=docs_route, + description=description, + terms_of_service=terms_of_service, + contact=contact, + license=license, + openapi_route=openapi_route, + static_route=static_route, + ) + + # TODO: Update docs for templates + self.templates = Templates(directory=templates_dir) self.requests = ( self.session() ) #: A Requests session that is connected to the ASGI app. - @staticmethod - def _default_wsgi_app(environ, start_response): - pass + @property + def static_app(self): + if not hasattr(self, "_static_app"): + assert self.static_dir is not None + self._static_app = StaticFiles(directory=self.static_dir) + return self._static_app @staticmethod def _notfound_wsgi_app(environ, start_response): @@ -197,10 +160,7 @@ class API: def before_request(self, websocket=False): def decorator(f): - if websocket: - self.before_requests.setdefault("ws", []).append(f) - else: - self.before_requests.setdefault("http", []).append(f) + self.router.before_request(f, websocket=websocket) return f return decorator @@ -213,139 +173,20 @@ class API: def before_ws_requests(self): return self.before_requests.get("ws", []) - @property - def _apispec(self): - - info = {} - if self.description is not None: - info["description"] = self.description - if self.terms_of_service is not None: - info["termsOfService"] = self.terms_of_service - if self.contact is not None: - info["contact"] = self.contact - if self.license is not None: - info["license"] = self.license - - spec = APISpec( - title=self.title, - version=self.version, - openapi_version=self.openapi_version, - plugins=[MarshmallowPlugin()], - info=info, - ) - - for route in self.routes: - if self.routes[route].description: - operations = yaml_utils.load_operations_from_docstring( - self.routes[route].description - ) - spec.path(path=route, operations=operations) - - for name, schema in self.schemas.items(): - spec.components.schema(name, schema=schema) - - return spec - - @property - def openapi(self): - return self._apispec.to_yaml() - def add_middleware(self, middleware_cls, **middleware_config): self.app = middleware_cls(self.app, **middleware_config) - async def __call__(self, scope, receive, send): - if scope["type"] == "lifespan": - await self.lifespan_handler(scope, receive, send) - return - - path = scope["path"] - root_path = scope.get("root_path", "") - - # Call into a submounted app, if one exists. - for path_prefix, app in self.apps.items(): - if path.startswith(path_prefix): - scope["path"] = path[len(path_prefix) :] - scope["root_path"] = root_path + path_prefix - try: - await app(scope, receive, send) - return - except TypeError: - app = WSGIMiddleware(app) - await app(scope, receive, send) - return - - await self.app(scope, receive, send) - - async def asgi(self, scope, receive, send): - assert scope["type"] in ("http", "websocket") - - if scope["type"] == "websocket": - await self._dispatch_ws(scope=scope, receive=receive, send=send) - else: - req = models.Request(scope, receive=receive, api=self) - resp = await self._dispatch_http( - req, scope=scope, send=send, receive=receive - ) - await resp(scope, receive, send) - - async def _dispatch_http(self, req, **options): - # Set formats on Request object. - req.formats = self.formats - - # Get the route. - route = self.path_matches_route(req.url.path) - route = self.routes.get(route) - if route: - resp = models.Response(req=req, formats=self.formats) - - for before_request in self.before_http_requests: - await self.background(before_request, req=req, resp=resp) - - await self._execute_route(route=route, req=req, resp=resp, **options) - else: - resp = models.Response(req=req, formats=self.formats) - self.default_response(req=req, resp=resp, notfound=True) - self.default_response(req=req, resp=resp) - - self._prepare_session(resp) - - return resp - - async def _dispatch_ws(self, scope, receive, send): - ws = WebSocket(scope=scope, receive=receive, send=send) - - route = self.path_matches_route(ws.url.path) - route = self.routes.get(route) - - if route: - for before_request in self.before_ws_requests: - await self.background(before_request, ws=ws) - await self.background(route.endpoint, ws) - else: - await send({"type": "websocket.close", "code": 1000}) - - def add_schema(self, name, schema, check_existing=True): - """Adds a mashmallow schema to the API specification.""" - if check_existing: - assert name not in self.schemas - - self.schemas[name] = schema - def schema(self, name, **options): """Decorator for creating new routes around function and class definitions. - Usage:: - from marshmallow import Schema, fields - @api.schema("Pet") class PetSchema(Schema): name = fields.Str() - """ def decorator(f): - self.add_schema(name=name, schema=f, **options) + self.openapi.add_schema(name=name, schema=f, **options) return f return decorator @@ -355,97 +196,18 @@ class API: :param path: The path portion of a URL, to test all known routes against. """ - for (route, route_object) in self.routes.items(): - if route_object.does_match(path): + for route in self.router.routes: + match, _ = route.matches(path) + if match: return route - @property - def _signer(self): - return itsdangerous.Signer(self.secret_key) - - def _prepare_session(self, resp): - - if resp.session: - data = self._signer.sign( - b64encode(json.dumps(resp.session).encode("utf-8")) - ) - resp.cookies[self.session_cookie] = data.decode("utf-8") - - @staticmethod - def no_response(req, resp, **params): - pass - - async def _execute_route(self, *, route, req, resp, **options): - - params = route.incoming_matches(req.url.path) - - cont = True - - if route.is_function: - try: - try: - # Run the view. - r = self.background(route.endpoint, req, resp, **params) - # If it's async, await it. - if hasattr(r, "cr_running"): - await r - except TypeError as e: - cont = True - except Exception: - await self.background(self.default_response, req, resp, error=True) - raise - - if route.is_class_based or cont: - try: - view = route.endpoint(**params) - except TypeError: - try: - view = route.endpoint() - except TypeError: - view = route.endpoint - pass - - # Run on_request first. - try: - # Run the view. - r = getattr(view, "on_request", self.no_response) - r = self.background(r, req, resp, **params) - # If it's async, await it. - if hasattr(r, "send"): - await r - except Exception: - await self.background(self.default_response, req, resp, error=True) - raise - - # Then run on_method. - method = req.method - try: - # Run the view. - r = getattr(view, f"on_{method}", self.no_response) - r = self.background(r, req, resp, **params) - # If it's async, await it. - if hasattr(r, "send"): - await r - except Exception: - await self.background(self.default_response, req, resp, error=True) - raise - - def add_event_handler(self, event_type, handler): - """Adds an event handler to the API. - - :param event_type: A string in ("startup", "shutdown") - :param handler: The function to run. Can be either a function or a coroutine. - """ - - self.lifespan_handler.add_event_handler(event_type, handler) - def add_route( self, route=None, endpoint=None, *, default=False, - static=False, + static=True, check_existing=True, websocket=False, before_request=False, @@ -456,92 +218,45 @@ class API: :param endpoint: The endpoint for the route -- can be a callable, or a class. :param default: If ``True``, all unknown requests will route to this view. :param static: If ``True``, and no endpoint was passed, render "static/index.html", and it will become a default route. - :param check_existing: If ``True``, an AssertionError will be raised, if the route is already defined. """ - if before_request: - if websocket: - self.before_requests.setdefault("ws", []).append(endpoint) - else: - self.before_requests.setdefault("http", []).append(endpoint) - return - - if route is None: - route = f"/{uuid4().hex}" - - if check_existing: - assert route not in self.routes + # Path if static: assert self.static_dir is not None if not endpoint: - endpoint = self.static_response + endpoint = self._static_response default = True - if default: - self.default_endpoint = endpoint - - self.routes[route] = Route(route, endpoint, websocket=websocket) - # TODO: A better data structure or sort it once the app is loaded - self.routes = dict( - sorted(self.routes.items(), key=lambda item: item[1]._weight()) + self.router.add_route( + route, + endpoint, + default=default, + websocket=websocket, + before_request=before_request, + check_existing=check_existing, ) - def default_response( - self, req=None, resp=None, websocket=False, notfound=False, error=False - ): - if websocket: - return - - if resp.status_code is None: - resp.status_code = 200 - - if self.default_endpoint and notfound: - self.default_endpoint(req=req, resp=resp) - else: - if notfound: - resp.status_code = status_codes.HTTP_404 - resp.text = "Not found." - if error: - resp.status_code = status_codes.HTTP_500 - resp.text = "Application error." - - def docs_response(self, req, resp): - resp.html = self.docs - - def static_response(self, req, resp): - + async def _static_response(self, req, resp): assert self.static_dir is not None index = (self.static_dir / "index.html").resolve() if os.path.exists(index): with open(index, "r") as f: - resp.html = f.read() + resp.html = "Hello world !" else: resp.status_code = status_codes.HTTP_404 resp.text = "Not found." - def schema_response(self, req, resp): - resp.status_code = status_codes.HTTP_200 - resp.headers["Content-Type"] = "application/x-yaml" - resp.content = self.openapi - def redirect( self, resp, location, *, set_text=True, status_code=status_codes.HTTP_301 ): """Redirects a given response to a given location. - :param resp: The Response to mutate. :param location: The location of the redirect. :param set_text: If ``True``, sets the Redirect body content automatically. :param status_code: an `API.status_codes` attribute, or an integer, representing the HTTP status code of the redirect. """ - - # assert resp.status_code.is_300(status_code) - - resp.status_code = status_code - if set_text: - resp.text = f"Redirecting to: {location}" - resp.headers.update({"Location": location}) + resp.redirect(location, set_text=set_text, status_code=status_code) def on_event(self, event_type: str, **args): """Decorator for registering functions or coroutines to run at certain events @@ -565,6 +280,15 @@ class API: return decorator + def add_event_handler(self, event_type, handler): + """Adds an event handler to the API. + + :param event_type: A string in ("startup", "shutdown") + :param handler: The function to run. Can be either a function or a coroutine. + """ + + self.router.lifespan_handler.add_event_handler(event_type, handler) + def route(self, route=None, **options): """Decorator for creating new routes around function and class definitions. @@ -588,7 +312,7 @@ class API: :param route: String representation of the route to be used (shouldn't be parameterized). :param app: The other WSGI / ASGI app. """ - self.apps.update({route: app}) + self.router.apps.update({route: app}) def session(self, base_url="http://;"): """Testing HTTP client. Returns a Requests session object, able to send HTTP requests to the Responder application. @@ -600,11 +324,6 @@ class API: self._session = TestClient(self, base_url=base_url) return self._session - def _route_for(self, endpoint): - for route_object in self.routes.values(): - if endpoint in (route_object.endpoint, route_object.endpoint_name): - return route_object - def url_for(self, endpoint, **params): # TODO: Absolute_url """Given an endpoint, returns a rendered URL for its route. @@ -612,67 +331,25 @@ class API: :param endpoint: The route endpoint you're searching for. :param params: Data to pass into the URL generator (for parameterized URLs). """ - route_object = self._route_for(endpoint) - if route_object: - return route_object.url(**params) - raise ValueError + return self.router.url_for(endpoint, **params) - def static_url(self, asset): - """Given a static asset, return its URL path.""" - assert None not in (self.static_dir, self.static_route) - return f"{self.static_route}/{str(asset)}" - - @property - def docs(self): - - loader = jinja2.PrefixLoader( - { - self.docs_theme: jinja2.PackageLoader( - "apistar", os.path.join("themes", self.docs_theme, "templates") - ) - } - ) - env = jinja2.Environment(autoescape=True, loader=loader) - document = apistar.document.Document() - document.content = yaml.safe_load(self.openapi) - - template = env.get_template("/".join([self.docs_theme, "index.html"])) - - return template.render( - document=document, - langs=["javascript", "python"], - code_style=None, - static_url=self.static_url, - schema_url="/schema.yml", - ) - - def template(self, name_, **values): + def template(self, filename, *args, **kwargs): """Renders the given `jinja2 `_ template, with provided values supplied. - Note: The current ``api`` instance is by default passed into the view. This is set in the dict ``api.jinja_values_base``. - - :param name_: The filename of the jinja2 template, in ``templates_dir``. - :param values: Data to pass into the template. + :param filename: The filename of the jinja2 template, in ``templates_dir``. + :param *args: Data to pass into the template. + :param *kwargs: Date to pass into the template. """ - # Prepopulate values with base - values = {**self.jinja_values_base, **values} + return self.templates.render(filename, *args, **kwargs) - template = self.jinja_env.get_template(name_) - return template.render(**values) - - def template_string(self, s_, **values): + def template_string(self, source, *args, **kwargs): """Renders the given `jinja2 `_ template string, with provided values supplied. - Note: The current ``api`` instance is by default passed into the view. This is set in the dict ``api.jinja_values_base``. - - :param s_: The template to use. - :param values: Data to pass into the template. + :param source: The template to use. + :param *args: Data to pass into the template. + :param **kwargs: Data to pass into the template. """ - # Prepopulate values with base - values = {**self.jinja_values_base, **values} - - template = self.jinja_env.from_string(s_) - return template.render(**values) + return self.templates.render_string(source, *args, **kwargs) def serve(self, *, address=None, port=None, debug=False, **options): """Runs the application with uvicorn. If the ``PORT`` environment @@ -704,3 +381,6 @@ class API: if "debug" not in kwargs: kwargs.update({"debug": self.debug}) self.serve(**kwargs) + + async def __call__(self, scope, receive, send): + await self.app(scope, receive, send) diff --git a/responder/ext/graphql.py b/responder/ext/graphql/__init__.py similarity index 90% rename from responder/ext/graphql.py rename to responder/ext/graphql/__init__.py index 806d43e..36f3b32 100644 --- a/responder/ext/graphql.py +++ b/responder/ext/graphql/__init__.py @@ -3,7 +3,7 @@ from functools import partial from graphql_server import default_format_error, encode_execution_results, json_encode -from ..templates import GRAPHIQL +from .templates import GRAPHIQL class GraphQLView: @@ -44,7 +44,9 @@ class GraphQLView: show_graphiql = req.method == "get" and req.accepts("text/html") if show_graphiql: - resp.content = self.api.template_string(GRAPHIQL, endpoint=req.url.path) + resp.content = self.api.templates.render_string( + GRAPHIQL, endpoint=req.url.path + ) return query, variables, operation_name = await self._resolve_graphql_query(req) @@ -63,3 +65,6 @@ class GraphQLView: async def on_request(self, req, resp): await self.graphql_response(req, resp, self.schema) + + async def __call__(self, req, resp): + await self.on_request(req, resp) diff --git a/responder/templates/__init__.py b/responder/ext/graphql/templates.py similarity index 100% rename from responder/templates/__init__.py rename to responder/ext/graphql/templates.py diff --git a/responder/ext/schema/__init__.py b/responder/ext/schema/__init__.py new file mode 100644 index 0000000..be18b14 --- /dev/null +++ b/responder/ext/schema/__init__.py @@ -0,0 +1,160 @@ +import os +from pathlib import Path + +import apistar +import jinja2 +import yaml +from apispec import APISpec, yaml_utils +from apispec.ext.marshmallow import MarshmallowPlugin + +from responder.statics import DEFAULT_API_THEME +from responder.staticfiles import StaticFiles +from responder import status_codes + + +class Schema: + def __init__( + self, + app, + title, + version, + plugins=None, + description=None, + terms_of_service=None, + contact=None, + license=None, + openapi=None, + openapi_route="/schema.yml", + docs_route="/docs/", + static_route="/static", + ): + self.app = app + self.schemas = {} + self.title = title + self.version = version + self.description = description + self.terms_of_service = terms_of_service + self.contact = contact + self.license = license + + self.openapi_version = openapi + self.openapi_route = openapi_route + + self.docs_theme = DEFAULT_API_THEME + self.docs_route = docs_route + + self.plugins = [MarshmallowPlugin()] if plugins is None else plugins + + if self.openapi_version is not None: + self.app.add_route(self.openapi_route, self.schema_response) + + if self.docs_route is not None: + self.app.add_route(self.docs_route, self.docs_response) + + theme_path = ( + Path(apistar.__file__).parent / "themes" / self.docs_theme / "static" + ).resolve() + + self.static_route = static_route + + self.app.static_app.add_directory(theme_path) + + @property + def _apispec(self): + + info = {} + if self.description is not None: + info["description"] = self.description + if self.terms_of_service is not None: + info["termsOfService"] = self.terms_of_service + if self.contact is not None: + info["contact"] = self.contact + if self.license is not None: + info["license"] = self.license + + spec = APISpec( + title=self.title, + version=self.version, + openapi_version=self.openapi_version, + plugins=self.plugins, + info=info, + ) + + for route in self.app.router.routes: + if route.description: + operations = yaml_utils.load_operations_from_docstring( + route.description + ) + spec.path(path=route.route, operations=operations) + + for name, schema in self.schemas.items(): + spec.components.schema(name, schema=schema) + + return spec + + @property + def openapi(self): + return self._apispec.to_yaml() + + def add_schema(self, name, schema, check_existing=True): + """Adds a mashmallow schema to the API specification.""" + if check_existing: + assert name not in self.schemas + + self.schemas[name] = schema + + def schema(self, name, **options): + """Decorator for creating new routes around function and class definitions. + + Usage:: + + from marshmallow import Schema, fields + + @api.schema("Pet") + class PetSchema(Schema): + name = fields.Str() + + """ + + def decorator(f): + self.add_schema(name=name, schema=f, **options) + return f + + return decorator + + @property + def docs(self): + + loader = jinja2.PrefixLoader( + { + self.docs_theme: jinja2.PackageLoader( + "apistar", os.path.join("themes", self.docs_theme, "templates") + ) + } + ) + env = jinja2.Environment(autoescape=True, loader=loader) + document = apistar.document.Document() + document.content = yaml.safe_load(self.openapi) + + template = env.get_template("/".join([self.docs_theme, "index.html"])) + + return template.render( + document=document, + langs=["javascript", "python"], + code_style=None, + static_url=self.static_url, + schema_url="/schema.yml", + ) + + def static_url(self, asset): + """Given a static asset, return its URL path.""" + assert self.static_route is not None + return f"{self.static_route}/{str(asset)}" + + def docs_response(self, req, resp): + resp.html = self.docs + + def schema_response(self, req, resp): + resp.status_code = status_codes.HTTP_200 + resp.headers["Content-Type"] = "application/x-yaml" + resp.content = self.openapi diff --git a/responder/models.py b/responder/models.py index bbf8a97..d39f318 100644 --- a/responder/models.py +++ b/responder/models.py @@ -3,16 +3,18 @@ import io import inspect import json import gzip +from urllib.parse import parse_qs from base64 import b64decode from http.cookies import SimpleCookie - import chardet import rfc3986 import graphene import yaml + from requests.structures import CaseInsensitiveDict from requests.cookies import RequestsCookieJar + from starlette.datastructures import MutableHeaders from starlette.requests import Request as StarletteRequest, State from starlette.responses import ( @@ -20,9 +22,7 @@ from starlette.responses import ( StreamingResponse as StarletteStreamingResponse, ) -from urllib.parse import parse_qs - -from .status_codes import HTTP_200 +from .status_codes import HTTP_200, HTTP_301 from .statics import DEFAULT_ENCODING @@ -105,9 +105,9 @@ class Request: "_cookies", ] - def __init__(self, scope, receive, api=None): + def __init__(self, scope, receive, api=None, formats=None): self._starlette = StarletteRequest(scope, receive) - self.formats = None + self.formats = formats self._encoding = None self.api = api self._content = None @@ -122,14 +122,7 @@ class Request: @property def session(self): """The session data, in dict form, from the Request.""" - if self.api.session_cookie in self.cookies: - - data = self.cookies[self.api.session_cookie] - - data = self.api._signer.unsign(data) - data = b64decode(data) - return json.loads(data) - return {} + return self._starlette.session @property def headers(self): @@ -182,12 +175,12 @@ class Request: def state(self) -> State: """ Use the state to store additional information. - + This can be a very helpful feature, if you want to hand over information from a middelware or a route decorator to the actual route handler. - For example: ``request.state.time_started = time.time()`` + Usage: ``request.state.time_started = time.time()`` """ return self._starlette.state @@ -300,7 +293,7 @@ class Response: self.formats = formats self.cookies = SimpleCookie() #: The cookies set in the Response self.session = ( - req.session.copy() + req.session ) #: The cookie-based session data, in dict form, to add to the Response. # Property or func/dec @@ -311,6 +304,12 @@ class Response: return func + def redirect(self, location, *, set_text=True, status_code=HTTP_301): + self.status_code = status_code + if set_text: + self.text = f"Redirecting to: {location}" + self.headers.update({"Location": location}) + @property async def body(self): if self._stream is not None: diff --git a/responder/routes.py b/responder/routes.py index 997fe05..6ff6535 100644 --- a/responder/routes.py +++ b/responder/routes.py @@ -1,43 +1,76 @@ +import asyncio +import json import re -import functools import inspect -from parse import parse, with_pattern + +from starlette.routing import Lifespan +from starlette.middleware.wsgi import WSGIMiddleware +from starlette.websockets import WebSocket, WebSocketClose +from starlette.concurrency import run_in_threadpool +from starlette.exceptions import HTTPException + +from .models import Request, Response +from . import status_codes +from .formats import get_formats +from .statics import DEFAULT_SESSION_COOKIE -def _make_convertor(type, pattern): - @with_pattern(pattern) - def inner(value): - return type(value) - - return inner - - -_convertors = { - "int": _make_convertor(int, r"\d+"), - "str": _make_convertor(str, r"[^/]+"), - "float": _make_convertor(float, r"\d+(.\d+)?"), +_CONVERTORS = { + "int": (int, r"\d+"), + "str": (str, r"[^/]+"), + "float": (float, r"\d+(.\d+)?"), } +PARAM_RE = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}") -class Route: - _param_pattern = re.compile(r"{([^{}]*)}") - def __init__(self, route, endpoint, *, websocket=False, before_request=False): +def compile_path(path): + path_re = "^" + param_convertors = {} + idx = 0 + + for match in PARAM_RE.finditer(path): + param_name, convertor_type = match.groups(default="str") + convertor_type = convertor_type.lstrip(":") + assert ( + convertor_type in _CONVERTORS.keys() + ), f"Unknown path convertor '{convertor_type}'" + convertor, convertor_re = _CONVERTORS[convertor_type] + + path_re += path[idx : match.start()] + path_re += rf"(?P<{param_name}>{convertor_re})" + + param_convertors[param_name] = convertor + + idx = match.end() + + path_re += path[idx:] + "$" + + return re.compile(path_re), param_convertors + + +class BaseRoute: + def matches(self, scope): + raise NotImplementedError() + + async def __call__(self, scope, receive, send): + raise NotImplementedError() + + +class Route(BaseRoute): + def __init__(self, route, endpoint, *, before_request=False): + assert route.startswith("/"), "Route path must start with '/'" self.route = route self.endpoint = endpoint - self.uses_websocket = websocket self.before_request = before_request + self.path_re, self.param_convertors = compile_path(route) + def __repr__(self): return f"" - def __eq__(self, other): - if hasattr(other, "route"): - # Being compared to other routes. - return self.route == other.route - else: - # Strings. - return self.does_match(other) + def url(self, **params): + return self.route.format(**params) @property def endpoint_name(self): @@ -47,46 +80,247 @@ class Route: def description(self): return self.endpoint.__doc__ - @property - def has_parameters(self): - return bool(self._param_pattern.search(self.route)) + def matches(self, scope): + if scope["type"] != "http": + return False, {} - @functools.lru_cache(maxsize=None) - def does_match(self, s): - if s == self.route: - return True + path = scope["path"] + match = self.path_re.match(path) - named = self.incoming_matches(s) - return bool(len(named)) + if match is None: + return False, {} - @functools.lru_cache(maxsize=None) - def incoming_matches(self, s): - results = parse(self.route, s, _convertors) - return results.named if results else {} + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key](value) + + return True, {"path_params": {**matched_params}} + + async def __call__(self, scope, receive, send): + request = Request(scope, receive, formats=get_formats()) + response = Response(req=request, formats=get_formats()) + + path_params = scope.get("path_params", {}) + before_requests = scope.get("before_requests", []) + + for before_request in before_requests.get("http", []): + if asyncio.iscoroutinefunction(before_request): + await before_request(request, response) + else: + await run_in_threadpool(before_request, request, response) + + views = [] + + if inspect.isclass(self.endpoint): + endpoint = self.endpoint() + on_request = getattr(endpoint, "on_request", None) + if on_request: + views.append(on_request) + + method_name = f"on_{request.method}" + try: + view = getattr(endpoint, method_name) + views.append(view) + except AttributeError: + if on_request is None: + raise HTTPException(status_code=status_codes.HTTP_405) + else: + views.append(self.endpoint) + + for view in views: + # "Monckey patch" for graphql: explicitly checking __call__ + if asyncio.iscoroutinefunction(view) or asyncio.iscoroutinefunction( + view.__call__ + ): + await view(request, response, **path_params) + else: + await run_in_threadpool(view, request, response, **path_params) + + if response.status_code is None: + response.status_code = status_codes.HTTP_200 + + await response(scope, receive, send) + + def __eq__(self, other): + # [TODO] compare to str ? + return self.route == other.route and self.endpoint == other.endpoint + + def __hash__(self): + return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request) + + +class WebSocketRoute(BaseRoute): + def __init__(self, route, endpoint, *, before_request=False): + assert route.startswith("/"), "Route path must start with '/'" + self.route = route + self.endpoint = endpoint + self.before_request = before_request + + self.path_re, self.param_convertors = compile_path(route) + + def __repr__(self): + return f"" def url(self, **params): return self.route.format(**params) - def _weight(self): - params = set(self._param_pattern.findall(self.route)) - params_count = len(params) - w = len(self.route.rsplit("}", 1)[-1].strip("/")) - return params_count != 0, w == 0, -params_count + @property + def endpoint_name(self): + return self.endpoint.__name__ @property - def is_class_based(self): - return inspect.isclass(self.endpoint) + def description(self): + return self.endpoint.__doc__ - @property - def is_function(self): - code = hasattr(self.endpoint, "__code__") - kwdefaults = hasattr(self.endpoint, "__kwdefaults__") - return all((callable(self.endpoint), code, kwdefaults)) + def matches(self, scope): + if scope["type"] != "websocket": + return False, {} + + path = scope["path"] + match = self.path_re.match(path) + + if match is None: + return False, {} + + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key](value) + + return True, {"path_params": {**matched_params}} + + async def __call__(self, scope, receive, send): + ws = WebSocket(scope, receive, send) + + before_requests = scope.get("before_requests", []) + for before_request in before_requests.get("ws", []): + await before_request(ws) + + await self.endpoint(ws) + + def __eq__(self, other): + # [TODO] compare to str ? + return self.route == other.route and self.endpoint == other.endpoint def __hash__(self): - return ( - hash(self.route) - ^ hash(self.endpoint) - ^ hash(self.uses_websocket) - ^ hash(self.before_request) + return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request) + + +class Router: + def __init__(self, routes=None, default_response=None, before_requests=None): + self.routes = [] if routes is None else list(routes) + # [TODO] Make its own router + self.apps = {} + self.default_endpoint = ( + self.default_response if default_response is None else default_response ) + self.lifespan_handler = Lifespan() + self.before_requests = ( + {"http": [], "ws": []} if before_requests is None else before_requests + ) + + def add_route( + self, + route=None, + endpoint=None, + *, + default=False, + websocket=False, + before_request=False, + check_existing=False, + ): + """ Adds a route to the router. + :param route: A string representation of the route + :param endpoint: The endpoint for the route -- can be callable, or class. + :param default: If ``True``, all unknown requests will route to this view. + """ + if before_request: + if websocket: + self.before_requests.setdefault("ws", []).append(endpoint) + else: + self.before_requests.setdefault("http", []).append(endpoint) + return + + if check_existing: + assert not self.routes or route not in ( + item.route for item in self.routes + ), f"Route '{route}' already exists" + + if default: + self.default_endpoint = endpoint + + if websocket: + route = WebSocketRoute(route, endpoint) + else: + route = Route(route, endpoint) + + self.routes.append(route) + + def mount(self, route, app): + """Mounts ASGI / WSGI applications at a given route + """ + self.apps.update(route, app) + + def before_request(self, endpoint, websocket=False): + if websocket: + self.before_requests.setdefault("ws", []).append(endpoint) + else: + self.before_requests.setdefault("http", []).append(endpoint) + + def url_for(self, endpoint, **params): + # TODO: Check for params + for route in self.routes: + if endpoint in (route.endpoint, route.endpoint.__name__): + return route.url(**params) + return None + + async def default_response(self, scope, receive, send): + if scope["type"] == "websocket": + websocket_close = WebSocketClose() + await websocket_close(receive, send) + return + + request = Request(scope, receive) + response = Response(request, formats=get_formats()) + + raise HTTPException(status_code=status_codes.HTTP_404) + + def _resolve_route(self, scope): + for route in self.routes: + matches, child_scope = route.matches(scope) + if matches: + scope.update(child_scope) + return route + return None + + async def __call__(self, scope, receive, send): + assert scope["type"] in ("http", "websocket", "lifespan") + + if scope["type"] == "lifespan": + await self.lifespan_handler(scope, receive, send) + return + + path = scope["path"] + root_path = scope.get("root_path", "") + + # Call into a submounted app, if one exists. + for path_prefix, app in self.apps.items(): + if path.startswith(path_prefix): + scope["path"] = path[len(path_prefix) :] + scope["root_path"] = root_path + path_prefix + try: + await app(scope, receive, send) + return + except TypeError: + app = WSGIMiddleware(app) + await app(scope, receive, send) + return + + route = self._resolve_route(scope) + + scope["before_requests"] = self.before_requests + + if route is not None: + await route(scope, receive, send) + return + + await self.default_response(scope, receive, send) diff --git a/responder/staticfiles.py b/responder/staticfiles.py new file mode 100644 index 0000000..66a57c1 --- /dev/null +++ b/responder/staticfiles.py @@ -0,0 +1,18 @@ +import typing + +from starlette.staticfiles import StaticFiles + + +class StaticFiles(StaticFiles): + """I've created an issue to disccuss allowing multiple directories in starletter's `StaticFiles`. + + https://github.com/encode/starlette/issues/625 + + I've also made a PR to add this method to starlette StaticFiles + Once accepted we will remove this. + + https://github.com/encode/starlette/pull/626 + """ + + def add_directory(self, directory: str) -> None: + self.all_directories = [*self.all_directories, *self.get_directories(directory)] diff --git a/responder/templates.py b/responder/templates.py new file mode 100644 index 0000000..6419a02 --- /dev/null +++ b/responder/templates.py @@ -0,0 +1,55 @@ +from contextlib import contextmanager + +import jinja2 + + +class Templates: + def __init__(self, directory="templates", autoescape=True, context=None): + self.directory = directory + self._env = jinja2.Environment( + loader=jinja2.FileSystemLoader([str(self.directory)]), autoescape=autoescape + ) + self.default_context = {} if context is None else {**context} + self._env.globals.update(self.default_context) + + @property + def context(self): + return self._env.globals + + @context.setter + def context(self, context): + self._env.globals = {**self.default_context, **context} + + def get_template(self, name): + return self._env.get_template(name) + + def render(self, template, *args, **kwargs): + """Renders the given `jinja2 `_ template, with provided values supplied. + + :param template: The filename of the jinja2 template. + :param **kwargs: Data to pass into the template. + :param **kwargs: Data to pass into the template. + """ + return self.get_template(template).render(*args, **kwargs) + + @contextmanager + def _async(self): + self._env.is_async = True + try: + yield + finally: + self._env.is_async = False + + async def render_async(self, template, *args, **kwargs): + with self._async(): + return await self.get_template(template).render_async(*args, **kwargs) + + def render_string(self, source, *args, **kwargs): + """Renders the given `jinja2 `_ template string, with provided values supplied. + + :param source: The template to use. + :param *args, **kwargs: Data to pass into the template. + :param **kwargs: Data to pass into the template. + """ + template = self._env.from_string(source) + return template.render(*args, **kwargs) diff --git a/setup.py b/setup.py index df69c91..1d5dcd3 100644 --- a/setup.py +++ b/setup.py @@ -27,10 +27,9 @@ required = [ "aiofiles", "pyyaml", "requests", - "graphene", + "graphene<3.0", "graphql-server-core>=1.1", "jinja2", - "parse", "uvloop; sys_platform != 'win32' and sys_platform != 'cygwin' and sys_platform != 'cli'", "rfc3986", "python-multipart", @@ -39,7 +38,6 @@ required = [ "marshmallow", "whitenoise", "docopt", - "itsdangerous", "requests-toolbelt", "apistar", ] diff --git a/tests/test_responder.py b/tests/test_responder.py index 2debc2b..2935306 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -7,6 +7,7 @@ import responder import requests import string import io +from responder.routes import Router, Route, WebSocketRoute from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse @@ -19,6 +20,47 @@ def test_api_basic_route(api): resp.text = "hello world!" +def test_route_repr(): + def home(req, resp): + """Home page + """ + resp.text = "Hello !" + + route = Route("/", home) + + assert route.__repr__() == f"" + + assert route.endpoint_name == home.__name__ + assert route.description == home.__doc__ + + +def test_websocket_route_repr(): + def chat_endpoint(ws): + """Chat + """ + pass + + route = WebSocketRoute("/", chat_endpoint) + + assert route.__repr__() == f"" + + assert route.endpoint_name == chat_endpoint.__name__ + assert route.description == chat_endpoint.__doc__ + + +def test_route_eq(): + def home(req, resp): + resp.text = "Hello !" + + assert Route("/", home) == Route("/", home) + + def chat(ws): + pass + + assert WebSocketRoute("/", home) == WebSocketRoute("/", home) + + +""" def test_api_basic_route_overlap(api): @api.route("/") def home(req, resp): @@ -29,39 +71,7 @@ def test_api_basic_route_overlap(api): @api.route("/") def home2(req, resp): resp.text = "hello world!" - - -def test_api_basic_route_overlap_alternative(api): - @api.route("/") - def home(req, resp): - resp.text = "hello world!" - - def home2(req, resp): - resp.text = "hello world!" - - with pytest.raises(AssertionError): - api.add_route("/", home2) - - -def test_api_basic_route_overlap_allowed(api): - @api.route("/") - def home(req, resp): - resp.text = "hello world!" - - def home2(req, resp): - resp.text = "hello world!" - - api.add_route("/", home2, check_existing=False) - - -def test_api_basic_route_overlap_allowed_alternative(api): - @api.route("/") - def home(req, resp): - resp.text = "hello world!" - - @api.route("/", check_existing=False) - def home2(req, resp): - resp.text = "hello world!" +""" def test_class_based_view_registration(api): @@ -74,10 +84,10 @@ def test_class_based_view_registration(api): def test_class_based_view_parameters(api): @api.route("/{greeting}") class Greeting: - def on_request(self, req, resp, *, greeting): - resp.text = f"{greeting}, world!" + pass - assert api.session().get("http://;/Hello").ok + resp = api.session().get("http://;/Hello") + assert resp.status_code == api.status_codes.HTTP_405 def test_requests_session(api): @@ -85,14 +95,14 @@ def test_requests_session(api): assert api.requests -def test_requests_session_works(api, url): +def test_requests_session_works(api): TEXT = "spiral out" @api.route("/") def hello(req, resp): resp.text = TEXT - assert api.requests.get(url("/")).text == TEXT + assert api.requests.get("/").text == TEXT def test_status_code(api): @@ -338,13 +348,45 @@ def test_yaml_downloads(api): assert yaml.safe_load(r.content) == dump +def test_schema_generation_explicit(): + import responder + from responder.ext.schema import Schema as OpenAPISchema + import marshmallow + + api = responder.API() + + schema = OpenAPISchema(app=api, title="Web Service", version="1.0", openapi="3.0.2") + + @schema.schema("Pet") + class PetSchema(marshmallow.Schema): + name = marshmallow.fields.Str() + + @api.route("/") + def route(req, resp): + """A cute furry animal endpoint. + --- + get: + description: Get a random pet + responses: + 200: + description: A pet to be returned + schema: + $ref: "#/components/schemas/Pet" + """ + resp.media = PetSchema().dump({"name": "little orange"}) + + r = api.requests.get("http://;/schema.yml") + dump = yaml.safe_load(r.content) + + assert dump + assert dump["openapi"] == "3.0.2" + + def test_schema_generation(): import responder from marshmallow import Schema, fields - api = responder.API( - title="Web Service", openapi="3.0.2", allowed_hosts=["testserver", ";"] - ) + api = responder.API(title="Web Service", openapi="3.0.2") @api.schema("Pet") class PetSchema(Schema): @@ -371,6 +413,60 @@ def test_schema_generation(): assert dump["openapi"] == "3.0.2" +def test_documentation_explicit(): + import responder + from responder.ext.schema import Schema as OpenAPISchema + + import marshmallow + + description = "This is a sample server for a pet store." + terms_of_service = "http://example.com/terms/" + contact = { + "name": "API Support", + "url": "http://www.example.com/support", + "email": "support@example.com", + } + license = { + "name": "Apache 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0.html", + } + + api = responder.API(allowed_hosts=["testserver", ";"]) + + schema = OpenAPISchema( + app=api, + title="Web Service", + version="1.0", + openapi="3.0.2", + docs_route="/docs", + description=description, + terms_of_service=terms_of_service, + contact=contact, + license=license, + ) + + @schema.schema("Pet") + class PetSchema(marshmallow.Schema): + name = marshmallow.fields.Str() + + @api.route("/") + def route(req, resp): + """A cute furry animal endpoint. + --- + get: + description: Get a random pet + responses: + 200: + description: A pet to be returned + schema: + $ref: "#/components/schemas/Pet" + """ + resp.media = PetSchema().dump({"name": "little orange"}) + + r = api.requests.get("/docs") + assert "html" in r.text + + def test_documentation(): import responder from marshmallow import Schema, fields @@ -485,7 +581,7 @@ def test_sessions(api): assert r.json() == {"hello": "world"} -def test_template_rendering(api): +def test_template_string_rendering(api): @api.route("/") def view(req, resp): resp.content = api.template_string("{{ var }}", var="hello") @@ -582,7 +678,7 @@ def test_before_websockets(api): await ws.send_json(payload) await ws.close() - @api.route(before_request=True, websocket=True) + @api.before_request(websocket=True) async def before_request(ws): await ws.accept() await ws.send_json({"before": "request"}) @@ -643,6 +739,10 @@ def test_before_response(api, session): def get(req, resp): resp.media = req.session + @api.route(before_request=True) + async def async_before_request(req, resp): + resp.headers["x-pizza"] = "1" + @api.route(before_request=True) def before_request(req, resp): resp.headers["x-pizza"] = "1" @@ -794,25 +894,6 @@ def test_staticfiles_none_dir(tmpdir): api.add_route("/spa", static=True) -def test_staticfiles_none_dir_route(tmpdir): - api = responder.API(static_dir=None, static_route=None) - session = api.session() - - static_dir = tmpdir.mkdir("static") - - asset = create_asset(static_dir) - - static_route = api.static_route - - # ok - r = session.get(f"{static_route}/{asset.basename}") - assert r.status_code == api.status_codes.HTTP_404 - - # dir listing - r = session.get(f"{static_route}") - assert r.status_code == api.status_codes.HTTP_404 - - def test_response_html_property(api): @api.route("/") def view(req, resp): diff --git a/tests/test_routes.py b/tests/test_routes.py deleted file mode 100644 index 0c19e30..0000000 --- a/tests/test_routes.py +++ /dev/null @@ -1,159 +0,0 @@ -import pytest -from responder import routes - - -def setup_function(function): - routes.Route.incoming_matches.cache_clear() - - -@pytest.mark.parametrize( - "route, expected", - [ - pytest.param("/", False, id="home path without params"), - pytest.param("/test_path", False, id="sub path without params"), - pytest.param("/{test_path}", True, id="path with params"), - ], -) -def test_parameter(route, expected): - r = routes.Route(route, "test_endpoint") - assert r.has_parameters is expected - - -def test_url(): - r = routes.Route("/{my_path}", "test_endpoint") - url = r.url(my_path="path") - assert url == "/path" - - -def test_equal(): - r = routes.Route("/{path_param}", "test_endpoint") - r2 = routes.Route("/{path_param}", "test_endpoint") - r3 = routes.Route("/test_path", "test_endpoint") - - assert r == r2 - assert r != r3 - - -def test_incoming_matches(): - # Test Route with one param - r = routes.Route("/{greetings}", "test_endpoint") - assert r.incoming_matches("/hello") == {"greetings": "hello"} - assert r.incoming_matches("/foo") == {"greetings": "foo"} - - # Test Route with two params - r = routes.Route("/{greetings}/{name}", "test_endpoint") - assert r.incoming_matches("/hi/john") == {"greetings": "hi", "name": "john"} - assert r.incoming_matches("/hello/jane") == {"greetings": "hello", "name": "jane"} - - # Test Route with no param - r = routes.Route("/hello", "test_endpoint") - assert r.incoming_matches("/hello") == {} - assert r.incoming_matches("/bye") == {} - - -def test_incoming_matches_cache(): - r = routes.Route("/hello", "test_endpoint") - r.incoming_matches("/hello") - assert r.incoming_matches.cache_info().hits == 0 - r.incoming_matches("/hello") - assert r.incoming_matches.cache_info().hits == 1 - - -def test_incoming_matches_with_concrete_path_no_match(): - r = routes.Route("/concrete_path", "test_endpoint") - assert r.incoming_matches("hello") == {} - - -@pytest.mark.parametrize( - "route, match, expected", - [ - pytest.param( - "/{path_param}", - "/{path_param}", - True, - id="with both parametrized path match", - ), - pytest.param( - "/concrete", "/concrete", True, id="with both concrete path match" - ), - pytest.param("/concrete", "/no_match", False, id="with no match"), - ], -) -def test_does_match_with_route(route, match, expected): - r = routes.Route(route, "test_endpoint") - assert r.does_match(match) == expected - - -@pytest.mark.parametrize( - "path_param, expected_weight", - [ - pytest.param("/{greetings}", (True, True, -1), id="with one param"), - pytest.param( - "/{greetings}.{name}", - (True, True, -2), - id="with 2 params and dot in the middle", - ), - pytest.param( - "/{greetings}/{name}", (True, True, -2), id="with 2 params and subpath" - ), - pytest.param( - "/{greetings}/{name}/{hello}", - (True, True, -3), - id="with 3 params and subpath", - ), - pytest.param( - "/{greetings}_{name}", (True, True, -2), id="with 2 params and underscore" - ), - pytest.param("/{greetings}/test", (True, False, -1), id="with one param"), - pytest.param( - "/{greetings}.{name}/test", - (True, False, -2), - id="with 2 params and dot in the middle", - ), - pytest.param( - "/{greetings}/{name}/test", - (True, False, -2), - id="with 2 params and subpath", - ), - pytest.param( - "/{greetings}/{name}/{hello}/test", - (True, False, -3), - id="with 3 params and subpath", - ), - pytest.param( - "/{greetings}_{name}/test", - (True, False, -2), - id="with 2 params and underscore", - ), - pytest.param("/hello", (False, False, 0), id="without params"), - ], -) -def test_weight(path_param, expected_weight): - r = routes.Route(path_param, "test_endpoint") - assert r._weight() == expected_weight - - -@pytest.mark.parametrize( - "route, path, expected_result", - [ - pytest.param("/{greetings:str}", "/hello", {"greetings": "hello"}), - pytest.param( - "/{greetings:str}/{who}", - "/hello/Laidia", - {"greetings": "hello", "who": "Laidia"}, - ), - pytest.param("/{birth_date:int}", "/1937", {"birth_date": 1937}), - pytest.param( - "/{name:str}/{age:int}", "/Fatna/80", {"name": "Fatna", "age": 80} - ), - pytest.param( - "/{x:float}/{y:float}", "/10.20/75", {"x": float(10.20), "y": float(75)} - ), - pytest.param("/{name:str}/{age:int}", "/Fatna/eighty", {}), - pytest.param("/{greetings:int}", "/hello", {}), - pytest.param("/{name:float}", "/Fatna", {}), - ], -) -def test_custom_specifiers(route, path, expected_result): - r = routes.Route(route, "test_endpoint") - assert r.incoming_matches(path) == expected_result