From 89ef187b3cc065de9ea01c505f40f1df53dbc6e4 Mon Sep 17 00:00:00 2001 From: Greg Troszak Date: Sat, 9 Dec 2017 08:22:07 -0800 Subject: [PATCH] Properly handle connection pools with txs. --- records.py | 79 +++++++++++++++++++++++++++++++++++--- tests/test_transactions.py | 51 +++++++++++++++++++----- 2 files changed, 114 insertions(+), 16 deletions(-) diff --git a/records.py b/records.py index 2a9bb56..b16d338 100644 --- a/records.py +++ b/records.py @@ -2,6 +2,7 @@ import os from collections import OrderedDict +from contextlib import contextmanager from inspect import isclass import tablib @@ -242,12 +243,11 @@ class Database(object): self._engine = create_engine(self.db_url, **kwargs) # Connect to the database. - self.db = self._engine.connect() self.open = True def close(self): """Closes the connection to the Database.""" - self.db.close() + self._engine.dispose() self.open = False def __enter__(self): @@ -265,6 +265,72 @@ class Database(object): # Setup SQLAlchemy for Database inspection. return inspect(self._engine).get_table_names() + def get_connection(self): + return Connection(self._engine.connect()) + + def query(self, query, fetchall=False, **params): + """Executes the given SQL query against the Database. Parameters + can, optionally, be provided. Returns a RecordCollection, which can be + iterated over to get result rows as dictionaries. + """ + + with Connection(self._engine.connect()) as conn: + return conn.query(query, fetchall, **params) + + def bulk_query(self, query, *multiparams): + """Bulk insert or update.""" + + with Connection(self._engine.connect()) as conn: + conn.bulk_query(query, *multiparams) + + def query_file(self, path, fetchall=False, **params): + """Like Database.query, but takes a filename to load a query from.""" + + with Connection(self._engine.connect()) as conn: + return conn.query_file(path, fetchall, **params) + + def bulk_query_file(self, path, *multiparams): + """Like Database.bulk_query, but takes a filename to load a query from.""" + + with Connection(self._engine.connect()) as conn: + conn.bulk_query_file(path, *multiparams) + + @contextmanager + def transaction(self): + """Returns a transaction object. Call ``commit`` or ``rollback`` + on the returned object as appropriate.""" + + with Connection(self._engine.connect()) as conn: + tx = conn.transaction() + try: + yield conn + tx.commit() + except: + tx.rollback() + finally: + conn.close() + + +class Connection(object): + """A database connection.""" + + def __init__(self, connection): + self._conn = connection + self.open = not connection.closed + + def close(self): + self._conn.close() + self.open = False + + def __enter__(self): + return self + + def __exit__(self, exc, val, traceback): + self.close() + + def __repr__(self): + return ''.format(self.open) + def query(self, query, fetchall=False, **params): """Executes the given SQL query against the Database. Parameters can, optionally, be provided. Returns a RecordCollection, which can be @@ -272,7 +338,7 @@ class Database(object): """ # Execute the given query. - cursor = self.db.execute(text(query), **params) # TODO: PARAMS GO HERE + cursor = self._conn.execute(text(query), **params) # TODO: PARAMS GO HERE # Row-by-row Record generator. row_gen = (Record(cursor.keys(), row) for row in cursor) @@ -289,7 +355,7 @@ class Database(object): def bulk_query(self, query, *multiparams): """Bulk insert or update.""" - self.db.execute(text(query), *multiparams) + self._conn.execute(text(query), *multiparams) def query_file(self, path, fetchall=False, **params): """Like Database.query, but takes a filename to load a query from.""" @@ -324,12 +390,13 @@ class Database(object): with open(path) as f: query = f.read() - self.db.execute(text(query), *multiparams) + self._conn.execute(text(query), *multiparams) def transaction(self): """Returns a transaction object. Call ``commit`` or ``rollback`` on the returned object as appropriate.""" - return self.db.begin() + + return self._conn.begin() def _reduce_datetimes(row): """Receives a row, converts datetimes to strings.""" diff --git a/tests/test_transactions.py b/tests/test_transactions.py index b416918..f2d5e92 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -1,29 +1,60 @@ +import pytest + import records db = records.Database('sqlite:///:memory:') -db.query('CREATE TABLE foo (a integer)') -def test_failing_transaction(): - tx = db.transaction() +@pytest.fixture +def table_setup(request): + db.query('CREATE TABLE foo (a integer)') + def drop_table(): + db.query('DROP TABLE foo') + request.addfinalizer(drop_table) + + +def test_failing_transaction_self_managed(table_setup): + conn = db.get_connection() + tx = conn.transaction() try: - db.query('INSERT INTO foo VALUES (42)') - db.query('INSERT INTO foo VALUES (43)') + conn.query('INSERT INTO foo VALUES (42)') + conn.query('INSERT INTO foo VALUES (43)') raise ValueError() tx.commit() - db.query('INSERT INTO foo VALUES (44)') + conn.query('INSERT INTO foo VALUES (44)') except: tx.rollback() finally: + conn.close() assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0 -def test_passing_transaction(): - tx = db.transaction() + +def test_failing_transaction(table_setup): + with db.transaction() as conn: + conn.query('INSERT INTO foo VALUES (42)') + conn.query('INSERT INTO foo VALUES (43)') + raise ValueError() + + assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0 + + +def test_passing_transaction_self_managed(table_setup): + conn = db.get_connection() + tx = conn.transaction() try: - db.query('INSERT INTO foo VALUES (42)') - db.query('INSERT INTO foo VALUES (43)') + conn.query('INSERT INTO foo VALUES (42)') + conn.query('INSERT INTO foo VALUES (43)') tx.commit() except: tx.rollback() finally: + conn.close() assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2 + + +def test_passing_transaction(table_setup): + with db.transaction() as conn: + conn.query('INSERT INTO foo VALUES (42)') + conn.query('INSERT INTO foo VALUES (43)') + + assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2