the big refactor

This commit is contained in:
2016-02-13 01:27:04 -05:00
parent 29301ada61
commit 3d30b5f3fc
+88 -82
View File
@@ -3,83 +3,96 @@
import os
from code import interact
from datetime import datetime
from collections import namedtuple, OrderedDict
import tablib
import psycopg2
from docopt import docopt
from psycopg2.extras import register_hstore, NamedTupleCursor
from psycopg2.extensions import cursor as _cursor
from sqlalchemy import text, create_engine
from sqlalchemy.ext.declarative import declarative_base
DATABASE_URL = os.environ.get('DATABASE_URL')
PG_INTERNAL_TABLES_QUERY = "SELECT * FROM pg_catalog.pg_tables"
PG_TABLES_QUERY = """
SELECT
*
FROM
pg_catalog.pg_tables
WHERE
schemaname != 'pg_catalog' AND
schemaname != 'information_schema'
"""
class Record(object):
"""A row, from a query, from a database."""
__slots__ = ('_keys', '_values')
def __init__(self, keys, values):
self._keys = keys
self._values = values
class RecordsCursor(NamedTupleCursor):
"""An enhanced cursor that generates Records."""
try:
from collections import namedtuple
except ImportError as _exc:
def _make_nt(self):
raise self._exc
else:
def _make_nt(self, namedtuple=namedtuple):
RecordBase = namedtuple("Record", [d[0] for d in self.description or ()])
# Esure that lengths match properly.
assert len(self._keys) == len(self._values)
# Extend the RecordsBase namedtuple, for enhanced API functionality.
class Record(RecordBase):
__slots__ = ()
def keys(self):
return self._fields
def keys(self):
"""Returns the list of column names from the query."""
return self._keys
def __getitem__(self, key):
if isinstance(key, int):
return super(RecordBase, self).__getitem__(key)
def values(self):
"""Returns the list of values from the query."""
return self._values
if key in self.keys():
return getattr(self, key)
def __repr__(self):
return '<Record {}>'.format(self.export('json'))
raise KeyError("Record contains no '{}' field.".format(key))
def __getitem__(self, key):
# Support for index-based lookup.
if isinstance(key, int):
return self.values()[key]
@property
def dataset(self):
"""A Tablib Dataset containing the row."""
data = tablib.Dataset()
data.headers = self._fields
# Support for string-based lookup.
if key in self.keys():
i = self.keys().index(key)
return self.values()[i]
row = _reduce_datetimes(self)
data.append(row)
raise KeyError("Record contains no '{}' field.".format(key))
return data
def __getattr__(self, key):
try:
return self[key]
except KeyError, e:
raise AttributeError(e)
def export(self, format, **kwargs):
"""Exports the row to the given format."""
return self.dataset.export(format, **kwargs)
def __dir__(self):
standard = [
# Would love to do this programatically, but couldn't figure out how.
'__class__', '__ddir__', '__delattr__', '__doc__', '__format__',
'__getattr__', '__getattribute__', '__getitem__', '__hash__',
'__init__', '__module__', '__new__', '__reduce__', '__reduce_ex__',
'__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__',
'__subclasshook__', '_keys', '_values', 'as_dict', 'dataset',
'export', 'get', 'keys', 'values'
]
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
# Merge standard attrs with generated ones (from column names).
return sorted(standard + [str(k) for k in self.keys()])
def as_dict(self, ordered=False):
"""Returns the row as a dictionary, as ordered."""
if ordered:
return self._asdict()
else:
return dict(self._asdict())
def get(self, key, default=None):
"""Returns the value for a given key, or default."""
try:
return self[key]
except KeyError:
return default
return Record
def as_dict(self, ordered=False):
"""Returns the row as a dictionary, as ordered."""
items = zip(self.keys(), self.values())
return OrderedDict(items) if ordered else dict(items)
@property
def dataset(self):
"""A Tablib Dataset containing the row."""
data = tablib.Dataset()
data.headers = self.keys()
row = _reduce_datetimes(self.values())
data.append(row)
return data
def export(self, format, **kwargs):
"""Exports the row to the given format."""
return self.dataset.export(format, **kwargs)
class ResultSet(object):
@@ -187,10 +200,7 @@ class Database(object):
raise ValueError('You must provide a db_url.')
# Connect to the database.
self.db = psycopg2.connect(self.db_url, cursor_factory=RecordsCursor)
# Enable hstore if it's available.
self._enable_hstore()
self.db = create_engine(self.db_url).connect()
self.open = True
def close(self):
@@ -207,34 +217,27 @@ class Database(object):
def __repr__(self):
return '<Database open={}>'.format(self.open)
def _enable_hstore(self):
"""Enables HSTORE support, if available."""
try:
register_hstore(self.db)
except psycopg2.ProgrammingError:
pass
def get_table_names(self, internal=False):
"""Returns a list of table names for the connected database."""
# Support listing internal table names as well.
query = PG_INTERNAL_TABLES_QUERY if internal else PG_TABLES_QUERY
# Setup SQLAlchemy for Database inspection.
metadata = declarative_base().metadata
metadata.reflect(create_engine(self.db_url))
# Return a list of tablenames.
return [r['tablename'] for r in self.query(query)]
# Serve the table names.
return metadata.tables.keys()
def query(self, query, params=None, fetchall=False):
def query(self, query, fetchall=False, **params):
"""Executes the given SQL query against the Database. Parameters
can, optionally, be provided. Returns a ResultSet, which can be
iterated over to get result rows as dictionaries.
"""
# Execute the given query.
c = self.db.cursor()
c.execute(query, params)
cursor = self.db.execute(text(query), **params) # TODO: PARAMS GO HERE
# Row-by-row result generator.
row_gen = (r for r in c)
# Row-by-row Record generator.
row_gen = (Record(cursor.keys(), row) for row in cursor)
# Convert psycopg2 results to ResultSet.
results = ResultSet(row_gen)
@@ -245,7 +248,7 @@ class Database(object):
return results
def query_file(self, path, params=None, fetchall=False):
def query_file(self, path, fetchall=False, **params):
"""Like Database.query, but takes a filename to load a query from."""
# If path doesn't exists
@@ -261,14 +264,17 @@ class Database(object):
query = f.read()
# Defer processing to self.query method.
return self.query(query=query, params=params, fetchall=fetchall)
return self.query(query=query, fetchall=fetchall, **params)
def _reduce_datetimes(row):
"""Receives a row, converts datetimes to strings."""
row = list(row)
for i in range(len(row)):
if hasattr(row[i], 'isoformat'):
row = row._replace(**{row._fields[0]: row[i].isoformat()})
return row
row[i] = row[i].isoformat()
return tuple(row)
def cli():