diff --git a/records.py b/records.py index b599e9a..bea5fa2 100644 --- a/records.py +++ b/records.py @@ -4,6 +4,7 @@ import os from code import interact from datetime import datetime from collections import OrderedDict +from inspect import isclass import tablib from docopt import docopt @@ -13,6 +14,17 @@ from sqlalchemy.ext.declarative import declarative_base DATABASE_URL = os.environ.get('DATABASE_URL') +def isexception(obj): + """Given an object, return a boolean indicating whether it is an instance + or subclass of :py:class:`Exception`. + """ + if isinstance(obj, Exception): + return True + if isclass(obj) and issubclass(obj, Exception): + return True + return False + + class Record(object): """A row, from a query, from a database.""" __slots__ = ('_keys', '_values') @@ -195,6 +207,8 @@ class RecordCollection(object): try: record = next(self) except StopIteration: + if isexception(default): + raise default return default # Ensure that we don't have more than one row. diff --git a/tests/test_records.py b/tests/test_records.py index 8df6f78..9166446 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -76,6 +76,16 @@ class TestRecordCollection: rows = records.RecordCollection(IdRecord(i) for i in range(3)) raises(ValueError, rows.one) + def test_one_raises_default_if_its_an_exception_subclass(self): + rows = records.RecordCollection(IdRecord(i) for i in range(1)) + class Cheese(Exception): pass + raises(Cheese, rows.one, Cheese) + + def test_one_raises_default_if_its_an_exception_instance(self): + rows = records.RecordCollection(IdRecord(i) for i in range(1)) + class Cheese(Exception): pass + raises(Cheese, rows.one, Cheese('cheddar')) + class TestRecord: