"""Utitilities to make development easier.""" from functools import wraps import time from typing import Any, Callable, Iterable, Optional, Union import flask from werkzeug.local import LocalProxy from .html import Page sign_in_snippet = ( '' ) def whoami(): """Returns the username of the authenticated Replit user, else None.""" return flask.request.headers.get('X-Replit-User-Name') def sign_in(title: str = "Please Sign In") -> Page: """Return a sign-in page. Args: title (str): The title of the sign in page. Defaults to "Please Sign In". Returns: Page: The sign-in page. """ return Page(title=title, body=sign_in_snippet) sign_in_page = sign_in() def needs_sign_in(func: Callable = None, login_res: str = sign_in_page) -> Callable: """A decorator that enforces that the user is signed in before accessing the page. Args: func (Callable): The function passed in if used as a decorator. Defaults to None. login_res (str): The HTML to show when the user needs to sign in. Defaults to sign_in_snippet. Returns: Callable: The new handler. """ def decorator(func: Callable) -> Callable: @wraps(func) def handler(*args: Any, **kwargs: Any) -> flask.Response: if flask.request.is_authenticated: return func(*args, **kwargs) else: return login_res return handler if func is not None: # called with no options @needs_signin return decorator(func) else: # called with options, eg @needs_signin(login_html='...') return decorator def needs_params( *param_names: str, src: Union[str, dict] = "form", onerror: Callable[[str], flask.Response] = None, ) -> Callable: """Require paramaters before a handler can be activated. Args: param_names (str): The paramaters that must be in the request. src (Union[str, dict]): The source to get the paramaters from. Can be "form" to use flask.request.form (POST requests), "query" for flask.request.query (GET requests), or a custom dictionary. onerror (Callable): A function to handle when a paramater is missing. It will be passed the parameter that is missing. If no function is specified a handler that returns a descriptive error and 400 Bad Request status code will be used. Raises: TypeError: No paramaters were provided or an invalid one was provided. Returns: Callable: The new handler. """ if len(param_names) < 1: raise TypeError("You must specify at least one required paramater name") # If function is used as a decorator with no arguments, the first argument will be # a function, so type check all of the param names to catch mistakes if not all(isinstance(p, str) for p in param_names): raise TypeError("All paramater names should be strings.") def default_onerror(missing_param: str) -> flask.Response: return flask.Response( f"Parameter {missing_param!r} is required but is missing", 400, mimetype="text/plain", ) onerror = default_onerror if onerror is None else onerror def decorator(func: Callable) -> Callable: @wraps(func) def handler(*args: Any, **ignoredkwargs: Any) -> flask.Response: if src == "form": params = flask.request.form elif src == "query": params = flask.request.args else: params = src param_kwargs = {} for p in param_names: if p not in params: return onerror(p) param_kwargs[p] = params[p] return func(*args, **param_kwargs) return handler return decorator def local_redirect(location: str, code: int = 302) -> flask.Response: """Perform a redirection to a local path without downgrading to HTTP. Args: location (str): The path to redirect to. code (int): The code to use for the redirect. Defaults to 302. Returns: flask.Response: The redirect response. """ # Use a LocalProxy so that it can be called before the request context is available return LocalProxy( lambda: flask.redirect( "https://" + flask.request.headers["host"] + location, code ) ) def authed_ratelimit( max_requests: int, period: float, login_res: str = sign_in_page, get_ratelimited_res: Callable[[float], str] = ( lambda left: f"Too many requests, wait {left} sec" ), ) -> Callable[[Callable], flask.Response]: """Require sign in and limit the amount of requests each signed in user can perform. This decorator also calls needs_signin for you and passes the login_res kwarg directly to it. Args: max_requests (int): The maximum amount of requests allowed in the period. period (float): The length of the period. login_res (str): The response to be shown if the user is not signed in, passed to needs_sign_in. get_ratelimited_res (Callable[[float], str]): A callable which is passed the amount of time remaining before the user can request again and returns the response that should be sent to the user. Returns: Callable[[Callable], flask.Response]: A function which decorates the handler. """ def decorator(func: Callable) -> flask.Response: last_reset = time.time() num_requests = {} # Checks for signin first, before checking ratelimit @needs_sign_in(login_res=login_res) @wraps(func) def handler(*args: Any, **kwargs: Any) -> flask.Response: nonlocal last_reset nonlocal num_requests name = flask.request.auth.name now = time.time() if now - last_reset >= period: last_reset = now num_requests = {} times_requested = num_requests.get(name, 0) if times_requested >= max_requests: res = get_ratelimited_res(period - (now - last_reset)) # Make a reponse object so that status can be set if not isinstance(res, flask.Response): res = flask.make_response(res) res.status = "429" return res num_requests[name] = times_requested + 1 return func(*args, **kwargs) return handler return decorator def find( data: Iterable, cond: Callable[[Any], bool], allow_multiple: bool = False ) -> Optional[Any]: """Find an item in an iterable. Args: data (Iterable): The iterable to search through. cond (Callable[[Any], bool]): The function to call for each item to check if it is a match. allow_multiple (bool): If multiple result are found, return the first one if allow_multiple is True, otherwise return None. Returns: Optional[Any]: The item if exactly one match was found, otherwise None. """ matches = [item for item in data if cond(item)] if len(matches) > 1: return matches[0] if allow_multiple else None return matches[0] if len(matches) == 1 else None def chain_decorators(*decorators: Callable[[Callable], Any]) -> Callable: """Return a decorator that applies each of the decorators to the function. Args: *decorators (Callable[[Callable], Any]): The decorators to apply to the function. They are treated as if they are written in the order they appear. Raises: TypeError: If no decorators are passed. Returns: Callable: A decorator function. """ def dec(func: Callable) -> Callable: for decorator in reversed(list(decorators) + [wraps(func)]): func = decorator(func) return func if not decorators: raise TypeError("You must provide at least one decorator to chain") return dec