From e7da94b249a014c97ce01d6168cb1d11b6fbe7a3 Mon Sep 17 00:00:00 2001 From: Chad Whitacre Date: Sat, 13 Feb 2016 04:34:35 -0500 Subject: [PATCH] Add a .one method to RecordCollections It's a common use-case to want one and only one result from a query. This adds a .one method to RecordCollections that is parallel to .all. --- records.py | 24 ++++++++++++++++++++++++ tests/test_records.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/records.py b/records.py index 74c50c4..b599e9a 100644 --- a/records.py +++ b/records.py @@ -188,6 +188,30 @@ class RecordCollection(object): def as_dict(self, ordered=False): return self.all(as_dict=not(ordered), as_ordereddict=ordered) + def one(self, default=None, as_dict=False, as_ordereddict=False): + """Returns a single record for the RecordCollection, or `default`.""" + + # Try to get a record, or return default. + try: + record = next(self) + except StopIteration: + return default + + # Ensure that we don't have more than one row. + try: + next(self) + except StopIteration: + pass + else: + raise ValueError('RecordCollection contains too many rows.') + + # Cast and return. + if as_dict: + return record.as_dict() + elif as_ordereddict: + return record.as_dict(ordered=True) + else: + return record class Database(object): """A Database connection.""" diff --git a/tests/test_records.py b/tests/test_records.py index 1af26e9..8df6f78 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -2,6 +2,8 @@ from collections import namedtuple import records +from pytest import raises + IdRecord = namedtuple('IdRecord', 'id') @@ -49,6 +51,32 @@ class TestRecordCollection: assert len(rows) == 10 + # all + + def test_all_returns_a_list_of_records(self): + rows = records.RecordCollection(IdRecord(i) for i in range(3)) + assert rows.all() == [IdRecord(0), IdRecord(1), IdRecord(2)] + + + # one + + def test_one_returns_a_single_record(self): + rows = records.RecordCollection(IdRecord(i) for i in range(1)) + assert rows.one() == IdRecord(0) + + def test_one_defaults_to_None(self): + rows = records.RecordCollection(iter([])) + assert rows.one() is None + + def test_one_default_is_overridable(self): + rows = records.RecordCollection(iter([])) + assert rows.one('Cheese') == 'Cheese' + + def test_one_raises_when_more_than_one(self): + rows = records.RecordCollection(IdRecord(i) for i in range(3)) + raises(ValueError, rows.one) + + class TestRecord: def test_record_dir(self):