From f447c4d35e7059d23e96e450213a2e48bac175cc Mon Sep 17 00:00:00 2001 From: Sidney Kochman Date: Wed, 19 Aug 2020 10:55:11 -0400 Subject: [PATCH] Database refactor (#15) --- .semaphore/semaphore.yml | 2 +- src/replit/database/__init__.py | 671 +------------------------------- src/replit/database/database.py | 232 +++++++++++ src/replit/database/jsonkey.py | 362 +++++++++++++++++ src/replit/maqpy/__init__.py | 2 +- tests/test_database.py | 66 +++- 6 files changed, 659 insertions(+), 676 deletions(-) create mode 100644 src/replit/database/database.py create mode 100644 src/replit/database/jsonkey.py diff --git a/.semaphore/semaphore.yml b/.semaphore/semaphore.yml index ceb471f..f86db5c 100644 --- a/.semaphore/semaphore.yml +++ b/.semaphore/semaphore.yml @@ -31,4 +31,4 @@ blocks: - checkout --use-cache - python -m pip install --upgrade poetry - poetry install - - poetry run mypy src/replit || true + - bash -c '! poetry run mypy src/replit | grep database.py' diff --git a/src/replit/database/__init__.py b/src/replit/database/__init__.py index e75d607..306a37e 100644 --- a/src/replit/database/__init__.py +++ b/src/replit/database/__init__.py @@ -1,677 +1,18 @@ """Interface with the Replit Database.""" -import json import os from sys import stderr -from typing import Any, Callable, Dict, Optional, Tuple, Union -import urllib +from typing import Optional -import aiohttp -import requests +from .database import AsyncDatabase, Database +from .jsonkey import AsyncJSONKey, JSONKey - -JSON_TYPE = Optional[Union[str, int, float, bool, dict, list]] - - -class AsyncJSONKey: - """Represents an key in the async database that holds a JSON value. - - db.jsonkey() will initialize an instance for you, - you don't have to do it manually. - """ - - __slots__ = ("db", "key", "dtype", "get_default", "discard_bad_data", "do_raise") - - def __init__( - self, - db: Any, - key: str, - dtype: JSON_TYPE, - get_default: Callable = None, - discard_bad_data: bool = False, - do_raise: bool = False, - ) -> None: - """Initialize the key. - - Args: - db (Any): An instance of ReplitDb - key (str): The key to read - dtype (JSON_TYPE): The datatype the key should be, can be typing.Any. - get_default (Callable): A function that returns the default - value if the key is not set. If it is None (the default) the dtype - argument is used. - discard_bad_data (bool): Don't prompt if bad data is read, overwrite it - with the default. Defaults to False. - do_raise (bool): Whether to raise exceptions when errors are encountered. - """ - self.db = db - self.key = key - self.dtype = dtype - self.get_default = get_default - self.discard_bad_data = discard_bad_data - self.do_raise = do_raise - - def _default(self) -> JSON_TYPE: - get_default_func = self.get_default or self.dtype - return get_default_func() - - def _is_valid_type(self, data: JSON_TYPE) -> bool: - return self.dtype is Any or isinstance(data, self.dtype) - - def _type_mismatch_msg(self, data: Any) -> str: - return ( - f"Type mismatch: Got type {type(data).__name__}," - f"expected {self.dtype.__name__}" - ) - - async def get(self) -> JSON_TYPE: - """Get the value of the key. - - If an invalid JSON value is read or the type does not match, it will show a - prompt asking the user what to do unless discard_bad_data is set. - - Raises: - KeyError: If do_raise is true and the key does not exist. - json.JSONDecodeError: If do_raise is true and invalid JSON data is read - from the key. - - Returns: - JSON_TYPE: The value read from the database - """ - try: - read = await self.db.get(self.key) - except KeyError: - if self.do_raise: - raise - print(f"Database key {self.key} not set, setting it to default value") - default = self._default() - await self.db.set(self.key, default) - return default - - try: - data = json.loads(read) - except json.JSONDecodeError: - if self.do_raise: - raise - return await self._error("Invalid JSON data read", read) - - if not self._is_valid_type(data): - return await self._error(self._type_mismatch_msg(data), read,) - return data - - async def _error(self, error: str, read: str) -> JSON_TYPE: - if self.do_raise: - raise ValueError(error) - - print(f"Error reading key {self.key!r}: {error}", file=stderr) - if self.discard_bad_data: - val = self._default() - await self.db.set(self.key, json.dumps(val)) - print(f"Wrote default to key {self.key!r}") - return val - return await self._should_discard_prompt(error, read) - - async def _should_discard_prompt(self, error: str, read: str) -> bool: - while True: - choice = input( - "d to use default, v to view the invalid data, c to insert custom " - "value, ^C to exit: " - ) - if choice.startswith("d"): - print("Writing default...") - val = self._default() - await self.db.set(self.key, val) - return val - elif choice.startswith("v"): - print(f"Data read from key: {read!r}") - elif choice.startswith("c"): - toset = input( - f"Enter data to write, should be of type {self.dtype.__name__!r}" - " (leave blank to return to menu): " - ) - if not toset: - continue - try: - data = json.loads(toset) - except json.JSONDecodeError: - print("Invalid JSON data!") - else: - if not self._is_valid_type(data): - print(self._type_mismatch_msg(data)) - continue - - await self.db.set(self.key, toset) - print("Wrote data to key") - return data - - async def set(self, data: JSON_TYPE) -> None: - """Set the value of the jsonkey. - - Args: - data (JSON_TYPE): The value to set it to. - - Raises: - TypeError: The type of the value set does not match the datatype. - """ - if not self._is_valid_type(data): - raise TypeError(self._type_mismatch_msg(data)) - - await self.db.set(self.key, json.dumps(data)) - - -class AsyncReplitDb: - """Async client interface with the Replit Database.""" - - __slots__ = ("db_url", "sess") - - def __init__(self, db_url: str) -> None: - """Initialize database. You shouldn't have to do this manually. - - Args: - db_url (str): Database url to use. - """ - self.db_url = db_url - - async def get(self, key: str) -> str: - """Get the value of an item from the database. - - Args: - key (str): The key to retreive - - Raises: - KeyError: Key is not set - - Returns: - str: The value of the key - """ - async with aiohttp.ClientSession() as session: - async with session.get(self.db_url + "/" + key) as response: - if response.status == 404: - raise KeyError(key) - response.raise_for_status() - return await response.text() - - async def set(self, key: str, value: str) -> None: - """Set a key in the database to value. - - Args: - key (str): The key to set - value (str): The value to set it to - """ - async with aiohttp.ClientSession() as session: - async with session.post(self.db_url, data={key: value}) as response: - response.raise_for_status() - - async def delete(self, key: str) -> None: - """Delete a key from the database. - - Args: - key (str): The key to delete - """ - async with aiohttp.ClientSession() as session: - async with session.delete(self.db_url + "/" + key) as response: - response.raise_for_status() - - async def list(self, prefix: str) -> Tuple[str]: - """List keys in the database which start with prefix. - - Args: - prefix (str): The prefix keys must start with, blank not not check. - - Returns: - Tuple[str]: The keys found. - """ - params = {"prefix": prefix, "encode": "true"} - async with aiohttp.ClientSession() as session: - async with session.get(self.db_url, params=params) as response: - response.raise_for_status() - text = await response.text() - if not text: - return tuple() - else: - return tuple(urllib.parse.unquote(k) for k in text.split("\n")) - - async def to_dict(self, prefix: str = "") -> Dict[str, str]: - """Dump all data in the database into a dictionary. - - Args: - prefix (str): The prefix the keys must start with, - blank means anything. Defaults to "". - - Returns: - Dict[str, str]: All keys in the database. - """ - ret = {} - keys = await self.list(prefix=prefix) - for i in keys: - ret[i] = await self.get(i) - return ret - - async def keys(self) -> Tuple[str]: - """Get all keys in the database. - - Returns: - Tuple[str]: The keys in the database. - """ - return await self.list("") - - async def values(self) -> Tuple[str]: - """Get every value in the database. - - Returns: - Tuple[str]: The values in the database. - """ - data = await self.to_dict() - return tuple(data.values()) - - async def items(self) -> Tuple[Tuple[str]]: - """Convert the database to a dict and return the dict's items method. - - Returns: - Tuple[Tuple[str]]: The items - """ - return (await self.to_dict()).items() - - def jsonkey( - self, - key: str, - dtype: JSON_TYPE, - get_default: Callable = None, - discard_bad_data: bool = False, - do_raise: bool = False, - ) -> AsyncJSONKey: - """Initialize an AsyncJSONKey instance. - - A AsyncJSONKey is used to easily read and set JSON data from the database. - Arguments are the same as AsyncJSONKey constructor. - - Args: - key (str): The key to read - dtype (JSON_TYPE): The datatype the key should be, can be typing.Any. - get_default (Callable): A function that returns the default - value if the key is not set. If it is None (the default) the dtype - argument is used. - discard_bad_data (bool): Don't prompt if bad data is read, overwrite it - with the default. Defaults to False. - do_raise (bool): Whether to raise exceptions when errors are encountered. - - Returns: - AsyncJSONKey: The initialized AsyncJSONKey instance. - """ - return AsyncJSONKey( - db=self, - key=key, - dtype=dtype, - get_default=get_default, - discard_bad_data=discard_bad_data, - do_raise=do_raise, - ) - - def __repr__(self) -> str: - """A representation of the database. - - Returns: - A string representation of the database object. - """ - return f"<{self.__class__.__name__}(db_url={self.db_url!r})>" - - -class JSONKey(AsyncJSONKey): - """Represents a key in the database that holds a JSON value. - - db.jsonkey() will initialize an instance for you, - you don't have to do it manually. - """ - - __slots__ = ("db", "key", "dtype", "get_default", "discard_bad_data", "do_raise") - - def __init__( - self, - db: Any, - key: str, - dtype: JSON_TYPE, - get_default: Callable = None, - discard_bad_data: bool = False, - do_raise: bool = False, - ) -> None: - """Initialize the key. - - Args: - db (Any): An instance of ReplitDb - key (str): The key to read - dtype (JSON_TYPE): The datatype the key should be, can be typing.Any. - get_default (Callable): A function that returns the default - value if the key is not set. If it is None (the default) the dtype - argument is used. - discard_bad_data (bool): Don't prompt if bad data is read, overwrite it - with the default. Defaults to False. - do_raise (bool): Whether to raise exceptions when errors are encountered. - """ - self.db = db - self.key = key - self.dtype = dtype - self.get_default = get_default - self.discard_bad_data = discard_bad_data - self.do_raise = do_raise - - def get(self) -> JSON_TYPE: - """Get the value of the key. - - If an invalid JSON value is read or the type does not match, it will show a - prompt asking the user what to do unless discard_bad_data is set. - - Returns: - JSON_TYPE: The value read from the database - """ - try: - read = self.db[self.key] - except KeyError: - print(f"Database key {self.key} not set, setting it to default value") - default = self._default() - self.db[self.key] = default - return default - - if isinstance(self.db, ReplitDb): - try: - data = json.loads(read) - except json.JSONDecodeError: - return self._error("Invalid JSON data read", read) - else: - data = read - - if not self._is_valid_type(data): - return self._error(self._type_mismatch_msg(data), read,) - return data - - def _error(self, error: str, read: str) -> JSON_TYPE: - print(f"Error reading key {self.key!r}: {error}", file=stderr) - if self.discard_bad_data: - val = self._default() - self.db[self.key] = json.dumps(val) - print(f"Wrote default to key {self.key!r}") - return val - return self._should_discard_prompt(error, read) - - def _should_discard_prompt(self, error: str, read: str) -> bool: - while True: - choice = input( - "d to use default, v to view the invalid data, c to insert custom " - "value, ^C to exit: " - ) - if choice.startswith("d"): - print("Writing default...") - val = self._default() - self.db[self.key] = val - return val - elif choice.startswith("v"): - print(f"Data read from key: {read!r}") - elif choice.startswith("c"): - toset = input( - f"Enter data to write, should be of type {self.dtype.__name__!r}" - " (leave blank to return to menu): " - ) - if not toset: - continue - try: - data = json.loads(toset) - except json.JSONDecodeError: - print("Invalid JSON data!") - else: - if not self._is_valid_type(data): - print(self._type_mismatch_msg(data)) - continue - - self.db[self.key] = toset - print("Wrote data to key") - return data - - def set(self, data: JSON_TYPE) -> None: - """Set the value of the jsonkey. - - Args: - data (JSON_TYPE): The value to set it to. - - Raises: - TypeError: The type of the value set does not match the datatype. - """ - if not self._is_valid_type(data): - raise TypeError(self._type_mismatch_msg(data)) - if isinstance(self.db, ReplitDb): - data = json.dumps(data) - self.db[self.key] = data - - def read(self, key: str, default: Any = None) -> Any: - """Shorthand for self.get().get(name, default) if datatype is dict. - - Args: - key (str): The name to get. - default (Any): The default if the key doesn't exist. Defaults to None. - - Returns: - Any: The value read or the default. - """ - return self.get().get(key, default) - - def keys(self, *keys: str) -> Any: - """Reads multiple keys from the key's value and allows setting. - - Args: - *keys (str): The keys to read from the data. - - Returns: - Any: The value accessed from self.get()[k1][k2][kn] - """ - data = self - for key in keys[:-1]: - data = type(self)(db=data, key=key, dtype=Any) - check = data[keys[-1]] - if type(check) is dict: - return type(self)(db=data, key=keys[-1], dtype=dict) - else: - return check - - def __getitem__(self, name: str) -> JSON_TYPE: - """Retrieve a key from the JSONKey's value if it is a dict. - - Args: - name (str): The name to retrieve. - - Returns: - JSON_TYPE: The value of the key. - """ - return self.keys(name) - - def __setitem__(self, name: str, value: JSON_TYPE) -> None: - """Sets a key inside the JSONKey's value if it is a dict. - - Args: - name (str): The key to set. - value (JSON_TYPE): The value to set it to. - """ - data = self.get() - data[name] = value - self.set(data) - - def append(self, item: JSON_TYPE) -> None: - """Append to the JSONKey's value if it is a list. - - Args: - item (JSON_TYPE): The item to append. - """ - data = self.get() - self.set(data + [item]) - - def __add__(self, item: Any) -> Any: - """Add to the JSONKey's value and return the result. - - Args: - item (Any): The item to add. - - Returns: - Any: The result of adding. - """ - return self.get() + item - - def __iadd__(self, item: Any) -> Any: - """Add to the JSONKey's value and set the result. - - Args: - item (Any): The item to add. - - Returns: - Any: self - """ - r = self.get() + item - self.set(r) - return self - - -class ReplitDb(AsyncReplitDb): - """Interface with the Replit Database.""" - - __slots__ = ("db_url", "sess") - - def __init__(self, db_url: str) -> None: - """Initialize database. You shouldn't have to do this manually. - - Args: - db_url (str): Database url to use. - """ - self.db_url = db_url - self.sess = requests.Session() - - def __getitem__(self, key: str) -> str: - """Get the value of an item from the database. - - Args: - key (str): The key to retreive - - Raises: - KeyError: Key is not set - - Returns: - str: The value of the key - """ - r = self.sess.get(f"{self.db_url}/{key}") - if r.status_code == 404: - raise KeyError(key) - - r.raise_for_status() - return r.text - - def __setitem__(self, key: str, value: str) -> None: - """Set a key in the database to value. - - Args: - key (str): The key to set - value (str): The value to set it to - """ - r = self.sess.post(self.db_url, data={key: value}) - r.raise_for_status() - - def __delitem__(self, key: str) -> None: - """Delete a key from the database. - - Args: - key (str): The key to delete - """ - r = self.sess.delete(f"{self.db_url}/{key}") - r.raise_for_status() - - def keys(self, prefix: str = "") -> Tuple[str]: - """Return all of the keys in the database. - - Args: - prefix (str): The prefix the keys must start with, - blank means anything. Defaults to "". - - Returns: - Tuple[str]: The keys found. - """ - r = requests.get(f"{self.db_url}", params={"prefix": prefix}) - r.raise_for_status() - - if not r.text: - return tuple() - else: - return tuple(r.text.split("\n")) - - def to_dict(self, prefix: str = "") -> Dict[str, str]: - """Dump all data in the database into a dictionary. - - Args: - prefix (str): The prefix the keys must start with, - blank means anything. Defaults to "". - - Returns: - Dict[str, str]: All keys in the database. - """ - keys = self.keys() - data = {} - for k in keys: - data[k] = self[k] - return data - - def values(self) -> Tuple[str]: - """Get every value in the database. - - Returns: - Tuple[str]: The values in the database. - """ - data = self.to_dict() - return tuple(data.values()) - - def items(self) -> Tuple[Tuple[str]]: - """Convert the database to a dict and return the dict's items method. - - Returns: - Tuple[Tuple[str]]: The items - """ - return self.to_dict().items() - - def jsonkey( - self, - key: str, - dtype: JSON_TYPE, - get_default: Callable = None, - discard_bad_data: bool = False, - ) -> JSONKey: - """Initialize a JSONKey instance. - - A JSONKey is used to easily read and set JSON data from the database. - Arguments are the same as JSONKey constructor. - - Args: - key (str): The key to read - dtype (JSON_TYPE): The datatype the key should be, can be typing.Any. - get_default (Callable): A function that returns the default - value if the key is not set. If it is None (the default) the dtype - argument is used. - discard_bad_data (bool): Don't prompt if bad data is read, overwrite it - with the default. Defaults to False. - - Returns: - JSONKey: The initialized JSONKey instance. - """ - return JSONKey( - db=self, - key=key, - dtype=dtype, - get_default=get_default, - discard_bad_data=discard_bad_data, - ) - - def __repr__(self) -> str: - """A representation of the database. - - Returns: - A string representation of the database object. - """ - return f"<{self.__class__.__name__}(db_url={self.db_url!r})>" +__all__ = ["AsyncDatabase", "Database", "AsyncJSONKey", "JSONKey"] +db: Optional[Database] db_url = os.environ.get("REPLIT_DB_URL") if db_url: - db = ReplitDb(db_url) + db = Database(db_url) else: print( "Warning: REPLIT_DB_URL does not exist, are we running on repl.it? " diff --git a/src/replit/database/database.py b/src/replit/database/database.py new file mode 100644 index 0000000..00d5d7c --- /dev/null +++ b/src/replit/database/database.py @@ -0,0 +1,232 @@ +"""Async and dict-like interfaces for interacting with Repl.it Database.""" +from collections import abc +import json +from typing import Any, Dict, Iterator, Tuple +import urllib + +import aiohttp +import requests + + +class AsyncDatabase: + """Async interface for Repl.it Database.""" + + __slots__ = ("db_url", "sess") + + def __init__(self, db_url: str) -> None: + """Initialize database. You shouldn't have to do this manually. + + Args: + db_url (str): Database url to use. + """ + self.db_url = db_url + + async def get(self, key: str) -> str: + """Get the value of an item from the database. + + Args: + key (str): The key to retreive + + Raises: + KeyError: Key is not set + + Returns: + str: The value of the key + """ + async with aiohttp.ClientSession() as session: + async with session.get( + self.db_url + "/" + urllib.parse.quote(key) + ) as response: + if response.status == 404: + raise KeyError(key) + response.raise_for_status() + return await response.text() + + async def set(self, key: str, value: str) -> None: + """Set a key in the database to value. + + Args: + key (str): The key to set + value (str): The value to set it to + """ + async with aiohttp.ClientSession() as session: + async with session.post(self.db_url, data={key: value}) as response: + response.raise_for_status() + + async def delete(self, key: str) -> None: + """Delete a key from the database. + + Args: + key (str): The key to delete + """ + async with aiohttp.ClientSession() as session: + async with session.delete( + self.db_url + "/" + urllib.parse.quote(key) + ) as response: + response.raise_for_status() + + async def list(self, prefix: str) -> Tuple[str, ...]: + """List keys in the database which start with prefix. + + Args: + prefix (str): The prefix keys must start with, blank not not check. + + Returns: + Tuple[str]: The keys found. + """ + params = {"prefix": prefix, "encode": "true"} + async with aiohttp.ClientSession() as session: + async with session.get(self.db_url, params=params) as response: + response.raise_for_status() + text = await response.text() + if not text: + return tuple() + else: + return tuple(urllib.parse.unquote(k) for k in text.split("\n")) + + async def to_dict(self, prefix: str = "") -> Dict[str, str]: + """Dump all data in the database into a dictionary. + + Args: + prefix (str): The prefix the keys must start with, + blank means anything. Defaults to "". + + Returns: + Dict[str, str]: All keys in the database. + """ + ret = {} + keys = await self.list(prefix=prefix) + for i in keys: + ret[i] = await self.get(i) + return ret + + async def keys(self) -> Tuple[str, ...]: + """Get all keys in the database. + + Returns: + Tuple[str]: The keys in the database. + """ + return await self.list("") + + async def values(self) -> Tuple[str, ...]: + """Get every value in the database. + + Returns: + Tuple[str]: The values in the database. + """ + data = await self.to_dict() + return tuple(data.values()) + + async def items(self) -> Tuple[Tuple[str, str], ...]: + """Convert the database to a dict and return the dict's items method. + + Returns: + Tuple[Tuple[str]]: The items + """ + return tuple((await self.to_dict()).items()) + + def __repr__(self) -> str: + """A representation of the database. + + Returns: + A string representation of the database object. + """ + return f"<{self.__class__.__name__}(db_url={self.db_url!r})>" + + +class Database(abc.MutableMapping): + """Dictionary-like interface for Repl.it Database. + + This interface will coerce all values everything to and from JSON. If you + don't want this, use AsyncDatabase instead. + """ + + __slots__ = ("db_url", "sess") + + def __init__(self, db_url: str) -> None: + """Initialize database. You shouldn't have to do this manually. + + Args: + db_url (str): Database url to use. + """ + self.db_url = db_url + self.sess = requests.Session() + + def __getitem__(self, key: str) -> Any: + """Get the value of an item from the database. + + Args: + key (str): The key to retreive + + Raises: + KeyError: Key is not set + + Returns: + Any: The value of the key + """ + r = self.sess.get(f"{self.db_url}/{key}") + if r.status_code == 404: + raise KeyError(key) + + r.raise_for_status() + return json.loads(r.text) + + def __setitem__(self, key: str, value: Any) -> None: + """Set a key in the database to value. + + Args: + key (str): The key to set + value (Any): The value to set it to. Must be JSON-serializable. + """ + j = json.dumps(value, separators=(",", ":")) + r = self.sess.post(self.db_url, data={key: j}) + r.raise_for_status() + + def __delitem__(self, key: str) -> None: + """Delete a key from the database. + + Args: + key (str): The key to delete + + Raises: + KeyError: Key is not set + """ + r = self.sess.delete(f"{self.db_url}/{key}") + if r.status_code == 404: + raise KeyError(key) + + r.raise_for_status() + + def __iter__(self) -> Iterator[str]: + """Return an iterator for the database.""" + return iter(self.prefix("")) + + def __len__(self) -> int: + """The number of keys in the database.""" + return len(self.prefix("")) + + def prefix(self, prefix: str) -> Tuple[str, ...]: + """Return all of the keys in the database that begin with the prefix. + + Args: + prefix (str): The prefix the keys must start with, + blank means anything. + + Returns: + Tuple[str]: The keys found. + """ + r = requests.get(f"{self.db_url}", params={"prefix": prefix, "encode": "true"}) + r.raise_for_status() + + if not r.text: + return tuple() + else: + return tuple(urllib.parse.unquote(k) for k in r.text.split("\n")) + + def __repr__(self) -> str: + """A representation of the database. + + Returns: + A string representation of the database object. + """ + return f"<{self.__class__.__name__}(db_url={self.db_url!r})>" diff --git a/src/replit/database/jsonkey.py b/src/replit/database/jsonkey.py new file mode 100644 index 0000000..9921889 --- /dev/null +++ b/src/replit/database/jsonkey.py @@ -0,0 +1,362 @@ +# flake8: noqa +from typing import Any, Callable, Dict, Optional, Tuple, Union +import json + +JSON_TYPE = Optional[Union[str, int, float, bool, dict, list]] + + +class AsyncJSONKey: + """Represents an key in the async database that holds a JSON value. + db.jsonkey() will initialize an instance for you, + you don't have to do it manually. + """ + + __slots__ = ("db", "key", "dtype", "get_default", "discard_bad_data", "do_raise") + + def __init__( + self, + db: Any, + key: str, + dtype: JSON_TYPE, + get_default: Callable[[], JSON_TYPE] = None, + discard_bad_data: bool = False, + do_raise: bool = False, + ) -> None: + """Initialize the key. + Args: + db (Any): An instance of ReplitDb + key (str): The key to read + dtype (JSON_TYPE): The datatype the key should be, can be typing.Any. + get_default (Callable): A function that returns the default + value if the key is not set. If it is None (the default) the dtype + argument is used. + discard_bad_data (bool): Don't prompt if bad data is read, overwrite it + with the default. Defaults to False. + do_raise (bool): Whether to raise exceptions when errors are encountered. + """ + self.db = db + self.key = key + self.dtype = dtype + self.get_default = get_default + self.discard_bad_data = discard_bad_data + self.do_raise = do_raise + + def _default(self) -> JSON_TYPE: + if self.get_default: + return self.get_default() + return self.dtype + + def _is_valid_type(self, data: JSON_TYPE) -> bool: + return self.dtype is Any or isinstance(data, self.dtype) + + def _type_mismatch_msg(self, data: Any) -> str: + return ( + f"Type mismatch: Got type {type(data).__name__}," + f"expected {self.dtype.__name__}" + ) + + async def get(self) -> JSON_TYPE: + """Get the value of the key. + If an invalid JSON value is read or the type does not match, it will show a + prompt asking the user what to do unless discard_bad_data is set. + Raises: + KeyError: If do_raise is true and the key does not exist. + json.JSONDecodeError: If do_raise is true and invalid JSON data is read + from the key. + Returns: + JSON_TYPE: The value read from the database + """ + try: + read = await self.db.get(self.key) + except KeyError: + if self.do_raise: + raise + print(f"Database key {self.key} not set, setting it to default value") + default = self._default() + await self.db.set(self.key, default) + return default + + try: + data = json.loads(read) + except json.JSONDecodeError: + if self.do_raise: + raise + return await self._error("Invalid JSON data read", read) + + if not self._is_valid_type(data): + return await self._error(self._type_mismatch_msg(data), read,) + return data + + async def _error(self, error: str, read: str) -> JSON_TYPE: + if self.do_raise: + raise ValueError(error) + + print(f"Error reading key {self.key!r}: {error}", file=stderr) + if self.discard_bad_data: + val = self._default() + await self.db.set(self.key, json.dumps(val)) + print(f"Wrote default to key {self.key!r}") + return val + return await self._should_discard_prompt(error, read) + + async def _should_discard_prompt(self, error: str, read: str) -> bool: + while True: + choice = input( + "d to use default, v to view the invalid data, c to insert custom " + "value, ^C to exit: " + ) + if choice.startswith("d"): + print("Writing default...") + val = self._default() + await self.db.set(self.key, val) + return val + elif choice.startswith("v"): + print(f"Data read from key: {read!r}") + elif choice.startswith("c"): + toset = input( + f"Enter data to write, should be of type {self.dtype.__name__!r}" + " (leave blank to return to menu): " + ) + if not toset: + continue + try: + data = json.loads(toset) + except json.JSONDecodeError: + print("Invalid JSON data!") + else: + if not self._is_valid_type(data): + print(self._type_mismatch_msg(data)) + continue + + await self.db.set(self.key, toset) + print("Wrote data to key") + return data + + async def set(self, data: JSON_TYPE) -> None: + """Set the value of the jsonkey. + Args: + data (JSON_TYPE): The value to set it to. + Raises: + TypeError: The type of the value set does not match the datatype. + """ + if not self._is_valid_type(data): + raise TypeError(self._type_mismatch_msg(data)) + + await self.db.set(self.key, json.dumps(data)) + + +class JSONKey(AsyncJSONKey): + """Represents a key in the database that holds a JSON value. + + db.jsonkey() will initialize an instance for you, + you don't have to do it manually. + """ + + __slots__ = ("db", "key", "dtype", "get_default", "discard_bad_data", "do_raise") + + def __init__( + self, + db: Any, + key: str, + dtype: JSON_TYPE, + get_default: Callable = None, + discard_bad_data: bool = False, + do_raise: bool = False, + ) -> None: + """Initialize the key. + + Args: + db (Any): An instance of ReplitDb + key (str): The key to read + dtype (JSON_TYPE): The datatype the key should be, can be typing.Any. + get_default (Callable): A function that returns the default + value if the key is not set. If it is None (the default) the dtype + argument is used. + discard_bad_data (bool): Don't prompt if bad data is read, overwrite it + with the default. Defaults to False. + do_raise (bool): Whether to raise exceptions when errors are encountered. + """ + self.db = db + self.key = key + self.dtype = dtype + self.get_default = get_default + self.discard_bad_data = discard_bad_data + self.do_raise = do_raise + + def get(self) -> JSON_TYPE: + """Get the value of the key. + + If an invalid JSON value is read or the type does not match, it will show a + prompt asking the user what to do unless discard_bad_data is set. + + Returns: + JSON_TYPE: The value read from the database + """ + try: + read = self.db[self.key] + except KeyError: + print(f"Database key {self.key} not set, setting it to default value") + default = self._default() + self.db[self.key] = default + return default + + if isinstance(self.db, ReplitDb): + try: + data = json.loads(read) + except json.JSONDecodeError: + return self._error("Invalid JSON data read", read) + else: + data = read + + if not self._is_valid_type(data): + return self._error(self._type_mismatch_msg(data), read,) + return data + + def _error(self, error: str, read: str) -> JSON_TYPE: + print(f"Error reading key {self.key!r}: {error}", file=stderr) + if self.discard_bad_data: + val = self._default() + self.db[self.key] = json.dumps(val) + print(f"Wrote default to key {self.key!r}") + return val + return self._should_discard_prompt(error, read) + + def _should_discard_prompt(self, error: str, read: str) -> bool: + while True: + choice = input( + "d to use default, v to view the invalid data, c to insert custom " + "value, ^C to exit: " + ) + if choice.startswith("d"): + print("Writing default...") + val = self._default() + self.db[self.key] = val + return val + elif choice.startswith("v"): + print(f"Data read from key: {read!r}") + elif choice.startswith("c"): + toset = input( + f"Enter data to write, should be of type {self.dtype.__name__!r}" + " (leave blank to return to menu): " + ) + if not toset: + continue + try: + data = json.loads(toset) + except json.JSONDecodeError: + print("Invalid JSON data!") + else: + if not self._is_valid_type(data): + print(self._type_mismatch_msg(data)) + continue + + self.db[self.key] = toset + print("Wrote data to key") + return data + + def set(self, data: JSON_TYPE) -> None: + """Set the value of the jsonkey. + + Args: + data (JSON_TYPE): The value to set it to. + + Raises: + TypeError: The type of the value set does not match the datatype. + """ + if not self._is_valid_type(data): + raise TypeError(self._type_mismatch_msg(data)) + if isinstance(self.db, ReplitDb): + data = json.dumps(data) + self.db[self.key] = data + + def read(self, key: str, default: Any = None) -> Any: + """Shorthand for self.get().get(name, default) if datatype is dict. + + Args: + key (str): The name to get. + default (Any): The default if the key doesn't exist. Defaults to None. + + Returns: + Any: The value read or the default. + """ + data = self.get() + if not isinstance(data, dict): + raise TypeError + return data.get(key, default) + + def keys(self, *keys: str) -> Any: + """Reads multiple keys from the key's value and allows setting. + + Args: + *keys (str): The keys to read from the data. + + Returns: + Any: The value accessed from self.get()[k1][k2][kn] + """ + data = self + for key in keys[:-1]: + data = type(self)(db=data, key=key, dtype=Any) + check = data[keys[-1]] + if type(check) is dict: + return type(self)(db=data, key=keys[-1], dtype=dict) + else: + return check + + def __getitem__(self, name: str) -> JSON_TYPE: + """Retrieve a key from the JSONKey's value if it is a dict. + + Args: + name (str): The name to retrieve. + + Returns: + JSON_TYPE: The value of the key. + """ + return self.keys(name) + + def __setitem__(self, name: str, value: JSON_TYPE) -> None: + """Sets a key inside the JSONKey's value if it is a dict. + + Args: + name (str): The key to set. + value (JSON_TYPE): The value to set it to. + """ + data = self.get() + if not isinstance(data, dict): + raise TypeError + data[name] = value + self.set(data) + + def append(self, item: JSON_TYPE) -> None: + """Append to the JSONKey's value if it is a list. + + Args: + item (JSON_TYPE): The item to append. + """ + data = self.get() + if not isinstance(data, list): + raise TypeError + self.set(data + [item]) + + def __add__(self, item: Any) -> Any: + """Add to the JSONKey's value and return the result. + + Args: + item (Any): The item to add. + + Returns: + Any: The result of adding. + """ + return self.get() + item + + def __iadd__(self, item: Any) -> Any: + """Add to the JSONKey's value and set the result. + + Args: + item (Any): The item to add. + + Returns: + Any: self + """ + r = self.get() + item + self.set(r) + return self diff --git a/src/replit/maqpy/__init__.py b/src/replit/maqpy/__init__.py index c422e83..36ab163 100644 --- a/src/replit/maqpy/__init__.py +++ b/src/replit/maqpy/__init__.py @@ -20,7 +20,7 @@ from .utils import ( sign_in_page, sign_in_snippet, ) -from ..database import AsyncJSONKey, AsyncReplitDb, db, JSONKey, ReplitDb +from ..database import AsyncDatabase, AsyncJSONKey, Database, db, JSONKey auth = LocalProxy(lambda: flask.request.auth) signed_in = LocalProxy(lambda: flask.request.signed_in) diff --git a/tests/test_database.py b/tests/test_database.py index a2ae862..8974a91 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -3,12 +3,12 @@ import os import unittest -from replit.database import AsyncReplitDb, ReplitDb +from replit.database import AsyncDatabase, AsyncJSONKey, Database import requests class TestAsyncDatabase(unittest.IsolatedAsyncioTestCase): - """Tests for replit.database.AsyncReplitDb.""" + """Tests for replit.database.AsyncDatabase.""" async def asyncSetUp(self) -> None: """Grab a JWT for all the tests to share.""" @@ -17,7 +17,7 @@ class TestAsyncDatabase(unittest.IsolatedAsyncioTestCase): "https://database-test-jwt.kochman.repl.co", auth=("test", password) ) url = req.text - self.db = AsyncReplitDb(url) + self.db = AsyncDatabase(url) # nuke whatever is already here for k in await self.db.keys(): @@ -39,6 +39,18 @@ class TestAsyncDatabase(unittest.IsolatedAsyncioTestCase): with self.assertRaises(KeyError): await self.db.get("test-key") + async def test_get_set_delete_newline(self) -> None: + """Test that we can get, set, and delete a key with newline.""" + key = "test-key-with\nnewline" + await self.db.set(key, "value") + + val = await self.db.get(key) + self.assertEqual(val, "value") + + await self.db.delete(key) + with self.assertRaises(KeyError): + await self.db.get(key) + async def test_list_keys(self) -> None: """Test that we can list keys.""" key = "test-list-keys-with\nnewline" @@ -80,7 +92,7 @@ class TestAsyncDatabase(unittest.IsolatedAsyncioTestCase): """Test replit.database.AsyncJSONKey.""" key = "test-jsonkey" - jk = self.db.jsonkey(key, dtype=str, do_raise=True) + jk = AsyncJSONKey(db=self.db, key=key, dtype=str, do_raise=True) with self.assertRaises(KeyError): await jk.get() await jk.set("value") @@ -91,13 +103,13 @@ class TestAsyncDatabase(unittest.IsolatedAsyncioTestCase): """Test replit.database.AsyncJSONKey with a default callable.""" key = "test-jsonkey" - jk = self.db.jsonkey(key, dtype=str, get_default=lambda: "value") + jk = AsyncJSONKey(db=self.db, key=key, dtype=str, get_default=lambda: "value") val = await jk.get() self.assertEqual(val, "value") -class TestDatabase(unittest.IsolatedAsyncioTestCase): - """Tests for replit.database.ReplitDb.""" +class TestDatabase(unittest.TestCase): + """Tests for replit.database.Database.""" def setUp(self) -> None: """Grab a JWT for all the tests to share.""" @@ -106,7 +118,7 @@ class TestDatabase(unittest.IsolatedAsyncioTestCase): "https://database-test-jwt.kochman.repl.co", auth=("test", password) ) url = req.text - self.db = ReplitDb(url) + self.db = Database(url) # nuke whatever is already here for k in self.db.keys(): @@ -115,7 +127,7 @@ class TestDatabase(unittest.IsolatedAsyncioTestCase): def tearDown(self) -> None: """Nuke whatever the test added.""" for k in self.db.keys(): - self.db.delete(k) + del self.db[k] def test_get_set_delete(self) -> None: """Test get, set, and delete.""" @@ -129,3 +141,39 @@ class TestDatabase(unittest.IsolatedAsyncioTestCase): del self.db["key"] with self.assertRaises(KeyError): val = self.db["key"] + + def test_list_keys(self) -> None: + """Test that we can list keys.""" + key = "test-list-keys-with\nnewline" + self.db[key] = "value" + + val = self.db[key] + self.assertEqual(val, "value") + + keys = self.db.prefix(key) + self.assertEqual(keys, (key,)) + + keys = self.db.keys() + self.assertTupleEqual(tuple(keys), (key,)) + + # just to make sure... + self.assertTupleEqual(tuple(self.db.keys()), self.db.prefix("")) + + del self.db[key] + with self.assertRaises(KeyError): + val = self.db[key] + + def test_delete_nonexistent_key(self) -> None: + """Test that deleting a non-existent key returns 404.""" + key = "this-doesn't-exist" + with self.assertRaises(KeyError): + self.db[key] + + def test_get_set_fancy_object(self) -> None: + """Test that we can get/set/delete something that's more than a string.""" + key = "big-ol-list" + val = ["this", {"is": "a", "complex": "object"}, 1337] + + self.db[key] = val + act = self.db[key] + self.assertEqual(act, val)