From 9cf5c102ca2c04e7a61155e75ec4595dc8ad66db Mon Sep 17 00:00:00 2001 From: Scoder12 <34356756+Scoder12@users.noreply.github.com> Date: Tue, 4 Aug 2020 16:30:26 -0700 Subject: [PATCH] Overhaul async and sync code Add new _async2sync function to copy async code into a sync class --- src/replit/database/__init__.py | 144 +++++++++++++++++++++++++++++--- src/replit/database/_async.py | 47 ----------- 2 files changed, 134 insertions(+), 57 deletions(-) delete mode 100644 src/replit/database/_async.py diff --git a/src/replit/database/__init__.py b/src/replit/database/__init__.py index cb1884d..15248ef 100644 --- a/src/replit/database/__init__.py +++ b/src/replit/database/__init__.py @@ -1,6 +1,8 @@ """Interface with the Replit Database.""" import asyncio +import functools import json +import inspect import os from sys import stderr from typing import Any, Callable, Dict, Tuple, Union @@ -11,8 +13,8 @@ import aiohttp JSON_TYPE = Union[str, int, float, bool, type(None), dict, list] -class JSONKey: - """Represents a key in the database that holds a JSON value. +class AsyncJSONKey: + """Represents an key in the async database that holds a JSON value. db.jsonkey() will initialize an instance for you, you don't have to do it manually. @@ -142,7 +144,7 @@ class JSONKey: await self.db.set(self.key, json.dumps(data)) -class AsyncClient: +class AsyncReplitDb: """Async client interface with the Replit Database.""" __slots__ = ("db_url", "sess") @@ -230,6 +232,11 @@ class AsyncClient: return ret async def keys(self) -> Tuple[str]: + """Get all keys in the database. + + Returns: + Tuple[str]: The keys in the database. + """ return tuple(await self.list("")) async def values(self) -> Tuple[str]: @@ -240,6 +247,128 @@ class AsyncClient: """ return tuple((await self.to_dict()).values()) + def jsonkey( + self, + key: str, + dtype: JSON_TYPE, + get_default: Callable = None, + discard_bad_data: bool = False, + ) -> AsyncJSONKey: + """Initialize an AsyncJSONKey instance. + + A AsyncJSONKey is used to easily read and set JSON data from the database. + Arguments are the same as AsyncJSONKey constructor. + + Args: + key (str): The key to read + dtype (JSON_TYPE): The datatype the key should be, can be typing.Any. + get_default (Callable): A function that returns the default + value if the key is not set. If it is None (the default) the dtype + argument is used. + discard_bad_data (bool): Don't prompt if bad data is read, overwrite it + with the default. Defaults to False. + + Returns: + AsyncJSONKey: The initialized AsyncJSONKey instance. + """ + return AsyncJSONKey( + db=self, + key=key, + dtype=dtype, + get_default=get_default, + discard_bad_data=discard_bad_data, + ) + + def __repr__(self) -> str: + """A representation of the database. + + Returns: + A string representation of the database object. + """ + return f"" + + +def _async2sync(src_cls: object, res_cls: object) -> None: + class ReprWrapped(object): + def __init__(self, func: Callable) -> None: + self.func = func + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.func(*args, **kwargs) + + def __repr__(self) -> str: + return ( + f"" + ) + + for attr in dir(src_cls): + if attr in ["__class__", "__dict__"]: + continue + val = getattr(src_cls, attr) + if inspect.iscoroutinefunction(val): + # Convert the async function to sync with asyncio.run + @functools.wraps(val) + def sync_func(*args: Any, **kwargs: Any) -> Any: + print(f"Calling {attr}: {val}") + return asyncio.run(val(*args, **kwargs)) + + setattr(res_cls, attr, ReprWrapped(sync_func)) + elif inspect.isfunction(val): + # Wrap the source function + @functools.wraps(val) + def new_func(*args: Any, **kwargs: Any) -> Any: + return val(*args, **kwargs) + + setattr(res_cls, attr, ReprWrapped(new_func)) + else: + setattr(res_cls, attr, val) + + +class JSONKey: + """Represents an key in the async database that holds a JSON value. + + db.jsonkey() will initialize an instance for you, + you don't have to do it manually. + """ + + pass + + +_async2sync(AsyncJSONKey, JSONKey) + + +class ReplitDb: + """Client interface with the Replit Database.""" + + def __getitem__(self, item: str) -> str: + """Retrieve a key from the database. + + Args: + item (str): The key to retrieve. + + Returns: + str: The value of the key. + """ + return self.get(item) + + def __setitem__(self, item: str, value: str) -> None: + """Set a key in the database. + + Args: + item (str): The key to set. + value (str): The value to set the key to. + """ + self.set(item, value) + + def __delitem__(self, name: str) -> None: + """Delete a key in the database. + + Args: + name (str): The key to delete. + """ + self.delete(name) + def jsonkey( self, key: str, @@ -247,7 +376,7 @@ class AsyncClient: get_default: Callable = None, discard_bad_data: bool = False, ) -> JSONKey: - """Initialize a JSONKey instance. + """Initialize an JSONKey instance. A JSONKey is used to easily read and set JSON data from the database. Arguments are the same as JSONKey constructor. @@ -272,13 +401,8 @@ class AsyncClient: discard_bad_data=discard_bad_data, ) - def __repr__(self) -> str: - """A representation of the database. - Returns: - A string representation of the database object. - """ - return f"" +_async2sync(AsyncReplitDb, ReplitDb) db_url = os.environ.get("REPLIT_DB_URL") diff --git a/src/replit/database/_async.py b/src/replit/database/_async.py deleted file mode 100644 index 08af0ab..0000000 --- a/src/replit/database/_async.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Allows asyncio.run to work in async environment.""" -import asyncio -import sys -import threading -from typing import Any, Callable, List - -asyncio._run = asyncio.run - - -class AsyncThread(threading.Thread): - def __init__(self, res: List[None], exc: List[None], func: Callable) -> None: - self.result = res - self.exc = exc - self.func = func - threading.Thread.__init__(self) - - def run(self) -> Any: - def inner(func: Callable) -> Any: - try: - res = asyncio.run(func) - except Exception: - self.exc[0] = sys.exc_info() - res = "" - return res - - self.result[0] = inner(self.func) - - -def run(func: Callable) -> Any: - def error() -> Any: - ret = [None] - exc = [None] - thread = AsyncThread(ret, exc, func) - thread.start() - thread.join() - exc = exc[0] - if exc is not None: - raise exc[1].with_traceback(exc[2]) - sys.exit(1) - return ret[0] - - try: - return asyncio._run(func) - except RuntimeError: - return error() - except RuntimeWarning: - return error()