diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a43a28a --- /dev/null +++ b/Makefile @@ -0,0 +1,9 @@ +test: + py.test tests +init: + pip install -r requirements.txt +publish: + python setup.py register + python setup.py sdist upload + python setup.py bdist_wheel --universal upload + rm -fr build dist .egg records.egg-info \ No newline at end of file diff --git a/records.py b/records.py index ff31bd9..321c1ad 100644 --- a/records.py +++ b/records.py @@ -34,7 +34,7 @@ class RecordsCursor(NamedTupleCursor): def _make_nt(self, namedtuple=namedtuple): RecordBase = namedtuple("Record", [d[0] for d in self.description or ()]) - # Extend the RecordsBase namedtupe, for enhanced API functionality. + # Extend the RecordsBase namedtuple, for enhanced API functionality. class Record(RecordBase): __slots__ = () def keys(self): @@ -92,7 +92,7 @@ class ResultSet(object): # Other code may have iterated between yields, # so always check the cache. if i < len(self): - yield self._all_rows[i] + yield self[i] else: # Throws StopIteration when done. yield next(self) @@ -113,28 +113,23 @@ class ResultSet(object): def __getitem__(self, key): - is_int = False + is_int = isinstance(key, int) # Convert ResultSet[1] into slice. - if isinstance(key, int): - is_int = True - key = slice(key, key + 1, None) + if is_int: + key = slice(key, key + 1) - while len(self._all_rows) < key.stop or key.stop is None: + while len(self) < key.stop or key.stop is None: try: next(self) except StopIteration: break - item = self._all_rows[key] - if not is_int: - r = ResultSet(self._rows) - r._all_rows = item - item = r + rows = self._all_rows[key] + if is_int: + return rows[0] else: - item = item[0] - - return item + return ResultSet(iter(rows)) def __len__(self): return len(self._all_rows) @@ -196,6 +191,9 @@ class Database(object): def __exit__(self, exc, val, traceback): self.close() + def __repr__(self): + return ''.format(self.open) + def _enable_hstore(self): """Enables HSTORE support, if available.""" try: diff --git a/requirements.txt b/requirements.txt index 2b88733..f46e9ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,4 @@ psycopg2==2.6.1 -tablib==0.10.0 +py==1.4.31 +pytest==2.8.7 +tablib==0.11.0 \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_records.py b/tests/test_records.py new file mode 100644 index 0000000..27ecf68 --- /dev/null +++ b/tests/test_records.py @@ -0,0 +1,47 @@ +from collections import namedtuple + +import records + + +IdRecord = namedtuple('IdRecord', 'id') + +def check_id(i, row): + assert row.id == i + +class TestResultSet: + def test_iter(self): + rows = records.ResultSet(IdRecord(i) for i in range(10)) + for i, row in enumerate(rows): + check_id(i, row) + + def test_next(self): + rows = records.ResultSet(IdRecord(i) for i in range(10)) + for i in range(10): + check_id(i, next(rows)) + + def test_iter_and_next(self): + rows = records.ResultSet(IdRecord(i) for i in range(10)) + i = enumerate(iter(rows)) + check_id(*next(i)) # Cache first row. + next(rows) # Cache second row. + check_id(*next(i)) # Read second row from cache. + + def test_multiple_iter(self): + rows = records.ResultSet(IdRecord(i) for i in range(10)) + i = enumerate(iter(rows)) + j = enumerate(iter(rows)) + + check_id(*next(i)) # Cache first row. + + check_id(*next(j)) # Read first row from cache. + check_id(*next(j)) # Cache second row. + + check_id(*next(i)) # Read second row from cache. + + def test_slice_iter(self): + rows = records.ResultSet(IdRecord(i) for i in range(10)) + for i, row in enumerate(rows[:5]): + check_id(i, row) + for i, row in enumerate(rows): + check_id(i, row) + assert len(rows) == 10