mirror of
https://github.com/kennethreitz/records.git
synced 2026-06-05 06:46:17 +00:00
Add a .one method to RecordCollections
It's a common use-case to want one and only one result from a query. This adds a .one method to RecordCollections that is parallel to .all.
This commit is contained in:
+24
@@ -188,6 +188,30 @@ class RecordCollection(object):
|
|||||||
def as_dict(self, ordered=False):
|
def as_dict(self, ordered=False):
|
||||||
return self.all(as_dict=not(ordered), as_ordereddict=ordered)
|
return self.all(as_dict=not(ordered), as_ordereddict=ordered)
|
||||||
|
|
||||||
|
def one(self, default=None, as_dict=False, as_ordereddict=False):
|
||||||
|
"""Returns a single record for the RecordCollection, or `default`."""
|
||||||
|
|
||||||
|
# Try to get a record, or return default.
|
||||||
|
try:
|
||||||
|
record = next(self)
|
||||||
|
except StopIteration:
|
||||||
|
return default
|
||||||
|
|
||||||
|
# Ensure that we don't have more than one row.
|
||||||
|
try:
|
||||||
|
next(self)
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError('RecordCollection contains too many rows.')
|
||||||
|
|
||||||
|
# Cast and return.
|
||||||
|
if as_dict:
|
||||||
|
return record.as_dict()
|
||||||
|
elif as_ordereddict:
|
||||||
|
return record.as_dict(ordered=True)
|
||||||
|
else:
|
||||||
|
return record
|
||||||
|
|
||||||
class Database(object):
|
class Database(object):
|
||||||
"""A Database connection."""
|
"""A Database connection."""
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ from collections import namedtuple
|
|||||||
|
|
||||||
import records
|
import records
|
||||||
|
|
||||||
|
from pytest import raises
|
||||||
|
|
||||||
|
|
||||||
IdRecord = namedtuple('IdRecord', 'id')
|
IdRecord = namedtuple('IdRecord', 'id')
|
||||||
|
|
||||||
@@ -49,6 +51,32 @@ class TestRecordCollection:
|
|||||||
assert len(rows) == 10
|
assert len(rows) == 10
|
||||||
|
|
||||||
|
|
||||||
|
# all
|
||||||
|
|
||||||
|
def test_all_returns_a_list_of_records(self):
|
||||||
|
rows = records.RecordCollection(IdRecord(i) for i in range(3))
|
||||||
|
assert rows.all() == [IdRecord(0), IdRecord(1), IdRecord(2)]
|
||||||
|
|
||||||
|
|
||||||
|
# one
|
||||||
|
|
||||||
|
def test_one_returns_a_single_record(self):
|
||||||
|
rows = records.RecordCollection(IdRecord(i) for i in range(1))
|
||||||
|
assert rows.one() == IdRecord(0)
|
||||||
|
|
||||||
|
def test_one_defaults_to_None(self):
|
||||||
|
rows = records.RecordCollection(iter([]))
|
||||||
|
assert rows.one() is None
|
||||||
|
|
||||||
|
def test_one_default_is_overridable(self):
|
||||||
|
rows = records.RecordCollection(iter([]))
|
||||||
|
assert rows.one('Cheese') == 'Cheese'
|
||||||
|
|
||||||
|
def test_one_raises_when_more_than_one(self):
|
||||||
|
rows = records.RecordCollection(IdRecord(i) for i in range(3))
|
||||||
|
raises(ValueError, rows.one)
|
||||||
|
|
||||||
|
|
||||||
class TestRecord:
|
class TestRecord:
|
||||||
|
|
||||||
def test_record_dir(self):
|
def test_record_dir(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user