diff --git a/records.py b/records.py index 807784f..064eb24 100644 --- a/records.py +++ b/records.py @@ -49,15 +49,25 @@ class ResultSet(object): data = tablib.Dataset() # Set the column names as headers on Tablib Dataset. - data.headers = self.all()[0].keys() + first = self.all()[0] + if isinstance(first, dict): + data.headers = self.all()[0].keys() + # Take each row, string-ify datetimes, insert into Tablib Dataset. + for row in self.all(): + row = _reduce_datetimes([v for k, v in row.items()]) + data.append(row) + elif _isnamedtupleinstance(first): + data.headers = first._fields + for row in self.all(): + row = _reduce_datetimes(row) + data.append(row) + else: + raise Exception('Unsupported cursor type') - # Take each row, string-ify datetimes, insert into Tablib Dataset. - for row in self.all(): - row = _reduce_datetimes([v for k, v in row.items()]) - data.append(row) return data + def all(self): """Returns a list of all rows for the ResultSet. If they haven't been fetched yet, consume the iterator and cache the results.""" @@ -70,16 +80,18 @@ class ResultSet(object): class Database(object): """A Database connection.""" - def __init__(self, db_url=None): + def __init__(self, db_url=None, cursor_factory=RealDictCursor): # If no db_url was provided, fallback to $DATABASE_URL. self.db_url = db_url or DATABASE_URL + self.cursor_factory = cursor_factory if not self.db_url: raise ValueError('You must provide a db_url.') # Connect to the database. - self.db = psycopg2.connect(self.db_url, cursor_factory=RealDictCursor) + self.db = psycopg2.connect(self.db_url, + cursor_factory=self.cursor_factory) # Enable hstore if it's available. self._enable_hstore() @@ -135,3 +147,11 @@ def _reduce_datetimes(row): if isinstance(row[i], datetime): row[i] = '{}'.format(row[i]) return row + +def _isnamedtupleinstance(x): + t = type(x) + b = t.__bases__ + if len(b) != 1 or b[0] != tuple: return False + f = getattr(t, '_fields', None) + if not isinstance(f, tuple): return False + return all(type(n)==str for n in f)