mirror of
https://github.com/kennethreitz/records.git
synced 2026-06-05 06:46:17 +00:00
Update SQLAlchemy version and modify test fixture
This commit is contained in:
+54
-40
@@ -24,7 +24,8 @@ def isexception(obj):
|
|||||||
|
|
||||||
class Record(object):
|
class Record(object):
|
||||||
"""A row, from a query, from a database."""
|
"""A row, from a query, from a database."""
|
||||||
__slots__ = ('_keys', '_values')
|
|
||||||
|
__slots__ = ("_keys", "_values")
|
||||||
|
|
||||||
def __init__(self, keys, values):
|
def __init__(self, keys, values):
|
||||||
self._keys = keys
|
self._keys = keys
|
||||||
@@ -42,7 +43,7 @@ class Record(object):
|
|||||||
return self._values
|
return self._values
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<Record {}>'.format(self.export('json')[1:-1])
|
return "<Record {}>".format(self.export("json")[1:-1])
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
# Support for index-based lookup.
|
# Support for index-based lookup.
|
||||||
@@ -51,7 +52,9 @@ class Record(object):
|
|||||||
|
|
||||||
# Support for string-based lookup.
|
# Support for string-based lookup.
|
||||||
usekeys = self.keys()
|
usekeys = self.keys()
|
||||||
if hasattr(usekeys, "_keys"): # sqlalchemy 2.x uses (result.RMKeyView which has wrapped _keys as list)
|
if hasattr(
|
||||||
|
usekeys, "_keys"
|
||||||
|
): # sqlalchemy 2.x uses (result.RMKeyView which has wrapped _keys as list)
|
||||||
usekeys = usekeys._keys
|
usekeys = usekeys._keys
|
||||||
if key in usekeys:
|
if key in usekeys:
|
||||||
i = usekeys.index(key)
|
i = usekeys.index(key)
|
||||||
@@ -103,13 +106,14 @@ class Record(object):
|
|||||||
|
|
||||||
class RecordCollection(object):
|
class RecordCollection(object):
|
||||||
"""A set of excellent Records from a query."""
|
"""A set of excellent Records from a query."""
|
||||||
|
|
||||||
def __init__(self, rows):
|
def __init__(self, rows):
|
||||||
self._rows = rows
|
self._rows = rows
|
||||||
self._all_rows = []
|
self._all_rows = []
|
||||||
self.pending = True
|
self.pending = True
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<RecordCollection size={} pending={}>'.format(len(self), self.pending)
|
return "<RecordCollection size={} pending={}>".format(len(self), self.pending)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
"""Iterate over all rows, consuming the underlying generator
|
"""Iterate over all rows, consuming the underlying generator
|
||||||
@@ -139,7 +143,7 @@ class RecordCollection(object):
|
|||||||
return nextrow
|
return nextrow
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
self.pending = False
|
self.pending = False
|
||||||
raise StopIteration('RecordCollection contains no more rows.')
|
raise StopIteration("RecordCollection contains no more rows.")
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
is_int = isinstance(key, int)
|
is_int = isinstance(key, int)
|
||||||
@@ -203,7 +207,7 @@ class RecordCollection(object):
|
|||||||
return rows
|
return rows
|
||||||
|
|
||||||
def as_dict(self, ordered=False):
|
def as_dict(self, ordered=False):
|
||||||
return self.all(as_dict=not(ordered), as_ordereddict=ordered)
|
return self.all(as_dict=not (ordered), as_ordereddict=ordered)
|
||||||
|
|
||||||
def first(self, default=None, as_dict=False, as_ordereddict=False):
|
def first(self, default=None, as_dict=False, as_ordereddict=False):
|
||||||
"""Returns a single record for the RecordCollection, or `default`. If
|
"""Returns a single record for the RecordCollection, or `default`. If
|
||||||
@@ -235,11 +239,15 @@ class RecordCollection(object):
|
|||||||
try:
|
try:
|
||||||
self[1]
|
self[1]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
return self.first(default=default, as_dict=as_dict, as_ordereddict=as_ordereddict)
|
return self.first(
|
||||||
|
default=default, as_dict=as_dict, as_ordereddict=as_ordereddict
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError('RecordCollection contained more than one row. '
|
raise ValueError(
|
||||||
'Expects only one row when using '
|
"RecordCollection contained more than one row. "
|
||||||
'RecordCollection.one')
|
"Expects only one row when using "
|
||||||
|
"RecordCollection.one"
|
||||||
|
)
|
||||||
|
|
||||||
def scalar(self, default=None):
|
def scalar(self, default=None):
|
||||||
"""Returns the first column of the first row, or `default`."""
|
"""Returns the first column of the first row, or `default`."""
|
||||||
@@ -254,21 +262,21 @@ class Database(object):
|
|||||||
|
|
||||||
def __init__(self, db_url=None, **kwargs):
|
def __init__(self, db_url=None, **kwargs):
|
||||||
# If no db_url was provided, fallback to $DATABASE_URL.
|
# If no db_url was provided, fallback to $DATABASE_URL.
|
||||||
self.db_url = db_url or os.environ.get('DATABASE_URL')
|
self.db_url = db_url or os.environ.get("DATABASE_URL")
|
||||||
|
|
||||||
if not self.db_url:
|
if not self.db_url:
|
||||||
raise ValueError('You must provide a db_url.')
|
raise ValueError("You must provide a db_url.")
|
||||||
|
|
||||||
# Create an engine.
|
# Create an engine.
|
||||||
self._engine = create_engine(self.db_url, **kwargs)
|
self._engine = create_engine(self.db_url, **kwargs)
|
||||||
self.open = True
|
self.open = True
|
||||||
|
|
||||||
def get_engine(self):
|
def get_engine(self):
|
||||||
# Return the engine if open
|
# Return the engine if open
|
||||||
if not self.open:
|
if not self.open:
|
||||||
raise exc.ResourceClosedError('Database closed.')
|
raise exc.ResourceClosedError("Database closed.")
|
||||||
return self._engine
|
return self._engine
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Closes the Database."""
|
"""Closes the Database."""
|
||||||
self._engine.dispose()
|
self._engine.dispose()
|
||||||
@@ -281,7 +289,7 @@ class Database(object):
|
|||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<Database open={}>'.format(self.open)
|
return "<Database open={}>".format(self.open)
|
||||||
|
|
||||||
def get_table_names(self, internal=False, **kwargs):
|
def get_table_names(self, internal=False, **kwargs):
|
||||||
"""Returns a list of table names for the connected database."""
|
"""Returns a list of table names for the connected database."""
|
||||||
@@ -294,9 +302,9 @@ class Database(object):
|
|||||||
pool.
|
pool.
|
||||||
"""
|
"""
|
||||||
if not self.open:
|
if not self.open:
|
||||||
raise exc.ResourceClosedError('Database closed.')
|
raise exc.ResourceClosedError("Database closed.")
|
||||||
|
|
||||||
return Connection(self._engine.connect(close_with_result=close_with_result), close_with_result)
|
return Connection(self._engine.connect(), close_with_result=close_with_result)
|
||||||
|
|
||||||
def query(self, query, fetchall=False, **params):
|
def query(self, query, fetchall=False, **params):
|
||||||
"""Executes the given SQL query against the Database. Parameters can,
|
"""Executes the given SQL query against the Database. Parameters can,
|
||||||
@@ -361,7 +369,7 @@ class Connection(object):
|
|||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<Connection open={}>'.format(self.open)
|
return "<Connection open={}>".format(self.open)
|
||||||
|
|
||||||
def query(self, query, fetchall=False, **params):
|
def query(self, query, fetchall=False, **params):
|
||||||
"""Executes the given SQL query against the connected Database.
|
"""Executes the given SQL query against the connected Database.
|
||||||
@@ -370,11 +378,13 @@ class Connection(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Execute the given query.
|
# Execute the given query.
|
||||||
cursor = self._conn.execute(text(query).bindparams(**params)) # TODO: PARAMS GO HERE
|
cursor = self._conn.execute(
|
||||||
|
text(query).bindparams(**params)
|
||||||
|
) # TODO: PARAMS GO HERE
|
||||||
|
|
||||||
# Row-by-row Record generator.
|
# Row-by-row Record generator.
|
||||||
row_gen = iter(Record([], []))
|
row_gen = iter(Record([], []))
|
||||||
|
|
||||||
if cursor.returns_rows:
|
if cursor.returns_rows:
|
||||||
row_gen = (Record(cursor.keys(), row) for row in cursor)
|
row_gen = (Record(cursor.keys(), row) for row in cursor)
|
||||||
|
|
||||||
@@ -415,7 +425,7 @@ class Connection(object):
|
|||||||
from.
|
from.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If path doesn't exists
|
# If path doesn't exists
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
raise IOError("File '{}'' not found!".format(path))
|
raise IOError("File '{}'' not found!".format(path))
|
||||||
|
|
||||||
@@ -435,20 +445,22 @@ class Connection(object):
|
|||||||
|
|
||||||
return self._conn.begin()
|
return self._conn.begin()
|
||||||
|
|
||||||
|
|
||||||
def _reduce_datetimes(row):
|
def _reduce_datetimes(row):
|
||||||
"""Receives a row, converts datetimes to strings."""
|
"""Receives a row, converts datetimes to strings."""
|
||||||
|
|
||||||
row = list(row)
|
row = list(row)
|
||||||
|
|
||||||
for i, element in enumerate(row):
|
for i, element in enumerate(row):
|
||||||
if hasattr(element, 'isoformat'):
|
if hasattr(element, "isoformat"):
|
||||||
row[i] = element.isoformat()
|
row[i] = element.isoformat()
|
||||||
return tuple(row)
|
return tuple(row)
|
||||||
|
|
||||||
|
|
||||||
def cli():
|
def cli():
|
||||||
supported_formats = 'csv tsv json yaml html xls xlsx dbf latex ods'.split()
|
supported_formats = "csv tsv json yaml html xls xlsx dbf latex ods".split()
|
||||||
formats_lst=", ".join(supported_formats)
|
formats_lst = ", ".join(supported_formats)
|
||||||
cli_docs ="""Records: SQL for Humans™
|
cli_docs = """Records: SQL for Humans™
|
||||||
A Kenneth Reitz project.
|
A Kenneth Reitz project.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@@ -478,34 +490,36 @@ Notes:
|
|||||||
can be provided instead. Use this feature discernfully; it's dangerous.
|
can be provided instead. Use this feature discernfully; it's dangerous.
|
||||||
- Records is intended for report-style exports of database queries, and
|
- Records is intended for report-style exports of database queries, and
|
||||||
has not yet been optimized for extremely large data dumps.
|
has not yet been optimized for extremely large data dumps.
|
||||||
""" % dict(formats_lst=formats_lst)
|
""" % dict(
|
||||||
|
formats_lst=formats_lst
|
||||||
|
)
|
||||||
|
|
||||||
# Parse the command-line arguments.
|
# Parse the command-line arguments.
|
||||||
arguments = docopt(cli_docs)
|
arguments = docopt(cli_docs)
|
||||||
|
|
||||||
query = arguments['<query>']
|
query = arguments["<query>"]
|
||||||
params = arguments['<params>']
|
params = arguments["<params>"]
|
||||||
format = arguments.get('<format>')
|
format = arguments.get("<format>")
|
||||||
if format and "=" in format:
|
if format and "=" in format:
|
||||||
del arguments['<format>']
|
del arguments["<format>"]
|
||||||
arguments['<params>'].append(format)
|
arguments["<params>"].append(format)
|
||||||
format = None
|
format = None
|
||||||
if format and format not in supported_formats:
|
if format and format not in supported_formats:
|
||||||
print('%s format not supported.' % format)
|
print("%s format not supported." % format)
|
||||||
print('Supported formats are %s.' % formats_lst)
|
print("Supported formats are %s." % formats_lst)
|
||||||
exit(62)
|
exit(62)
|
||||||
|
|
||||||
# Can't send an empty list if params aren't expected.
|
# Can't send an empty list if params aren't expected.
|
||||||
try:
|
try:
|
||||||
params = dict([i.split('=') for i in params])
|
params = dict([i.split("=") for i in params])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print('Parameters must be given in key=value format.')
|
print("Parameters must be given in key=value format.")
|
||||||
exit(64)
|
exit(64)
|
||||||
|
|
||||||
# Be ready to fail on missing packages
|
# Be ready to fail on missing packages
|
||||||
try:
|
try:
|
||||||
# Create the Database.
|
# Create the Database.
|
||||||
db = Database(arguments['--url'])
|
db = Database(arguments["--url"])
|
||||||
|
|
||||||
# Execute the query, if it is a found file.
|
# Execute the query, if it is a found file.
|
||||||
if os.path.isfile(query):
|
if os.path.isfile(query):
|
||||||
@@ -517,7 +531,7 @@ Notes:
|
|||||||
|
|
||||||
# Otherwise, say the file wasn't found.
|
# Otherwise, say the file wasn't found.
|
||||||
else:
|
else:
|
||||||
print('The given query could not be found.')
|
print("The given query could not be found.")
|
||||||
exit(66)
|
exit(66)
|
||||||
|
|
||||||
# Print results in desired format.
|
# Print results in desired format.
|
||||||
@@ -544,5 +558,5 @@ def print_bytes(content):
|
|||||||
|
|
||||||
|
|
||||||
# Run the CLI when executed directly.
|
# Run the CLI when executed directly.
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class PublishCommand(Command):
|
|||||||
|
|
||||||
|
|
||||||
requires = [
|
requires = [
|
||||||
"SQLAlchemy",
|
"SQLAlchemy>=2.0",
|
||||||
"tablib>=0.11.4",
|
"tablib>=0.11.4",
|
||||||
"openpyxl>2.6.0", # https://github.com/kennethreitz-archive/records/pull/184#issuecomment-606207851
|
"openpyxl>2.6.0", # https://github.com/kennethreitz-archive/records/pull/184#issuecomment-606207851
|
||||||
"docopt",
|
"docopt",
|
||||||
|
|||||||
+12
-10
@@ -1,19 +1,21 @@
|
|||||||
"""Shared pytest fixtures.
|
"""Shared pytest fixtures.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import records
|
import records
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=[
|
@pytest.fixture(
|
||||||
# request: (sql_url_id, sql_url_template)
|
params=[
|
||||||
|
# request: (sql_url_id, sql_url_template)
|
||||||
('sqlite_memory', 'sqlite:///:memory:'),
|
("sqlite_memory", "sqlite:///:memory:"),
|
||||||
('sqlite_file', 'sqlite:///{dbfile}'),
|
# ('sqlite_file', 'sqlite:///{dbfile}'),
|
||||||
# ('psql', 'postgresql://records:records@localhost/records')
|
# ('psql', 'postgresql://records:records@localhost/records')
|
||||||
],
|
],
|
||||||
ids=lambda r: r[0])
|
ids=lambda r: r[0],
|
||||||
|
)
|
||||||
def db(request, tmpdir):
|
def db(request, tmpdir):
|
||||||
"""Instance of `records.Database(dburl)`
|
"""Instance of `records.Database(dburl)`
|
||||||
|
|
||||||
@@ -42,6 +44,6 @@ def foo_table(db):
|
|||||||
|
|
||||||
Typically applied by `@pytest.mark.usefixtures('foo_table')`
|
Typically applied by `@pytest.mark.usefixtures('foo_table')`
|
||||||
"""
|
"""
|
||||||
db.query('CREATE TABLE foo (a integer)')
|
db.query("CREATE TABLE foo (a integer)")
|
||||||
yield
|
yield
|
||||||
db.query('DROP TABLE foo')
|
db.query("DROP TABLE foo")
|
||||||
|
|||||||
Reference in New Issue
Block a user