Properly handle connection pools with txs.

This commit is contained in:
Greg Troszak
2017-12-09 08:22:07 -08:00
parent fcd41c36c9
commit 89ef187b3c
2 changed files with 114 additions and 16 deletions
+73 -6
View File
@@ -2,6 +2,7 @@
import os
from collections import OrderedDict
from contextlib import contextmanager
from inspect import isclass
import tablib
@@ -242,12 +243,11 @@ class Database(object):
self._engine = create_engine(self.db_url, **kwargs)
# Connect to the database.
self.db = self._engine.connect()
self.open = True
def close(self):
"""Closes the connection to the Database."""
self.db.close()
self._engine.dispose()
self.open = False
def __enter__(self):
@@ -265,6 +265,72 @@ class Database(object):
# Setup SQLAlchemy for Database inspection.
return inspect(self._engine).get_table_names()
def get_connection(self):
return Connection(self._engine.connect())
def query(self, query, fetchall=False, **params):
"""Executes the given SQL query against the Database. Parameters
can, optionally, be provided. Returns a RecordCollection, which can be
iterated over to get result rows as dictionaries.
"""
with Connection(self._engine.connect()) as conn:
return conn.query(query, fetchall, **params)
def bulk_query(self, query, *multiparams):
"""Bulk insert or update."""
with Connection(self._engine.connect()) as conn:
conn.bulk_query(query, *multiparams)
def query_file(self, path, fetchall=False, **params):
"""Like Database.query, but takes a filename to load a query from."""
with Connection(self._engine.connect()) as conn:
return conn.query_file(path, fetchall, **params)
def bulk_query_file(self, path, *multiparams):
"""Like Database.bulk_query, but takes a filename to load a query from."""
with Connection(self._engine.connect()) as conn:
conn.bulk_query_file(path, *multiparams)
@contextmanager
def transaction(self):
"""Returns a transaction object. Call ``commit`` or ``rollback``
on the returned object as appropriate."""
with Connection(self._engine.connect()) as conn:
tx = conn.transaction()
try:
yield conn
tx.commit()
except:
tx.rollback()
finally:
conn.close()
class Connection(object):
"""A database connection."""
def __init__(self, connection):
self._conn = connection
self.open = not connection.closed
def close(self):
self._conn.close()
self.open = False
def __enter__(self):
return self
def __exit__(self, exc, val, traceback):
self.close()
def __repr__(self):
return '<Connection open={}>'.format(self.open)
def query(self, query, fetchall=False, **params):
"""Executes the given SQL query against the Database. Parameters
can, optionally, be provided. Returns a RecordCollection, which can be
@@ -272,7 +338,7 @@ class Database(object):
"""
# Execute the given query.
cursor = self.db.execute(text(query), **params) # TODO: PARAMS GO HERE
cursor = self._conn.execute(text(query), **params) # TODO: PARAMS GO HERE
# Row-by-row Record generator.
row_gen = (Record(cursor.keys(), row) for row in cursor)
@@ -289,7 +355,7 @@ class Database(object):
def bulk_query(self, query, *multiparams):
"""Bulk insert or update."""
self.db.execute(text(query), *multiparams)
self._conn.execute(text(query), *multiparams)
def query_file(self, path, fetchall=False, **params):
"""Like Database.query, but takes a filename to load a query from."""
@@ -324,12 +390,13 @@ class Database(object):
with open(path) as f:
query = f.read()
self.db.execute(text(query), *multiparams)
self._conn.execute(text(query), *multiparams)
def transaction(self):
"""Returns a transaction object. Call ``commit`` or ``rollback``
on the returned object as appropriate."""
return self.db.begin()
return self._conn.begin()
def _reduce_datetimes(row):
"""Receives a row, converts datetimes to strings."""
+41 -10
View File
@@ -1,29 +1,60 @@
import pytest
import records
db = records.Database('sqlite:///:memory:')
db.query('CREATE TABLE foo (a integer)')
def test_failing_transaction():
tx = db.transaction()
@pytest.fixture
def table_setup(request):
db.query('CREATE TABLE foo (a integer)')
def drop_table():
db.query('DROP TABLE foo')
request.addfinalizer(drop_table)
def test_failing_transaction_self_managed(table_setup):
conn = db.get_connection()
tx = conn.transaction()
try:
db.query('INSERT INTO foo VALUES (42)')
db.query('INSERT INTO foo VALUES (43)')
conn.query('INSERT INTO foo VALUES (42)')
conn.query('INSERT INTO foo VALUES (43)')
raise ValueError()
tx.commit()
db.query('INSERT INTO foo VALUES (44)')
conn.query('INSERT INTO foo VALUES (44)')
except:
tx.rollback()
finally:
conn.close()
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0
def test_passing_transaction():
tx = db.transaction()
def test_failing_transaction(table_setup):
with db.transaction() as conn:
conn.query('INSERT INTO foo VALUES (42)')
conn.query('INSERT INTO foo VALUES (43)')
raise ValueError()
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0
def test_passing_transaction_self_managed(table_setup):
conn = db.get_connection()
tx = conn.transaction()
try:
db.query('INSERT INTO foo VALUES (42)')
db.query('INSERT INTO foo VALUES (43)')
conn.query('INSERT INTO foo VALUES (42)')
conn.query('INSERT INTO foo VALUES (43)')
tx.commit()
except:
tx.rollback()
finally:
conn.close()
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2
def test_passing_transaction(table_setup):
with db.transaction() as conn:
conn.query('INSERT INTO foo VALUES (42)')
conn.query('INSERT INTO foo VALUES (43)')
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2