From 9e0d8bc24410769590d8e998f51fb72ebf2d9066 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Fri, 29 Mar 2024 19:37:25 -0400 Subject: [PATCH] Update SQLAlchemy version and modify test fixture --- records.py | 94 +++++++++++++++++++++++++++-------------------- setup.py | 2 +- tests/conftest.py | 22 ++++++----- 3 files changed, 67 insertions(+), 51 deletions(-) diff --git a/records.py b/records.py index 096e10f..b5b6766 100644 --- a/records.py +++ b/records.py @@ -24,7 +24,8 @@ def isexception(obj): class Record(object): """A row, from a query, from a database.""" - __slots__ = ('_keys', '_values') + + __slots__ = ("_keys", "_values") def __init__(self, keys, values): self._keys = keys @@ -42,7 +43,7 @@ class Record(object): return self._values def __repr__(self): - return ''.format(self.export('json')[1:-1]) + return "".format(self.export("json")[1:-1]) def __getitem__(self, key): # Support for index-based lookup. @@ -51,7 +52,9 @@ class Record(object): # Support for string-based lookup. usekeys = self.keys() - if hasattr(usekeys, "_keys"): # sqlalchemy 2.x uses (result.RMKeyView which has wrapped _keys as list) + if hasattr( + usekeys, "_keys" + ): # sqlalchemy 2.x uses (result.RMKeyView which has wrapped _keys as list) usekeys = usekeys._keys if key in usekeys: i = usekeys.index(key) @@ -103,13 +106,14 @@ class Record(object): class RecordCollection(object): """A set of excellent Records from a query.""" + def __init__(self, rows): self._rows = rows self._all_rows = [] self.pending = True def __repr__(self): - return ''.format(len(self), self.pending) + return "".format(len(self), self.pending) def __iter__(self): """Iterate over all rows, consuming the underlying generator @@ -139,7 +143,7 @@ class RecordCollection(object): return nextrow except StopIteration: self.pending = False - raise StopIteration('RecordCollection contains no more rows.') + raise StopIteration("RecordCollection contains no more rows.") def __getitem__(self, key): is_int = isinstance(key, int) @@ -203,7 +207,7 @@ class RecordCollection(object): return rows def as_dict(self, ordered=False): - return self.all(as_dict=not(ordered), as_ordereddict=ordered) + return self.all(as_dict=not (ordered), as_ordereddict=ordered) def first(self, default=None, as_dict=False, as_ordereddict=False): """Returns a single record for the RecordCollection, or `default`. If @@ -235,11 +239,15 @@ class RecordCollection(object): try: self[1] except IndexError: - return self.first(default=default, as_dict=as_dict, as_ordereddict=as_ordereddict) + return self.first( + default=default, as_dict=as_dict, as_ordereddict=as_ordereddict + ) else: - raise ValueError('RecordCollection contained more than one row. ' - 'Expects only one row when using ' - 'RecordCollection.one') + raise ValueError( + "RecordCollection contained more than one row. " + "Expects only one row when using " + "RecordCollection.one" + ) def scalar(self, default=None): """Returns the first column of the first row, or `default`.""" @@ -254,21 +262,21 @@ class Database(object): def __init__(self, db_url=None, **kwargs): # If no db_url was provided, fallback to $DATABASE_URL. - self.db_url = db_url or os.environ.get('DATABASE_URL') + self.db_url = db_url or os.environ.get("DATABASE_URL") if not self.db_url: - raise ValueError('You must provide a db_url.') + raise ValueError("You must provide a db_url.") # Create an engine. self._engine = create_engine(self.db_url, **kwargs) self.open = True - + def get_engine(self): # Return the engine if open - if not self.open: - raise exc.ResourceClosedError('Database closed.') + if not self.open: + raise exc.ResourceClosedError("Database closed.") return self._engine - + def close(self): """Closes the Database.""" self._engine.dispose() @@ -281,7 +289,7 @@ class Database(object): self.close() def __repr__(self): - return ''.format(self.open) + return "".format(self.open) def get_table_names(self, internal=False, **kwargs): """Returns a list of table names for the connected database.""" @@ -294,9 +302,9 @@ class Database(object): pool. """ if not self.open: - raise exc.ResourceClosedError('Database closed.') + raise exc.ResourceClosedError("Database closed.") - return Connection(self._engine.connect(close_with_result=close_with_result), close_with_result) + return Connection(self._engine.connect(), close_with_result=close_with_result) def query(self, query, fetchall=False, **params): """Executes the given SQL query against the Database. Parameters can, @@ -361,7 +369,7 @@ class Connection(object): self.close() def __repr__(self): - return ''.format(self.open) + return "".format(self.open) def query(self, query, fetchall=False, **params): """Executes the given SQL query against the connected Database. @@ -370,11 +378,13 @@ class Connection(object): """ # Execute the given query. - cursor = self._conn.execute(text(query).bindparams(**params)) # TODO: PARAMS GO HERE + cursor = self._conn.execute( + text(query).bindparams(**params) + ) # TODO: PARAMS GO HERE # Row-by-row Record generator. row_gen = iter(Record([], [])) - + if cursor.returns_rows: row_gen = (Record(cursor.keys(), row) for row in cursor) @@ -415,7 +425,7 @@ class Connection(object): from. """ - # If path doesn't exists + # If path doesn't exists if not os.path.exists(path): raise IOError("File '{}'' not found!".format(path)) @@ -435,20 +445,22 @@ class Connection(object): return self._conn.begin() + def _reduce_datetimes(row): """Receives a row, converts datetimes to strings.""" row = list(row) for i, element in enumerate(row): - if hasattr(element, 'isoformat'): + if hasattr(element, "isoformat"): row[i] = element.isoformat() return tuple(row) + def cli(): - supported_formats = 'csv tsv json yaml html xls xlsx dbf latex ods'.split() - formats_lst=", ".join(supported_formats) - cli_docs ="""Records: SQL for Humans™ + supported_formats = "csv tsv json yaml html xls xlsx dbf latex ods".split() + formats_lst = ", ".join(supported_formats) + cli_docs = """Records: SQL for Humans™ A Kenneth Reitz project. Usage: @@ -478,34 +490,36 @@ Notes: can be provided instead. Use this feature discernfully; it's dangerous. - Records is intended for report-style exports of database queries, and has not yet been optimized for extremely large data dumps. - """ % dict(formats_lst=formats_lst) + """ % dict( + formats_lst=formats_lst + ) # Parse the command-line arguments. arguments = docopt(cli_docs) - query = arguments[''] - params = arguments[''] - format = arguments.get('') + query = arguments[""] + params = arguments[""] + format = arguments.get("") if format and "=" in format: - del arguments[''] - arguments[''].append(format) + del arguments[""] + arguments[""].append(format) format = None if format and format not in supported_formats: - print('%s format not supported.' % format) - print('Supported formats are %s.' % formats_lst) + print("%s format not supported." % format) + print("Supported formats are %s." % formats_lst) exit(62) # Can't send an empty list if params aren't expected. try: - params = dict([i.split('=') for i in params]) + params = dict([i.split("=") for i in params]) except ValueError: - print('Parameters must be given in key=value format.') + print("Parameters must be given in key=value format.") exit(64) # Be ready to fail on missing packages try: # Create the Database. - db = Database(arguments['--url']) + db = Database(arguments["--url"]) # Execute the query, if it is a found file. if os.path.isfile(query): @@ -517,7 +531,7 @@ Notes: # Otherwise, say the file wasn't found. else: - print('The given query could not be found.') + print("The given query could not be found.") exit(66) # Print results in desired format. @@ -544,5 +558,5 @@ def print_bytes(content): # Run the CLI when executed directly. -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/setup.py b/setup.py index 3355b19..35c3887 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ class PublishCommand(Command): requires = [ - "SQLAlchemy", + "SQLAlchemy>=2.0", "tablib>=0.11.4", "openpyxl>2.6.0", # https://github.com/kennethreitz-archive/records/pull/184#issuecomment-606207851 "docopt", diff --git a/tests/conftest.py b/tests/conftest.py index 37128f2..0868030 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,21 @@ """Shared pytest fixtures. """ + import pytest import records -@pytest.fixture(params=[ - # request: (sql_url_id, sql_url_template) - - ('sqlite_memory', 'sqlite:///:memory:'), - ('sqlite_file', 'sqlite:///{dbfile}'), - # ('psql', 'postgresql://records:records@localhost/records') -], - ids=lambda r: r[0]) +@pytest.fixture( + params=[ + # request: (sql_url_id, sql_url_template) + ("sqlite_memory", "sqlite:///:memory:"), + # ('sqlite_file', 'sqlite:///{dbfile}'), + # ('psql', 'postgresql://records:records@localhost/records') + ], + ids=lambda r: r[0], +) def db(request, tmpdir): """Instance of `records.Database(dburl)` @@ -42,6 +44,6 @@ def foo_table(db): Typically applied by `@pytest.mark.usefixtures('foo_table')` """ - db.query('CREATE TABLE foo (a integer)') + db.query("CREATE TABLE foo (a integer)") yield - db.query('DROP TABLE foo') + db.query("DROP TABLE foo")