Update SQLAlchemy version and modify test fixture

This commit is contained in:
2024-03-29 19:37:25 -04:00
parent 08845f792a
commit 9e0d8bc244
3 changed files with 67 additions and 51 deletions
+54 -40
View File
@@ -24,7 +24,8 @@ def isexception(obj):
class Record(object): class Record(object):
"""A row, from a query, from a database.""" """A row, from a query, from a database."""
__slots__ = ('_keys', '_values')
__slots__ = ("_keys", "_values")
def __init__(self, keys, values): def __init__(self, keys, values):
self._keys = keys self._keys = keys
@@ -42,7 +43,7 @@ class Record(object):
return self._values return self._values
def __repr__(self): def __repr__(self):
return '<Record {}>'.format(self.export('json')[1:-1]) return "<Record {}>".format(self.export("json")[1:-1])
def __getitem__(self, key): def __getitem__(self, key):
# Support for index-based lookup. # Support for index-based lookup.
@@ -51,7 +52,9 @@ class Record(object):
# Support for string-based lookup. # Support for string-based lookup.
usekeys = self.keys() usekeys = self.keys()
if hasattr(usekeys, "_keys"): # sqlalchemy 2.x uses (result.RMKeyView which has wrapped _keys as list) if hasattr(
usekeys, "_keys"
): # sqlalchemy 2.x uses (result.RMKeyView which has wrapped _keys as list)
usekeys = usekeys._keys usekeys = usekeys._keys
if key in usekeys: if key in usekeys:
i = usekeys.index(key) i = usekeys.index(key)
@@ -103,13 +106,14 @@ class Record(object):
class RecordCollection(object): class RecordCollection(object):
"""A set of excellent Records from a query.""" """A set of excellent Records from a query."""
def __init__(self, rows): def __init__(self, rows):
self._rows = rows self._rows = rows
self._all_rows = [] self._all_rows = []
self.pending = True self.pending = True
def __repr__(self): def __repr__(self):
return '<RecordCollection size={} pending={}>'.format(len(self), self.pending) return "<RecordCollection size={} pending={}>".format(len(self), self.pending)
def __iter__(self): def __iter__(self):
"""Iterate over all rows, consuming the underlying generator """Iterate over all rows, consuming the underlying generator
@@ -139,7 +143,7 @@ class RecordCollection(object):
return nextrow return nextrow
except StopIteration: except StopIteration:
self.pending = False self.pending = False
raise StopIteration('RecordCollection contains no more rows.') raise StopIteration("RecordCollection contains no more rows.")
def __getitem__(self, key): def __getitem__(self, key):
is_int = isinstance(key, int) is_int = isinstance(key, int)
@@ -203,7 +207,7 @@ class RecordCollection(object):
return rows return rows
def as_dict(self, ordered=False): def as_dict(self, ordered=False):
return self.all(as_dict=not(ordered), as_ordereddict=ordered) return self.all(as_dict=not (ordered), as_ordereddict=ordered)
def first(self, default=None, as_dict=False, as_ordereddict=False): def first(self, default=None, as_dict=False, as_ordereddict=False):
"""Returns a single record for the RecordCollection, or `default`. If """Returns a single record for the RecordCollection, or `default`. If
@@ -235,11 +239,15 @@ class RecordCollection(object):
try: try:
self[1] self[1]
except IndexError: except IndexError:
return self.first(default=default, as_dict=as_dict, as_ordereddict=as_ordereddict) return self.first(
default=default, as_dict=as_dict, as_ordereddict=as_ordereddict
)
else: else:
raise ValueError('RecordCollection contained more than one row. ' raise ValueError(
'Expects only one row when using ' "RecordCollection contained more than one row. "
'RecordCollection.one') "Expects only one row when using "
"RecordCollection.one"
)
def scalar(self, default=None): def scalar(self, default=None):
"""Returns the first column of the first row, or `default`.""" """Returns the first column of the first row, or `default`."""
@@ -254,21 +262,21 @@ class Database(object):
def __init__(self, db_url=None, **kwargs): def __init__(self, db_url=None, **kwargs):
# If no db_url was provided, fallback to $DATABASE_URL. # If no db_url was provided, fallback to $DATABASE_URL.
self.db_url = db_url or os.environ.get('DATABASE_URL') self.db_url = db_url or os.environ.get("DATABASE_URL")
if not self.db_url: if not self.db_url:
raise ValueError('You must provide a db_url.') raise ValueError("You must provide a db_url.")
# Create an engine. # Create an engine.
self._engine = create_engine(self.db_url, **kwargs) self._engine = create_engine(self.db_url, **kwargs)
self.open = True self.open = True
def get_engine(self): def get_engine(self):
# Return the engine if open # Return the engine if open
if not self.open: if not self.open:
raise exc.ResourceClosedError('Database closed.') raise exc.ResourceClosedError("Database closed.")
return self._engine return self._engine
def close(self): def close(self):
"""Closes the Database.""" """Closes the Database."""
self._engine.dispose() self._engine.dispose()
@@ -281,7 +289,7 @@ class Database(object):
self.close() self.close()
def __repr__(self): def __repr__(self):
return '<Database open={}>'.format(self.open) return "<Database open={}>".format(self.open)
def get_table_names(self, internal=False, **kwargs): def get_table_names(self, internal=False, **kwargs):
"""Returns a list of table names for the connected database.""" """Returns a list of table names for the connected database."""
@@ -294,9 +302,9 @@ class Database(object):
pool. pool.
""" """
if not self.open: if not self.open:
raise exc.ResourceClosedError('Database closed.') raise exc.ResourceClosedError("Database closed.")
return Connection(self._engine.connect(close_with_result=close_with_result), close_with_result) return Connection(self._engine.connect(), close_with_result=close_with_result)
def query(self, query, fetchall=False, **params): def query(self, query, fetchall=False, **params):
"""Executes the given SQL query against the Database. Parameters can, """Executes the given SQL query against the Database. Parameters can,
@@ -361,7 +369,7 @@ class Connection(object):
self.close() self.close()
def __repr__(self): def __repr__(self):
return '<Connection open={}>'.format(self.open) return "<Connection open={}>".format(self.open)
def query(self, query, fetchall=False, **params): def query(self, query, fetchall=False, **params):
"""Executes the given SQL query against the connected Database. """Executes the given SQL query against the connected Database.
@@ -370,11 +378,13 @@ class Connection(object):
""" """
# Execute the given query. # Execute the given query.
cursor = self._conn.execute(text(query).bindparams(**params)) # TODO: PARAMS GO HERE cursor = self._conn.execute(
text(query).bindparams(**params)
) # TODO: PARAMS GO HERE
# Row-by-row Record generator. # Row-by-row Record generator.
row_gen = iter(Record([], [])) row_gen = iter(Record([], []))
if cursor.returns_rows: if cursor.returns_rows:
row_gen = (Record(cursor.keys(), row) for row in cursor) row_gen = (Record(cursor.keys(), row) for row in cursor)
@@ -415,7 +425,7 @@ class Connection(object):
from. from.
""" """
# If path doesn't exists # If path doesn't exists
if not os.path.exists(path): if not os.path.exists(path):
raise IOError("File '{}'' not found!".format(path)) raise IOError("File '{}'' not found!".format(path))
@@ -435,20 +445,22 @@ class Connection(object):
return self._conn.begin() return self._conn.begin()
def _reduce_datetimes(row): def _reduce_datetimes(row):
"""Receives a row, converts datetimes to strings.""" """Receives a row, converts datetimes to strings."""
row = list(row) row = list(row)
for i, element in enumerate(row): for i, element in enumerate(row):
if hasattr(element, 'isoformat'): if hasattr(element, "isoformat"):
row[i] = element.isoformat() row[i] = element.isoformat()
return tuple(row) return tuple(row)
def cli(): def cli():
supported_formats = 'csv tsv json yaml html xls xlsx dbf latex ods'.split() supported_formats = "csv tsv json yaml html xls xlsx dbf latex ods".split()
formats_lst=", ".join(supported_formats) formats_lst = ", ".join(supported_formats)
cli_docs ="""Records: SQL for Humans™ cli_docs = """Records: SQL for Humans™
A Kenneth Reitz project. A Kenneth Reitz project.
Usage: Usage:
@@ -478,34 +490,36 @@ Notes:
can be provided instead. Use this feature discernfully; it's dangerous. can be provided instead. Use this feature discernfully; it's dangerous.
- Records is intended for report-style exports of database queries, and - Records is intended for report-style exports of database queries, and
has not yet been optimized for extremely large data dumps. has not yet been optimized for extremely large data dumps.
""" % dict(formats_lst=formats_lst) """ % dict(
formats_lst=formats_lst
)
# Parse the command-line arguments. # Parse the command-line arguments.
arguments = docopt(cli_docs) arguments = docopt(cli_docs)
query = arguments['<query>'] query = arguments["<query>"]
params = arguments['<params>'] params = arguments["<params>"]
format = arguments.get('<format>') format = arguments.get("<format>")
if format and "=" in format: if format and "=" in format:
del arguments['<format>'] del arguments["<format>"]
arguments['<params>'].append(format) arguments["<params>"].append(format)
format = None format = None
if format and format not in supported_formats: if format and format not in supported_formats:
print('%s format not supported.' % format) print("%s format not supported." % format)
print('Supported formats are %s.' % formats_lst) print("Supported formats are %s." % formats_lst)
exit(62) exit(62)
# Can't send an empty list if params aren't expected. # Can't send an empty list if params aren't expected.
try: try:
params = dict([i.split('=') for i in params]) params = dict([i.split("=") for i in params])
except ValueError: except ValueError:
print('Parameters must be given in key=value format.') print("Parameters must be given in key=value format.")
exit(64) exit(64)
# Be ready to fail on missing packages # Be ready to fail on missing packages
try: try:
# Create the Database. # Create the Database.
db = Database(arguments['--url']) db = Database(arguments["--url"])
# Execute the query, if it is a found file. # Execute the query, if it is a found file.
if os.path.isfile(query): if os.path.isfile(query):
@@ -517,7 +531,7 @@ Notes:
# Otherwise, say the file wasn't found. # Otherwise, say the file wasn't found.
else: else:
print('The given query could not be found.') print("The given query could not be found.")
exit(66) exit(66)
# Print results in desired format. # Print results in desired format.
@@ -544,5 +558,5 @@ def print_bytes(content):
# Run the CLI when executed directly. # Run the CLI when executed directly.
if __name__ == '__main__': if __name__ == "__main__":
cli() cli()
+1 -1
View File
@@ -47,7 +47,7 @@ class PublishCommand(Command):
requires = [ requires = [
"SQLAlchemy", "SQLAlchemy>=2.0",
"tablib>=0.11.4", "tablib>=0.11.4",
"openpyxl>2.6.0", # https://github.com/kennethreitz-archive/records/pull/184#issuecomment-606207851 "openpyxl>2.6.0", # https://github.com/kennethreitz-archive/records/pull/184#issuecomment-606207851
"docopt", "docopt",
+12 -10
View File
@@ -1,19 +1,21 @@
"""Shared pytest fixtures. """Shared pytest fixtures.
""" """
import pytest import pytest
import records import records
@pytest.fixture(params=[ @pytest.fixture(
# request: (sql_url_id, sql_url_template) params=[
# request: (sql_url_id, sql_url_template)
('sqlite_memory', 'sqlite:///:memory:'), ("sqlite_memory", "sqlite:///:memory:"),
('sqlite_file', 'sqlite:///{dbfile}'), # ('sqlite_file', 'sqlite:///{dbfile}'),
# ('psql', 'postgresql://records:records@localhost/records') # ('psql', 'postgresql://records:records@localhost/records')
], ],
ids=lambda r: r[0]) ids=lambda r: r[0],
)
def db(request, tmpdir): def db(request, tmpdir):
"""Instance of `records.Database(dburl)` """Instance of `records.Database(dburl)`
@@ -42,6 +44,6 @@ def foo_table(db):
Typically applied by `@pytest.mark.usefixtures('foo_table')` Typically applied by `@pytest.mark.usefixtures('foo_table')`
""" """
db.query('CREATE TABLE foo (a integer)') db.query("CREATE TABLE foo (a integer)")
yield yield
db.query('DROP TABLE foo') db.query("DROP TABLE foo")