diff --git a/poetry.lock b/poetry.lock index 86d781b..b82aa2b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -101,7 +101,7 @@ typed-ast = ">=1.4.0" d = ["aiohttp (>=3.3.2)", "aiohttp-cors"] [[package]] -category = "main" +category = "dev" description = "Python package for providing Mozilla's CA Bundle." name = "certifi" optional = false @@ -347,6 +347,14 @@ optional = false python-versions = ">=3.5" version = "4.7.6" +[[package]] +category = "main" +description = "Patch asyncio to allow nested event loops" +name = "nest-asyncio" +optional = false +python-versions = ">=3.5" +version = "1.4.0" + [[package]] category = "dev" description = "Core utilities for Python packages" @@ -443,7 +451,7 @@ python-versions = "*" version = "2020.7.14" [[package]] -category = "main" +category = "dev" description = "Python HTTP for Humans." name = "requests" optional = false @@ -651,7 +659,7 @@ python-versions = "*" version = "3.7.4.2" [[package]] -category = "main" +category = "dev" description = "HTTP library with thread-safe connection pooling, file post, and more." name = "urllib3" optional = false @@ -688,7 +696,7 @@ idna = ">=2.0" multidict = ">=4.0" [metadata] -content-hash = "93d94d333f6a693b503220db880ca32a9115d2d29a9c3cf8dac4b597c7152ed6" +content-hash = "c92c18a07248bbd0df18fba9327c2387af07b06a9a489123e3276810fb4e89fc" python-versions = "^3.8" [metadata.files] @@ -874,6 +882,10 @@ multidict = [ {file = "multidict-4.7.6-cp38-cp38-win_amd64.whl", hash = "sha256:7388d2ef3c55a8ba80da62ecfafa06a1c097c18032a501ffd4cabbc52d7f2b19"}, {file = "multidict-4.7.6.tar.gz", hash = "sha256:fbb77a75e529021e7c4a8d4e823d88ef4d23674a202be4f5addffc72cbb91430"}, ] +nest-asyncio = [ + {file = "nest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:ea51120725212ef02e5870dd77fc67ba7343fc945e3b9a7ff93384436e043b6a"}, + {file = "nest_asyncio-1.4.0.tar.gz", hash = "sha256:5773054bbc14579b000236f85bc01ecced7ffd045ec8ca4a9809371ec65a59c8"}, +] packaging = [ {file = "packaging-20.4-py2.py3-none-any.whl", hash = "sha256:998416ba6962ae7fbd6596850b80e17859a5753ba17c32284f67bfff33784181"}, {file = "packaging-20.4.tar.gz", hash = "sha256:4357f74f47b9c12db93624a82154e9b120fa8293699949152b22065d556079f8"}, diff --git a/pyproject.toml b/pyproject.toml index 3289504..8d7a1d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,11 +11,11 @@ documentation = "https://replit-python-docs.scoder12.repl.co" [tool.poetry.dependencies] python = "^3.8" -requests = "^2.24.0" typing_extensions = "^3.7.4" flask = "^1.1.2" werkzeug = "^1.0.1" aiohttp = "^3.6.2" +nest_asyncio = "^1.4.0" [tool.poetry.dev-dependencies] flake8 = "^3.8.3" diff --git a/src/replit/database/__init__.py b/src/replit/database/__init__.py index b8b222b..69174f4 100644 --- a/src/replit/database/__init__.py +++ b/src/replit/database/__init__.py @@ -1,23 +1,20 @@ """Interface with the Replit Database.""" +import asyncio +import functools import json -import aiohttp import os from sys import stderr from typing import Any, Callable, Dict, Tuple, Union -from . import _async +import aiohttp +import nest_asyncio -import asyncio - -asyncio.run = _async.run - -import requests JSON_TYPE = Union[str, int, float, bool, type(None), dict, list] -class JSONKey: - """Represents a key in the database that holds a JSON value. +class AsyncJSONKey: + """Represents an key in the async database that holds a JSON value. db.jsonkey() will initialize an instance for you, you don't have to do it manually. @@ -26,12 +23,12 @@ class JSONKey: __slots__ = ("db", "key", "dtype", "get_default", "discard_bad_data") def __init__( - self, - db: Any, - key: str, - dtype: JSON_TYPE, - get_default: Callable = None, - discard_bad_data: bool = False, + self, + db: Any, + key: str, + dtype: JSON_TYPE, + get_default: Callable = None, + discard_bad_data: bool = False, ) -> None: """Initialize the key. @@ -59,10 +56,12 @@ class JSONKey: return self.dtype is Any or isinstance(data, self.dtype) def _type_mismatch_msg(self, data: Any) -> str: - return (f"Type mismatch: Got type {type(data).__name__}," - "expected {self.dtype.__name__}") + return ( + f"Type mismatch: Got type {type(data).__name__}," + "expected {self.dtype.__name__}" + ) - def get(self) -> JSON_TYPE: + async def get(self) -> JSON_TYPE: """Get the value of the key. If an invalid JSON value is read or the type does not match, it will show a @@ -72,52 +71,49 @@ class JSONKey: JSON_TYPE: The value read from the database """ try: - read = self.db[self.key] + read = await self.db.get(self.key) except KeyError: - print( - f"Database key {self.key} not set, setting it to default value" - ) + print(f"Database key {self.key} not set, setting it to default value") default = self._default() - self.db[self.key] = default + await self.db.set(self.key, default) return default try: data = json.loads(read) except json.JSONDecodeError: - return self._error("Invalid JSON data read", read) + return await self._error("Invalid JSON data read", read) if not self._is_valid_type(data): - return self._error( - self._type_mismatch_msg(data), - read, - ) + return await self._error(self._type_mismatch_msg(data), read,) return data - def _error(self, error: str, read: str) -> JSON_TYPE: + async def _error(self, error: str, read: str) -> JSON_TYPE: print(f"Error reading key {self.key!r}: {error}", file=stderr) if self.discard_bad_data: val = self._default() - self.db[self.key] = json.dumps(val) + await self.db.set(self.key, json.dumps(val)) print(f"Wrote default to key {self.key!r}") return val - return self._should_discard_prompt(error, read) + return await self._should_discard_prompt(error, read) - def _should_discard_prompt(self, error: str, read: str) -> bool: + async def _should_discard_prompt(self, error: str, read: str) -> bool: while True: choice = input( "d to use default, v to view the invalid data, c to insert custom " - "value, ^C to exit: ") + "value, ^C to exit: " + ) if choice.startswith("d"): print("Writing default...") val = self._default() - self.db[self.key] = val + await self.db.set(self.key, val) return val elif choice.startswith("v"): print(f"Data read from key: {read!r}") elif choice.startswith("c"): toset = input( f"Enter data to write, should be of type {self.dtype.__name__!r}" - " (leave blank to return to menu): ") + " (leave blank to return to menu): " + ) if not toset: continue try: @@ -129,11 +125,11 @@ class JSONKey: print(self._type_mismatch_msg(data)) continue - self.db[self.key] = toset + await self.db.set(self.key, toset) print("Wrote data to key") return data - def set(self, data: JSON_TYPE) -> None: + async def set(self, data: JSON_TYPE) -> None: """Set the value of the jsonkey. Args: @@ -144,11 +140,12 @@ class JSONKey: """ if not self._is_valid_type(data): raise TypeError(self._type_mismatch_msg(data)) - self.db[self.key] = json.dumps(data) + + await self.db.set(self.key, json.dumps(data)) -class ReplitDb(dict): - """Interface with the Replit Database.""" +class AsyncReplitDb: + """Async client interface with the Replit Database.""" __slots__ = ("db_url", "sess") @@ -159,9 +156,8 @@ class ReplitDb(dict): db_url (str): Database url to use. """ self.db_url = db_url - self.sess = _AsyncBackend(db_url) - def __getitem__(self, key: str) -> str: + async def get(self, key: str) -> str: """Get the value of an item from the database. Args: @@ -173,39 +169,53 @@ class ReplitDb(dict): Returns: str: The value of the key """ - r = asyncio.run(self.sess.view(key)) - return r + async with aiohttp.ClientSession() as session: + async with session.get(self.db_url + "/" + key) as response: + if response.status == 404: + raise KeyError(key) + response.raise_for_status() + return await response.text() - def __setitem__(self, key: str, value: str) -> None: + async def set(self, key: str, value: str) -> None: """Set a key in the database to value. Args: key (str): The key to set value (str): The value to set it to """ - asyncio.run(self.sess.set(key, value)) + async with aiohttp.ClientSession() as session: + async with session.post(self.db_url, data={key: value}) as response: + response.raise_for_status() - def __delitem__(self, key: str) -> None: + async def delete(self, key: str) -> None: """Delete a key from the database. Args: key (str): The key to delete """ - asyncio.run(self.sess.delete(key)) + async with aiohttp.ClientSession() as session: + async with session.delete(self.db_url + "/" + key) as response: + response.raise_for_status() - def keys(self, prefix: str = "") -> Tuple[str]: - """Return all of the keys in the database. + async def list(self, prefix: str) -> Tuple[str]: + """List keys in the database which start with prefix. Args: - prefix (str): The prefix the keys must start with, - blank means anything. Defaults to "". + prefix (str): The prefix keys must start with, blank not not check. Returns: Tuple[str]: The keys found. """ - return asyncio.run(self.sess.list(prefix)) + async with aiohttp.ClientSession() as session: + async with session.get(self.db_url + "?prefix=" + prefix) as response: + response.raise_for_status() + text = await response.text() + if not text: + return tuple() + else: + return tuple(text.split("\n")) - def to_dict(self, prefix: str = "") -> Dict[str, str]: + async def to_dict(self, prefix: str = "") -> Dict[str, str]: """Dump all data in the database into a dictionary. Args: @@ -215,32 +225,127 @@ class ReplitDb(dict): Returns: Dict[str, str]: All keys in the database. """ - keys = self.keys() - data = {} - for k in keys: - data[k] = self[k] - return data + ret = {} + keys = await self.list(prefix=prefix) + for i in keys: + ret[i] = await self.view(i) + return ret - def values(self) -> Tuple[str]: + async def keys(self) -> Tuple[str]: + """Get all keys in the database. + + Returns: + Tuple[str]: The keys in the database. + """ + return tuple(await self.list("")) + + async def values(self) -> Tuple[str]: """Get every value in the database. Returns: Tuple[str]: The values in the database. """ - data = self.to_dict() - return tuple(data.values()) - - def items(self) -> Tuple[Tuple[str]]: - return self.to_dict().items() + return tuple((await self.to_dict()).values()) def jsonkey( - self, - key: str, - dtype: JSON_TYPE, - get_default: Callable = None, - discard_bad_data: bool = False, + self, + key: str, + dtype: JSON_TYPE, + get_default: Callable = None, + discard_bad_data: bool = False, + ) -> AsyncJSONKey: + """Initialize an AsyncJSONKey instance. + + A AsyncJSONKey is used to easily read and set JSON data from the database. + Arguments are the same as AsyncJSONKey constructor. + + Args: + key (str): The key to read + dtype (JSON_TYPE): The datatype the key should be, can be typing.Any. + get_default (Callable): A function that returns the default + value if the key is not set. If it is None (the default) the dtype + argument is used. + discard_bad_data (bool): Don't prompt if bad data is read, overwrite it + with the default. Defaults to False. + + Returns: + AsyncJSONKey: The initialized AsyncJSONKey instance. + """ + return AsyncJSONKey( + db=self, + key=key, + dtype=dtype, + get_default=get_default, + discard_bad_data=discard_bad_data, + ) + + def __repr__(self) -> str: + """A representation of the database. + + Returns: + A string representation of the database object. + """ + return f"" + + +def _async2sync(coro: Callable) -> None: + @functools.wraps(coro) + def sync_func(self: object, *args: Any, **kwargs: Any) -> Any: + return asyncio.run(coro(self, *args, **kwargs)) + + return sync_func + + +class JSONKey(AsyncJSONKey): + """Represents an key in the async database that holds a JSON value. + + db.jsonkey() will initialize an instance for you, + you don't have to do it manually. + """ + + get = _async2sync(AsyncJSONKey.get) + set = _async2sync(AsyncJSONKey.set) + + +class ReplitDb(AsyncReplitDb): + """Client interface with the Replit Database.""" + + def __getitem__(self, item: str) -> str: + """Retrieve a key from the database. + + Args: + item (str): The key to retrieve. + + Returns: + str: The value of the key. + """ + return self.get(item) + + def __setitem__(self, item: str, value: str) -> None: + """Set a key in the database. + + Args: + item (str): The key to set. + value (str): The value to set the key to. + """ + self.set(item, value) + + def __delitem__(self, name: str) -> None: + """Delete a key in the database. + + Args: + name (str): The key to delete. + """ + self.delete(name) + + def jsonkey( + self, + key: str, + dtype: JSON_TYPE, + get_default: Callable = None, + discard_bad_data: bool = False, ) -> JSONKey: - """Initialize a JSONKey instance. + """Initialize an JSONKey instance. A JSONKey is used to easily read and set JSON data from the database. Arguments are the same as JSONKey constructor. @@ -258,62 +363,23 @@ class ReplitDb(dict): JSONKey: The initialized JSONKey instance. """ return JSONKey( - db=self, + db=super(), key=key, dtype=dtype, get_default=get_default, discard_bad_data=discard_bad_data, ) - def __repr__(self) -> str: - """A representation of the database. - - Returns: - A string representation of the database object. - """ - return f"" - - -class AsyncClient(): - def __init__(self, db_url): - self.db_url = db_url - - -class _AsyncBackend(): - def __init__(self, db_url): - self.db_url = db_url - - async def set(self, key, val): - async with aiohttp.ClientSession() as session: - async with session.post(self.db_url, data={key: val}) as response: - response.raise_for_status() - return await response.text() - - async def view(self, key): - async with aiohttp.ClientSession() as session: - async with session.get(self.db_url + "/" + key) as response: - if response.status == 404: - raise KeyError(key) - response.raise_for_status() - return await response.text() - - async def delete(self, key): - async with aiohttp.ClientSession() as session: - async with session.delete(self.db_url + "/" + key) as response: - response.raise_for_status() - return await response.text() - - async def list(self, prefix): - async with aiohttp.ClientSession() as session: - async with session.get(self.db_url + "?prefix=" + - prefix) as response: - response.raise_for_status() - if not await response.text(): - return tuple() - else: - return tuple((await response.text()).split("\n")) + get = _async2sync(AsyncReplitDb.get) + set = _async2sync(AsyncReplitDb.set) + delete = _async2sync(AsyncReplitDb.delete) + list = _async2sync(AsyncReplitDb.list) + keys = _async2sync(AsyncReplitDb.keys) + to_dict = _async2sync(AsyncReplitDb.to_dict) + values = _async2sync(AsyncReplitDb.values) +nest_asyncio.apply() db_url = os.environ.get("REPLIT_DB_URL") if db_url: db = ReplitDb(db_url) diff --git a/src/replit/database/_async.py b/src/replit/database/_async.py deleted file mode 100644 index eafc32c..0000000 --- a/src/replit/database/_async.py +++ /dev/null @@ -1,44 +0,0 @@ -import threading -import asyncio -import sys -asyncio._run = asyncio.run - - -class AsyncThread(threading.Thread): - def __init__(self, res, exc, func): - self.result = res - self.exc = exc - self.func = func - threading.Thread.__init__(self) - - def run(self): - def inner(func): - try: - res = asyncio.run(func) - except Exception as e: - self.exc[0] = sys.exc_info() - res = '' - return res - - self.result[0] = inner(self.func) - - -def run(func): - def error(): - ret = [None] - exc = [None] - thread = AsyncThread(ret, exc, func) - thread.start() - thread.join() - exc = exc[0] - if exc != None: - raise exc[1].with_traceback(exc[2]) - sys.exit(1) - return ret[0] - - try: - return asyncio._run(func) - except RuntimeError: - return error() - except RuntimeWarning: - return error()