From 3d30b5f3fcc1b120aed11fe8f2e07e8ee435a775 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Sat, 13 Feb 2016 01:27:04 -0500 Subject: [PATCH] the big refactor --- records.py | 170 +++++++++++++++++++++++++++-------------------------- 1 file changed, 88 insertions(+), 82 deletions(-) diff --git a/records.py b/records.py index 7cf69cd..5ed071d 100644 --- a/records.py +++ b/records.py @@ -3,83 +3,96 @@ import os from code import interact from datetime import datetime +from collections import namedtuple, OrderedDict import tablib -import psycopg2 from docopt import docopt -from psycopg2.extras import register_hstore, NamedTupleCursor -from psycopg2.extensions import cursor as _cursor - +from sqlalchemy import text, create_engine +from sqlalchemy.ext.declarative import declarative_base DATABASE_URL = os.environ.get('DATABASE_URL') -PG_INTERNAL_TABLES_QUERY = "SELECT * FROM pg_catalog.pg_tables" -PG_TABLES_QUERY = """ - SELECT - * - FROM - pg_catalog.pg_tables - WHERE - schemaname != 'pg_catalog' AND - schemaname != 'information_schema' -""" +class Record(object): + """A row, from a query, from a database.""" + __slots__ = ('_keys', '_values') + def __init__(self, keys, values): + self._keys = keys + self._values = values -class RecordsCursor(NamedTupleCursor): - """An enhanced cursor that generates Records.""" - try: - from collections import namedtuple - except ImportError as _exc: - def _make_nt(self): - raise self._exc - else: - def _make_nt(self, namedtuple=namedtuple): - RecordBase = namedtuple("Record", [d[0] for d in self.description or ()]) + # Esure that lengths match properly. + assert len(self._keys) == len(self._values) - # Extend the RecordsBase namedtuple, for enhanced API functionality. - class Record(RecordBase): - __slots__ = () - def keys(self): - return self._fields + def keys(self): + """Returns the list of column names from the query.""" + return self._keys - def __getitem__(self, key): - if isinstance(key, int): - return super(RecordBase, self).__getitem__(key) + def values(self): + """Returns the list of values from the query.""" + return self._values - if key in self.keys(): - return getattr(self, key) + def __repr__(self): + return ''.format(self.export('json')) - raise KeyError("Record contains no '{}' field.".format(key)) + def __getitem__(self, key): + # Support for index-based lookup. + if isinstance(key, int): + return self.values()[key] - @property - def dataset(self): - """A Tablib Dataset containing the row.""" - data = tablib.Dataset() - data.headers = self._fields + # Support for string-based lookup. + if key in self.keys(): + i = self.keys().index(key) + return self.values()[i] - row = _reduce_datetimes(self) - data.append(row) + raise KeyError("Record contains no '{}' field.".format(key)) - return data + def __getattr__(self, key): + try: + return self[key] + except KeyError, e: + raise AttributeError(e) - def export(self, format, **kwargs): - """Exports the row to the given format.""" - return self.dataset.export(format, **kwargs) + def __dir__(self): + standard = [ + # Would love to do this programatically, but couldn't figure out how. + '__class__', '__ddir__', '__delattr__', '__doc__', '__format__', + '__getattr__', '__getattribute__', '__getitem__', '__hash__', + '__init__', '__module__', '__new__', '__reduce__', '__reduce_ex__', + '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', + '__subclasshook__', '_keys', '_values', 'as_dict', 'dataset', + 'export', 'get', 'keys', 'values' + ] - def get(self, key, default=None): - try: - return self[key] - except KeyError: - return default + # Merge standard attrs with generated ones (from column names). + return sorted(standard + [str(k) for k in self.keys()]) - def as_dict(self, ordered=False): - """Returns the row as a dictionary, as ordered.""" - if ordered: - return self._asdict() - else: - return dict(self._asdict()) + def get(self, key, default=None): + """Returns the value for a given key, or default.""" + try: + return self[key] + except KeyError: + return default - return Record + def as_dict(self, ordered=False): + """Returns the row as a dictionary, as ordered.""" + items = zip(self.keys(), self.values()) + + return OrderedDict(items) if ordered else dict(items) + + @property + def dataset(self): + """A Tablib Dataset containing the row.""" + data = tablib.Dataset() + data.headers = self.keys() + + row = _reduce_datetimes(self.values()) + data.append(row) + + return data + + def export(self, format, **kwargs): + """Exports the row to the given format.""" + return self.dataset.export(format, **kwargs) class ResultSet(object): @@ -187,10 +200,7 @@ class Database(object): raise ValueError('You must provide a db_url.') # Connect to the database. - self.db = psycopg2.connect(self.db_url, cursor_factory=RecordsCursor) - - # Enable hstore if it's available. - self._enable_hstore() + self.db = create_engine(self.db_url).connect() self.open = True def close(self): @@ -207,34 +217,27 @@ class Database(object): def __repr__(self): return ''.format(self.open) - def _enable_hstore(self): - """Enables HSTORE support, if available.""" - try: - register_hstore(self.db) - except psycopg2.ProgrammingError: - pass - def get_table_names(self, internal=False): """Returns a list of table names for the connected database.""" - # Support listing internal table names as well. - query = PG_INTERNAL_TABLES_QUERY if internal else PG_TABLES_QUERY + # Setup SQLAlchemy for Database inspection. + metadata = declarative_base().metadata + metadata.reflect(create_engine(self.db_url)) - # Return a list of tablenames. - return [r['tablename'] for r in self.query(query)] + # Serve the table names. + return metadata.tables.keys() - def query(self, query, params=None, fetchall=False): + def query(self, query, fetchall=False, **params): """Executes the given SQL query against the Database. Parameters can, optionally, be provided. Returns a ResultSet, which can be iterated over to get result rows as dictionaries. """ # Execute the given query. - c = self.db.cursor() - c.execute(query, params) + cursor = self.db.execute(text(query), **params) # TODO: PARAMS GO HERE - # Row-by-row result generator. - row_gen = (r for r in c) + # Row-by-row Record generator. + row_gen = (Record(cursor.keys(), row) for row in cursor) # Convert psycopg2 results to ResultSet. results = ResultSet(row_gen) @@ -245,7 +248,7 @@ class Database(object): return results - def query_file(self, path, params=None, fetchall=False): + def query_file(self, path, fetchall=False, **params): """Like Database.query, but takes a filename to load a query from.""" # If path doesn't exists @@ -261,14 +264,17 @@ class Database(object): query = f.read() # Defer processing to self.query method. - return self.query(query=query, params=params, fetchall=fetchall) + return self.query(query=query, fetchall=fetchall, **params) def _reduce_datetimes(row): """Receives a row, converts datetimes to strings.""" + + row = list(row) + for i in range(len(row)): if hasattr(row[i], 'isoformat'): - row = row._replace(**{row._fields[0]: row[i].isoformat()}) - return row + row[i] = row[i].isoformat() + return tuple(row) def cli():