From d24b921cdc4e19f062f8ae3dad01a5eaa69180dd Mon Sep 17 00:00:00 2001 From: taoufik Date: Sat, 17 Aug 2019 16:03:12 +0200 Subject: [PATCH] Move open api schema to ext/schema --- responder/api.py | 145 ++++------------------------ responder/ext/schema/__init__.py | 160 ++++++++++++++++++++++++++++++ responder/models.py | 5 +- responder/routes.py | 11 ++- responder/staticfiles.py | 18 ++++ tests/test_responder.py | 161 ++++++++++++++++++++++++++----- 6 files changed, 348 insertions(+), 152 deletions(-) create mode 100644 responder/ext/schema/__init__.py create mode 100644 responder/staticfiles.py diff --git a/responder/api.py b/responder/api.py index d1095a3..fd9885e 100644 --- a/responder/api.py +++ b/responder/api.py @@ -5,13 +5,8 @@ from uuid import uuid4 from pathlib import Path from base64 import b64encode -import apistar -import itsdangerous import jinja2 import uvicorn -import yaml -from apispec import APISpec, yaml_utils -from apispec.ext.marshmallow import MarshmallowPlugin from starlette.exceptions import ExceptionMiddleware from starlette.middleware.wsgi import WSGIMiddleware from starlette.middleware.errors import ServerErrorMiddleware @@ -32,6 +27,7 @@ from .formats import get_formats from .routes import Router from .statics import DEFAULT_API_THEME, DEFAULT_CORS_PARAMS, DEFAULT_SECRET_KEY from .templates import GRAPHIQL +from .ext.schema import Schema as OpenAPISchema class API: @@ -41,12 +37,6 @@ class API: :param templates_dir: The directory to use for templates. Will be created for you if it doesn't already exist. :param auto_escape: If ``True``, HTML and XML templates will automatically be escaped. :param enable_hsts: If ``True``, send all responses to HTTPS URLs. - :param title: The title of the application (OpenAPI Info Object) - :param version: The version of the OpenAPI document (OpenAPI Info Object) - :param description: The description of the OpenAPI document (OpenAPI Info Object) - :param terms_of_service: A URL to the Terms of Service for the API (OpenAPI Info Object) - :param contact: The contact dictionary of the application (OpenAPI Contact Object) - :param license: The license information of the exposed API (OpenAPI License Object) """ status_codes = status_codes @@ -77,13 +67,23 @@ class API: self.background = BackgroundQueue() self.secret_key = secret_key - self.title = title - self.version = version - self.description = description - self.terms_of_service = terms_of_service - self.contact = contact - self.license = license - self.openapi_version = openapi + + self.router = Router() + + if openapi or docs_route: + self.openapi = OpenAPISchema( + app=self, + title="Web Service", + version="1.0", + openapi="3.0.2", + docs_route=docs_route, + description=description, + terms_of_service=terms_of_service, + contact=contact, + license=license, + openapi_route=openapi_route, + static_route=static_route, + ) if static_dir is not None: if static_route is None: @@ -102,12 +102,6 @@ class API: self.templates_dir = templates_dir or self.built_in_templates_dir - self.router = Router() - - self.docs_theme = DEFAULT_API_THEME - self.docs_route = docs_route - self.schemas = {} - self.hsts_enabled = enable_hsts self.cors = cors self.cors_params = cors_params @@ -130,15 +124,6 @@ class API: self.whitenoise = WhiteNoise(application=self._notfound_wsgi_app) self.whitenoise.add_files(str(self.static_dir)) - self.whitenoise.add_files( - ( - Path(apistar.__file__).parent - / "themes" - / self.docs_theme - / "static" - ).resolve() - ) - self.mount(self.static_route, self.whitenoise) self.formats = get_formats() @@ -146,12 +131,6 @@ class API: # Cached requests session. self._session = None - if self.openapi_version: - self.add_route(openapi_route, self.schema_response) - - if self.docs_route: - self.add_route(self.docs_route, self.docs_response) - self.default_endpoint = None self.app = ExceptionMiddleware(self.router, debug=debug) self.add_middleware(GZipMiddleware) @@ -199,71 +178,23 @@ class API: def before_ws_requests(self): return self.before_requests.get("ws", []) - @property - def _apispec(self): - - info = {} - if self.description is not None: - info["description"] = self.description - if self.terms_of_service is not None: - info["termsOfService"] = self.terms_of_service - if self.contact is not None: - info["contact"] = self.contact - if self.license is not None: - info["license"] = self.license - - spec = APISpec( - title=self.title, - version=self.version, - openapi_version=self.openapi_version, - plugins=[MarshmallowPlugin()], - info=info, - ) - - for route in self.router.routes: - if route.description: - operations = yaml_utils.load_operations_from_docstring( - route.description - ) - spec.path(path=route.route, operations=operations) - - for name, schema in self.schemas.items(): - spec.components.schema(name, schema=schema) - - return spec - - @property - def openapi(self): - return self._apispec.to_yaml() - def add_middleware(self, middleware_cls, **middleware_config): self.app = middleware_cls(self.app, **middleware_config) async def __call__(self, scope, receive, send): await self.app(scope, receive, send) - def add_schema(self, name, schema, check_existing=True): - """Adds a mashmallow schema to the API specification.""" - if check_existing: - assert name not in self.schemas - - self.schemas[name] = schema - def schema(self, name, **options): """Decorator for creating new routes around function and class definitions. - Usage:: - from marshmallow import Schema, fields - @api.schema("Pet") class PetSchema(Schema): name = fields.Str() - """ def decorator(f): - self.add_schema(name=name, schema=f, **options) + self.openapi.add_schema(name=name, schema=f, **options) return f return decorator @@ -301,16 +232,9 @@ class API: default=default, websocket=websocket, before_request=before_request, + check_existing=check_existing, ) - def docs_response(self, req, resp): - resp.html = self.docs - - def schema_response(self, req, resp): - resp.status_code = status_codes.HTTP_200 - resp.headers["Content-Type"] = "application/x-yaml" - resp.content = self.openapi - def redirect( self, resp, location, *, set_text=True, status_code=status_codes.HTTP_301 ): @@ -397,35 +321,6 @@ class API: """ return self.router.url_for(endpoint, **params) - def static_url(self, asset): - """Given a static asset, return its URL path.""" - assert None not in (self.static_dir, self.static_route) - return f"{self.static_route}/{str(asset)}" - - @property - def docs(self): - - loader = jinja2.PrefixLoader( - { - self.docs_theme: jinja2.PackageLoader( - "apistar", os.path.join("themes", self.docs_theme, "templates") - ) - } - ) - env = jinja2.Environment(autoescape=True, loader=loader) - document = apistar.document.Document() - document.content = yaml.safe_load(self.openapi) - - template = env.get_template("/".join([self.docs_theme, "index.html"])) - - return template.render( - document=document, - langs=["javascript", "python"], - code_style=None, - static_url=self.static_url, - schema_url="/schema.yml", - ) - def template(self, name_, **values): """Renders the given `jinja2 `_ template, with provided values supplied. diff --git a/responder/ext/schema/__init__.py b/responder/ext/schema/__init__.py new file mode 100644 index 0000000..9c717c0 --- /dev/null +++ b/responder/ext/schema/__init__.py @@ -0,0 +1,160 @@ +import os +from pathlib import Path + +import apistar +import jinja2 +import yaml +from apispec import APISpec, yaml_utils +from apispec.ext.marshmallow import MarshmallowPlugin + +from responder.statics import DEFAULT_API_THEME +from responder.staticfiles import StaticFiles +from responder import status_codes + + +class Schema: + def __init__( + self, + app, + title, + version, + plugins=None, + description=None, + terms_of_service=None, + contact=None, + license=None, + openapi=None, + openapi_route="/schema.yml", + docs_route="/docs/", + static_route="/static", + ): + self.app = app + self.schemas = {} + self.title = title + self.version = version + self.description = description + self.terms_of_service = terms_of_service + self.contact = contact + self.license = license + + self.openapi_version = openapi + self.openapi_route = openapi_route + + self.docs_theme = DEFAULT_API_THEME + self.docs_route = docs_route + + self.plugins = [MarshmallowPlugin()] if plugins is None else plugins + + if self.openapi_version is not None: + self.app.add_route(self.openapi_route, self.schema_response) + + if self.docs_route is not None: + self.app.add_route(self.docs_route, self.docs_response) + + theme_path = ( + Path(apistar.__file__).parent / "themes" / self.docs_theme / "static" + ).resolve() + + self.static_route = static_route + + self.app.mount(self.static_route, StaticFiles(directory=theme_path)) + + @property + def _apispec(self): + + info = {} + if self.description is not None: + info["description"] = self.description + if self.terms_of_service is not None: + info["termsOfService"] = self.terms_of_service + if self.contact is not None: + info["contact"] = self.contact + if self.license is not None: + info["license"] = self.license + + spec = APISpec( + title=self.title, + version=self.version, + openapi_version=self.openapi_version, + plugins=self.plugins, + info=info, + ) + + for route in self.app.router.routes: + if route.description: + operations = yaml_utils.load_operations_from_docstring( + route.description + ) + spec.path(path=route.route, operations=operations) + + for name, schema in self.schemas.items(): + spec.components.schema(name, schema=schema) + + return spec + + @property + def openapi(self): + return self._apispec.to_yaml() + + def add_schema(self, name, schema, check_existing=True): + """Adds a mashmallow schema to the API specification.""" + if check_existing: + assert name not in self.schemas + + self.schemas[name] = schema + + def schema(self, name, **options): + """Decorator for creating new routes around function and class definitions. + + Usage:: + + from marshmallow import Schema, fields + + @api.schema("Pet") + class PetSchema(Schema): + name = fields.Str() + + """ + + def decorator(f): + self.add_schema(name=name, schema=f, **options) + return f + + return decorator + + @property + def docs(self): + + loader = jinja2.PrefixLoader( + { + self.docs_theme: jinja2.PackageLoader( + "apistar", os.path.join("themes", self.docs_theme, "templates") + ) + } + ) + env = jinja2.Environment(autoescape=True, loader=loader) + document = apistar.document.Document() + document.content = yaml.safe_load(self.openapi) + + template = env.get_template("/".join([self.docs_theme, "index.html"])) + + return template.render( + document=document, + langs=["javascript", "python"], + code_style=None, + static_url=self.static_url, + schema_url="/schema.yml", + ) + + def static_url(self, asset): + """Given a static asset, return its URL path.""" + assert self.static_route is not None + return f"{self.static_route}/{str(asset)}" + + def docs_response(self, req, resp): + resp.html = self.docs + + def schema_response(self, req, resp): + resp.status_code = status_codes.HTTP_200 + resp.headers["Content-Type"] = "application/x-yaml" + resp.content = self.openapi diff --git a/responder/models.py b/responder/models.py index be7c802..d39f318 100644 --- a/responder/models.py +++ b/responder/models.py @@ -151,7 +151,6 @@ class Request: @property def cookies(self): """The cookies sent in the Request, as a dictionary.""" - return self._starlette.cookies if self._cookies is None: cookies = RequestsCookieJar() cookie_header = self.headers.get("Cookie", "") @@ -176,12 +175,12 @@ class Request: def state(self) -> State: """ Use the state to store additional information. - + This can be a very helpful feature, if you want to hand over information from a middelware or a route decorator to the actual route handler. - For example: ``request.state.time_started = time.time()`` + Usage: ``request.state.time_started = time.time()`` """ return self._starlette.state diff --git a/responder/routes.py b/responder/routes.py index e85c89c..c53def1 100644 --- a/responder/routes.py +++ b/responder/routes.py @@ -139,7 +139,6 @@ class Route(BaseRoute): if response.status_code is None: response.status_code = status_codes.HTTP_200 - print("here", response) await response(scope, receive, send) def __eq__(self, other): @@ -198,6 +197,10 @@ class WebSocketRoute(BaseRoute): await self.endpoint(ws) + def __eq__(self, other): + # [TODO] compare to str ? + return self.route == other.route and self.endpoint == other.endpoint + def __hash__(self): return hash(self.route) ^ hash(self.endpoint) ^ hash(self.before_request) @@ -223,6 +226,7 @@ class Router: default=False, websocket=False, before_request=False, + check_existing=False, ): """ Adds a route to the router. :param route: A string representation of the route @@ -236,6 +240,11 @@ class Router: self.before_requests.setdefault("http", []).append(endpoint) return + if check_existing: + assert not self.routes or route not in ( + item.route for item in self.routes + ), f"Route '{route}' already exists" + if default: self.default_endpoint = endpoint diff --git a/responder/staticfiles.py b/responder/staticfiles.py new file mode 100644 index 0000000..93f26d7 --- /dev/null +++ b/responder/staticfiles.py @@ -0,0 +1,18 @@ +from whitenoise import WhiteNoise + + +def _notfound_wsgi_app(environ, start_response): + start_response("404 NOT FOUND", [("Content-Type", "text/plain")]) + return [b"Not Found."] + + +class StaticFiles: + def __init__(self, directory=None, mkdir=True): + self.directory = directory + self.app = WhiteNoise(_notfound_wsgi_app, root=self.directory) + + def __call__(self, environ, start_response): + return self.app(environ, start_response) + + +from starlette.staticfiles import StaticFiles diff --git a/tests/test_responder.py b/tests/test_responder.py index 508e444..4ae39b8 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -7,6 +7,7 @@ import responder import requests import string import io +from responder.routes import Router, Route, WebSocketRoute from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse @@ -19,6 +20,46 @@ def test_api_basic_route(api): resp.text = "hello world!" +def test_route_repr(): + def home(req, resp): + """Home page + """ + resp.text = "Hello !" + + route = Route("/", home) + + assert route.__repr__() == f"" + + assert route.endpoint_name == home.__name__ + assert route.description == home.__doc__ + + +def test_websocket_route_repr(): + def chat_endpoint(ws): + """Chat + """ + pass + + route = WebSocketRoute("/", chat_endpoint) + + assert route.__repr__() == f"" + + assert route.endpoint_name == chat_endpoint.__name__ + assert route.description == chat_endpoint.__doc__ + + +def test_route_eq(): + def home(req, resp): + resp.text = "Hello !" + + assert Route("/", home) == Route("/", home) + + def chat(ws): + pass + + assert WebSocketRoute("/", home) == WebSocketRoute("/", home) + + """ def test_api_basic_route_overlap(api): @api.route("/") @@ -30,6 +71,8 @@ def test_api_basic_route_overlap(api): @api.route("/") def home2(req, resp): resp.text = "hello world!" +""" + def test_class_based_view_registration(api): @api.route("/") @@ -41,11 +84,10 @@ def test_class_based_view_registration(api): def test_class_based_view_parameters(api): @api.route("/{greeting}") class Greeting: - def on_request(self, req, resp, *, greeting): - resp.text = f"{greeting}, world!" + pass - assert api.session().get("http://;/Hello").ok -""" + resp = api.session().get("http://;/Hello") + assert resp.status_code == api.status_codes.HTTP_405 def test_requests_session(api): @@ -306,6 +348,40 @@ def test_yaml_downloads(api): assert yaml.safe_load(r.content) == dump +def test_schema_generation_explicit(): + import responder + from responder.ext.schema import Schema as OpenAPISchema + import marshmallow + + api = responder.API() + + schema = OpenAPISchema(app=api, title="Web Service", version="1.0", openapi="3.0.2") + + @schema.schema("Pet") + class PetSchema(marshmallow.Schema): + name = marshmallow.fields.Str() + + @api.route("/") + def route(req, resp): + """A cute furry animal endpoint. + --- + get: + description: Get a random pet + responses: + 200: + description: A pet to be returned + schema: + $ref: "#/components/schemas/Pet" + """ + resp.media = PetSchema().dump({"name": "little orange"}) + + r = api.requests.get("http://;/schema.yml") + dump = yaml.safe_load(r.content) + + assert dump + assert dump["openapi"] == "3.0.2" + + def test_schema_generation(): import responder from marshmallow import Schema, fields @@ -337,6 +413,60 @@ def test_schema_generation(): assert dump["openapi"] == "3.0.2" +def test_documentation_explicit(): + import responder + from responder.ext.schema import Schema as OpenAPISchema + + import marshmallow + + description = "This is a sample server for a pet store." + terms_of_service = "http://example.com/terms/" + contact = { + "name": "API Support", + "url": "http://www.example.com/support", + "email": "support@example.com", + } + license = { + "name": "Apache 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0.html", + } + + api = responder.API(allowed_hosts=["testserver", ";"]) + + schema = OpenAPISchema( + app=api, + title="Web Service", + version="1.0", + openapi="3.0.2", + docs_route="/docs", + description=description, + terms_of_service=terms_of_service, + contact=contact, + license=license, + ) + + @schema.schema("Pet") + class PetSchema(marshmallow.Schema): + name = marshmallow.fields.Str() + + @api.route("/") + def route(req, resp): + """A cute furry animal endpoint. + --- + get: + description: Get a random pet + responses: + 200: + description: A pet to be returned + schema: + $ref: "#/components/schemas/Pet" + """ + resp.media = PetSchema().dump({"name": "little orange"}) + + r = api.requests.get("/docs") + assert "html" in r.text + + def test_documentation(): import responder from marshmallow import Schema, fields @@ -609,6 +739,10 @@ def test_before_response(api, session): def get(req, resp): resp.media = req.session + @api.route(before_request=True) + async def async_before_request(req, resp): + resp.headers["x-pizza"] = "1" + @api.route(before_request=True) def before_request(req, resp): resp.headers["x-pizza"] = "1" @@ -760,25 +894,6 @@ def test_staticfiles_none_dir(tmpdir): api.add_route("/spa", static=True) -def test_staticfiles_none_dir_route(tmpdir): - api = responder.API(static_dir=None, static_route=None) - session = api.session() - - static_dir = tmpdir.mkdir("static") - - asset = create_asset(static_dir) - - static_route = api.static_route - - # ok - r = session.get(f"{static_route}/{asset.basename}") - assert r.status_code == api.status_codes.HTTP_404 - - # dir listing - r = session.get(f"{static_route}") - assert r.status_code == api.status_codes.HTTP_404 - - def test_response_html_property(api): @api.route("/") def view(req, resp):