From 89ef187b3cc065de9ea01c505f40f1df53dbc6e4 Mon Sep 17 00:00:00 2001 From: Greg Troszak Date: Sat, 9 Dec 2017 08:22:07 -0800 Subject: [PATCH 1/3] Properly handle connection pools with txs. --- records.py | 79 +++++++++++++++++++++++++++++++++++--- tests/test_transactions.py | 51 +++++++++++++++++++----- 2 files changed, 114 insertions(+), 16 deletions(-) diff --git a/records.py b/records.py index 2a9bb56..b16d338 100644 --- a/records.py +++ b/records.py @@ -2,6 +2,7 @@ import os from collections import OrderedDict +from contextlib import contextmanager from inspect import isclass import tablib @@ -242,12 +243,11 @@ class Database(object): self._engine = create_engine(self.db_url, **kwargs) # Connect to the database. - self.db = self._engine.connect() self.open = True def close(self): """Closes the connection to the Database.""" - self.db.close() + self._engine.dispose() self.open = False def __enter__(self): @@ -265,6 +265,72 @@ class Database(object): # Setup SQLAlchemy for Database inspection. return inspect(self._engine).get_table_names() + def get_connection(self): + return Connection(self._engine.connect()) + + def query(self, query, fetchall=False, **params): + """Executes the given SQL query against the Database. Parameters + can, optionally, be provided. Returns a RecordCollection, which can be + iterated over to get result rows as dictionaries. + """ + + with Connection(self._engine.connect()) as conn: + return conn.query(query, fetchall, **params) + + def bulk_query(self, query, *multiparams): + """Bulk insert or update.""" + + with Connection(self._engine.connect()) as conn: + conn.bulk_query(query, *multiparams) + + def query_file(self, path, fetchall=False, **params): + """Like Database.query, but takes a filename to load a query from.""" + + with Connection(self._engine.connect()) as conn: + return conn.query_file(path, fetchall, **params) + + def bulk_query_file(self, path, *multiparams): + """Like Database.bulk_query, but takes a filename to load a query from.""" + + with Connection(self._engine.connect()) as conn: + conn.bulk_query_file(path, *multiparams) + + @contextmanager + def transaction(self): + """Returns a transaction object. Call ``commit`` or ``rollback`` + on the returned object as appropriate.""" + + with Connection(self._engine.connect()) as conn: + tx = conn.transaction() + try: + yield conn + tx.commit() + except: + tx.rollback() + finally: + conn.close() + + +class Connection(object): + """A database connection.""" + + def __init__(self, connection): + self._conn = connection + self.open = not connection.closed + + def close(self): + self._conn.close() + self.open = False + + def __enter__(self): + return self + + def __exit__(self, exc, val, traceback): + self.close() + + def __repr__(self): + return ''.format(self.open) + def query(self, query, fetchall=False, **params): """Executes the given SQL query against the Database. Parameters can, optionally, be provided. Returns a RecordCollection, which can be @@ -272,7 +338,7 @@ class Database(object): """ # Execute the given query. - cursor = self.db.execute(text(query), **params) # TODO: PARAMS GO HERE + cursor = self._conn.execute(text(query), **params) # TODO: PARAMS GO HERE # Row-by-row Record generator. row_gen = (Record(cursor.keys(), row) for row in cursor) @@ -289,7 +355,7 @@ class Database(object): def bulk_query(self, query, *multiparams): """Bulk insert or update.""" - self.db.execute(text(query), *multiparams) + self._conn.execute(text(query), *multiparams) def query_file(self, path, fetchall=False, **params): """Like Database.query, but takes a filename to load a query from.""" @@ -324,12 +390,13 @@ class Database(object): with open(path) as f: query = f.read() - self.db.execute(text(query), *multiparams) + self._conn.execute(text(query), *multiparams) def transaction(self): """Returns a transaction object. Call ``commit`` or ``rollback`` on the returned object as appropriate.""" - return self.db.begin() + + return self._conn.begin() def _reduce_datetimes(row): """Receives a row, converts datetimes to strings.""" diff --git a/tests/test_transactions.py b/tests/test_transactions.py index b416918..f2d5e92 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -1,29 +1,60 @@ +import pytest + import records db = records.Database('sqlite:///:memory:') -db.query('CREATE TABLE foo (a integer)') -def test_failing_transaction(): - tx = db.transaction() +@pytest.fixture +def table_setup(request): + db.query('CREATE TABLE foo (a integer)') + def drop_table(): + db.query('DROP TABLE foo') + request.addfinalizer(drop_table) + + +def test_failing_transaction_self_managed(table_setup): + conn = db.get_connection() + tx = conn.transaction() try: - db.query('INSERT INTO foo VALUES (42)') - db.query('INSERT INTO foo VALUES (43)') + conn.query('INSERT INTO foo VALUES (42)') + conn.query('INSERT INTO foo VALUES (43)') raise ValueError() tx.commit() - db.query('INSERT INTO foo VALUES (44)') + conn.query('INSERT INTO foo VALUES (44)') except: tx.rollback() finally: + conn.close() assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0 -def test_passing_transaction(): - tx = db.transaction() + +def test_failing_transaction(table_setup): + with db.transaction() as conn: + conn.query('INSERT INTO foo VALUES (42)') + conn.query('INSERT INTO foo VALUES (43)') + raise ValueError() + + assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0 + + +def test_passing_transaction_self_managed(table_setup): + conn = db.get_connection() + tx = conn.transaction() try: - db.query('INSERT INTO foo VALUES (42)') - db.query('INSERT INTO foo VALUES (43)') + conn.query('INSERT INTO foo VALUES (42)') + conn.query('INSERT INTO foo VALUES (43)') tx.commit() except: tx.rollback() finally: + conn.close() assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2 + + +def test_passing_transaction(table_setup): + with db.transaction() as conn: + conn.query('INSERT INTO foo VALUES (42)') + conn.query('INSERT INTO foo VALUES (43)') + + assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2 From 522199af20f297350179770df56fb5ce268151a2 Mon Sep 17 00:00:00 2001 From: Greg Troszak Date: Sat, 9 Dec 2017 08:48:14 -0800 Subject: [PATCH 2/3] Clean up. --- records.py | 60 +++++++++++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/records.py b/records.py index b16d338..2fd0fdd 100644 --- a/records.py +++ b/records.py @@ -231,7 +231,9 @@ class RecordCollection(object): return record class Database(object): - """A Database connection.""" + """A Database. Encapsulates a url and an SQLAlchemy engine with a pool of + connections. + """ def __init__(self, db_url=None, **kwargs): # If no db_url was provided, fallback to $DATABASE_URL. @@ -240,13 +242,12 @@ class Database(object): if not self.db_url: raise ValueError('You must provide a db_url.') + # Create an engine. self._engine = create_engine(self.db_url, **kwargs) - - # Connect to the database. self.open = True def close(self): - """Closes the connection to the Database.""" + """Closes the Database.""" self._engine.dispose() self.open = False @@ -266,53 +267,54 @@ class Database(object): return inspect(self._engine).get_table_names() def get_connection(self): + """Get a connection to this Database. Connections are retrieved from a + pool. + """ return Connection(self._engine.connect()) def query(self, query, fetchall=False, **params): - """Executes the given SQL query against the Database. Parameters - can, optionally, be provided. Returns a RecordCollection, which can be + """Executes the given SQL query against the Database. Parameters can, + optionally, be provided. Returns a RecordCollection, which can be iterated over to get result rows as dictionaries. """ - - with Connection(self._engine.connect()) as conn: + with self.get_connection() as conn: return conn.query(query, fetchall, **params) def bulk_query(self, query, *multiparams): """Bulk insert or update.""" - with Connection(self._engine.connect()) as conn: + with self.get_connection() as conn: conn.bulk_query(query, *multiparams) def query_file(self, path, fetchall=False, **params): """Like Database.query, but takes a filename to load a query from.""" - with Connection(self._engine.connect()) as conn: + with self.get_connection() as conn: return conn.query_file(path, fetchall, **params) def bulk_query_file(self, path, *multiparams): """Like Database.bulk_query, but takes a filename to load a query from.""" - with Connection(self._engine.connect()) as conn: + with self.get_connection() as conn: conn.bulk_query_file(path, *multiparams) @contextmanager def transaction(self): - """Returns a transaction object. Call ``commit`` or ``rollback`` - on the returned object as appropriate.""" + """A context manager for executing a transaction on this Database.""" - with Connection(self._engine.connect()) as conn: - tx = conn.transaction() - try: - yield conn - tx.commit() - except: - tx.rollback() - finally: - conn.close() + conn = self.get_connection() + tx = conn.transaction() + try: + yield conn + tx.commit() + except: + tx.rollback() + finally: + conn.close() class Connection(object): - """A database connection.""" + """A Database connection.""" def __init__(self, connection): self._conn = connection @@ -332,9 +334,9 @@ class Connection(object): return ''.format(self.open) def query(self, query, fetchall=False, **params): - """Executes the given SQL query against the Database. Parameters - can, optionally, be provided. Returns a RecordCollection, which can be - iterated over to get result rows as dictionaries. + """Executes the given SQL query against the connected Database. + Parameters can, optionally, be provided. Returns a RecordCollection, + which can be iterated over to get result rows as dictionaries. """ # Execute the given query. @@ -358,7 +360,7 @@ class Connection(object): self._conn.execute(text(query), *multiparams) def query_file(self, path, fetchall=False, **params): - """Like Database.query, but takes a filename to load a query from.""" + """Like Connection.query, but takes a filename to load a query from.""" # If path doesn't exists if not os.path.exists(path): @@ -376,7 +378,9 @@ class Connection(object): return self.query(query=query, fetchall=fetchall, **params) def bulk_query_file(self, path, *multiparams): - """Like Database.bulk_query, but takes a filename to load a query from.""" + """Like Connection.bulk_query, but takes a filename to load a query + from. + """ # If path doesn't exists if not os.path.exists(path): From 3d47e45f1e0236235680c15576768b4a2084aba2 Mon Sep 17 00:00:00 2001 From: Greg Troszak Date: Sat, 9 Dec 2017 13:14:59 -0800 Subject: [PATCH 3/3] Throw exception if db closed. --- records.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/records.py b/records.py index 2fd0fdd..15b4afe 100644 --- a/records.py +++ b/records.py @@ -7,7 +7,7 @@ from inspect import isclass import tablib from docopt import docopt -from sqlalchemy import create_engine, inspect, text +from sqlalchemy import create_engine, exc, inspect, text DATABASE_URL = os.environ.get('DATABASE_URL') @@ -270,6 +270,9 @@ class Database(object): """Get a connection to this Database. Connections are retrieved from a pool. """ + if not self.open: + raise exc.ResourceClosedError('Database closed.') + return Connection(self._engine.connect()) def query(self, query, fetchall=False, **params):