From 522199af20f297350179770df56fb5ce268151a2 Mon Sep 17 00:00:00 2001 From: Greg Troszak Date: Sat, 9 Dec 2017 08:48:14 -0800 Subject: [PATCH] Clean up. --- records.py | 60 +++++++++++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/records.py b/records.py index b16d338..2fd0fdd 100644 --- a/records.py +++ b/records.py @@ -231,7 +231,9 @@ class RecordCollection(object): return record class Database(object): - """A Database connection.""" + """A Database. Encapsulates a url and an SQLAlchemy engine with a pool of + connections. + """ def __init__(self, db_url=None, **kwargs): # If no db_url was provided, fallback to $DATABASE_URL. @@ -240,13 +242,12 @@ class Database(object): if not self.db_url: raise ValueError('You must provide a db_url.') + # Create an engine. self._engine = create_engine(self.db_url, **kwargs) - - # Connect to the database. self.open = True def close(self): - """Closes the connection to the Database.""" + """Closes the Database.""" self._engine.dispose() self.open = False @@ -266,53 +267,54 @@ class Database(object): return inspect(self._engine).get_table_names() def get_connection(self): + """Get a connection to this Database. Connections are retrieved from a + pool. + """ return Connection(self._engine.connect()) def query(self, query, fetchall=False, **params): - """Executes the given SQL query against the Database. Parameters - can, optionally, be provided. Returns a RecordCollection, which can be + """Executes the given SQL query against the Database. Parameters can, + optionally, be provided. Returns a RecordCollection, which can be iterated over to get result rows as dictionaries. """ - - with Connection(self._engine.connect()) as conn: + with self.get_connection() as conn: return conn.query(query, fetchall, **params) def bulk_query(self, query, *multiparams): """Bulk insert or update.""" - with Connection(self._engine.connect()) as conn: + with self.get_connection() as conn: conn.bulk_query(query, *multiparams) def query_file(self, path, fetchall=False, **params): """Like Database.query, but takes a filename to load a query from.""" - with Connection(self._engine.connect()) as conn: + with self.get_connection() as conn: return conn.query_file(path, fetchall, **params) def bulk_query_file(self, path, *multiparams): """Like Database.bulk_query, but takes a filename to load a query from.""" - with Connection(self._engine.connect()) as conn: + with self.get_connection() as conn: conn.bulk_query_file(path, *multiparams) @contextmanager def transaction(self): - """Returns a transaction object. Call ``commit`` or ``rollback`` - on the returned object as appropriate.""" + """A context manager for executing a transaction on this Database.""" - with Connection(self._engine.connect()) as conn: - tx = conn.transaction() - try: - yield conn - tx.commit() - except: - tx.rollback() - finally: - conn.close() + conn = self.get_connection() + tx = conn.transaction() + try: + yield conn + tx.commit() + except: + tx.rollback() + finally: + conn.close() class Connection(object): - """A database connection.""" + """A Database connection.""" def __init__(self, connection): self._conn = connection @@ -332,9 +334,9 @@ class Connection(object): return ''.format(self.open) def query(self, query, fetchall=False, **params): - """Executes the given SQL query against the Database. Parameters - can, optionally, be provided. Returns a RecordCollection, which can be - iterated over to get result rows as dictionaries. + """Executes the given SQL query against the connected Database. + Parameters can, optionally, be provided. Returns a RecordCollection, + which can be iterated over to get result rows as dictionaries. """ # Execute the given query. @@ -358,7 +360,7 @@ class Connection(object): self._conn.execute(text(query), *multiparams) def query_file(self, path, fetchall=False, **params): - """Like Database.query, but takes a filename to load a query from.""" + """Like Connection.query, but takes a filename to load a query from.""" # If path doesn't exists if not os.path.exists(path): @@ -376,7 +378,9 @@ class Connection(object): return self.query(query=query, fetchall=fetchall, **params) def bulk_query_file(self, path, *multiparams): - """Like Database.bulk_query, but takes a filename to load a query from.""" + """Like Connection.bulk_query, but takes a filename to load a query + from. + """ # If path doesn't exists if not os.path.exists(path):