From d1b76caaab598c7a3ea2aba1523f255f819c49ce Mon Sep 17 00:00:00 2001 From: Joakim Uddholm Date: Sun, 11 Feb 2018 17:27:57 +0100 Subject: [PATCH] Add scalar method to RecordCollection --- records.py | 6 ++++++ tests/test_records.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/records.py b/records.py index 0fe09ca..3fedb0e 100644 --- a/records.py +++ b/records.py @@ -231,6 +231,12 @@ class RecordCollection(object): else: return record + def scalar(self, default=None): + """Returns the first column of the first row, or `default`.""" + first = self.first() + return first[0] if first else default + + class Database(object): """A Database connection.""" diff --git a/tests/test_records.py b/tests/test_records.py index 7e42731..5bfecf6 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -86,6 +86,24 @@ class TestRecordCollection: class Cheese(Exception): pass raises(Cheese, rows.first, Cheese('cheddar')) + # scalar + + def test_scalar_returns_a_single_record(self): + rows = records.RecordCollection(IdRecord(i) for i in range(1)) + assert rows.scalar() == 0 + + def test_scalar_defaults_to_None(self): + rows = records.RecordCollection(iter([])) + assert rows.scalar() is None + + def test_scalar_default_is_overridable(self): + rows = records.RecordCollection(iter([])) + assert rows.scalar('Kaffe') == 'Kaffe' + + def test_scalar_raises_when_more_than_first(self): + rows = records.RecordCollection(IdRecord(i) for i in range(3)) + raises(ValueError, rows.scalar) + class TestRecord: