diff --git a/records.py b/records.py index c791b60..91039a6 100644 --- a/records.py +++ b/records.py @@ -270,6 +270,11 @@ class Database(object): # Defer processing to self.query method. return self.query(query=query, fetchall=fetchall, **params) + def transaction(self): + """Returns a transaction object. Call ``commit`` or ``rollback`` + on the returned object as appropriate.""" + return self.db.begin() + def _reduce_datetimes(row): """Receives a row, converts datetimes to strings.""" @@ -280,7 +285,6 @@ def _reduce_datetimes(row): row[i] = row[i].isoformat() return tuple(row) - def cli(): cli_docs ="""Records: SQL for Humans™ A Kenneth Reitz project. diff --git a/tests/test_transactions.py b/tests/test_transactions.py new file mode 100644 index 0000000..a335d12 --- /dev/null +++ b/tests/test_transactions.py @@ -0,0 +1,31 @@ +import pytest + +import records + +db = records.Database('sqlite:///:memory:') + +db.query('CREATE TABLE foo (a integer)') + +def test_failing_transaction(): + tx = db.transaction() + try: + db.query('INSERT INTO foo VALUES (42)') + db.query('INSERT INTO foo VALUES (43)') + raise ValueError() + tx.commit() + db.query('INSERT INTO foo VALUES (44)') + except: + tx.rollback() + finally: + assert db.query('SELECT count(*) AS n FROM foo')[0].n == 0 + +def test_passing_transaction(): + tx = db.transaction() + try: + db.query('INSERT INTO foo VALUES (42)') + db.query('INSERT INTO foo VALUES (43)') + tx.commit() + except: + tx.rollback() + finally: + assert db.query('SELECT count(*) AS n FROM foo')[0].n == 2