diff --git a/README.rst b/README.rst index 3087192..e8df2d3 100644 --- a/README.rst +++ b/README.rst @@ -49,6 +49,13 @@ Or store a copy of your record collection for later reference: >>> rows.all() [, , , ...] +If you're only expecting one result: + +.. code:: python + + >>> rows.first() + + Other options include ``rows.as_dict()`` and ``rows.as_dict(ordered=True)``. ☤ Features diff --git a/records.py b/records.py index 74c50c4..f65e8ff 100644 --- a/records.py +++ b/records.py @@ -1,9 +1,8 @@ # -*- coding: utf-8 -*- import os -from code import interact -from datetime import datetime from collections import OrderedDict +from inspect import isclass import tablib from docopt import docopt @@ -13,6 +12,17 @@ from sqlalchemy.ext.declarative import declarative_base DATABASE_URL = os.environ.get('DATABASE_URL') +def isexception(obj): + """Given an object, return a boolean indicating whether it is an instance + or subclass of :py:class:`Exception`. + """ + if isinstance(obj, Exception): + return True + if isclass(obj) and issubclass(obj, Exception): + return True + return False + + class Record(object): """A row, from a query, from a database.""" __slots__ = ('_keys', '_values') @@ -188,6 +198,34 @@ class RecordCollection(object): def as_dict(self, ordered=False): return self.all(as_dict=not(ordered), as_ordereddict=ordered) + def first(self, default=None, as_dict=False, as_ordereddict=False): + """Returns a single record for the RecordCollection, or `default`. If + `default` is an instance or subclass of Exception, then raise it + instead of returning it.""" + + # Try to get a record, or return/raise default. + try: + record = self[0] + except IndexError: + if isexception(default): + raise default + return default + + # Ensure that we don't have more than one row. + try: + self[1] + except IndexError: + pass + else: + raise ValueError('RecordCollection contains too many rows.') + + # Cast and return. + if as_dict: + return record.as_dict() + elif as_ordereddict: + return record.as_dict(ordered=True) + else: + return record class Database(object): """A Database connection.""" diff --git a/tests/test_records.py b/tests/test_records.py index 1af26e9..722591f 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -2,6 +2,8 @@ from collections import namedtuple import records +from pytest import raises + IdRecord = namedtuple('IdRecord', 'id') @@ -49,6 +51,42 @@ class TestRecordCollection: assert len(rows) == 10 + # all + + def test_all_returns_a_list_of_records(self): + rows = records.RecordCollection(IdRecord(i) for i in range(3)) + assert rows.all() == [IdRecord(0), IdRecord(1), IdRecord(2)] + + + # first + + def test_first_returns_a_single_record(self): + rows = records.RecordCollection(IdRecord(i) for i in range(1)) + assert rows.first() == IdRecord(0) + + def test_first_defaults_to_None(self): + rows = records.RecordCollection(iter([])) + assert rows.first() is None + + def test_first_default_is_overridable(self): + rows = records.RecordCollection(iter([])) + assert rows.first('Cheese') == 'Cheese' + + def test_first_raises_when_more_than_first(self): + rows = records.RecordCollection(IdRecord(i) for i in range(3)) + raises(ValueError, rows.first) + + def test_first_raises_default_if_its_an_exception_subclass(self): + rows = records.RecordCollection(iter([])) + class Cheese(Exception): pass + raises(Cheese, rows.first, Cheese) + + def test_first_raises_default_if_its_an_exception_instance(self): + rows = records.RecordCollection(iter([])) + class Cheese(Exception): pass + raises(Cheese, rows.first, Cheese('cheddar')) + + class TestRecord: def test_record_dir(self):