diff --git a/records.py b/records.py index 5f4a2e3..32dd1af 100644 --- a/records.py +++ b/records.py @@ -19,26 +19,32 @@ class ResultSet(object): def __init__(self, rows): self._rows = rows self._all_rows = [] - self._completed = False def __repr__(self): return ''.format(id(self)) def __iter__(self): - # Use cached results if available. - if self._completed: + ''' + Starts by returning the cached items and then consumes the + generator in case it is not fully consumed. + ''' + if self._all_rows: for row in self._all_rows: yield row - - # Iterate over result cursor, cache rows. - for row in self._rows: - self._all_rows.append(row) - yield row - self._completed = True + try: + while True: + yield self.__next__() + except StopIteration: + pass def next(self): + return self.__next__() + + def __next__(self): try: - return self._rows.next() + nextrow = next(self._rows) + self._all_rows.append(nextrow) + return nextrow except StopIteration: raise StopIteration("ResultSet contains no more rows.") @@ -62,10 +68,8 @@ class ResultSet(object): """Returns a list of all rows for the ResultSet. If they haven't been fetched yet, consume the iterator and cache the results.""" - # If rows aren't cached, fetch them. - if not self._all_rows: - self._all_rows = list(self._rows) - return self._all_rows + # By calling list it calls the __iter__ method + return list(self) class Database(object): """A Database connection."""