diff --git a/src/replit/database/__init__.py b/src/replit/database/__init__.py index b8b222b..1d9e82c 100644 --- a/src/replit/database/__init__.py +++ b/src/replit/database/__init__.py @@ -1,17 +1,16 @@ """Interface with the Replit Database.""" +import asyncio import json -import aiohttp import os from sys import stderr -from typing import Any, Callable, Dict, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union + +import aiohttp from . import _async -import asyncio - asyncio.run = _async.run -import requests JSON_TYPE = Union[str, int, float, bool, type(None), dict, list] @@ -26,12 +25,12 @@ class JSONKey: __slots__ = ("db", "key", "dtype", "get_default", "discard_bad_data") def __init__( - self, - db: Any, - key: str, - dtype: JSON_TYPE, - get_default: Callable = None, - discard_bad_data: bool = False, + self, + db: Any, + key: str, + dtype: JSON_TYPE, + get_default: Callable = None, + discard_bad_data: bool = False, ) -> None: """Initialize the key. @@ -59,8 +58,10 @@ class JSONKey: return self.dtype is Any or isinstance(data, self.dtype) def _type_mismatch_msg(self, data: Any) -> str: - return (f"Type mismatch: Got type {type(data).__name__}," - "expected {self.dtype.__name__}") + return ( + f"Type mismatch: Got type {type(data).__name__}," + "expected {self.dtype.__name__}" + ) def get(self) -> JSON_TYPE: """Get the value of the key. @@ -74,9 +75,7 @@ class JSONKey: try: read = self.db[self.key] except KeyError: - print( - f"Database key {self.key} not set, setting it to default value" - ) + print(f"Database key {self.key} not set, setting it to default value") default = self._default() self.db[self.key] = default return default @@ -87,10 +86,7 @@ class JSONKey: return self._error("Invalid JSON data read", read) if not self._is_valid_type(data): - return self._error( - self._type_mismatch_msg(data), - read, - ) + return self._error(self._type_mismatch_msg(data), read,) return data def _error(self, error: str, read: str) -> JSON_TYPE: @@ -106,7 +102,8 @@ class JSONKey: while True: choice = input( "d to use default, v to view the invalid data, c to insert custom " - "value, ^C to exit: ") + "value, ^C to exit: " + ) if choice.startswith("d"): print("Writing default...") val = self._default() @@ -117,7 +114,8 @@ class JSONKey: elif choice.startswith("c"): toset = input( f"Enter data to write, should be of type {self.dtype.__name__!r}" - " (leave blank to return to menu): ") + " (leave blank to return to menu): " + ) if not toset: continue try: @@ -234,11 +232,11 @@ class ReplitDb(dict): return self.to_dict().items() def jsonkey( - self, - key: str, - dtype: JSON_TYPE, - get_default: Callable = None, - discard_bad_data: bool = False, + self, + key: str, + dtype: JSON_TYPE, + get_default: Callable = None, + discard_bad_data: bool = False, ) -> JSONKey: """Initialize a JSONKey instance. @@ -274,22 +272,47 @@ class ReplitDb(dict): return f"" -class AsyncClient(): - def __init__(self, db_url): +class AsyncClient: + def __init__(self, db_url: str) -> None: + self.db_url = db_url + self.backend = _AsyncBackend(db_url) + + async def view(self, key: str) -> str: + return await self.backend.view(key) + + async def list(self, prefix: str) -> List[str]: + return await self.backend.list(prefix) + + async def delete(self, key: str) -> None: + return await self.backend.delete(key) + + async def set(self, key: str, val: str) -> None: + return await self.backend.set(key, val) + + async def to_dict(self) -> Dict[str, str]: + ret = {} + keys = await self.keys() + for i in keys: + ret[i] = await self.view(i) + return ret + + async def keys(self) -> Tuple[str]: + return tuple(await self.list("")) + + async def values(self) -> Tuple[str]: + return tuple((await self.to_dict()).values()) + + +class _AsyncBackend: + def __init__(self, db_url: str) -> None: self.db_url = db_url - -class _AsyncBackend(): - def __init__(self, db_url): - self.db_url = db_url - - async def set(self, key, val): + async def set(self, key: str, val: str) -> None: async with aiohttp.ClientSession() as session: async with session.post(self.db_url, data={key: val}) as response: response.raise_for_status() - return await response.text() - async def view(self, key): + async def view(self, key: str) -> str: async with aiohttp.ClientSession() as session: async with session.get(self.db_url + "/" + key) as response: if response.status == 404: @@ -297,16 +320,14 @@ class _AsyncBackend(): response.raise_for_status() return await response.text() - async def delete(self, key): + async def delete(self, key: str) -> None: async with aiohttp.ClientSession() as session: async with session.delete(self.db_url + "/" + key) as response: response.raise_for_status() - return await response.text() - async def list(self, prefix): + async def list(self, prefix: str) -> Tuple[str]: async with aiohttp.ClientSession() as session: - async with session.get(self.db_url + "?prefix=" + - prefix) as response: + async with session.get(self.db_url + "?prefix=" + prefix) as response: response.raise_for_status() if not await response.text(): return tuple() diff --git a/src/replit/database/_async.py b/src/replit/database/_async.py index eafc32c..08af0ab 100644 --- a/src/replit/database/_async.py +++ b/src/replit/database/_async.py @@ -1,37 +1,40 @@ -import threading +"""Allows asyncio.run to work in async environment.""" import asyncio import sys +import threading +from typing import Any, Callable, List + asyncio._run = asyncio.run class AsyncThread(threading.Thread): - def __init__(self, res, exc, func): + def __init__(self, res: List[None], exc: List[None], func: Callable) -> None: self.result = res self.exc = exc self.func = func threading.Thread.__init__(self) - def run(self): - def inner(func): + def run(self) -> Any: + def inner(func: Callable) -> Any: try: res = asyncio.run(func) - except Exception as e: + except Exception: self.exc[0] = sys.exc_info() - res = '' + res = "" return res self.result[0] = inner(self.func) -def run(func): - def error(): +def run(func: Callable) -> Any: + def error() -> Any: ret = [None] exc = [None] thread = AsyncThread(ret, exc, func) thread.start() thread.join() exc = exc[0] - if exc != None: + if exc is not None: raise exc[1].with_traceback(exc[2]) sys.exit(1) return ret[0]