diff --git a/records.py b/records.py index 3041978..3735773 100644 --- a/records.py +++ b/records.py @@ -208,6 +208,27 @@ class RecordCollection(object): `default` is an instance or subclass of Exception, then raise it instead of returning it.""" + # Try to get a record, or return/raise default. + try: + record = self[0] + except IndexError: + if isexception(default): + raise default + return default + + # Cast and return. + if as_dict: + return record.as_dict() + elif as_ordereddict: + return record.as_dict(ordered=True) + else: + return record + + def one(self, default=None, as_dict=False, as_ordereddict=False): + """Returns a single record for the RecordCollection, ensuring that it + is the only record, or returns `default`. If `default` is an instance + or subclass of Exception, then raise it instead of returning it.""" + # Try to get a record, or return/raise default. try: record = self[0] @@ -234,8 +255,8 @@ class RecordCollection(object): def scalar(self, default=None): """Returns the first column of the first row, or `default`.""" - first = self.first() - return first[0] if first else default + row = self.one() + return row[0] if row else default class Database(object): diff --git a/tests/test_records.py b/tests/test_records.py index 5bfecf6..6c6aca5 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -72,10 +72,6 @@ class TestRecordCollection: rows = records.RecordCollection(iter([])) assert rows.first('Cheese') == 'Cheese' - def test_first_raises_when_more_than_first(self): - rows = records.RecordCollection(IdRecord(i) for i in range(3)) - raises(ValueError, rows.first) - def test_first_raises_default_if_its_an_exception_subclass(self): rows = records.RecordCollection(iter([])) class Cheese(Exception): pass @@ -86,6 +82,34 @@ class TestRecordCollection: class Cheese(Exception): pass raises(Cheese, rows.first, Cheese('cheddar')) + # one + + def test_one_returns_a_single_record(self): + rows = records.RecordCollection(IdRecord(i) for i in range(1)) + assert rows.one() == IdRecord(0) + + def test_one_defaults_to_None(self): + rows = records.RecordCollection(iter([])) + assert rows.one() is None + + def test_one_default_is_overridable(self): + rows = records.RecordCollection(iter([])) + assert rows.one('Cheese') == 'Cheese' + + def test_one_raises_when_more_than_one(self): + rows = records.RecordCollection(IdRecord(i) for i in range(3)) + raises(ValueError, rows.one) + + def test_one_raises_default_if_its_an_exception_subclass(self): + rows = records.RecordCollection(iter([])) + class Cheese(Exception): pass + raises(Cheese, rows.one, Cheese) + + def test_one_raises_default_if_its_an_exception_instance(self): + rows = records.RecordCollection(iter([])) + class Cheese(Exception): pass + raises(Cheese, rows.one, Cheese('cheddar')) + # scalar def test_scalar_returns_a_single_record(self): @@ -100,7 +124,7 @@ class TestRecordCollection: rows = records.RecordCollection(iter([])) assert rows.scalar('Kaffe') == 'Kaffe' - def test_scalar_raises_when_more_than_first(self): + def test_scalar_raises_when_more_than_one(self): rows = records.RecordCollection(IdRecord(i) for i in range(3)) raises(ValueError, rows.scalar)