mirror of
https://github.com/kennethreitz/records.git
synced 2026-06-05 23:00:20 +00:00
Properly handle connection pools with txs.
This commit is contained in:
+73
-6
@@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from inspect import isclass
|
||||
|
||||
import tablib
|
||||
@@ -242,12 +243,11 @@ class Database(object):
|
||||
self._engine = create_engine(self.db_url, **kwargs)
|
||||
|
||||
# Connect to the database.
|
||||
self.db = self._engine.connect()
|
||||
self.open = True
|
||||
|
||||
def close(self):
|
||||
"""Closes the connection to the Database."""
|
||||
self.db.close()
|
||||
self._engine.dispose()
|
||||
self.open = False
|
||||
|
||||
def __enter__(self):
|
||||
@@ -265,6 +265,72 @@ class Database(object):
|
||||
# Setup SQLAlchemy for Database inspection.
|
||||
return inspect(self._engine).get_table_names()
|
||||
|
||||
def get_connection(self):
|
||||
return Connection(self._engine.connect())
|
||||
|
||||
def query(self, query, fetchall=False, **params):
|
||||
"""Executes the given SQL query against the Database. Parameters
|
||||
can, optionally, be provided. Returns a RecordCollection, which can be
|
||||
iterated over to get result rows as dictionaries.
|
||||
"""
|
||||
|
||||
with Connection(self._engine.connect()) as conn:
|
||||
return conn.query(query, fetchall, **params)
|
||||
|
||||
def bulk_query(self, query, *multiparams):
|
||||
"""Bulk insert or update."""
|
||||
|
||||
with Connection(self._engine.connect()) as conn:
|
||||
conn.bulk_query(query, *multiparams)
|
||||
|
||||
def query_file(self, path, fetchall=False, **params):
|
||||
"""Like Database.query, but takes a filename to load a query from."""
|
||||
|
||||
with Connection(self._engine.connect()) as conn:
|
||||
return conn.query_file(path, fetchall, **params)
|
||||
|
||||
def bulk_query_file(self, path, *multiparams):
|
||||
"""Like Database.bulk_query, but takes a filename to load a query from."""
|
||||
|
||||
with Connection(self._engine.connect()) as conn:
|
||||
conn.bulk_query_file(path, *multiparams)
|
||||
|
||||
@contextmanager
|
||||
def transaction(self):
|
||||
"""Returns a transaction object. Call ``commit`` or ``rollback``
|
||||
on the returned object as appropriate."""
|
||||
|
||||
with Connection(self._engine.connect()) as conn:
|
||||
tx = conn.transaction()
|
||||
try:
|
||||
yield conn
|
||||
tx.commit()
|
||||
except:
|
||||
tx.rollback()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
class Connection(object):
|
||||
"""A database connection."""
|
||||
|
||||
def __init__(self, connection):
|
||||
self._conn = connection
|
||||
self.open = not connection.closed
|
||||
|
||||
def close(self):
|
||||
self._conn.close()
|
||||
self.open = False
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc, val, traceback):
|
||||
self.close()
|
||||
|
||||
def __repr__(self):
|
||||
return '<Connection open={}>'.format(self.open)
|
||||
|
||||
def query(self, query, fetchall=False, **params):
|
||||
"""Executes the given SQL query against the Database. Parameters
|
||||
can, optionally, be provided. Returns a RecordCollection, which can be
|
||||
@@ -272,7 +338,7 @@ class Database(object):
|
||||
"""
|
||||
|
||||
# Execute the given query.
|
||||
cursor = self.db.execute(text(query), **params) # TODO: PARAMS GO HERE
|
||||
cursor = self._conn.execute(text(query), **params) # TODO: PARAMS GO HERE
|
||||
|
||||
# Row-by-row Record generator.
|
||||
row_gen = (Record(cursor.keys(), row) for row in cursor)
|
||||
@@ -289,7 +355,7 @@ class Database(object):
|
||||
def bulk_query(self, query, *multiparams):
|
||||
"""Bulk insert or update."""
|
||||
|
||||
self.db.execute(text(query), *multiparams)
|
||||
self._conn.execute(text(query), *multiparams)
|
||||
|
||||
def query_file(self, path, fetchall=False, **params):
|
||||
"""Like Database.query, but takes a filename to load a query from."""
|
||||
@@ -324,12 +390,13 @@ class Database(object):
|
||||
with open(path) as f:
|
||||
query = f.read()
|
||||
|
||||
self.db.execute(text(query), *multiparams)
|
||||
self._conn.execute(text(query), *multiparams)
|
||||
|
||||
def transaction(self):
|
||||
"""Returns a transaction object. Call ``commit`` or ``rollback``
|
||||
on the returned object as appropriate."""
|
||||
return self.db.begin()
|
||||
|
||||
return self._conn.begin()
|
||||
|
||||
def _reduce_datetimes(row):
|
||||
"""Receives a row, converts datetimes to strings."""
|
||||
|
||||
+41
-10
@@ -1,29 +1,60 @@
|
||||
import pytest
|
||||
|
||||
import records
|
||||
|
||||
db = records.Database('sqlite:///:memory:')
|
||||
|
||||
db.query('CREATE TABLE foo (a integer)')
|
||||
|
||||
def test_failing_transaction():
|
||||
tx = db.transaction()
|
||||
@pytest.fixture
|
||||
def table_setup(request):
|
||||
db.query('CREATE TABLE foo (a integer)')
|
||||
def drop_table():
|
||||
db.query('DROP TABLE foo')
|
||||
request.addfinalizer(drop_table)
|
||||
|
||||
|
||||
def test_failing_transaction_self_managed(table_setup):
|
||||
conn = db.get_connection()
|
||||
tx = conn.transaction()
|
||||
try:
|
||||
db.query('INSERT INTO foo VALUES (42)')
|
||||
db.query('INSERT INTO foo VALUES (43)')
|
||||
conn.query('INSERT INTO foo VALUES (42)')
|
||||
conn.query('INSERT INTO foo VALUES (43)')
|
||||
raise ValueError()
|
||||
tx.commit()
|
||||
db.query('INSERT INTO foo VALUES (44)')
|
||||
conn.query('INSERT INTO foo VALUES (44)')
|
||||
except:
|
||||
tx.rollback()
|
||||
finally:
|
||||
conn.close()
|
||||
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0
|
||||
|
||||
def test_passing_transaction():
|
||||
tx = db.transaction()
|
||||
|
||||
def test_failing_transaction(table_setup):
|
||||
with db.transaction() as conn:
|
||||
conn.query('INSERT INTO foo VALUES (42)')
|
||||
conn.query('INSERT INTO foo VALUES (43)')
|
||||
raise ValueError()
|
||||
|
||||
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0
|
||||
|
||||
|
||||
def test_passing_transaction_self_managed(table_setup):
|
||||
conn = db.get_connection()
|
||||
tx = conn.transaction()
|
||||
try:
|
||||
db.query('INSERT INTO foo VALUES (42)')
|
||||
db.query('INSERT INTO foo VALUES (43)')
|
||||
conn.query('INSERT INTO foo VALUES (42)')
|
||||
conn.query('INSERT INTO foo VALUES (43)')
|
||||
tx.commit()
|
||||
except:
|
||||
tx.rollback()
|
||||
finally:
|
||||
conn.close()
|
||||
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2
|
||||
|
||||
|
||||
def test_passing_transaction(table_setup):
|
||||
with db.transaction() as conn:
|
||||
conn.query('INSERT INTO foo VALUES (42)')
|
||||
conn.query('INSERT INTO foo VALUES (43)')
|
||||
|
||||
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2
|
||||
|
||||
Reference in New Issue
Block a user