Fix _async2sync to pass a correct self value

Initialize a _super variable in __init__ which is an instance of the async base class and is used as the self argument to async methods so that the class can await internal methods
This commit is contained in:
Scoder12
2020-08-10 15:09:40 -07:00
parent 764804fb6b
commit 1c4c7c069f
+42 -2
View File
@@ -300,13 +300,14 @@ class AsyncReplitDb:
Returns:
A string representation of the database object.
"""
return f"<ReplitDb(db_url={self.db_url!r})>"
return f"<{self.__class__.__name__}(db_url={self.db_url!r})>"
def _async2sync(coro: Callable) -> None:
@functools.wraps(coro)
def sync_func(self: object, *args: Any, **kwargs: Any) -> Any:
return asyncio.run(coro(self, *args, **kwargs))
res = asyncio.run(coro(self._super, *args, **kwargs))
return res
return sync_func
@@ -318,6 +319,37 @@ class JSONKey(AsyncJSONKey):
you don't have to do it manually.
"""
def __init__(
self,
db: Any,
key: str,
dtype: JSON_TYPE,
get_default: Callable = None,
discard_bad_data: bool = False,
do_raise: bool = False,
) -> None:
"""Initialize the key.
Args:
db (Any): An instance of ReplitDb
key (str): The key to read
dtype (JSON_TYPE): The datatype the key should be, can be typing.Any.
get_default (Callable): A function that returns the default
value if the key is not set. If it is None (the default) the dtype
argument is used.
discard_bad_data (bool): Don't prompt if bad data is read, overwrite it
with the default. Defaults to False.
do_raise (bool): Whether to raise exceptions when errors are encountered.
"""
self._super = AsyncJSONKey(
db=db,
key=key,
dtype=dtype,
get_default=get_default,
discard_bad_data=discard_bad_data,
do_raise=do_raise,
)
get = _async2sync(AsyncJSONKey.get)
set = _async2sync(AsyncJSONKey.set)
@@ -325,6 +357,14 @@ class JSONKey(AsyncJSONKey):
class ReplitDb(AsyncReplitDb):
"""Client interface with the Replit Database."""
def __init__(self, db_url: str) -> None:
"""Initialize the class.
Args:
db_url (str): The database URL to connect to.
"""
self._super = AsyncReplitDb(db_url)
def __getitem__(self, item: str) -> str:
"""Retrieve a key from the database.