diff --git a/records.py b/records.py index 3fedb0e..3041978 100644 --- a/records.py +++ b/records.py @@ -2,11 +2,12 @@ import os from collections import OrderedDict +from contextlib import contextmanager from inspect import isclass import tablib from docopt import docopt -from sqlalchemy import create_engine, inspect, text +from sqlalchemy import create_engine, exc, inspect, text DATABASE_URL = os.environ.get('DATABASE_URL') @@ -238,7 +239,9 @@ class RecordCollection(object): class Database(object): - """A Database connection.""" + """A Database. Encapsulates a url and an SQLAlchemy engine with a pool of + connections. + """ def __init__(self, db_url=None, **kwargs): # If no db_url was provided, fallback to $DATABASE_URL. @@ -247,15 +250,13 @@ class Database(object): if not self.db_url: raise ValueError('You must provide a db_url.') + # Create an engine. self._engine = create_engine(self.db_url, **kwargs) - - # Connect to the database. - self.db = self._engine.connect() self.open = True def close(self): - """Closes the connection to the Database.""" - self.db.close() + """Closes the Database.""" + self._engine.dispose() self.open = False def __enter__(self): @@ -273,14 +274,84 @@ class Database(object): # Setup SQLAlchemy for Database inspection. return inspect(self._engine).get_table_names() + def get_connection(self): + """Get a connection to this Database. Connections are retrieved from a + pool. + """ + if not self.open: + raise exc.ResourceClosedError('Database closed.') + + return Connection(self._engine.connect()) + def query(self, query, fetchall=False, **params): - """Executes the given SQL query against the Database. Parameters - can, optionally, be provided. Returns a RecordCollection, which can be + """Executes the given SQL query against the Database. Parameters can, + optionally, be provided. Returns a RecordCollection, which can be iterated over to get result rows as dictionaries. """ + with self.get_connection() as conn: + return conn.query(query, fetchall, **params) + + def bulk_query(self, query, *multiparams): + """Bulk insert or update.""" + + with self.get_connection() as conn: + conn.bulk_query(query, *multiparams) + + def query_file(self, path, fetchall=False, **params): + """Like Database.query, but takes a filename to load a query from.""" + + with self.get_connection() as conn: + return conn.query_file(path, fetchall, **params) + + def bulk_query_file(self, path, *multiparams): + """Like Database.bulk_query, but takes a filename to load a query from.""" + + with self.get_connection() as conn: + conn.bulk_query_file(path, *multiparams) + + @contextmanager + def transaction(self): + """A context manager for executing a transaction on this Database.""" + + conn = self.get_connection() + tx = conn.transaction() + try: + yield conn + tx.commit() + except: + tx.rollback() + finally: + conn.close() + + +class Connection(object): + """A Database connection.""" + + def __init__(self, connection): + self._conn = connection + self.open = not connection.closed + + def close(self): + self._conn.close() + self.open = False + + def __enter__(self): + return self + + def __exit__(self, exc, val, traceback): + self.close() + + def __repr__(self): + return ''.format(self.open) + + def query(self, query, fetchall=False, **params): + """Executes the given SQL query against the connected Database. + Parameters can, optionally, be provided. Returns a RecordCollection, + which can be iterated over to get result rows as dictionaries. + """ # Execute the given query. - cursor = self.db.execute(text(query), **params) # TODO: PARAMS GO HERE + cursor = self._conn.execute(text(query), **params) # TODO: PARAMS GO HERE # Row-by-row Record generator. row_gen = (Record(cursor.keys(), row) for row in cursor) @@ -297,10 +368,10 @@ class Database(object): def bulk_query(self, query, *multiparams): """Bulk insert or update.""" - self.db.execute(text(query), *multiparams) + self._conn.execute(text(query), *multiparams) def query_file(self, path, fetchall=False, **params): - """Like Database.query, but takes a filename to load a query from.""" + """Like Connection.query, but takes a filename to load a query from.""" # If path doesn't exists if not os.path.exists(path): @@ -318,7 +389,9 @@ class Database(object): return self.query(query=query, fetchall=fetchall, **params) def bulk_query_file(self, path, *multiparams): - """Like Database.bulk_query, but takes a filename to load a query from.""" + """Like Connection.bulk_query, but takes a filename to load a query + from. + """ # If path doesn't exists if not os.path.exists(path): @@ -332,12 +405,13 @@ class Database(object): with open(path) as f: query = f.read() - self.db.execute(text(query), *multiparams) + self._conn.execute(text(query), *multiparams) def transaction(self): """Returns a transaction object. Call ``commit`` or ``rollback`` on the returned object as appropriate.""" - return self.db.begin() + + return self._conn.begin() def _reduce_datetimes(row): """Receives a row, converts datetimes to strings.""" diff --git a/tests/test_transactions.py b/tests/test_transactions.py index b416918..f2d5e92 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -1,29 +1,60 @@ +import pytest + import records db = records.Database('sqlite:///:memory:') -db.query('CREATE TABLE foo (a integer)') -def test_failing_transaction(): - tx = db.transaction() +@pytest.fixture +def table_setup(request): + db.query('CREATE TABLE foo (a integer)') + def drop_table(): + db.query('DROP TABLE foo') + request.addfinalizer(drop_table) + + +def test_failing_transaction_self_managed(table_setup): + conn = db.get_connection() + tx = conn.transaction() try: - db.query('INSERT INTO foo VALUES (42)') - db.query('INSERT INTO foo VALUES (43)') + conn.query('INSERT INTO foo VALUES (42)') + conn.query('INSERT INTO foo VALUES (43)') raise ValueError() tx.commit() - db.query('INSERT INTO foo VALUES (44)') + conn.query('INSERT INTO foo VALUES (44)') except: tx.rollback() finally: + conn.close() assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0 -def test_passing_transaction(): - tx = db.transaction() + +def test_failing_transaction(table_setup): + with db.transaction() as conn: + conn.query('INSERT INTO foo VALUES (42)') + conn.query('INSERT INTO foo VALUES (43)') + raise ValueError() + + assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0 + + +def test_passing_transaction_self_managed(table_setup): + conn = db.get_connection() + tx = conn.transaction() try: - db.query('INSERT INTO foo VALUES (42)') - db.query('INSERT INTO foo VALUES (43)') + conn.query('INSERT INTO foo VALUES (42)') + conn.query('INSERT INTO foo VALUES (43)') tx.commit() except: tx.rollback() finally: + conn.close() assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2 + + +def test_passing_transaction(table_setup): + with db.transaction() as conn: + conn.query('INSERT INTO foo VALUES (42)') + conn.query('INSERT INTO foo VALUES (43)') + + assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2