mirror of
https://github.com/kennethreitz/records.git
synced 2026-06-05 14:50:18 +00:00
Add scalar method to RecordCollection
This commit is contained in:
@@ -231,6 +231,12 @@ class RecordCollection(object):
|
||||
else:
|
||||
return record
|
||||
|
||||
def scalar(self, default=None):
|
||||
"""Returns the first column of the first row, or `default`."""
|
||||
first = self.first()
|
||||
return first[0] if first else default
|
||||
|
||||
|
||||
class Database(object):
|
||||
"""A Database connection."""
|
||||
|
||||
|
||||
@@ -86,6 +86,24 @@ class TestRecordCollection:
|
||||
class Cheese(Exception): pass
|
||||
raises(Cheese, rows.first, Cheese('cheddar'))
|
||||
|
||||
# scalar
|
||||
|
||||
def test_scalar_returns_a_single_record(self):
|
||||
rows = records.RecordCollection(IdRecord(i) for i in range(1))
|
||||
assert rows.scalar() == 0
|
||||
|
||||
def test_scalar_defaults_to_None(self):
|
||||
rows = records.RecordCollection(iter([]))
|
||||
assert rows.scalar() is None
|
||||
|
||||
def test_scalar_default_is_overridable(self):
|
||||
rows = records.RecordCollection(iter([]))
|
||||
assert rows.scalar('Kaffe') == 'Kaffe'
|
||||
|
||||
def test_scalar_raises_when_more_than_first(self):
|
||||
rows = records.RecordCollection(IdRecord(i) for i in range(3))
|
||||
raises(ValueError, rows.scalar)
|
||||
|
||||
|
||||
class TestRecord:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user