diff --git a/records.py b/records.py index 1fd6a93..5334a66 100644 --- a/records.py +++ b/records.py @@ -280,21 +280,21 @@ class Database(object): # Setup SQLAlchemy for Database inspection. return inspect(self._engine).get_table_names() - def get_connection(self): + def get_connection(self, close_with_result=False): """Get a connection to this Database. Connections are retrieved from a pool. """ if not self.open: raise exc.ResourceClosedError('Database closed.') - return Connection(self._engine.connect()) + return Connection(self._engine.connect(close_with_result=close_with_result), close_with_result) def query(self, query, fetchall=False, **params): """Executes the given SQL query against the Database. Parameters can, optionally, be provided. Returns a RecordCollection, which can be iterated over to get result rows as dictionaries. """ - with self.get_connection() as conn: + with self.get_connection(True) as conn: return conn.query(query, fetchall, **params) def bulk_query(self, query, *multiparams): @@ -306,7 +306,7 @@ class Database(object): def query_file(self, path, fetchall=False, **params): """Like Database.query, but takes a filename to load a query from.""" - with self.get_connection() as conn: + with self.get_connection(True) as conn: return conn.query_file(path, fetchall, **params) def bulk_query_file(self, path, *multiparams): @@ -333,12 +333,16 @@ class Database(object): class Connection(object): """A Database connection.""" - def __init__(self, connection): + def __init__(self, connection, close_with_result=False): self._conn = connection self.open = not connection.closed + self._close_with_result = close_with_result def close(self): - self._conn.close() + # No need to close if this connection is used for a single result. + # The connection will close when the results are all consumed or GCed. + if not self._close_with_result: + self._conn.close() self.open = False def __enter__(self):