Clean up.

This commit is contained in:
Greg Troszak
2017-12-09 08:48:14 -08:00
parent 89ef187b3c
commit 522199af20
+32 -28
View File
@@ -231,7 +231,9 @@ class RecordCollection(object):
return record
class Database(object):
"""A Database connection."""
"""A Database. Encapsulates a url and an SQLAlchemy engine with a pool of
connections.
"""
def __init__(self, db_url=None, **kwargs):
# If no db_url was provided, fallback to $DATABASE_URL.
@@ -240,13 +242,12 @@ class Database(object):
if not self.db_url:
raise ValueError('You must provide a db_url.')
# Create an engine.
self._engine = create_engine(self.db_url, **kwargs)
# Connect to the database.
self.open = True
def close(self):
"""Closes the connection to the Database."""
"""Closes the Database."""
self._engine.dispose()
self.open = False
@@ -266,53 +267,54 @@ class Database(object):
return inspect(self._engine).get_table_names()
def get_connection(self):
"""Get a connection to this Database. Connections are retrieved from a
pool.
"""
return Connection(self._engine.connect())
def query(self, query, fetchall=False, **params):
"""Executes the given SQL query against the Database. Parameters
can, optionally, be provided. Returns a RecordCollection, which can be
"""Executes the given SQL query against the Database. Parameters can,
optionally, be provided. Returns a RecordCollection, which can be
iterated over to get result rows as dictionaries.
"""
with Connection(self._engine.connect()) as conn:
with self.get_connection() as conn:
return conn.query(query, fetchall, **params)
def bulk_query(self, query, *multiparams):
"""Bulk insert or update."""
with Connection(self._engine.connect()) as conn:
with self.get_connection() as conn:
conn.bulk_query(query, *multiparams)
def query_file(self, path, fetchall=False, **params):
"""Like Database.query, but takes a filename to load a query from."""
with Connection(self._engine.connect()) as conn:
with self.get_connection() as conn:
return conn.query_file(path, fetchall, **params)
def bulk_query_file(self, path, *multiparams):
"""Like Database.bulk_query, but takes a filename to load a query from."""
with Connection(self._engine.connect()) as conn:
with self.get_connection() as conn:
conn.bulk_query_file(path, *multiparams)
@contextmanager
def transaction(self):
"""Returns a transaction object. Call ``commit`` or ``rollback``
on the returned object as appropriate."""
"""A context manager for executing a transaction on this Database."""
with Connection(self._engine.connect()) as conn:
tx = conn.transaction()
try:
yield conn
tx.commit()
except:
tx.rollback()
finally:
conn.close()
conn = self.get_connection()
tx = conn.transaction()
try:
yield conn
tx.commit()
except:
tx.rollback()
finally:
conn.close()
class Connection(object):
"""A database connection."""
"""A Database connection."""
def __init__(self, connection):
self._conn = connection
@@ -332,9 +334,9 @@ class Connection(object):
return '<Connection open={}>'.format(self.open)
def query(self, query, fetchall=False, **params):
"""Executes the given SQL query against the Database. Parameters
can, optionally, be provided. Returns a RecordCollection, which can be
iterated over to get result rows as dictionaries.
"""Executes the given SQL query against the connected Database.
Parameters can, optionally, be provided. Returns a RecordCollection,
which can be iterated over to get result rows as dictionaries.
"""
# Execute the given query.
@@ -358,7 +360,7 @@ class Connection(object):
self._conn.execute(text(query), *multiparams)
def query_file(self, path, fetchall=False, **params):
"""Like Database.query, but takes a filename to load a query from."""
"""Like Connection.query, but takes a filename to load a query from."""
# If path doesn't exists
if not os.path.exists(path):
@@ -376,7 +378,9 @@ class Connection(object):
return self.query(query=query, fetchall=fetchall, **params)
def bulk_query_file(self, path, *multiparams):
"""Like Database.bulk_query, but takes a filename to load a query from."""
"""Like Connection.bulk_query, but takes a filename to load a query
from.
"""
# If path doesn't exists
if not os.path.exists(path):