Merge pull request #119 from gdtroszak/proper-pooling

Properly handle SQLAlchemy connection pools
This commit is contained in:
2018-02-13 07:39:42 -06:00
committed by GitHub
2 changed files with 130 additions and 25 deletions
+89 -15
View File
@@ -2,11 +2,12 @@
import os
from collections import OrderedDict
from contextlib import contextmanager
from inspect import isclass
import tablib
from docopt import docopt
from sqlalchemy import create_engine, inspect, text
from sqlalchemy import create_engine, exc, inspect, text
DATABASE_URL = os.environ.get('DATABASE_URL')
@@ -238,7 +239,9 @@ class RecordCollection(object):
class Database(object):
"""A Database connection."""
"""A Database. Encapsulates a url and an SQLAlchemy engine with a pool of
connections.
"""
def __init__(self, db_url=None, **kwargs):
# If no db_url was provided, fallback to $DATABASE_URL.
@@ -247,15 +250,13 @@ class Database(object):
if not self.db_url:
raise ValueError('You must provide a db_url.')
# Create an engine.
self._engine = create_engine(self.db_url, **kwargs)
# Connect to the database.
self.db = self._engine.connect()
self.open = True
def close(self):
"""Closes the connection to the Database."""
self.db.close()
"""Closes the Database."""
self._engine.dispose()
self.open = False
def __enter__(self):
@@ -273,14 +274,84 @@ class Database(object):
# Setup SQLAlchemy for Database inspection.
return inspect(self._engine).get_table_names()
def get_connection(self):
"""Get a connection to this Database. Connections are retrieved from a
pool.
"""
if not self.open:
raise exc.ResourceClosedError('Database closed.')
return Connection(self._engine.connect())
def query(self, query, fetchall=False, **params):
"""Executes the given SQL query against the Database. Parameters
can, optionally, be provided. Returns a RecordCollection, which can be
"""Executes the given SQL query against the Database. Parameters can,
optionally, be provided. Returns a RecordCollection, which can be
iterated over to get result rows as dictionaries.
"""
with self.get_connection() as conn:
return conn.query(query, fetchall, **params)
def bulk_query(self, query, *multiparams):
"""Bulk insert or update."""
with self.get_connection() as conn:
conn.bulk_query(query, *multiparams)
def query_file(self, path, fetchall=False, **params):
"""Like Database.query, but takes a filename to load a query from."""
with self.get_connection() as conn:
return conn.query_file(path, fetchall, **params)
def bulk_query_file(self, path, *multiparams):
"""Like Database.bulk_query, but takes a filename to load a query from."""
with self.get_connection() as conn:
conn.bulk_query_file(path, *multiparams)
@contextmanager
def transaction(self):
"""A context manager for executing a transaction on this Database."""
conn = self.get_connection()
tx = conn.transaction()
try:
yield conn
tx.commit()
except:
tx.rollback()
finally:
conn.close()
class Connection(object):
"""A Database connection."""
def __init__(self, connection):
self._conn = connection
self.open = not connection.closed
def close(self):
self._conn.close()
self.open = False
def __enter__(self):
return self
def __exit__(self, exc, val, traceback):
self.close()
def __repr__(self):
return '<Connection open={}>'.format(self.open)
def query(self, query, fetchall=False, **params):
"""Executes the given SQL query against the connected Database.
Parameters can, optionally, be provided. Returns a RecordCollection,
which can be iterated over to get result rows as dictionaries.
"""
# Execute the given query.
cursor = self.db.execute(text(query), **params) # TODO: PARAMS GO HERE
cursor = self._conn.execute(text(query), **params) # TODO: PARAMS GO HERE
# Row-by-row Record generator.
row_gen = (Record(cursor.keys(), row) for row in cursor)
@@ -297,10 +368,10 @@ class Database(object):
def bulk_query(self, query, *multiparams):
"""Bulk insert or update."""
self.db.execute(text(query), *multiparams)
self._conn.execute(text(query), *multiparams)
def query_file(self, path, fetchall=False, **params):
"""Like Database.query, but takes a filename to load a query from."""
"""Like Connection.query, but takes a filename to load a query from."""
# If path doesn't exists
if not os.path.exists(path):
@@ -318,7 +389,9 @@ class Database(object):
return self.query(query=query, fetchall=fetchall, **params)
def bulk_query_file(self, path, *multiparams):
"""Like Database.bulk_query, but takes a filename to load a query from."""
"""Like Connection.bulk_query, but takes a filename to load a query
from.
"""
# If path doesn't exists
if not os.path.exists(path):
@@ -332,12 +405,13 @@ class Database(object):
with open(path) as f:
query = f.read()
self.db.execute(text(query), *multiparams)
self._conn.execute(text(query), *multiparams)
def transaction(self):
"""Returns a transaction object. Call ``commit`` or ``rollback``
on the returned object as appropriate."""
return self.db.begin()
return self._conn.begin()
def _reduce_datetimes(row):
"""Receives a row, converts datetimes to strings."""
+41 -10
View File
@@ -1,29 +1,60 @@
import pytest
import records
db = records.Database('sqlite:///:memory:')
db.query('CREATE TABLE foo (a integer)')
def test_failing_transaction():
tx = db.transaction()
@pytest.fixture
def table_setup(request):
db.query('CREATE TABLE foo (a integer)')
def drop_table():
db.query('DROP TABLE foo')
request.addfinalizer(drop_table)
def test_failing_transaction_self_managed(table_setup):
conn = db.get_connection()
tx = conn.transaction()
try:
db.query('INSERT INTO foo VALUES (42)')
db.query('INSERT INTO foo VALUES (43)')
conn.query('INSERT INTO foo VALUES (42)')
conn.query('INSERT INTO foo VALUES (43)')
raise ValueError()
tx.commit()
db.query('INSERT INTO foo VALUES (44)')
conn.query('INSERT INTO foo VALUES (44)')
except:
tx.rollback()
finally:
conn.close()
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0
def test_passing_transaction():
tx = db.transaction()
def test_failing_transaction(table_setup):
with db.transaction() as conn:
conn.query('INSERT INTO foo VALUES (42)')
conn.query('INSERT INTO foo VALUES (43)')
raise ValueError()
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0
def test_passing_transaction_self_managed(table_setup):
conn = db.get_connection()
tx = conn.transaction()
try:
db.query('INSERT INTO foo VALUES (42)')
db.query('INSERT INTO foo VALUES (43)')
conn.query('INSERT INTO foo VALUES (42)')
conn.query('INSERT INTO foo VALUES (43)')
tx.commit()
except:
tx.rollback()
finally:
conn.close()
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2
def test_passing_transaction(table_setup):
with db.transaction() as conn:
conn.query('INSERT INTO foo VALUES (42)')
conn.query('INSERT INTO foo VALUES (43)')
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2