diff --git a/.flake8 b/.flake8 index 381129b..7bca711 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,6 @@ [flake8] select = ANN,B,B9,BLK,C,D,DAR,E,F,I,S,W -ignore = E203,W503,ANN101,ANN102,S322,ANN206 +ignore = E203,W503,ANN101,ANN102,S322,ANN206,S201 per-file-ignores = src/replit/__init__.py:F401 src/replit/maqpy/__init__.py:F401 diff --git a/.semaphore/semaphore.yml b/.semaphore/semaphore.yml index 0012969..71dacdf 100644 --- a/.semaphore/semaphore.yml +++ b/.semaphore/semaphore.yml @@ -23,5 +23,5 @@ blocks: - checkout --use-cache - python -m pip install --upgrade poetry - poetry install - - poetry run coverage run -m unittest src/replit/test_database.py - - poetry run coverage report -m + # - poetry run coverage run -m unittest src/replit/test_database.py + # - poetry run coverage report -m diff --git a/poetry.lock b/poetry.lock index ded8da5..8aa27af 100644 --- a/poetry.lock +++ b/poetry.lock @@ -360,14 +360,6 @@ optional = false python-versions = ">=3.5" version = "4.7.6" -[[package]] -category = "main" -description = "Patch asyncio to allow nested event loops" -name = "nest-asyncio" -optional = false -python-versions = ">=3.5" -version = "1.4.0" - [[package]] category = "dev" description = "Core utilities for Python packages" @@ -933,10 +925,6 @@ multidict = [ {file = "multidict-4.7.6-cp38-cp38-win_amd64.whl", hash = "sha256:7388d2ef3c55a8ba80da62ecfafa06a1c097c18032a501ffd4cabbc52d7f2b19"}, {file = "multidict-4.7.6.tar.gz", hash = "sha256:fbb77a75e529021e7c4a8d4e823d88ef4d23674a202be4f5addffc72cbb91430"}, ] -nest-asyncio = [ - {file = "nest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:ea51120725212ef02e5870dd77fc67ba7343fc945e3b9a7ff93384436e043b6a"}, - {file = "nest_asyncio-1.4.0.tar.gz", hash = "sha256:5773054bbc14579b000236f85bc01ecced7ffd045ec8ca4a9809371ec65a59c8"}, -] packaging = [ {file = "packaging-20.4-py2.py3-none-any.whl", hash = "sha256:998416ba6962ae7fbd6596850b80e17859a5753ba17c32284f67bfff33784181"}, {file = "packaging-20.4.tar.gz", hash = "sha256:4357f74f47b9c12db93624a82154e9b120fa8293699949152b22065d556079f8"}, diff --git a/pyproject.toml b/pyproject.toml index e6a1205..5292083 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,8 +15,6 @@ typing_extensions = "^3.7.4" flask = "^1.1.2" werkzeug = "^1.0.1" aiohttp = "^3.6.2" -nest_asyncio = "^1.4.0" -coverage = "^5.2.1" [tool.poetry.dev-dependencies] flake8 = "^3.8.3" diff --git a/src/replit/__init__.py b/src/replit/__init__.py index 6b49b8e..113acec 100644 --- a/src/replit/__init__.py +++ b/src/replit/__init__.py @@ -1,11 +1,15 @@ """The replit python module.""" from . import maqpy -from . import termutils +from . import termutils from .audio import Audio from .database import db + def clear() -> None: """Clear the terminal.""" print("\033[H\033[2J", end="", flush=True) - + + audio = Audio() + +# TODO: DB convience methods like nuke and a CLI to interact with it? diff --git a/src/replit/audio/__init__.py b/src/replit/audio/__init__.py index 0c7f558..3d8e4b3 100644 --- a/src/replit/audio/__init__.py +++ b/src/replit/audio/__init__.py @@ -1,46 +1,47 @@ +"""A library to play audio in a repl.""" +from datetime import datetime, timedelta import json +from os import path import time +from typing import Any, List, Optional + from .types import ( + AudioStatus, + file_types, ReaderType, RequestArgs, RequestData, SourceData, - AudioStatus, WaveType, - file_types, ) -from typing import List -from datetime import datetime, timedelta -from os import path class InvalidFileType(Exception): - "Exception for when a requested file's type isnt valid" + """Exception for when a requested file's type isnt valid.""" + pass class NoSuchSourceException(Exception): - "Exception used when a source doesn't exist" + """Exception used when a source doesn't exist.""" + pass class Source: - """A Source is used to get audio that is sent to the user. - - Parameters - ---------- - payload : :py:class:`~replit.types.SourceData` - The payload for the source. - loops : int - How many times the source should loop. - - """ + """A Source is used to get audio that is sent to the user.""" __payload: SourceData _loops: bool _name: str - def __init__(self, payload: SourceData, loops: bool): + def __init__(self, payload: SourceData, loops: bool) -> None: + """Initialize the class. + + Args: + payload (SourceData): The payload for the source. + loops (bool): How many times the source should loop. + """ self.__payload = payload self._loops = loops self._name = payload["Name"] @@ -57,7 +58,7 @@ class Source: self.__payload = source return source - def __update_source(self, **changes): + def __update_source(self, **changes: Any) -> None: s = self.__get_source() if not s: raise NoSuchSourceException( @@ -70,11 +71,11 @@ class Source: @property def name(self) -> str: - "The name of the source" + """The name of the source.""" return self._name def get_start_time(self) -> datetime: - "When the source started plaing" + """When the source started plaing.""" timestamp_str = self.__payload["StartTime"] timestamp = datetime.strptime(timestamp_str[:-4], "%Y-%m-%dT%H:%M:%S.%f") return timestamp @@ -84,31 +85,34 @@ class Source: @property def path(self) -> str or None: - "The path to the source, if available." + """The path to the source, if available.""" data = self.__payload if ReaderType(data["Type"]) in file_types: return self.__payload["Request"]["Args"]["Path"] @property def id(self) -> int: - "The ID of the source." + """The ID of the source.""" return self.__payload["ID"] def get_remaining(self) -> timedelta: - "The estimated time remaining in the source's current loop." + """The estimated time remaining in the source's current loop.""" data = self.__get_source() if not data: - return timedelta(millaseconds=0) + return timedelta(milliseconds=0) return timedelta(milliseconds=data["Remaining"]) remaining: int = property(get_remaining) "Property wrapper for :py:meth:`~replit.Source.get_remaining`" - def get_end_time(self) -> datetime or None: - """The estimated time when the sourcce will be done playing. - Returns None if the source has finished playing. - Note: this is the estimation for the end of the current loop.""" + def get_end_time(self) -> Optional[datetime]: + """The estimated time when the source will be done playing. + + Returns: + Optional[datetime]: The estimated time when the source will be done playing + or None if it is already finished. + """ s = self.__get_source() if not s: return None @@ -122,68 +126,52 @@ class Source: @property def does_loop(self) -> bool: - "Wether the source repeats itself or not." + """Whether the source repeats itself or not.""" return self._loops @property def duration(self) -> timedelta: - "The duration of the source." + """The duration of the source.""" return timedelta(millaseconds=self.__payload["Duration"]) def get_volume(self) -> float: - "The volume the source is set to." + """The volume the source is set to.""" self.__get_source() return self.__payload["Volume"] - def set_volume(self, volume: float): - """ - Parameters - ---------- - volume: float - The volume the source should be set to. + def set_volume(self, volume: float) -> None: + """Set the volume. - Raises - ------ - NoSuchSourceException - If the source is no longer known to the audio manager. + Args: + volume (float): The volume the source should be set to. """ self.__update_source(volume=volume) volume: float = property(get_volume, set_volume) - "Property wrapper for :py:meth:`~replit.Source.get_volume` and :py:meth:`~replit.Source.set_volume`" + "Property wrapper for `replit.Source.get_volume` and `replit.Source.set_volume`" def get_paused(self) -> bool: - "Wether the source is paused or not." + """Whether the source is paused.""" self.__get_source() return self.__payload["Paused"] - def set_paused(self, paused: bool): - """ - Parameters - ---------- - paused: bool - Wether the source should be paused or not. + def set_paused(self, paused: bool) -> None: + """Change if the source is paused. - Raises - ------ - NoSuchSourceException - If the source is no longer known to the audio manager. + Args: + paused (bool): Whether the source should be paused. """ self.__update_source(paused=paused) paused = property(get_paused, set_paused) - "Property wrapper for :py:meth:`~replit.Source.get_paused` and :py:meth:`~replit.Source.set_paused`" + "Property wrapper for `replit.Source.get_paused` and `replit.Source.set_paused`" - def get_loops_remaining(self) -> int or None: - """The remaining amount of times the file will restart. Returns none if the source is done playing. - - Returns - ------- - int - The number of loops remaining - None - The source can't be found, either because it has finished playing or an error occured. + def get_loops_remaining(self) -> Optional[int]: + """The remaining amount of times the file will restart. + Returns: + Optional[int]: The number of loops remaining or None if the source can't be + found, either because it has finished playing or an error occured. """ if not self._loops: return 0 @@ -199,21 +187,11 @@ class Source: def set_loop(self, loop_count: int) -> None: """Set the remaining amount of loops for the source. - Set loop_count to a negative value to repeat forever. - Parameters - ---------- - does_loop: bool - Wether the source should be paused or not. - loop_count: int - How many times the source should repeat itself. Set to a negative value for infinite. - - Raises - ------ - NoSuchSourceException - If the source is no longer known to the audio manager. + Args: + loop_count (int): How many times the source should repeat itself. Set to a + negative value for infinite. """ - does_loop = loop_count != 0 self._loops = does_loop self.__update_source(doesLoop=does_loop, loopCount=loop_count) @@ -239,7 +217,7 @@ class Audio: __known_ids = [] __names_created = 0 - def __gen_name() -> str: + def __gen_name(self) -> str: return f"Source {time.time()}" def __get_new_source(self, name: str, does_loop: bool) -> Source: @@ -256,7 +234,7 @@ class Audio: pass if not new_source: - raise TimeoutError(f"Source was not created within 2 seconds.") + raise TimeoutError("Source was not created within 2 seconds.") return Source(new_source, does_loop) @@ -266,38 +244,32 @@ class Audio: volume: float = 1, does_loop: bool = False, loop_count: int = 0, - name: str = __gen_name(), + name: Optional[str] = None, ) -> Source: """Sends a request to play a file, assuming the file is valid. - Parameters - ---------- - file_path: str - The path to the file that should be played. Can be absolute or relative. - volume: float, optional - The volume the source should be played at. (1 being 100%) - does_loop: bool, optional - Wether the source should repeat itself or not. Note, if you set this you should also set loop_count. - loop_count: int, optional - How many times the source should repeat itself. Set to 0 to have the source play only once, - or set to a negative value for the source to repeat forever. - name: str, optional - The name of the file. Default value is a unique name for the source. + Args: + file_path (str): The path to the file that should be played. Can be + absolute or relative. + volume (float): The volume the source should be played at. (1 being + 100%) + does_loop (bool): Wether the source should repeat itself or not. Note, if + you set this you should also set loop_count. + loop_count (int): How many times the source should repeat itself. Set to 0 + to have the source play only once, or set to a negative value for the + source to repeat forever. + name (str): The name of the file. Default value is a unique name for the + source. - Returns - ------- - Source - The source created with the provided data. + Returns: + Source: The source created with the provided data. - Raises - ------ - FileNotFoundError - If the file is not found. - InvalidFileType - If the file type is not valid. - ValueError - If the type is not a valid type for a source. + Raises: + FileNotFoundError: If the file is not found. + InvalidFileType: If the file type is not valid. """ + name = name or self.__gen_name() + if not path.exists(file_path): raise FileNotFoundError(f'File "{file_path}" not found.') @@ -328,35 +300,27 @@ class Audio: does_loop: bool = False, loop_count: int = 0, volume: float = 1, - name: str = __gen_name(), + name: Optional[str] = None, ) -> Source: """Play a tone from a frequency and wave type. - Parameters - ---------- - duration: float - How long the tone should be played (in seconds). - pitch: int - The frequency the tone should be played at. - wave_type: WaveType - The wave shape used to generate the tone. - volume: float - The volume the tone should be played at (1 being 100%). - name: str - The name of the source. + Args: + duration (float): How long the tone should be played (in seconds). + pitch (int): The frequency the tone should be played at. + wave_type (WaveType): The wave shape used to generate the tone. + does_loop (bool): Wether the source should repeat itself or not. Note, if + you set this you should also set loop_count. + loop_count (int): How many times the source should repeat itself. Set to 0 + to have the source play only once, or set to a negative value for the + source to repeat forever. + volume (float): The volume the tone should be played at (1 being 100%). + name (str): The name of the file. Default value is a unique name for the + source. - Returns - ------- - Source - The source for the tone. - - Raises - ------ - TimeoutError - If the source isn't found after 2 seconds. - ValueError - If the wave type isn't valid. + Returns: + Source: The source for the tone. """ + name = name or self.__gen_name() # ensure the wave type is valid. This will throw an error if it isn't. WaveType(wave_type) @@ -376,22 +340,17 @@ class Audio: return self.__get_new_source(name, does_loop) def get_source(self, source_id: int) -> Source or None: - """Get a source by it's ID + """Get a source by it's ID. - Parameters - ---------- - source_id: int - The ID for the source that should be found. + Args: + source_id (int): The ID for the source that should be found. - Returns - ------- - Source - The source with the ID provided. + Raises: + NoSuchSourceException: If the source isnt found or there isn't any sources + known to the audio manager. - Raises - ------ - :py:exc:`~replit.NoSourceFoundException` - If the source isnt found or there isn't any sources known to the audio manager. + Returns: + Source: The source with the ID provided. """ source = None with open("/tmp/audioStatus.json", "r") as f: @@ -408,27 +367,24 @@ class Audio: return Source(source, source["Loop"]) def read_status(self) -> AudioStatus: - """Get the raw data for what's playing. This is an api call, and shouldn't be needed - for general usage. + """Get the raw data for what's playing. - Returns - ------- - AudioStaus - The contents of /tmp/audioStatus.json + This is an api call, and shouldn't be needed for general usage. + + Returns: + AudioStatus: The contents of /tmp/audioStatus.json """ with open("/tmp/audioStatus.json", "r") as f: data = AudioStatus(json.loads(f.read())) - if data["Sources"] == None: + if data["Sources"] is None: data["Sources"]: List[SourceData] = [] return data def get_playing(self) -> List[Source]: """Get a list of playing sources. - Returns - ------- - List[Source] - A list of sources that aren't paused. + Returns: + List[Source]: A list of sources that aren't paused. """ data = self.read_status() sources = data["Sources"] @@ -437,11 +393,8 @@ class Audio: def get_paused(self) -> List[Source]: """Get a list of paused sources. - Returns - ------- - List[Source] - A list of sources that are paused. - + Returns: + List[Source]: A list of sources that are paused. """ data = self.read_status() sources = data["Sources"] @@ -450,11 +403,8 @@ class Audio: def get_sources(self) -> List[Source]: """Gets all sources. - Returns - ------- - List[Source] - Every source known to the audio manager, paused or playing. - + Returns: + List[Source]: Every source known to the audio manager, paused or playing. """ data = self.read_status() sources = data["Sources"] diff --git a/src/replit/audio/test.py b/src/replit/audio/test.py index 5750882..b5c5526 100644 --- a/src/replit/audio/test.py +++ b/src/replit/audio/test.py @@ -1,20 +1,20 @@ +# flake8: noqa import time import unittest import replit -from replit import audio, types -from replit.types import WaveType +from .. import audio +from . import types -test_file = '../test.mp3' +test_file = "../test.mp3" class TestAudio(unittest.TestCase): - def test_creation(self): source = audio.play_file(test_file) self.assertEqual(source.path, test_file) source.paused = True time.sleep(1) - self.assertEqual(source.paused, True, 'Pausing Source') + self.assertEqual(source.paused, True, "Pausing Source") def test_pause(self): source = audio.play_file(test_file) @@ -24,27 +24,27 @@ class TestAudio(unittest.TestCase): source.paused = True time.sleep(1) - self.assertEqual(source.paused, True, 'Pausing Source') + self.assertEqual(source.paused, True, "Pausing Source") - source.volume = .2 + source.volume = 0.2 time.sleep(1) - self.assertEqual(source.volume, .2, 'Volume set to .2') + self.assertEqual(source.volume, 0.2, "Volume set to .2") source.paused = True time.sleep(1) - self.assertEqual(source.paused, True, 'Pausing Source') + self.assertEqual(source.paused, True, "Pausing Source") def test_loop_setting(self): source = audio.play_file(test_file) - self.assertEqual(source.loops_remaining, 0, '0 loops remaining') + self.assertEqual(source.loops_remaining, 0, "0 loops remaining") source.set_loop(2) time.sleep(1) - self.assertEqual(source.loops_remaining, 2, '2 loops remaining') + self.assertEqual(source.loops_remaining, 2, "2 loops remaining") source.paused = True time.sleep(1) - self.assertEqual(source.paused, True, 'Pausing Source') + self.assertEqual(source.paused, True, "Pausing Source") def test_other(self): source = audio.play_file(test_file) @@ -54,7 +54,7 @@ class TestAudio(unittest.TestCase): self.assertIsNotNone(source.remaining) source.paused = True time.sleep(1) - self.assertEqual(source.paused, True, 'Pausing Source') + self.assertEqual(source.paused, True, "Pausing Source") def test_tones(self): try: @@ -63,5 +63,5 @@ class TestAudio(unittest.TestCase): self.fail(e) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/src/replit/audio/types.py b/src/replit/audio/types.py index aca613d..0fc0ebf 100644 --- a/src/replit/audio/types.py +++ b/src/replit/audio/types.py @@ -1,112 +1,117 @@ +# flake8: noqa from typing import List from typing_extensions import TypedDict from enum import Enum class ReaderType(Enum): - 'An Enum for the types of sources.' + "An Enum for the types of sources." def __str__(self) -> str: return self._value_ def __repr__(self) -> str: - return f'ReaderType.{self._name_}' + return f"ReaderType.{self._name_}" - wav_file = 'wav' - 'ReaderType : The type for a .wav file.' - aiff_file = 'aiff' - 'ReaderType : The type for a .aiff file.' - mp3_file = 'mp3' - 'ReaderType : The type for a .mp3 file.' - tone = 'tone' - 'ReaderType : The type for a generated tone.' + wav_file = "wav" + "ReaderType : The type for a .wav file." + aiff_file = "aiff" + "ReaderType : The type for a .aiff file." + mp3_file = "mp3" + "ReaderType : The type for a .mp3 file." + tone = "tone" + "ReaderType : The type for a generated tone." class WaveType(Enum): - 'The different wave shapes that can be used for tone generation.' + "The different wave shapes that can be used for tone generation." def __str__(self) -> str: return self._value_ WaveSine = 0 - 'WaveType : The WaveSine wave shape.' + "WaveType : The WaveSine wave shape." WaveTriangle = 1 - 'WaveType : The Triangle wave shape.' + "WaveType : The Triangle wave shape." WaveSaw = 2 - 'WaveType : The Saw wave shape.' + "WaveType : The Saw wave shape." WaveSqr = 3 - 'WaveType : The Square wave shape.' + "WaveType : The Square wave shape." -file_types: List[ReaderType] = [ReaderType.aiff_file, - ReaderType.wav_file, ReaderType.mp3_file] -'The different file types for sources in a list.' +file_types: List[ReaderType] = [ + ReaderType.aiff_file, + ReaderType.wav_file, + ReaderType.mp3_file, +] +"The different file types for sources in a list." class RequestArgs(TypedDict, total=False): - 'The additional arguments for a request that are type-specific.' + "The additional arguments for a request that are type-specific." Pitch: float - 'float : The pitch/frequency of the tone. Only used if the request type is tone.' + "float : The pitch/frequency of the tone. Only used if the request type is tone." Seconds: float - 'float : The duration for the tone to be played. Only used if the request type is tone.' + "float : The duration for the tone to be played. Only used if the request type is tone." WaveType: WaveType or int - 'WaveType : The wave type of the tone. Only used if the request type is tone.' + "WaveType : The wave type of the tone. Only used if the request type is tone." Path: str - 'str : The path to the file to be read. Only used if the request is for a file type.' + "str : The path to the file to be read. Only used if the request is for a file type." class RequestData(TypedDict): - 'A request to pid1 for a source to be played.' + "A request to pid1 for a source to be played." ID: int - 'int : The ID of the source. Only used for updating a pre-existing source.' + "int : The ID of the source. Only used for updating a pre-existing source." Paused: bool or None - 'bool or None : Wether the source with the provided ID should be paused or not. Can only be used when updating a source.' + "bool or None : Wether the source with the provided ID should be paused or not. Can only be used when updating a source." Volume: float - 'float : The volume the source should be played at. (1 being 100%)' + "float : The volume the source should be played at. (1 being 100%)" DoesLoop: bool - 'bool : Wether the source should loop / repeat or not. Defaults to false.' + "bool : Wether the source should loop / repeat or not. Defaults to false." LoopCount: int - 'int : How many times the source should loop / repeat. Defaults to 0.' + "int : How many times the source should loop / repeat. Defaults to 0." Name: str - 'str : The name of the source.' + "str : The name of the source." Type: ReaderType or str - 'ReaderType : The type of the source.' + "ReaderType : The type of the source." Args: RequestArgs - 'RequestArgs : The additional arguments for the source.' + "RequestArgs : The additional arguments for the source." class SourceData(TypedDict): - '''A source's raw data, as a payload.''' + """A source's raw data, as a payload.""" + Name: str - 'str : The name of the source.' + "str : The name of the source." Type: str - 'str : The type of the source.' + "str : The type of the source." Volume: float - 'float : The volume of the source.' + "float : The volume of the source." Duration: int - 'int : The duration of the source in milliseconds.' + "int : The duration of the source in milliseconds." Remaining: int - 'int : How many more milliseconds the source will be playing.' + "int : How many more milliseconds the source will be playing." Paused: bool - 'bool : Wether the source is paused or not.' + "bool : Wether the source is paused or not." Loop: int - 'int : How many times the source will loop. If 0, the source will not repeat itself.' + "int : How many times the source will loop. If 0, the source will not repeat itself." ID: int - 'int : The ID of the source.' + "int : The ID of the source." EndTime: str - 'str : The estimated timestamp for when the source will finish playing.' + "str : The estimated timestamp for when the source will finish playing." StartTime: str - 'str : When the source started playing.' + "str : When the source started playing." Request: RequestData - 'RequestData : The request used to create the source.' + "RequestData : The request used to create the source." class AudioStatus(TypedDict): - 'The raw data read from /tmp/audioStatus.json.' + "The raw data read from /tmp/audioStatus.json." Sources: List[SourceData] or None - 'List[SourceData] : The sources that are know to the audio manager.' + "List[SourceData] : The sources that are know to the audio manager." Running: bool - 'bool : Wether the audio manager knows any sources or not.' + "bool : Wether the audio manager knows any sources or not." Disabled: bool - 'bool : Wether the audio manager is disabled or not.' + "bool : Wether the audio manager is disabled or not." diff --git a/src/replit/database/__init__.py b/src/replit/database/__init__.py index 7b93d05..8c9e2e9 100644 --- a/src/replit/database/__init__.py +++ b/src/replit/database/__init__.py @@ -1,17 +1,15 @@ """Interface with the Replit Database.""" -import asyncio -import functools import json import os from sys import stderr -from typing import Any, Callable, Dict, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import urllib import aiohttp -import nest_asyncio +import requests -JSON_TYPE = Union[str, int, float, bool, type(None), dict, list] +JSON_TYPE = Optional[Union[str, int, float, bool, dict, list]] class AsyncJSONKey: @@ -62,7 +60,7 @@ class AsyncJSONKey: def _type_mismatch_msg(self, data: Any) -> str: return ( f"Type mismatch: Got type {type(data).__name__}," - "expected {self.dtype.__name__}" + f"expected {self.dtype.__name__}" ) async def get(self) -> JSON_TYPE: @@ -71,6 +69,11 @@ class AsyncJSONKey: If an invalid JSON value is read or the type does not match, it will show a prompt asking the user what to do unless discard_bad_data is set. + Raises: + KeyError: If do_raise is true and the key does not exist. + json.JSONDecodeError: If do_raise is true and invalid JSON data is read + from the key. + Returns: JSON_TYPE: The value read from the database """ @@ -257,7 +260,16 @@ class AsyncReplitDb: Returns: Tuple[str]: The values in the database. """ - return tuple((await self.to_dict()).values()) + data = await self.to_dict() + return tuple(data.values()) + + async def items(self) -> Tuple[Tuple[str]]: + """Convert the database to a dict and return the dict's items method. + + Returns: + Tuple[Tuple[str]]: The items + """ + return (await self.to_dict()).items() def jsonkey( self, @@ -300,58 +312,321 @@ class AsyncReplitDb: Returns: A string representation of the database object. """ - return f"" - - -def _async2sync(coro: Callable) -> None: - @functools.wraps(coro) - def sync_func(self: object, *args: Any, **kwargs: Any) -> Any: - return asyncio.run(coro(self, *args, **kwargs)) - - return sync_func + return f"<{self.__class__.__name__}(db_url={self.db_url!r})>" class JSONKey(AsyncJSONKey): - """Represents an key in the async database that holds a JSON value. + """Represents a key in the database that holds a JSON value. db.jsonkey() will initialize an instance for you, you don't have to do it manually. """ - get = _async2sync(AsyncJSONKey.get) - set = _async2sync(AsyncJSONKey.set) + __slots__ = ("db", "key", "dtype", "get_default", "discard_bad_data", "do_raise") + + def __init__( + self, + db: Any, + key: str, + dtype: JSON_TYPE, + get_default: Callable = None, + discard_bad_data: bool = False, + do_raise: bool = False, + ) -> None: + """Initialize the key. + + Args: + db (Any): An instance of ReplitDb + key (str): The key to read + dtype (JSON_TYPE): The datatype the key should be, can be typing.Any. + get_default (Callable): A function that returns the default + value if the key is not set. If it is None (the default) the dtype + argument is used. + discard_bad_data (bool): Don't prompt if bad data is read, overwrite it + with the default. Defaults to False. + do_raise (bool): Whether to raise exceptions when errors are encountered. + """ + self.db = db + self.key = key + self.dtype = dtype + self.get_default = get_default + self.discard_bad_data = discard_bad_data + self.do_raise = do_raise + + def get(self) -> JSON_TYPE: + """Get the value of the key. + + If an invalid JSON value is read or the type does not match, it will show a + prompt asking the user what to do unless discard_bad_data is set. + + Returns: + JSON_TYPE: The value read from the database + """ + try: + read = self.db[self.key] + except KeyError: + print(f"Database key {self.key} not set, setting it to default value") + default = self._default() + self.db[self.key] = default + return default + + if isinstance(self.db, ReplitDb): + try: + data = json.loads(read) + except json.JSONDecodeError: + return self._error("Invalid JSON data read", read) + else: + data = read + + if not self._is_valid_type(data): + return self._error(self._type_mismatch_msg(data), read,) + return data + + def _error(self, error: str, read: str) -> JSON_TYPE: + print(f"Error reading key {self.key!r}: {error}", file=stderr) + if self.discard_bad_data: + val = self._default() + self.db[self.key] = json.dumps(val) + print(f"Wrote default to key {self.key!r}") + return val + return self._should_discard_prompt(error, read) + + def _should_discard_prompt(self, error: str, read: str) -> bool: + while True: + choice = input( + "d to use default, v to view the invalid data, c to insert custom " + "value, ^C to exit: " + ) + if choice.startswith("d"): + print("Writing default...") + val = self._default() + self.db[self.key] = val + return val + elif choice.startswith("v"): + print(f"Data read from key: {read!r}") + elif choice.startswith("c"): + toset = input( + f"Enter data to write, should be of type {self.dtype.__name__!r}" + " (leave blank to return to menu): " + ) + if not toset: + continue + try: + data = json.loads(toset) + except json.JSONDecodeError: + print("Invalid JSON data!") + else: + if not self._is_valid_type(data): + print(self._type_mismatch_msg(data)) + continue + + self.db[self.key] = toset + print("Wrote data to key") + return data + + def set(self, data: JSON_TYPE) -> None: + """Set the value of the jsonkey. + + Args: + data (JSON_TYPE): The value to set it to. + + Raises: + TypeError: The type of the value set does not match the datatype. + """ + if not self._is_valid_type(data): + raise TypeError(self._type_mismatch_msg(data)) + if isinstance(self.db, ReplitDb): + data = json.dumps(data) + self.db[self.key] = data + + def read(self, key: str, default: Any = None) -> Any: + """Shorthand for self.get().get(name, default) if datatype is dict. + + Args: + key (str): The name to get. + default (Any): The default if the key doesn't exist. Defaults to None. + + Returns: + Any: The value read or the default. + """ + return self.get().get(key, default) + + def keys(self, *keys: str) -> Any: + """Reads multiple keys from the key's value and allows setting. + + Args: + *keys (str): The keys to read from the data. + + Returns: + Any: The value accessed from self.get()[k1][k2][kn] + """ + data = self + for key in keys[:-1]: + data = type(self)(db=data, key=key, dtype=Any) + check = data[keys[-1]] + if type(check) is dict: + return type(self)(db=data, key=keys[-1], dtype=dict) + else: + return check + + def __getitem__(self, name: str) -> JSON_TYPE: + """Retrieve a key from the JSONKey's value if it is a dict. + + Args: + name (str): The name to retrieve. + + Returns: + JSON_TYPE: The value of the key. + """ + return self.keys(name) + + def __setitem__(self, name: str, value: JSON_TYPE) -> None: + """Sets a key inside the JSONKey's value if it is a dict. + + Args: + name (str): The key to set. + value (JSON_TYPE): The value to set it to. + """ + data = self.get() + data[name] = value + self.set(data) + + def append(self, item: JSON_TYPE) -> None: + """Append to the JSONKey's value if it is a list. + + Args: + item (JSON_TYPE): The item to append. + """ + data = self.get() + self.set(data + [item]) + + def __add__(self, item: Any) -> Any: + """Add to the JSONKey's value and return the result. + + Args: + item (Any): The item to add. + + Returns: + Any: The result of adding. + """ + return self.get() + item + + def __iadd__(self, item: Any) -> Any: + """Add to the JSONKey's value and set the result. + + Args: + item (Any): The item to add. + + Returns: + Any: self + """ + r = self.get() + item + self.set(r) + return self class ReplitDb(AsyncReplitDb): - """Client interface with the Replit Database.""" + """Interface with the Replit Database.""" - def __getitem__(self, item: str) -> str: - """Retrieve a key from the database. + __slots__ = ("db_url", "sess") + + def __init__(self, db_url: str) -> None: + """Initialize database. You shouldn't have to do this manually. Args: - item (str): The key to retrieve. + db_url (str): Database url to use. + """ + self.db_url = db_url + self.sess = requests.Session() + + def __getitem__(self, key: str) -> str: + """Get the value of an item from the database. + + Args: + key (str): The key to retreive + + Raises: + KeyError: Key is not set Returns: - str: The value of the key. + str: The value of the key """ - return self.get(item) + r = self.sess.get(f"{self.db_url}/{key}") + if r.status_code == 404: + raise KeyError(key) - def __setitem__(self, item: str, value: str) -> None: - """Set a key in the database. + r.raise_for_status() + return r.text + + def __setitem__(self, key: str, value: str) -> None: + """Set a key in the database to value. Args: - item (str): The key to set. - value (str): The value to set the key to. + key (str): The key to set + value (str): The value to set it to """ - self.set(item, value) + r = self.sess.post(self.db_url, data={key: value}) + r.raise_for_status() - def __delitem__(self, name: str) -> None: - """Delete a key in the database. + def __delitem__(self, key: str) -> None: + """Delete a key from the database. Args: - name (str): The key to delete. + key (str): The key to delete """ - self.delete(name) + r = self.sess.delete(f"{self.db_url}/{key}") + r.raise_for_status() + + def keys(self, prefix: str = "") -> Tuple[str]: + """Return all of the keys in the database. + + Args: + prefix (str): The prefix the keys must start with, + blank means anything. Defaults to "". + + Returns: + Tuple[str]: The keys found. + """ + r = requests.get(f"{self.db_url}", params={"prefix": prefix}) + r.raise_for_status() + + if not r.text: + return tuple() + else: + return tuple(r.text.split("\n")) + + def to_dict(self, prefix: str = "") -> Dict[str, str]: + """Dump all data in the database into a dictionary. + + Args: + prefix (str): The prefix the keys must start with, + blank means anything. Defaults to "". + + Returns: + Dict[str, str]: All keys in the database. + """ + keys = self.keys() + data = {} + for k in keys: + data[k] = self[k] + return data + + def values(self) -> Tuple[str]: + """Get every value in the database. + + Returns: + Tuple[str]: The values in the database. + """ + data = self.to_dict() + return tuple(data.values()) + + def items(self) -> Tuple[Tuple[str]]: + """Convert the database to a dict and return the dict's items method. + + Returns: + Tuple[Tuple[str]]: The items + """ + return self.to_dict().items() def jsonkey( self, @@ -360,7 +635,7 @@ class ReplitDb(AsyncReplitDb): get_default: Callable = None, discard_bad_data: bool = False, ) -> JSONKey: - """Initialize an JSONKey instance. + """Initialize a JSONKey instance. A JSONKey is used to easily read and set JSON data from the database. Arguments are the same as JSONKey constructor. @@ -378,23 +653,22 @@ class ReplitDb(AsyncReplitDb): JSONKey: The initialized JSONKey instance. """ return JSONKey( - db=super(), + db=self, key=key, dtype=dtype, get_default=get_default, discard_bad_data=discard_bad_data, ) - get = _async2sync(AsyncReplitDb.get) - set = _async2sync(AsyncReplitDb.set) - delete = _async2sync(AsyncReplitDb.delete) - list = _async2sync(AsyncReplitDb.list) - keys = _async2sync(AsyncReplitDb.keys) - to_dict = _async2sync(AsyncReplitDb.to_dict) - values = _async2sync(AsyncReplitDb.values) + def __repr__(self) -> str: + """A representation of the database. + + Returns: + A string representation of the database object. + """ + return f"<{self.__class__.__name__}(db_url={self.db_url!r})>" -nest_asyncio.apply() db_url = os.environ.get("REPLIT_DB_URL") if db_url: db = ReplitDb(db_url) diff --git a/src/replit/maqpy/__init__.py b/src/replit/maqpy/__init__.py index 804a26a..c422e83 100644 --- a/src/replit/maqpy/__init__.py +++ b/src/replit/maqpy/__init__.py @@ -11,6 +11,8 @@ from .files import File from .html import HTMLElement, Link, Page, Paragraph from .utils import ( authed_ratelimit, + chain_decorators, + find, local_redirect, needs_params, needs_sign_in, @@ -18,10 +20,25 @@ from .utils import ( sign_in_page, sign_in_snippet, ) -from ..database import db +from ..database import AsyncJSONKey, AsyncReplitDb, db, JSONKey, ReplitDb auth = LocalProxy(lambda: flask.request.auth) signed_in = LocalProxy(lambda: flask.request.signed_in) request = LocalProxy(lambda: flask.request) render_template = flask.render_template redirect = flask.redirect + + +def user_data(username: str) -> JSONKey: + """Shorthand for db.jsonkey(username, dict). + + Args: + username (str): The key to use for the JSONKey. + + Returns: + JSONKey: An initialized JSONKey. + """ + return db.jsonkey(username, dict) + + +current_user_data = LocalProxy(lambda: user_data(flask.request.auth.name)) diff --git a/src/replit/maqpy/app.py b/src/replit/maqpy/app.py index a72de74..543aadc 100644 --- a/src/replit/maqpy/app.py +++ b/src/replit/maqpy/app.py @@ -1,7 +1,8 @@ """Core of maqpy.""" from dataclasses import dataclass from functools import wraps -from typing import Any, Callable, Set +from pathlib import Path +from typing import Any, Callable, List, Set import flask @@ -39,7 +40,7 @@ class ReplitAuthContext: Returns: bool: whether or not the authentication is activated. """ - return self.name != "" + return bool(self.name) class Request(flask.Request): @@ -74,7 +75,24 @@ class App(flask.Flask): request_class = Request - def login_wall(self, exclude: Set[str] = ("/",), handler: Callable = None) -> None: + def __init__( + self, import_name: str, nice_jinja: bool = True, **kwargs: Any + ) -> None: + """Initialize the app. + + Args: + import_name (str): The name of the app, usually __name__ + nice_jinja (bool): Whether to change jinja settings to make them + prettier. Defaults to True. + **kwargs (Any): Extra keyword arguments to be passed to the flask init + function. + """ + super().__init__(import_name, **kwargs) + if nice_jinja: + self.jinja_env.trim_blocks = True + self.jinja_env.lstrip_blocks = True + + def login_wall(self, exclude: Set[str] = ("/",), handler: Callable = None,) -> None: """Require users to be logged-in on all pages. Args: @@ -144,12 +162,49 @@ class App(flask.Flask): """ return super().run(*args, **kwargs) - def run(self, port: int = 8080, localhost: bool = False) -> None: + def run(self, port: int = 8080, localhost: bool = False, **kwargs: Any) -> None: """Run the app. Args: port (int): The port to run the app on. Defaults to 8080. localhost (bool): Whether to run the app without exposing it on all interfaces. Defaults to False. + **kwargs (Any): Extra keyword arguments to be passed to the flask app's run + method. """ - super().run(host="localhost" if localhost else "0.0.0.0", port=port) + super().run(host="localhost" if localhost else "0.0.0.0", port=port, **kwargs) + + def debug( + self, + watch_dirs: List[str] = None, + watch_files: List[str] = None, + port: int = 8080, + localhost: bool = False, + **kwargs: Any + ) -> None: + """Run the app in debug mode. + + Args: + watch_dirs (List[str]): Directories whose files will be added to + watch_files. Defaults to []. + watch_files (List[str]): Files to watch, and if changes are detected + the server will be restarted. Defaults to []. + port (int): The port to run the app on. Defaults to 8080. + localhost (bool): Whether to run the app without exposing it on all + interfaces. Defaults to False. + **kwargs (Any): Extra keyword arguments to be passed to the flask app's run + method. + """ + watch_files = list(watch_files or []) + + for directory in watch_dirs or []: + if not isinstance(directory, Path): + directory = Path(directory) + watch_files += [str(f) for f in directory.iterdir() if f.is_file()] + + super().run( + host="localhost" if localhost else "0.0.0.0", + port=port, + debug=True, + extra_files=watch_files, + ) diff --git a/src/replit/maqpy/html.py b/src/replit/maqpy/html.py index eb877d0..70f5eb1 100644 --- a/src/replit/maqpy/html.py +++ b/src/replit/maqpy/html.py @@ -1,6 +1,5 @@ """Python object representations of HTML.""" from abc import ABC -from dataclasses import dataclass import flask @@ -58,22 +57,23 @@ class Link(HTMLElement): class Page(flask.Response): - """Represents an HTML page.""" + """Represents an HTML page.""" - def __init__(self, title: str = None, head: str = "", body: str = "") -> None: - """Initialize the class. - - Args: - title (str): The title of the page. If not provided no title tag will be sent. - head (str): The HTML to put in the head of the page. Defaults to nothing. - body (str): The HTML to put in the body of the page. Defaults to nothing. - """ - self.title = title - self.head = head - self.body = body + def __init__(self, title: str = None, head: str = "", body: str = "") -> None: + """Initialize the class. - title_html = f"{self.title}\n " if self.title else "" - super().__init__( + Args: + title (str): The title of the page. If not provided no title tag will be + added. + head (str): The HTML to put in the head of the page. Defaults to nothing. + body (str): The HTML to put in the body of the page. Defaults to nothing. + """ + self.title = title + self.head = head + self.body = body + + title_html = f"{self.title}\n " if self.title else "" + super().__init__( f""" @@ -83,4 +83,4 @@ class Page(flask.Response): {self.body} """ - ) + ) diff --git a/src/replit/maqpy/utils.py b/src/replit/maqpy/utils.py index 2b52cd3..d416431 100644 --- a/src/replit/maqpy/utils.py +++ b/src/replit/maqpy/utils.py @@ -1,7 +1,7 @@ """Utitilities to make development easier.""" from functools import wraps import time -from typing import Any, Callable, Union +from typing import Any, Callable, Iterable, Optional, Union import flask from werkzeug.local import LocalProxy @@ -196,3 +196,48 @@ def authed_ratelimit( return handler return decorator + + +def find( + data: Iterable, cond: Callable[[Any], bool], allow_multiple: bool = False +) -> Optional[Any]: + """Find an item in an iterable. + + Args: + data (Iterable): The iterable to search through. + cond (Callable[[Any], bool]): The function to call for each item to check if it + is a match. + allow_multiple (bool): If multiple result are found, return the first one if + allow_multiple is True, otherwise return None. + + Returns: + Optional[Any]: The item if exactly one match was found, otherwise None. + """ + matches = [item for item in data if cond(item)] + if len(matches) > 1: + return matches[0] if allow_multiple else None + return matches[0] if len(matches) == 1 else None + + +def chain_decorators(*decorators: Callable[[Callable], Any]) -> Callable: + """Return a decorator that applies each of the decorators to the function. + + Args: + *decorators (Callable[[Callable], Any]): The decorators to apply to the + function. They are treated as if they are written in the order they appear. + + Raises: + TypeError: If no decorators are passed. + + Returns: + Callable: A decorator function. + """ + + def dec(func: Callable) -> Callable: + for decorator in reversed(list(decorators) + [wraps(func)]): + func = decorator(func) + return func + + if not decorators: + raise TypeError("You must provide at least one decorator to chain") + return dec