mirror of
https://github.com/kennethreitz/records.git
synced 2026-06-05 06:46:17 +00:00
Add a .one method to RecordCollections
It's a common use-case to want one and only one result from a query. This adds a .one method to RecordCollections that is parallel to .all.
This commit is contained in:
+24
@@ -188,6 +188,30 @@ class RecordCollection(object):
|
||||
def as_dict(self, ordered=False):
|
||||
return self.all(as_dict=not(ordered), as_ordereddict=ordered)
|
||||
|
||||
def one(self, default=None, as_dict=False, as_ordereddict=False):
|
||||
"""Returns a single record for the RecordCollection, or `default`."""
|
||||
|
||||
# Try to get a record, or return default.
|
||||
try:
|
||||
record = next(self)
|
||||
except StopIteration:
|
||||
return default
|
||||
|
||||
# Ensure that we don't have more than one row.
|
||||
try:
|
||||
next(self)
|
||||
except StopIteration:
|
||||
pass
|
||||
else:
|
||||
raise ValueError('RecordCollection contains too many rows.')
|
||||
|
||||
# Cast and return.
|
||||
if as_dict:
|
||||
return record.as_dict()
|
||||
elif as_ordereddict:
|
||||
return record.as_dict(ordered=True)
|
||||
else:
|
||||
return record
|
||||
|
||||
class Database(object):
|
||||
"""A Database connection."""
|
||||
|
||||
@@ -2,6 +2,8 @@ from collections import namedtuple
|
||||
|
||||
import records
|
||||
|
||||
from pytest import raises
|
||||
|
||||
|
||||
IdRecord = namedtuple('IdRecord', 'id')
|
||||
|
||||
@@ -49,6 +51,32 @@ class TestRecordCollection:
|
||||
assert len(rows) == 10
|
||||
|
||||
|
||||
# all
|
||||
|
||||
def test_all_returns_a_list_of_records(self):
|
||||
rows = records.RecordCollection(IdRecord(i) for i in range(3))
|
||||
assert rows.all() == [IdRecord(0), IdRecord(1), IdRecord(2)]
|
||||
|
||||
|
||||
# one
|
||||
|
||||
def test_one_returns_a_single_record(self):
|
||||
rows = records.RecordCollection(IdRecord(i) for i in range(1))
|
||||
assert rows.one() == IdRecord(0)
|
||||
|
||||
def test_one_defaults_to_None(self):
|
||||
rows = records.RecordCollection(iter([]))
|
||||
assert rows.one() is None
|
||||
|
||||
def test_one_default_is_overridable(self):
|
||||
rows = records.RecordCollection(iter([]))
|
||||
assert rows.one('Cheese') == 'Cheese'
|
||||
|
||||
def test_one_raises_when_more_than_one(self):
|
||||
rows = records.RecordCollection(IdRecord(i) for i in range(3))
|
||||
raises(ValueError, rows.one)
|
||||
|
||||
|
||||
class TestRecord:
|
||||
|
||||
def test_record_dir(self):
|
||||
|
||||
Reference in New Issue
Block a user