Merge pull request #119 from gdtroszak/proper-pooling

Properly handle SQLAlchemy connection pools
This commit is contained in:
2018-02-13 07:39:42 -06:00
committed by GitHub
2 changed files with 130 additions and 25 deletions
+89 -15
View File
@@ -2,11 +2,12 @@
import os
from collections import OrderedDict
from contextlib import contextmanager
from inspect import isclass
import tablib
from docopt import docopt
from sqlalchemy import create_engine, inspect, text
from sqlalchemy import create_engine, exc, inspect, text
DATABASE_URL = os.environ.get('DATABASE_URL')
@@ -238,7 +239,9 @@ class RecordCollection(object):
class Database(object):
"""A Database connection."""
"""A Database. Encapsulates a url and an SQLAlchemy engine with a pool of
connections.
"""
def __init__(self, db_url=None, **kwargs):
# If no db_url was provided, fallback to $DATABASE_URL.
@@ -247,15 +250,13 @@ class Database(object):
if not self.db_url:
raise ValueError('You must provide a db_url.')
# Create an engine.
self._engine = create_engine(self.db_url, **kwargs)
# Connect to the database.
self.db = self._engine.connect()
self.open = True
def close(self):
"""Closes the connection to the Database."""
self.db.close()
"""Closes the Database."""
self._engine.dispose()
self.open = False
def __enter__(self):
@@ -273,14 +274,84 @@ class Database(object):
# Setup SQLAlchemy for Database inspection.
return inspect(self._engine).get_table_names()
def get_connection(self):
"""Get a connection to this Database. Connections are retrieved from a
pool.
"""
if not self.open:
raise exc.ResourceClosedError('Database closed.')
return Connection(self._engine.connect())
def query(self, query, fetchall=False, **params):
"""Executes the given SQL query against the Database. Parameters
can, optionally, be provided. Returns a RecordCollection, which can be
"""Executes the given SQL query against the Database. Parameters can,
optionally, be provided. Returns a RecordCollection, which can be
iterated over to get result rows as dictionaries.
"""
with self.get_connection() as conn:
return conn.query(query, fetchall, **params)
def bulk_query(self, query, *multiparams):
"""Bulk insert or update."""
with self.get_connection() as conn:
conn.bulk_query(query, *multiparams)
def query_file(self, path, fetchall=False, **params):
"""Like Database.query, but takes a filename to load a query from."""
with self.get_connection() as conn:
return conn.query_file(path, fetchall, **params)
def bulk_query_file(self, path, *multiparams):
"""Like Database.bulk_query, but takes a filename to load a query from."""
with self.get_connection() as conn:
conn.bulk_query_file(path, *multiparams)
@contextmanager
def transaction(self):
"""A context manager for executing a transaction on this Database."""
conn = self.get_connection()
tx = conn.transaction()
try:
yield conn
tx.commit()
except:
tx.rollback()
finally:
conn.close()
class Connection(object):
"""A Database connection."""
def __init__(self, connection):
self._conn = connection
self.open = not connection.closed
def close(self):
self._conn.close()
self.open = False
def __enter__(self):
return self
def __exit__(self, exc, val, traceback):
self.close()
def __repr__(self):
return '<Connection open={}>'.format(self.open)
def query(self, query, fetchall=False, **params):
"""Executes the given SQL query against the connected Database.
Parameters can, optionally, be provided. Returns a RecordCollection,
which can be iterated over to get result rows as dictionaries.
"""
# Execute the given query.
cursor = self.db.execute(text(query), **params) # TODO: PARAMS GO HERE
cursor = self._conn.execute(text(query), **params) # TODO: PARAMS GO HERE
# Row-by-row Record generator.
row_gen = (Record(cursor.keys(), row) for row in cursor)
@@ -297,10 +368,10 @@ class Database(object):
def bulk_query(self, query, *multiparams):
"""Bulk insert or update."""
self.db.execute(text(query), *multiparams)
self._conn.execute(text(query), *multiparams)
def query_file(self, path, fetchall=False, **params):
"""Like Database.query, but takes a filename to load a query from."""
"""Like Connection.query, but takes a filename to load a query from."""
# If path doesn't exists
if not os.path.exists(path):
@@ -318,7 +389,9 @@ class Database(object):
return self.query(query=query, fetchall=fetchall, **params)
def bulk_query_file(self, path, *multiparams):
"""Like Database.bulk_query, but takes a filename to load a query from."""
"""Like Connection.bulk_query, but takes a filename to load a query
from.
"""
# If path doesn't exists
if not os.path.exists(path):
@@ -332,12 +405,13 @@ class Database(object):
with open(path) as f:
query = f.read()
self.db.execute(text(query), *multiparams)
self._conn.execute(text(query), *multiparams)
def transaction(self):
"""Returns a transaction object. Call ``commit`` or ``rollback``
on the returned object as appropriate."""
return self.db.begin()
return self._conn.begin()
def _reduce_datetimes(row):
"""Receives a row, converts datetimes to strings."""