mirror of
https://github.com/kennethreitz/records.git
synced 2026-06-05 23:00:20 +00:00
the big refactor
This commit is contained in:
+88
-82
@@ -3,83 +3,96 @@
|
||||
import os
|
||||
from code import interact
|
||||
from datetime import datetime
|
||||
from collections import namedtuple, OrderedDict
|
||||
|
||||
import tablib
|
||||
import psycopg2
|
||||
from docopt import docopt
|
||||
from psycopg2.extras import register_hstore, NamedTupleCursor
|
||||
from psycopg2.extensions import cursor as _cursor
|
||||
|
||||
from sqlalchemy import text, create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
DATABASE_URL = os.environ.get('DATABASE_URL')
|
||||
|
||||
PG_INTERNAL_TABLES_QUERY = "SELECT * FROM pg_catalog.pg_tables"
|
||||
PG_TABLES_QUERY = """
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
pg_catalog.pg_tables
|
||||
WHERE
|
||||
schemaname != 'pg_catalog' AND
|
||||
schemaname != 'information_schema'
|
||||
"""
|
||||
|
||||
class Record(object):
|
||||
"""A row, from a query, from a database."""
|
||||
__slots__ = ('_keys', '_values')
|
||||
def __init__(self, keys, values):
|
||||
self._keys = keys
|
||||
self._values = values
|
||||
|
||||
class RecordsCursor(NamedTupleCursor):
|
||||
"""An enhanced cursor that generates Records."""
|
||||
try:
|
||||
from collections import namedtuple
|
||||
except ImportError as _exc:
|
||||
def _make_nt(self):
|
||||
raise self._exc
|
||||
else:
|
||||
def _make_nt(self, namedtuple=namedtuple):
|
||||
RecordBase = namedtuple("Record", [d[0] for d in self.description or ()])
|
||||
# Esure that lengths match properly.
|
||||
assert len(self._keys) == len(self._values)
|
||||
|
||||
# Extend the RecordsBase namedtuple, for enhanced API functionality.
|
||||
class Record(RecordBase):
|
||||
__slots__ = ()
|
||||
def keys(self):
|
||||
return self._fields
|
||||
def keys(self):
|
||||
"""Returns the list of column names from the query."""
|
||||
return self._keys
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, int):
|
||||
return super(RecordBase, self).__getitem__(key)
|
||||
def values(self):
|
||||
"""Returns the list of values from the query."""
|
||||
return self._values
|
||||
|
||||
if key in self.keys():
|
||||
return getattr(self, key)
|
||||
def __repr__(self):
|
||||
return '<Record {}>'.format(self.export('json'))
|
||||
|
||||
raise KeyError("Record contains no '{}' field.".format(key))
|
||||
def __getitem__(self, key):
|
||||
# Support for index-based lookup.
|
||||
if isinstance(key, int):
|
||||
return self.values()[key]
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
"""A Tablib Dataset containing the row."""
|
||||
data = tablib.Dataset()
|
||||
data.headers = self._fields
|
||||
# Support for string-based lookup.
|
||||
if key in self.keys():
|
||||
i = self.keys().index(key)
|
||||
return self.values()[i]
|
||||
|
||||
row = _reduce_datetimes(self)
|
||||
data.append(row)
|
||||
raise KeyError("Record contains no '{}' field.".format(key))
|
||||
|
||||
return data
|
||||
def __getattr__(self, key):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError, e:
|
||||
raise AttributeError(e)
|
||||
|
||||
def export(self, format, **kwargs):
|
||||
"""Exports the row to the given format."""
|
||||
return self.dataset.export(format, **kwargs)
|
||||
def __dir__(self):
|
||||
standard = [
|
||||
# Would love to do this programatically, but couldn't figure out how.
|
||||
'__class__', '__ddir__', '__delattr__', '__doc__', '__format__',
|
||||
'__getattr__', '__getattribute__', '__getitem__', '__hash__',
|
||||
'__init__', '__module__', '__new__', '__reduce__', '__reduce_ex__',
|
||||
'__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__',
|
||||
'__subclasshook__', '_keys', '_values', 'as_dict', 'dataset',
|
||||
'export', 'get', 'keys', 'values'
|
||||
]
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
# Merge standard attrs with generated ones (from column names).
|
||||
return sorted(standard + [str(k) for k in self.keys()])
|
||||
|
||||
def as_dict(self, ordered=False):
|
||||
"""Returns the row as a dictionary, as ordered."""
|
||||
if ordered:
|
||||
return self._asdict()
|
||||
else:
|
||||
return dict(self._asdict())
|
||||
def get(self, key, default=None):
|
||||
"""Returns the value for a given key, or default."""
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
return Record
|
||||
def as_dict(self, ordered=False):
|
||||
"""Returns the row as a dictionary, as ordered."""
|
||||
items = zip(self.keys(), self.values())
|
||||
|
||||
return OrderedDict(items) if ordered else dict(items)
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
"""A Tablib Dataset containing the row."""
|
||||
data = tablib.Dataset()
|
||||
data.headers = self.keys()
|
||||
|
||||
row = _reduce_datetimes(self.values())
|
||||
data.append(row)
|
||||
|
||||
return data
|
||||
|
||||
def export(self, format, **kwargs):
|
||||
"""Exports the row to the given format."""
|
||||
return self.dataset.export(format, **kwargs)
|
||||
|
||||
|
||||
class ResultSet(object):
|
||||
@@ -187,10 +200,7 @@ class Database(object):
|
||||
raise ValueError('You must provide a db_url.')
|
||||
|
||||
# Connect to the database.
|
||||
self.db = psycopg2.connect(self.db_url, cursor_factory=RecordsCursor)
|
||||
|
||||
# Enable hstore if it's available.
|
||||
self._enable_hstore()
|
||||
self.db = create_engine(self.db_url).connect()
|
||||
self.open = True
|
||||
|
||||
def close(self):
|
||||
@@ -207,34 +217,27 @@ class Database(object):
|
||||
def __repr__(self):
|
||||
return '<Database open={}>'.format(self.open)
|
||||
|
||||
def _enable_hstore(self):
|
||||
"""Enables HSTORE support, if available."""
|
||||
try:
|
||||
register_hstore(self.db)
|
||||
except psycopg2.ProgrammingError:
|
||||
pass
|
||||
|
||||
def get_table_names(self, internal=False):
|
||||
"""Returns a list of table names for the connected database."""
|
||||
|
||||
# Support listing internal table names as well.
|
||||
query = PG_INTERNAL_TABLES_QUERY if internal else PG_TABLES_QUERY
|
||||
# Setup SQLAlchemy for Database inspection.
|
||||
metadata = declarative_base().metadata
|
||||
metadata.reflect(create_engine(self.db_url))
|
||||
|
||||
# Return a list of tablenames.
|
||||
return [r['tablename'] for r in self.query(query)]
|
||||
# Serve the table names.
|
||||
return metadata.tables.keys()
|
||||
|
||||
def query(self, query, params=None, fetchall=False):
|
||||
def query(self, query, fetchall=False, **params):
|
||||
"""Executes the given SQL query against the Database. Parameters
|
||||
can, optionally, be provided. Returns a ResultSet, which can be
|
||||
iterated over to get result rows as dictionaries.
|
||||
"""
|
||||
|
||||
# Execute the given query.
|
||||
c = self.db.cursor()
|
||||
c.execute(query, params)
|
||||
cursor = self.db.execute(text(query), **params) # TODO: PARAMS GO HERE
|
||||
|
||||
# Row-by-row result generator.
|
||||
row_gen = (r for r in c)
|
||||
# Row-by-row Record generator.
|
||||
row_gen = (Record(cursor.keys(), row) for row in cursor)
|
||||
|
||||
# Convert psycopg2 results to ResultSet.
|
||||
results = ResultSet(row_gen)
|
||||
@@ -245,7 +248,7 @@ class Database(object):
|
||||
|
||||
return results
|
||||
|
||||
def query_file(self, path, params=None, fetchall=False):
|
||||
def query_file(self, path, fetchall=False, **params):
|
||||
"""Like Database.query, but takes a filename to load a query from."""
|
||||
|
||||
# If path doesn't exists
|
||||
@@ -261,14 +264,17 @@ class Database(object):
|
||||
query = f.read()
|
||||
|
||||
# Defer processing to self.query method.
|
||||
return self.query(query=query, params=params, fetchall=fetchall)
|
||||
return self.query(query=query, fetchall=fetchall, **params)
|
||||
|
||||
def _reduce_datetimes(row):
|
||||
"""Receives a row, converts datetimes to strings."""
|
||||
|
||||
row = list(row)
|
||||
|
||||
for i in range(len(row)):
|
||||
if hasattr(row[i], 'isoformat'):
|
||||
row = row._replace(**{row._fields[0]: row[i].isoformat()})
|
||||
return row
|
||||
row[i] = row[i].isoformat()
|
||||
return tuple(row)
|
||||
|
||||
|
||||
def cli():
|
||||
|
||||
Reference in New Issue
Block a user