Database refactor (#15)

This commit is contained in:
Sidney Kochman
2020-08-19 10:55:11 -04:00
committed by GitHub
parent e1a33cbe31
commit f447c4d35e
6 changed files with 659 additions and 676 deletions
+1 -1
View File
@@ -31,4 +31,4 @@ blocks:
- checkout --use-cache
- python -m pip install --upgrade poetry
- poetry install
- poetry run mypy src/replit || true
- bash -c '! poetry run mypy src/replit | grep database.py'
+6 -665
View File
@@ -1,677 +1,18 @@
"""Interface with the Replit Database."""
import json
import os
from sys import stderr
from typing import Any, Callable, Dict, Optional, Tuple, Union
import urllib
from typing import Optional
import aiohttp
import requests
from .database import AsyncDatabase, Database
from .jsonkey import AsyncJSONKey, JSONKey
JSON_TYPE = Optional[Union[str, int, float, bool, dict, list]]
class AsyncJSONKey:
"""Represents an key in the async database that holds a JSON value.
db.jsonkey() will initialize an instance for you,
you don't have to do it manually.
"""
__slots__ = ("db", "key", "dtype", "get_default", "discard_bad_data", "do_raise")
def __init__(
self,
db: Any,
key: str,
dtype: JSON_TYPE,
get_default: Callable = None,
discard_bad_data: bool = False,
do_raise: bool = False,
) -> None:
"""Initialize the key.
Args:
db (Any): An instance of ReplitDb
key (str): The key to read
dtype (JSON_TYPE): The datatype the key should be, can be typing.Any.
get_default (Callable): A function that returns the default
value if the key is not set. If it is None (the default) the dtype
argument is used.
discard_bad_data (bool): Don't prompt if bad data is read, overwrite it
with the default. Defaults to False.
do_raise (bool): Whether to raise exceptions when errors are encountered.
"""
self.db = db
self.key = key
self.dtype = dtype
self.get_default = get_default
self.discard_bad_data = discard_bad_data
self.do_raise = do_raise
def _default(self) -> JSON_TYPE:
get_default_func = self.get_default or self.dtype
return get_default_func()
def _is_valid_type(self, data: JSON_TYPE) -> bool:
return self.dtype is Any or isinstance(data, self.dtype)
def _type_mismatch_msg(self, data: Any) -> str:
return (
f"Type mismatch: Got type {type(data).__name__},"
f"expected {self.dtype.__name__}"
)
async def get(self) -> JSON_TYPE:
"""Get the value of the key.
If an invalid JSON value is read or the type does not match, it will show a
prompt asking the user what to do unless discard_bad_data is set.
Raises:
KeyError: If do_raise is true and the key does not exist.
json.JSONDecodeError: If do_raise is true and invalid JSON data is read
from the key.
Returns:
JSON_TYPE: The value read from the database
"""
try:
read = await self.db.get(self.key)
except KeyError:
if self.do_raise:
raise
print(f"Database key {self.key} not set, setting it to default value")
default = self._default()
await self.db.set(self.key, default)
return default
try:
data = json.loads(read)
except json.JSONDecodeError:
if self.do_raise:
raise
return await self._error("Invalid JSON data read", read)
if not self._is_valid_type(data):
return await self._error(self._type_mismatch_msg(data), read,)
return data
async def _error(self, error: str, read: str) -> JSON_TYPE:
if self.do_raise:
raise ValueError(error)
print(f"Error reading key {self.key!r}: {error}", file=stderr)
if self.discard_bad_data:
val = self._default()
await self.db.set(self.key, json.dumps(val))
print(f"Wrote default to key {self.key!r}")
return val
return await self._should_discard_prompt(error, read)
async def _should_discard_prompt(self, error: str, read: str) -> bool:
while True:
choice = input(
"d to use default, v to view the invalid data, c to insert custom "
"value, ^C to exit: "
)
if choice.startswith("d"):
print("Writing default...")
val = self._default()
await self.db.set(self.key, val)
return val
elif choice.startswith("v"):
print(f"Data read from key: {read!r}")
elif choice.startswith("c"):
toset = input(
f"Enter data to write, should be of type {self.dtype.__name__!r}"
" (leave blank to return to menu): "
)
if not toset:
continue
try:
data = json.loads(toset)
except json.JSONDecodeError:
print("Invalid JSON data!")
else:
if not self._is_valid_type(data):
print(self._type_mismatch_msg(data))
continue
await self.db.set(self.key, toset)
print("Wrote data to key")
return data
async def set(self, data: JSON_TYPE) -> None:
"""Set the value of the jsonkey.
Args:
data (JSON_TYPE): The value to set it to.
Raises:
TypeError: The type of the value set does not match the datatype.
"""
if not self._is_valid_type(data):
raise TypeError(self._type_mismatch_msg(data))
await self.db.set(self.key, json.dumps(data))
class AsyncReplitDb:
"""Async client interface with the Replit Database."""
__slots__ = ("db_url", "sess")
def __init__(self, db_url: str) -> None:
"""Initialize database. You shouldn't have to do this manually.
Args:
db_url (str): Database url to use.
"""
self.db_url = db_url
async def get(self, key: str) -> str:
"""Get the value of an item from the database.
Args:
key (str): The key to retreive
Raises:
KeyError: Key is not set
Returns:
str: The value of the key
"""
async with aiohttp.ClientSession() as session:
async with session.get(self.db_url + "/" + key) as response:
if response.status == 404:
raise KeyError(key)
response.raise_for_status()
return await response.text()
async def set(self, key: str, value: str) -> None:
"""Set a key in the database to value.
Args:
key (str): The key to set
value (str): The value to set it to
"""
async with aiohttp.ClientSession() as session:
async with session.post(self.db_url, data={key: value}) as response:
response.raise_for_status()
async def delete(self, key: str) -> None:
"""Delete a key from the database.
Args:
key (str): The key to delete
"""
async with aiohttp.ClientSession() as session:
async with session.delete(self.db_url + "/" + key) as response:
response.raise_for_status()
async def list(self, prefix: str) -> Tuple[str]:
"""List keys in the database which start with prefix.
Args:
prefix (str): The prefix keys must start with, blank not not check.
Returns:
Tuple[str]: The keys found.
"""
params = {"prefix": prefix, "encode": "true"}
async with aiohttp.ClientSession() as session:
async with session.get(self.db_url, params=params) as response:
response.raise_for_status()
text = await response.text()
if not text:
return tuple()
else:
return tuple(urllib.parse.unquote(k) for k in text.split("\n"))
async def to_dict(self, prefix: str = "") -> Dict[str, str]:
"""Dump all data in the database into a dictionary.
Args:
prefix (str): The prefix the keys must start with,
blank means anything. Defaults to "".
Returns:
Dict[str, str]: All keys in the database.
"""
ret = {}
keys = await self.list(prefix=prefix)
for i in keys:
ret[i] = await self.get(i)
return ret
async def keys(self) -> Tuple[str]:
"""Get all keys in the database.
Returns:
Tuple[str]: The keys in the database.
"""
return await self.list("")
async def values(self) -> Tuple[str]:
"""Get every value in the database.
Returns:
Tuple[str]: The values in the database.
"""
data = await self.to_dict()
return tuple(data.values())
async def items(self) -> Tuple[Tuple[str]]:
"""Convert the database to a dict and return the dict's items method.
Returns:
Tuple[Tuple[str]]: The items
"""
return (await self.to_dict()).items()
def jsonkey(
self,
key: str,
dtype: JSON_TYPE,
get_default: Callable = None,
discard_bad_data: bool = False,
do_raise: bool = False,
) -> AsyncJSONKey:
"""Initialize an AsyncJSONKey instance.
A AsyncJSONKey is used to easily read and set JSON data from the database.
Arguments are the same as AsyncJSONKey constructor.
Args:
key (str): The key to read
dtype (JSON_TYPE): The datatype the key should be, can be typing.Any.
get_default (Callable): A function that returns the default
value if the key is not set. If it is None (the default) the dtype
argument is used.
discard_bad_data (bool): Don't prompt if bad data is read, overwrite it
with the default. Defaults to False.
do_raise (bool): Whether to raise exceptions when errors are encountered.
Returns:
AsyncJSONKey: The initialized AsyncJSONKey instance.
"""
return AsyncJSONKey(
db=self,
key=key,
dtype=dtype,
get_default=get_default,
discard_bad_data=discard_bad_data,
do_raise=do_raise,
)
def __repr__(self) -> str:
"""A representation of the database.
Returns:
A string representation of the database object.
"""
return f"<{self.__class__.__name__}(db_url={self.db_url!r})>"
class JSONKey(AsyncJSONKey):
"""Represents a key in the database that holds a JSON value.
db.jsonkey() will initialize an instance for you,
you don't have to do it manually.
"""
__slots__ = ("db", "key", "dtype", "get_default", "discard_bad_data", "do_raise")
def __init__(
self,
db: Any,
key: str,
dtype: JSON_TYPE,
get_default: Callable = None,
discard_bad_data: bool = False,
do_raise: bool = False,
) -> None:
"""Initialize the key.
Args:
db (Any): An instance of ReplitDb
key (str): The key to read
dtype (JSON_TYPE): The datatype the key should be, can be typing.Any.
get_default (Callable): A function that returns the default
value if the key is not set. If it is None (the default) the dtype
argument is used.
discard_bad_data (bool): Don't prompt if bad data is read, overwrite it
with the default. Defaults to False.
do_raise (bool): Whether to raise exceptions when errors are encountered.
"""
self.db = db
self.key = key
self.dtype = dtype
self.get_default = get_default
self.discard_bad_data = discard_bad_data
self.do_raise = do_raise
def get(self) -> JSON_TYPE:
"""Get the value of the key.
If an invalid JSON value is read or the type does not match, it will show a
prompt asking the user what to do unless discard_bad_data is set.
Returns:
JSON_TYPE: The value read from the database
"""
try:
read = self.db[self.key]
except KeyError:
print(f"Database key {self.key} not set, setting it to default value")
default = self._default()
self.db[self.key] = default
return default
if isinstance(self.db, ReplitDb):
try:
data = json.loads(read)
except json.JSONDecodeError:
return self._error("Invalid JSON data read", read)
else:
data = read
if not self._is_valid_type(data):
return self._error(self._type_mismatch_msg(data), read,)
return data
def _error(self, error: str, read: str) -> JSON_TYPE:
print(f"Error reading key {self.key!r}: {error}", file=stderr)
if self.discard_bad_data:
val = self._default()
self.db[self.key] = json.dumps(val)
print(f"Wrote default to key {self.key!r}")
return val
return self._should_discard_prompt(error, read)
def _should_discard_prompt(self, error: str, read: str) -> bool:
while True:
choice = input(
"d to use default, v to view the invalid data, c to insert custom "
"value, ^C to exit: "
)
if choice.startswith("d"):
print("Writing default...")
val = self._default()
self.db[self.key] = val
return val
elif choice.startswith("v"):
print(f"Data read from key: {read!r}")
elif choice.startswith("c"):
toset = input(
f"Enter data to write, should be of type {self.dtype.__name__!r}"
" (leave blank to return to menu): "
)
if not toset:
continue
try:
data = json.loads(toset)
except json.JSONDecodeError:
print("Invalid JSON data!")
else:
if not self._is_valid_type(data):
print(self._type_mismatch_msg(data))
continue
self.db[self.key] = toset
print("Wrote data to key")
return data
def set(self, data: JSON_TYPE) -> None:
"""Set the value of the jsonkey.
Args:
data (JSON_TYPE): The value to set it to.
Raises:
TypeError: The type of the value set does not match the datatype.
"""
if not self._is_valid_type(data):
raise TypeError(self._type_mismatch_msg(data))
if isinstance(self.db, ReplitDb):
data = json.dumps(data)
self.db[self.key] = data
def read(self, key: str, default: Any = None) -> Any:
"""Shorthand for self.get().get(name, default) if datatype is dict.
Args:
key (str): The name to get.
default (Any): The default if the key doesn't exist. Defaults to None.
Returns:
Any: The value read or the default.
"""
return self.get().get(key, default)
def keys(self, *keys: str) -> Any:
"""Reads multiple keys from the key's value and allows setting.
Args:
*keys (str): The keys to read from the data.
Returns:
Any: The value accessed from self.get()[k1][k2][kn]
"""
data = self
for key in keys[:-1]:
data = type(self)(db=data, key=key, dtype=Any)
check = data[keys[-1]]
if type(check) is dict:
return type(self)(db=data, key=keys[-1], dtype=dict)
else:
return check
def __getitem__(self, name: str) -> JSON_TYPE:
"""Retrieve a key from the JSONKey's value if it is a dict.
Args:
name (str): The name to retrieve.
Returns:
JSON_TYPE: The value of the key.
"""
return self.keys(name)
def __setitem__(self, name: str, value: JSON_TYPE) -> None:
"""Sets a key inside the JSONKey's value if it is a dict.
Args:
name (str): The key to set.
value (JSON_TYPE): The value to set it to.
"""
data = self.get()
data[name] = value
self.set(data)
def append(self, item: JSON_TYPE) -> None:
"""Append to the JSONKey's value if it is a list.
Args:
item (JSON_TYPE): The item to append.
"""
data = self.get()
self.set(data + [item])
def __add__(self, item: Any) -> Any:
"""Add to the JSONKey's value and return the result.
Args:
item (Any): The item to add.
Returns:
Any: The result of adding.
"""
return self.get() + item
def __iadd__(self, item: Any) -> Any:
"""Add to the JSONKey's value and set the result.
Args:
item (Any): The item to add.
Returns:
Any: self
"""
r = self.get() + item
self.set(r)
return self
class ReplitDb(AsyncReplitDb):
"""Interface with the Replit Database."""
__slots__ = ("db_url", "sess")
def __init__(self, db_url: str) -> None:
"""Initialize database. You shouldn't have to do this manually.
Args:
db_url (str): Database url to use.
"""
self.db_url = db_url
self.sess = requests.Session()
def __getitem__(self, key: str) -> str:
"""Get the value of an item from the database.
Args:
key (str): The key to retreive
Raises:
KeyError: Key is not set
Returns:
str: The value of the key
"""
r = self.sess.get(f"{self.db_url}/{key}")
if r.status_code == 404:
raise KeyError(key)
r.raise_for_status()
return r.text
def __setitem__(self, key: str, value: str) -> None:
"""Set a key in the database to value.
Args:
key (str): The key to set
value (str): The value to set it to
"""
r = self.sess.post(self.db_url, data={key: value})
r.raise_for_status()
def __delitem__(self, key: str) -> None:
"""Delete a key from the database.
Args:
key (str): The key to delete
"""
r = self.sess.delete(f"{self.db_url}/{key}")
r.raise_for_status()
def keys(self, prefix: str = "") -> Tuple[str]:
"""Return all of the keys in the database.
Args:
prefix (str): The prefix the keys must start with,
blank means anything. Defaults to "".
Returns:
Tuple[str]: The keys found.
"""
r = requests.get(f"{self.db_url}", params={"prefix": prefix})
r.raise_for_status()
if not r.text:
return tuple()
else:
return tuple(r.text.split("\n"))
def to_dict(self, prefix: str = "") -> Dict[str, str]:
"""Dump all data in the database into a dictionary.
Args:
prefix (str): The prefix the keys must start with,
blank means anything. Defaults to "".
Returns:
Dict[str, str]: All keys in the database.
"""
keys = self.keys()
data = {}
for k in keys:
data[k] = self[k]
return data
def values(self) -> Tuple[str]:
"""Get every value in the database.
Returns:
Tuple[str]: The values in the database.
"""
data = self.to_dict()
return tuple(data.values())
def items(self) -> Tuple[Tuple[str]]:
"""Convert the database to a dict and return the dict's items method.
Returns:
Tuple[Tuple[str]]: The items
"""
return self.to_dict().items()
def jsonkey(
self,
key: str,
dtype: JSON_TYPE,
get_default: Callable = None,
discard_bad_data: bool = False,
) -> JSONKey:
"""Initialize a JSONKey instance.
A JSONKey is used to easily read and set JSON data from the database.
Arguments are the same as JSONKey constructor.
Args:
key (str): The key to read
dtype (JSON_TYPE): The datatype the key should be, can be typing.Any.
get_default (Callable): A function that returns the default
value if the key is not set. If it is None (the default) the dtype
argument is used.
discard_bad_data (bool): Don't prompt if bad data is read, overwrite it
with the default. Defaults to False.
Returns:
JSONKey: The initialized JSONKey instance.
"""
return JSONKey(
db=self,
key=key,
dtype=dtype,
get_default=get_default,
discard_bad_data=discard_bad_data,
)
def __repr__(self) -> str:
"""A representation of the database.
Returns:
A string representation of the database object.
"""
return f"<{self.__class__.__name__}(db_url={self.db_url!r})>"
__all__ = ["AsyncDatabase", "Database", "AsyncJSONKey", "JSONKey"]
db: Optional[Database]
db_url = os.environ.get("REPLIT_DB_URL")
if db_url:
db = ReplitDb(db_url)
db = Database(db_url)
else:
print(
"Warning: REPLIT_DB_URL does not exist, are we running on repl.it? "
+232
View File
@@ -0,0 +1,232 @@
"""Async and dict-like interfaces for interacting with Repl.it Database."""
from collections import abc
import json
from typing import Any, Dict, Iterator, Tuple
import urllib
import aiohttp
import requests
class AsyncDatabase:
"""Async interface for Repl.it Database."""
__slots__ = ("db_url", "sess")
def __init__(self, db_url: str) -> None:
"""Initialize database. You shouldn't have to do this manually.
Args:
db_url (str): Database url to use.
"""
self.db_url = db_url
async def get(self, key: str) -> str:
"""Get the value of an item from the database.
Args:
key (str): The key to retreive
Raises:
KeyError: Key is not set
Returns:
str: The value of the key
"""
async with aiohttp.ClientSession() as session:
async with session.get(
self.db_url + "/" + urllib.parse.quote(key)
) as response:
if response.status == 404:
raise KeyError(key)
response.raise_for_status()
return await response.text()
async def set(self, key: str, value: str) -> None:
"""Set a key in the database to value.
Args:
key (str): The key to set
value (str): The value to set it to
"""
async with aiohttp.ClientSession() as session:
async with session.post(self.db_url, data={key: value}) as response:
response.raise_for_status()
async def delete(self, key: str) -> None:
"""Delete a key from the database.
Args:
key (str): The key to delete
"""
async with aiohttp.ClientSession() as session:
async with session.delete(
self.db_url + "/" + urllib.parse.quote(key)
) as response:
response.raise_for_status()
async def list(self, prefix: str) -> Tuple[str, ...]:
"""List keys in the database which start with prefix.
Args:
prefix (str): The prefix keys must start with, blank not not check.
Returns:
Tuple[str]: The keys found.
"""
params = {"prefix": prefix, "encode": "true"}
async with aiohttp.ClientSession() as session:
async with session.get(self.db_url, params=params) as response:
response.raise_for_status()
text = await response.text()
if not text:
return tuple()
else:
return tuple(urllib.parse.unquote(k) for k in text.split("\n"))
async def to_dict(self, prefix: str = "") -> Dict[str, str]:
"""Dump all data in the database into a dictionary.
Args:
prefix (str): The prefix the keys must start with,
blank means anything. Defaults to "".
Returns:
Dict[str, str]: All keys in the database.
"""
ret = {}
keys = await self.list(prefix=prefix)
for i in keys:
ret[i] = await self.get(i)
return ret
async def keys(self) -> Tuple[str, ...]:
"""Get all keys in the database.
Returns:
Tuple[str]: The keys in the database.
"""
return await self.list("")
async def values(self) -> Tuple[str, ...]:
"""Get every value in the database.
Returns:
Tuple[str]: The values in the database.
"""
data = await self.to_dict()
return tuple(data.values())
async def items(self) -> Tuple[Tuple[str, str], ...]:
"""Convert the database to a dict and return the dict's items method.
Returns:
Tuple[Tuple[str]]: The items
"""
return tuple((await self.to_dict()).items())
def __repr__(self) -> str:
"""A representation of the database.
Returns:
A string representation of the database object.
"""
return f"<{self.__class__.__name__}(db_url={self.db_url!r})>"
class Database(abc.MutableMapping):
"""Dictionary-like interface for Repl.it Database.
This interface will coerce all values everything to and from JSON. If you
don't want this, use AsyncDatabase instead.
"""
__slots__ = ("db_url", "sess")
def __init__(self, db_url: str) -> None:
"""Initialize database. You shouldn't have to do this manually.
Args:
db_url (str): Database url to use.
"""
self.db_url = db_url
self.sess = requests.Session()
def __getitem__(self, key: str) -> Any:
"""Get the value of an item from the database.
Args:
key (str): The key to retreive
Raises:
KeyError: Key is not set
Returns:
Any: The value of the key
"""
r = self.sess.get(f"{self.db_url}/{key}")
if r.status_code == 404:
raise KeyError(key)
r.raise_for_status()
return json.loads(r.text)
def __setitem__(self, key: str, value: Any) -> None:
"""Set a key in the database to value.
Args:
key (str): The key to set
value (Any): The value to set it to. Must be JSON-serializable.
"""
j = json.dumps(value, separators=(",", ":"))
r = self.sess.post(self.db_url, data={key: j})
r.raise_for_status()
def __delitem__(self, key: str) -> None:
"""Delete a key from the database.
Args:
key (str): The key to delete
Raises:
KeyError: Key is not set
"""
r = self.sess.delete(f"{self.db_url}/{key}")
if r.status_code == 404:
raise KeyError(key)
r.raise_for_status()
def __iter__(self) -> Iterator[str]:
"""Return an iterator for the database."""
return iter(self.prefix(""))
def __len__(self) -> int:
"""The number of keys in the database."""
return len(self.prefix(""))
def prefix(self, prefix: str) -> Tuple[str, ...]:
"""Return all of the keys in the database that begin with the prefix.
Args:
prefix (str): The prefix the keys must start with,
blank means anything.
Returns:
Tuple[str]: The keys found.
"""
r = requests.get(f"{self.db_url}", params={"prefix": prefix, "encode": "true"})
r.raise_for_status()
if not r.text:
return tuple()
else:
return tuple(urllib.parse.unquote(k) for k in r.text.split("\n"))
def __repr__(self) -> str:
"""A representation of the database.
Returns:
A string representation of the database object.
"""
return f"<{self.__class__.__name__}(db_url={self.db_url!r})>"
+362
View File
@@ -0,0 +1,362 @@
# flake8: noqa
from typing import Any, Callable, Dict, Optional, Tuple, Union
import json
JSON_TYPE = Optional[Union[str, int, float, bool, dict, list]]
class AsyncJSONKey:
"""Represents an key in the async database that holds a JSON value.
db.jsonkey() will initialize an instance for you,
you don't have to do it manually.
"""
__slots__ = ("db", "key", "dtype", "get_default", "discard_bad_data", "do_raise")
def __init__(
self,
db: Any,
key: str,
dtype: JSON_TYPE,
get_default: Callable[[], JSON_TYPE] = None,
discard_bad_data: bool = False,
do_raise: bool = False,
) -> None:
"""Initialize the key.
Args:
db (Any): An instance of ReplitDb
key (str): The key to read
dtype (JSON_TYPE): The datatype the key should be, can be typing.Any.
get_default (Callable): A function that returns the default
value if the key is not set. If it is None (the default) the dtype
argument is used.
discard_bad_data (bool): Don't prompt if bad data is read, overwrite it
with the default. Defaults to False.
do_raise (bool): Whether to raise exceptions when errors are encountered.
"""
self.db = db
self.key = key
self.dtype = dtype
self.get_default = get_default
self.discard_bad_data = discard_bad_data
self.do_raise = do_raise
def _default(self) -> JSON_TYPE:
if self.get_default:
return self.get_default()
return self.dtype
def _is_valid_type(self, data: JSON_TYPE) -> bool:
return self.dtype is Any or isinstance(data, self.dtype)
def _type_mismatch_msg(self, data: Any) -> str:
return (
f"Type mismatch: Got type {type(data).__name__},"
f"expected {self.dtype.__name__}"
)
async def get(self) -> JSON_TYPE:
"""Get the value of the key.
If an invalid JSON value is read or the type does not match, it will show a
prompt asking the user what to do unless discard_bad_data is set.
Raises:
KeyError: If do_raise is true and the key does not exist.
json.JSONDecodeError: If do_raise is true and invalid JSON data is read
from the key.
Returns:
JSON_TYPE: The value read from the database
"""
try:
read = await self.db.get(self.key)
except KeyError:
if self.do_raise:
raise
print(f"Database key {self.key} not set, setting it to default value")
default = self._default()
await self.db.set(self.key, default)
return default
try:
data = json.loads(read)
except json.JSONDecodeError:
if self.do_raise:
raise
return await self._error("Invalid JSON data read", read)
if not self._is_valid_type(data):
return await self._error(self._type_mismatch_msg(data), read,)
return data
async def _error(self, error: str, read: str) -> JSON_TYPE:
if self.do_raise:
raise ValueError(error)
print(f"Error reading key {self.key!r}: {error}", file=stderr)
if self.discard_bad_data:
val = self._default()
await self.db.set(self.key, json.dumps(val))
print(f"Wrote default to key {self.key!r}")
return val
return await self._should_discard_prompt(error, read)
async def _should_discard_prompt(self, error: str, read: str) -> bool:
while True:
choice = input(
"d to use default, v to view the invalid data, c to insert custom "
"value, ^C to exit: "
)
if choice.startswith("d"):
print("Writing default...")
val = self._default()
await self.db.set(self.key, val)
return val
elif choice.startswith("v"):
print(f"Data read from key: {read!r}")
elif choice.startswith("c"):
toset = input(
f"Enter data to write, should be of type {self.dtype.__name__!r}"
" (leave blank to return to menu): "
)
if not toset:
continue
try:
data = json.loads(toset)
except json.JSONDecodeError:
print("Invalid JSON data!")
else:
if not self._is_valid_type(data):
print(self._type_mismatch_msg(data))
continue
await self.db.set(self.key, toset)
print("Wrote data to key")
return data
async def set(self, data: JSON_TYPE) -> None:
"""Set the value of the jsonkey.
Args:
data (JSON_TYPE): The value to set it to.
Raises:
TypeError: The type of the value set does not match the datatype.
"""
if not self._is_valid_type(data):
raise TypeError(self._type_mismatch_msg(data))
await self.db.set(self.key, json.dumps(data))
class JSONKey(AsyncJSONKey):
"""Represents a key in the database that holds a JSON value.
db.jsonkey() will initialize an instance for you,
you don't have to do it manually.
"""
__slots__ = ("db", "key", "dtype", "get_default", "discard_bad_data", "do_raise")
def __init__(
self,
db: Any,
key: str,
dtype: JSON_TYPE,
get_default: Callable = None,
discard_bad_data: bool = False,
do_raise: bool = False,
) -> None:
"""Initialize the key.
Args:
db (Any): An instance of ReplitDb
key (str): The key to read
dtype (JSON_TYPE): The datatype the key should be, can be typing.Any.
get_default (Callable): A function that returns the default
value if the key is not set. If it is None (the default) the dtype
argument is used.
discard_bad_data (bool): Don't prompt if bad data is read, overwrite it
with the default. Defaults to False.
do_raise (bool): Whether to raise exceptions when errors are encountered.
"""
self.db = db
self.key = key
self.dtype = dtype
self.get_default = get_default
self.discard_bad_data = discard_bad_data
self.do_raise = do_raise
def get(self) -> JSON_TYPE:
"""Get the value of the key.
If an invalid JSON value is read or the type does not match, it will show a
prompt asking the user what to do unless discard_bad_data is set.
Returns:
JSON_TYPE: The value read from the database
"""
try:
read = self.db[self.key]
except KeyError:
print(f"Database key {self.key} not set, setting it to default value")
default = self._default()
self.db[self.key] = default
return default
if isinstance(self.db, ReplitDb):
try:
data = json.loads(read)
except json.JSONDecodeError:
return self._error("Invalid JSON data read", read)
else:
data = read
if not self._is_valid_type(data):
return self._error(self._type_mismatch_msg(data), read,)
return data
def _error(self, error: str, read: str) -> JSON_TYPE:
print(f"Error reading key {self.key!r}: {error}", file=stderr)
if self.discard_bad_data:
val = self._default()
self.db[self.key] = json.dumps(val)
print(f"Wrote default to key {self.key!r}")
return val
return self._should_discard_prompt(error, read)
def _should_discard_prompt(self, error: str, read: str) -> bool:
while True:
choice = input(
"d to use default, v to view the invalid data, c to insert custom "
"value, ^C to exit: "
)
if choice.startswith("d"):
print("Writing default...")
val = self._default()
self.db[self.key] = val
return val
elif choice.startswith("v"):
print(f"Data read from key: {read!r}")
elif choice.startswith("c"):
toset = input(
f"Enter data to write, should be of type {self.dtype.__name__!r}"
" (leave blank to return to menu): "
)
if not toset:
continue
try:
data = json.loads(toset)
except json.JSONDecodeError:
print("Invalid JSON data!")
else:
if not self._is_valid_type(data):
print(self._type_mismatch_msg(data))
continue
self.db[self.key] = toset
print("Wrote data to key")
return data
def set(self, data: JSON_TYPE) -> None:
"""Set the value of the jsonkey.
Args:
data (JSON_TYPE): The value to set it to.
Raises:
TypeError: The type of the value set does not match the datatype.
"""
if not self._is_valid_type(data):
raise TypeError(self._type_mismatch_msg(data))
if isinstance(self.db, ReplitDb):
data = json.dumps(data)
self.db[self.key] = data
def read(self, key: str, default: Any = None) -> Any:
"""Shorthand for self.get().get(name, default) if datatype is dict.
Args:
key (str): The name to get.
default (Any): The default if the key doesn't exist. Defaults to None.
Returns:
Any: The value read or the default.
"""
data = self.get()
if not isinstance(data, dict):
raise TypeError
return data.get(key, default)
def keys(self, *keys: str) -> Any:
"""Reads multiple keys from the key's value and allows setting.
Args:
*keys (str): The keys to read from the data.
Returns:
Any: The value accessed from self.get()[k1][k2][kn]
"""
data = self
for key in keys[:-1]:
data = type(self)(db=data, key=key, dtype=Any)
check = data[keys[-1]]
if type(check) is dict:
return type(self)(db=data, key=keys[-1], dtype=dict)
else:
return check
def __getitem__(self, name: str) -> JSON_TYPE:
"""Retrieve a key from the JSONKey's value if it is a dict.
Args:
name (str): The name to retrieve.
Returns:
JSON_TYPE: The value of the key.
"""
return self.keys(name)
def __setitem__(self, name: str, value: JSON_TYPE) -> None:
"""Sets a key inside the JSONKey's value if it is a dict.
Args:
name (str): The key to set.
value (JSON_TYPE): The value to set it to.
"""
data = self.get()
if not isinstance(data, dict):
raise TypeError
data[name] = value
self.set(data)
def append(self, item: JSON_TYPE) -> None:
"""Append to the JSONKey's value if it is a list.
Args:
item (JSON_TYPE): The item to append.
"""
data = self.get()
if not isinstance(data, list):
raise TypeError
self.set(data + [item])
def __add__(self, item: Any) -> Any:
"""Add to the JSONKey's value and return the result.
Args:
item (Any): The item to add.
Returns:
Any: The result of adding.
"""
return self.get() + item
def __iadd__(self, item: Any) -> Any:
"""Add to the JSONKey's value and set the result.
Args:
item (Any): The item to add.
Returns:
Any: self
"""
r = self.get() + item
self.set(r)
return self
+1 -1
View File
@@ -20,7 +20,7 @@ from .utils import (
sign_in_page,
sign_in_snippet,
)
from ..database import AsyncJSONKey, AsyncReplitDb, db, JSONKey, ReplitDb
from ..database import AsyncDatabase, AsyncJSONKey, Database, db, JSONKey
auth = LocalProxy(lambda: flask.request.auth)
signed_in = LocalProxy(lambda: flask.request.signed_in)
+57 -9
View File
@@ -3,12 +3,12 @@
import os
import unittest
from replit.database import AsyncReplitDb, ReplitDb
from replit.database import AsyncDatabase, AsyncJSONKey, Database
import requests
class TestAsyncDatabase(unittest.IsolatedAsyncioTestCase):
"""Tests for replit.database.AsyncReplitDb."""
"""Tests for replit.database.AsyncDatabase."""
async def asyncSetUp(self) -> None:
"""Grab a JWT for all the tests to share."""
@@ -17,7 +17,7 @@ class TestAsyncDatabase(unittest.IsolatedAsyncioTestCase):
"https://database-test-jwt.kochman.repl.co", auth=("test", password)
)
url = req.text
self.db = AsyncReplitDb(url)
self.db = AsyncDatabase(url)
# nuke whatever is already here
for k in await self.db.keys():
@@ -39,6 +39,18 @@ class TestAsyncDatabase(unittest.IsolatedAsyncioTestCase):
with self.assertRaises(KeyError):
await self.db.get("test-key")
async def test_get_set_delete_newline(self) -> None:
"""Test that we can get, set, and delete a key with newline."""
key = "test-key-with\nnewline"
await self.db.set(key, "value")
val = await self.db.get(key)
self.assertEqual(val, "value")
await self.db.delete(key)
with self.assertRaises(KeyError):
await self.db.get(key)
async def test_list_keys(self) -> None:
"""Test that we can list keys."""
key = "test-list-keys-with\nnewline"
@@ -80,7 +92,7 @@ class TestAsyncDatabase(unittest.IsolatedAsyncioTestCase):
"""Test replit.database.AsyncJSONKey."""
key = "test-jsonkey"
jk = self.db.jsonkey(key, dtype=str, do_raise=True)
jk = AsyncJSONKey(db=self.db, key=key, dtype=str, do_raise=True)
with self.assertRaises(KeyError):
await jk.get()
await jk.set("value")
@@ -91,13 +103,13 @@ class TestAsyncDatabase(unittest.IsolatedAsyncioTestCase):
"""Test replit.database.AsyncJSONKey with a default callable."""
key = "test-jsonkey"
jk = self.db.jsonkey(key, dtype=str, get_default=lambda: "value")
jk = AsyncJSONKey(db=self.db, key=key, dtype=str, get_default=lambda: "value")
val = await jk.get()
self.assertEqual(val, "value")
class TestDatabase(unittest.IsolatedAsyncioTestCase):
"""Tests for replit.database.ReplitDb."""
class TestDatabase(unittest.TestCase):
"""Tests for replit.database.Database."""
def setUp(self) -> None:
"""Grab a JWT for all the tests to share."""
@@ -106,7 +118,7 @@ class TestDatabase(unittest.IsolatedAsyncioTestCase):
"https://database-test-jwt.kochman.repl.co", auth=("test", password)
)
url = req.text
self.db = ReplitDb(url)
self.db = Database(url)
# nuke whatever is already here
for k in self.db.keys():
@@ -115,7 +127,7 @@ class TestDatabase(unittest.IsolatedAsyncioTestCase):
def tearDown(self) -> None:
"""Nuke whatever the test added."""
for k in self.db.keys():
self.db.delete(k)
del self.db[k]
def test_get_set_delete(self) -> None:
"""Test get, set, and delete."""
@@ -129,3 +141,39 @@ class TestDatabase(unittest.IsolatedAsyncioTestCase):
del self.db["key"]
with self.assertRaises(KeyError):
val = self.db["key"]
def test_list_keys(self) -> None:
"""Test that we can list keys."""
key = "test-list-keys-with\nnewline"
self.db[key] = "value"
val = self.db[key]
self.assertEqual(val, "value")
keys = self.db.prefix(key)
self.assertEqual(keys, (key,))
keys = self.db.keys()
self.assertTupleEqual(tuple(keys), (key,))
# just to make sure...
self.assertTupleEqual(tuple(self.db.keys()), self.db.prefix(""))
del self.db[key]
with self.assertRaises(KeyError):
val = self.db[key]
def test_delete_nonexistent_key(self) -> None:
"""Test that deleting a non-existent key returns 404."""
key = "this-doesn't-exist"
with self.assertRaises(KeyError):
self.db[key]
def test_get_set_fancy_object(self) -> None:
"""Test that we can get/set/delete something that's more than a string."""
key = "big-ol-list"
val = ["this", {"is": "a", "complex": "object"}, 1337]
self.db[key] = val
act = self.db[key]
self.assertEqual(act, val)