diff --git a/examples/maqpy/all_pages_auth.py b/examples/maqpy/all_pages_auth.py new file mode 100644 index 0000000..e69de29 diff --git a/src/replit/maqpy/app.py b/src/replit/maqpy/app.py index f74aca0..a833682 100644 --- a/src/replit/maqpy/app.py +++ b/src/replit/maqpy/app.py @@ -1,6 +1,7 @@ """Core of maqpi.""" from dataclasses import dataclass -from typing import Any +from functools import wraps +from typing import Any, Callable, Set import flask @@ -14,14 +15,14 @@ class ReplitAuthContext: roles: str @classmethod - def from_headers(cls, headers: dict): + def from_headers(cls, headers: dict) -> Any: """Initialize an instance using the Replit magic headers. Args: headers (dict): A dictionary of headers received Returns: - [type]: An initialized class instance + Any: An initialized class instance """ return cls( user_id=headers.get("X-Replit-User-Id"), @@ -31,7 +32,11 @@ class ReplitAuthContext: @property def signed_in(self) -> bool: - """Return whether or not the authentication is activated.""" + """Check whether the user is signed in with repl auth. + + Returns: + bool: whether or not the authentication is activated. + """ return self.name != "" @@ -39,7 +44,12 @@ class Request(flask.Request): """Represents a client request.""" def __init__(self, *args: Any, **kwargs: Any) -> None: - """Initialize request and run update_auth.""" + """Initializes request and runs update_auth. + + Args: + args (Any): The arguments to be passed to the superclass. + kwargs (Any): The keyword arguments to be passed to the superclass. + """ super().__init__(*args, **kwargs) self.update_auth() @@ -49,7 +59,11 @@ class Request(flask.Request): @property def signed_in(self) -> bool: - """Return whether or not the authentication is activated.""" + """Check whether the user is signed in with repl auth. + + Returns: + bool: Whether or not the user is signed in + """ return self.auth.signed_in @@ -58,12 +72,71 @@ class App(flask.Flask): request_class = Request - def all_pages_sign_in(self) -> None: - """Require sign-in on all pages.""" - raise NotImplementedError() + def all_pages_sign_in(self, exclude: Set[str] = ("/",)) -> None: + """Require sign-in on all pages. + + Args: + exclude (Tuple[str]): The routes that should not require sign in. + Defaults to just /. + """ + self._apsi_exclude = set(exclude) or set() + + def _request_handler(self, rule: str, view_func: Callable) -> Callable: + """Return a handler for a given request. + + This enables the all_pages_sign_in feature. + + Args: + rule (str): The url that the route will be matched to + view_func (Callable): The original view function that will be called. + + Returns: + Callable: A handler that runs the middleware and calls the original function + """ + + @wraps(view_func) + def handler(*args: Any, **kwargs: Any) -> Any: + if ( + hasattr(self, "_apsi_exclude") + and self._apsi_exclude is not None + and rule not in self._apsi_exclude + and not flask.request.signed_in + ): + return "" + return view_func(*args, **kwargs) + + return handler + + def add_url_rule( + self, + rule: str, + endpoint: str = None, + view_func: Callable = None, + provide_automatic_options: bool = None, + **options: Any + ) -> None: + """Replaces view function with custom handler.""" + return super().add_url_rule( + rule, + endpoint=endpoint, + view_func=self._request_handler(view_func) + if view_func is not None + else view_func, + provide_automatic_options=provide_automatic_options, + **options + ) def _run(self, *args: Any, **kwargs: Any) -> Any: - """Interface with the underlying flask instance's run function.""" + """Interface with the underlying flask instance's run function. + + Args: + args (Any): The arguments to be passed to the superclass' run method. + kwargs (Any): The keyword arguments to be passed to the superclass' run + method. + + Returns: + Any: The result of running the superclasses' run method. + """ return super().run(*args, **kwargs) def run(self, port: int = 8080, localhost: bool = False) -> None: