From b0bce695cd7a284d2df2f83abc502d628912f8a4 Mon Sep 17 00:00:00 2001 From: Scoder12 <34356756+Scoder12@users.noreply.github.com> Date: Tue, 4 Aug 2020 13:25:26 -0700 Subject: [PATCH] Initial code for switch to async --- poetry.lock | 6 +- pyproject.toml | 1 - src/replit/database/__init__.py | 162 +++++++++++--------------------- 3 files changed, 57 insertions(+), 112 deletions(-) diff --git a/poetry.lock b/poetry.lock index 8022d7e..5c6b905 100644 --- a/poetry.lock +++ b/poetry.lock @@ -141,7 +141,7 @@ typed-ast = ">=1.4.0" d = ["aiohttp (>=3.3.2)", "aiohttp-cors"] [[package]] -category = "main" +category = "dev" description = "Python package for providing Mozilla's CA Bundle." name = "certifi" optional = false @@ -493,7 +493,7 @@ python-versions = "*" version = "2020.7.14" [[package]] -category = "main" +category = "dev" description = "Python HTTP for Humans." name = "requests" optional = false @@ -710,7 +710,7 @@ python-versions = "*" version = "3.7.4.2" [[package]] -category = "main" +category = "dev" description = "HTTP library with thread-safe connection pooling, file post, and more." name = "urllib3" optional = false diff --git a/pyproject.toml b/pyproject.toml index 6aeb8da..202e535 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,6 @@ documentation = "https://replit-python-docs.scoder12.repl.co" [tool.poetry.dependencies] python = "^3.8" -requests = "^2.24.0" typing_extensions = "^3.7.4" flask = "^1.1.2" werkzeug = "^1.0.1" diff --git a/src/replit/database/__init__.py b/src/replit/database/__init__.py index 1d9e82c..cb1884d 100644 --- a/src/replit/database/__init__.py +++ b/src/replit/database/__init__.py @@ -3,14 +3,10 @@ import asyncio import json import os from sys import stderr -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, Tuple, Union import aiohttp -from . import _async - -asyncio.run = _async.run - JSON_TYPE = Union[str, int, float, bool, type(None), dict, list] @@ -63,7 +59,7 @@ class JSONKey: "expected {self.dtype.__name__}" ) - def get(self) -> JSON_TYPE: + async def get(self) -> JSON_TYPE: """Get the value of the key. If an invalid JSON value is read or the type does not match, it will show a @@ -73,32 +69,32 @@ class JSONKey: JSON_TYPE: The value read from the database """ try: - read = self.db[self.key] + read = await self.db.get(self.key) except KeyError: print(f"Database key {self.key} not set, setting it to default value") default = self._default() - self.db[self.key] = default + await self.db.set(self.key, default) return default try: data = json.loads(read) except json.JSONDecodeError: - return self._error("Invalid JSON data read", read) + return await self._error("Invalid JSON data read", read) if not self._is_valid_type(data): - return self._error(self._type_mismatch_msg(data), read,) + return await self._error(self._type_mismatch_msg(data), read,) return data - def _error(self, error: str, read: str) -> JSON_TYPE: + async def _error(self, error: str, read: str) -> JSON_TYPE: print(f"Error reading key {self.key!r}: {error}", file=stderr) if self.discard_bad_data: val = self._default() - self.db[self.key] = json.dumps(val) + await self.db.set(self.key, json.dumps(val)) print(f"Wrote default to key {self.key!r}") return val - return self._should_discard_prompt(error, read) + return await self._should_discard_prompt(error, read) - def _should_discard_prompt(self, error: str, read: str) -> bool: + async def _should_discard_prompt(self, error: str, read: str) -> bool: while True: choice = input( "d to use default, v to view the invalid data, c to insert custom " @@ -107,7 +103,7 @@ class JSONKey: if choice.startswith("d"): print("Writing default...") val = self._default() - self.db[self.key] = val + await self.db.set(self.key, val) return val elif choice.startswith("v"): print(f"Data read from key: {read!r}") @@ -127,11 +123,11 @@ class JSONKey: print(self._type_mismatch_msg(data)) continue - self.db[self.key] = toset + await self.db.set(self.key, toset) print("Wrote data to key") return data - def set(self, data: JSON_TYPE) -> None: + async def set(self, data: JSON_TYPE) -> None: """Set the value of the jsonkey. Args: @@ -142,11 +138,12 @@ class JSONKey: """ if not self._is_valid_type(data): raise TypeError(self._type_mismatch_msg(data)) - self.db[self.key] = json.dumps(data) + + await self.db.set(self.key, json.dumps(data)) -class ReplitDb(dict): - """Interface with the Replit Database.""" +class AsyncClient: + """Async client interface with the Replit Database.""" __slots__ = ("db_url", "sess") @@ -157,9 +154,8 @@ class ReplitDb(dict): db_url (str): Database url to use. """ self.db_url = db_url - self.sess = _AsyncBackend(db_url) - def __getitem__(self, key: str) -> str: + async def get(self, key: str) -> str: """Get the value of an item from the database. Args: @@ -171,39 +167,53 @@ class ReplitDb(dict): Returns: str: The value of the key """ - r = asyncio.run(self.sess.view(key)) - return r + async with aiohttp.ClientSession() as session: + async with session.get(self.db_url + "/" + key) as response: + if response.status == 404: + raise KeyError(key) + response.raise_for_status() + return await response.text() - def __setitem__(self, key: str, value: str) -> None: + async def set(self, key: str, value: str) -> None: """Set a key in the database to value. Args: key (str): The key to set value (str): The value to set it to """ - asyncio.run(self.sess.set(key, value)) + async with aiohttp.ClientSession() as session: + async with session.post(self.db_url, data={key: value}) as response: + response.raise_for_status() - def __delitem__(self, key: str) -> None: + async def delete(self, key: str) -> None: """Delete a key from the database. Args: key (str): The key to delete """ - asyncio.run(self.sess.delete(key)) + async with aiohttp.ClientSession() as session: + async with session.delete(self.db_url + "/" + key) as response: + response.raise_for_status() - def keys(self, prefix: str = "") -> Tuple[str]: - """Return all of the keys in the database. + async def list(self, prefix: str) -> Tuple[str]: + """List keys in the database which start with prefix. Args: - prefix (str): The prefix the keys must start with, - blank means anything. Defaults to "". + prefix (str): The prefix keys must start with, blank not not check. Returns: Tuple[str]: The keys found. """ - return asyncio.run(self.sess.list(prefix)) + async with aiohttp.ClientSession() as session: + async with session.get(self.db_url + "?prefix=" + prefix) as response: + response.raise_for_status() + text = await response.text() + if not text: + return tuple() + else: + return tuple(text.split("\n")) - def to_dict(self, prefix: str = "") -> Dict[str, str]: + async def to_dict(self, prefix: str = "") -> Dict[str, str]: """Dump all data in the database into a dictionary. Args: @@ -213,23 +223,22 @@ class ReplitDb(dict): Returns: Dict[str, str]: All keys in the database. """ - keys = self.keys() - data = {} - for k in keys: - data[k] = self[k] - return data + ret = {} + keys = await self.list(prefix=prefix) + for i in keys: + ret[i] = await self.view(i) + return ret - def values(self) -> Tuple[str]: + async def keys(self) -> Tuple[str]: + return tuple(await self.list("")) + + async def values(self) -> Tuple[str]: """Get every value in the database. Returns: Tuple[str]: The values in the database. """ - data = self.to_dict() - return tuple(data.values()) - - def items(self) -> Tuple[Tuple[str]]: - return self.to_dict().items() + return tuple((await self.to_dict()).values()) def jsonkey( self, @@ -272,69 +281,6 @@ class ReplitDb(dict): return f"" -class AsyncClient: - def __init__(self, db_url: str) -> None: - self.db_url = db_url - self.backend = _AsyncBackend(db_url) - - async def view(self, key: str) -> str: - return await self.backend.view(key) - - async def list(self, prefix: str) -> List[str]: - return await self.backend.list(prefix) - - async def delete(self, key: str) -> None: - return await self.backend.delete(key) - - async def set(self, key: str, val: str) -> None: - return await self.backend.set(key, val) - - async def to_dict(self) -> Dict[str, str]: - ret = {} - keys = await self.keys() - for i in keys: - ret[i] = await self.view(i) - return ret - - async def keys(self) -> Tuple[str]: - return tuple(await self.list("")) - - async def values(self) -> Tuple[str]: - return tuple((await self.to_dict()).values()) - - -class _AsyncBackend: - def __init__(self, db_url: str) -> None: - self.db_url = db_url - - async def set(self, key: str, val: str) -> None: - async with aiohttp.ClientSession() as session: - async with session.post(self.db_url, data={key: val}) as response: - response.raise_for_status() - - async def view(self, key: str) -> str: - async with aiohttp.ClientSession() as session: - async with session.get(self.db_url + "/" + key) as response: - if response.status == 404: - raise KeyError(key) - response.raise_for_status() - return await response.text() - - async def delete(self, key: str) -> None: - async with aiohttp.ClientSession() as session: - async with session.delete(self.db_url + "/" + key) as response: - response.raise_for_status() - - async def list(self, prefix: str) -> Tuple[str]: - async with aiohttp.ClientSession() as session: - async with session.get(self.db_url + "?prefix=" + prefix) as response: - response.raise_for_status() - if not await response.text(): - return tuple() - else: - return tuple((await response.text()).split("\n")) - - db_url = os.environ.get("REPLIT_DB_URL") if db_url: db = ReplitDb(db_url)