mirror of
https://github.com/kennethreitz/records.git
synced 2026-06-05 14:50:18 +00:00
Merge pull request #119 from gdtroszak/proper-pooling
Properly handle SQLAlchemy connection pools
This commit is contained in:
+89
-15
@@ -2,11 +2,12 @@
|
||||
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from inspect import isclass
|
||||
|
||||
import tablib
|
||||
from docopt import docopt
|
||||
from sqlalchemy import create_engine, inspect, text
|
||||
from sqlalchemy import create_engine, exc, inspect, text
|
||||
|
||||
DATABASE_URL = os.environ.get('DATABASE_URL')
|
||||
|
||||
@@ -238,7 +239,9 @@ class RecordCollection(object):
|
||||
|
||||
|
||||
class Database(object):
|
||||
"""A Database connection."""
|
||||
"""A Database. Encapsulates a url and an SQLAlchemy engine with a pool of
|
||||
connections.
|
||||
"""
|
||||
|
||||
def __init__(self, db_url=None, **kwargs):
|
||||
# If no db_url was provided, fallback to $DATABASE_URL.
|
||||
@@ -247,15 +250,13 @@ class Database(object):
|
||||
if not self.db_url:
|
||||
raise ValueError('You must provide a db_url.')
|
||||
|
||||
# Create an engine.
|
||||
self._engine = create_engine(self.db_url, **kwargs)
|
||||
|
||||
# Connect to the database.
|
||||
self.db = self._engine.connect()
|
||||
self.open = True
|
||||
|
||||
def close(self):
|
||||
"""Closes the connection to the Database."""
|
||||
self.db.close()
|
||||
"""Closes the Database."""
|
||||
self._engine.dispose()
|
||||
self.open = False
|
||||
|
||||
def __enter__(self):
|
||||
@@ -273,14 +274,84 @@ class Database(object):
|
||||
# Setup SQLAlchemy for Database inspection.
|
||||
return inspect(self._engine).get_table_names()
|
||||
|
||||
def get_connection(self):
|
||||
"""Get a connection to this Database. Connections are retrieved from a
|
||||
pool.
|
||||
"""
|
||||
if not self.open:
|
||||
raise exc.ResourceClosedError('Database closed.')
|
||||
|
||||
return Connection(self._engine.connect())
|
||||
|
||||
def query(self, query, fetchall=False, **params):
|
||||
"""Executes the given SQL query against the Database. Parameters
|
||||
can, optionally, be provided. Returns a RecordCollection, which can be
|
||||
"""Executes the given SQL query against the Database. Parameters can,
|
||||
optionally, be provided. Returns a RecordCollection, which can be
|
||||
iterated over to get result rows as dictionaries.
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
return conn.query(query, fetchall, **params)
|
||||
|
||||
def bulk_query(self, query, *multiparams):
|
||||
"""Bulk insert or update."""
|
||||
|
||||
with self.get_connection() as conn:
|
||||
conn.bulk_query(query, *multiparams)
|
||||
|
||||
def query_file(self, path, fetchall=False, **params):
|
||||
"""Like Database.query, but takes a filename to load a query from."""
|
||||
|
||||
with self.get_connection() as conn:
|
||||
return conn.query_file(path, fetchall, **params)
|
||||
|
||||
def bulk_query_file(self, path, *multiparams):
|
||||
"""Like Database.bulk_query, but takes a filename to load a query from."""
|
||||
|
||||
with self.get_connection() as conn:
|
||||
conn.bulk_query_file(path, *multiparams)
|
||||
|
||||
@contextmanager
|
||||
def transaction(self):
|
||||
"""A context manager for executing a transaction on this Database."""
|
||||
|
||||
conn = self.get_connection()
|
||||
tx = conn.transaction()
|
||||
try:
|
||||
yield conn
|
||||
tx.commit()
|
||||
except:
|
||||
tx.rollback()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
class Connection(object):
|
||||
"""A Database connection."""
|
||||
|
||||
def __init__(self, connection):
|
||||
self._conn = connection
|
||||
self.open = not connection.closed
|
||||
|
||||
def close(self):
|
||||
self._conn.close()
|
||||
self.open = False
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc, val, traceback):
|
||||
self.close()
|
||||
|
||||
def __repr__(self):
|
||||
return '<Connection open={}>'.format(self.open)
|
||||
|
||||
def query(self, query, fetchall=False, **params):
|
||||
"""Executes the given SQL query against the connected Database.
|
||||
Parameters can, optionally, be provided. Returns a RecordCollection,
|
||||
which can be iterated over to get result rows as dictionaries.
|
||||
"""
|
||||
|
||||
# Execute the given query.
|
||||
cursor = self.db.execute(text(query), **params) # TODO: PARAMS GO HERE
|
||||
cursor = self._conn.execute(text(query), **params) # TODO: PARAMS GO HERE
|
||||
|
||||
# Row-by-row Record generator.
|
||||
row_gen = (Record(cursor.keys(), row) for row in cursor)
|
||||
@@ -297,10 +368,10 @@ class Database(object):
|
||||
def bulk_query(self, query, *multiparams):
|
||||
"""Bulk insert or update."""
|
||||
|
||||
self.db.execute(text(query), *multiparams)
|
||||
self._conn.execute(text(query), *multiparams)
|
||||
|
||||
def query_file(self, path, fetchall=False, **params):
|
||||
"""Like Database.query, but takes a filename to load a query from."""
|
||||
"""Like Connection.query, but takes a filename to load a query from."""
|
||||
|
||||
# If path doesn't exists
|
||||
if not os.path.exists(path):
|
||||
@@ -318,7 +389,9 @@ class Database(object):
|
||||
return self.query(query=query, fetchall=fetchall, **params)
|
||||
|
||||
def bulk_query_file(self, path, *multiparams):
|
||||
"""Like Database.bulk_query, but takes a filename to load a query from."""
|
||||
"""Like Connection.bulk_query, but takes a filename to load a query
|
||||
from.
|
||||
"""
|
||||
|
||||
# If path doesn't exists
|
||||
if not os.path.exists(path):
|
||||
@@ -332,12 +405,13 @@ class Database(object):
|
||||
with open(path) as f:
|
||||
query = f.read()
|
||||
|
||||
self.db.execute(text(query), *multiparams)
|
||||
self._conn.execute(text(query), *multiparams)
|
||||
|
||||
def transaction(self):
|
||||
"""Returns a transaction object. Call ``commit`` or ``rollback``
|
||||
on the returned object as appropriate."""
|
||||
return self.db.begin()
|
||||
|
||||
return self._conn.begin()
|
||||
|
||||
def _reduce_datetimes(row):
|
||||
"""Receives a row, converts datetimes to strings."""
|
||||
|
||||
+41
-10
@@ -1,29 +1,60 @@
|
||||
import pytest
|
||||
|
||||
import records
|
||||
|
||||
db = records.Database('sqlite:///:memory:')
|
||||
|
||||
db.query('CREATE TABLE foo (a integer)')
|
||||
|
||||
def test_failing_transaction():
|
||||
tx = db.transaction()
|
||||
@pytest.fixture
|
||||
def table_setup(request):
|
||||
db.query('CREATE TABLE foo (a integer)')
|
||||
def drop_table():
|
||||
db.query('DROP TABLE foo')
|
||||
request.addfinalizer(drop_table)
|
||||
|
||||
|
||||
def test_failing_transaction_self_managed(table_setup):
|
||||
conn = db.get_connection()
|
||||
tx = conn.transaction()
|
||||
try:
|
||||
db.query('INSERT INTO foo VALUES (42)')
|
||||
db.query('INSERT INTO foo VALUES (43)')
|
||||
conn.query('INSERT INTO foo VALUES (42)')
|
||||
conn.query('INSERT INTO foo VALUES (43)')
|
||||
raise ValueError()
|
||||
tx.commit()
|
||||
db.query('INSERT INTO foo VALUES (44)')
|
||||
conn.query('INSERT INTO foo VALUES (44)')
|
||||
except:
|
||||
tx.rollback()
|
||||
finally:
|
||||
conn.close()
|
||||
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0
|
||||
|
||||
def test_passing_transaction():
|
||||
tx = db.transaction()
|
||||
|
||||
def test_failing_transaction(table_setup):
|
||||
with db.transaction() as conn:
|
||||
conn.query('INSERT INTO foo VALUES (42)')
|
||||
conn.query('INSERT INTO foo VALUES (43)')
|
||||
raise ValueError()
|
||||
|
||||
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0
|
||||
|
||||
|
||||
def test_passing_transaction_self_managed(table_setup):
|
||||
conn = db.get_connection()
|
||||
tx = conn.transaction()
|
||||
try:
|
||||
db.query('INSERT INTO foo VALUES (42)')
|
||||
db.query('INSERT INTO foo VALUES (43)')
|
||||
conn.query('INSERT INTO foo VALUES (42)')
|
||||
conn.query('INSERT INTO foo VALUES (43)')
|
||||
tx.commit()
|
||||
except:
|
||||
tx.rollback()
|
||||
finally:
|
||||
conn.close()
|
||||
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2
|
||||
|
||||
|
||||
def test_passing_transaction(table_setup):
|
||||
with db.transaction() as conn:
|
||||
conn.query('INSERT INTO foo VALUES (42)')
|
||||
conn.query('INSERT INTO foo VALUES (43)')
|
||||
|
||||
assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2
|
||||
|
||||
Reference in New Issue
Block a user