"""
Collection of postgres-specific extensions, currently including:

* Support for hstore, a key/value type storage
* Array fields, with indexing / slicing expressions
* JSON and JSONB fields, with lookup / path expressions
* TSVECTOR, INTERVAL and TIMESTAMPTZ fields
* A database class that can run queries through named cursors
"""
import json
import logging
import uuid

from peewee import *
from peewee import ColumnBase
from peewee import Expression
from peewee import Node
from peewee import NodeList
from peewee import SENTINEL
from peewee import __exception_wrapper__

# Optional: let psycopg2cffi stand in for psycopg2 when it is installed.
try:
    from psycopg2cffi import compat
    compat.register()
except ImportError:
    pass

try:
    from psycopg2.extras import register_hstore
except ImportError:
    def register_hstore(c, globally):
        """No-op stand-in used when psycopg2.extras is unavailable."""
        pass

# ``Json`` is optional; JSONField refuses to instantiate when it is missing.
# Only a failed import is tolerated here (same policy as the guards above) so
# that unrelated errors and interpreter-exit signals are not swallowed.
try:
    from psycopg2.extras import Json
except ImportError:
    Json = None


logger = logging.getLogger('peewee')


# hstore operators.
HCONTAINS_DICT = '@>'
HCONTAINS_KEYS = '?&'
HCONTAINS_KEY = '?'
HCONTAINS_ANY_KEY = '?|'
HKEY = '->'
HUPDATE = '||'

# Array operators.
ACONTAINS = '@>'
ACONTAINS_ANY = '&&'

# Full-text search operator.
TS_MATCH = '@@'

# jsonb operators.
JSONB_CONTAINS = '@>'
JSONB_CONTAINED_BY = '<@'
JSONB_CONTAINS_KEY = '?'
JSONB_CONTAINS_ANY_KEY = '?|'
JSONB_CONTAINS_ALL_KEYS = '?&'
JSONB_EXISTS = '?'
JSONB_REMOVE = '-'


class _LookupNode(ColumnBase):
    """Base for expressions that address part of another node.

    :param node: the node being indexed into.
    :param parts: the accumulated lookup components.
    """

    def __init__(self, node, parts):
        self.node = node
        self.parts = parts
        super(_LookupNode, self).__init__()

    def clone(self):
        """Return a new instance of the same type with a copied parts list."""
        return type(self)(self.node, list(self.parts))

    def __hash__(self):
        # Identity-based hash: two distinct lookup nodes never hash alike.
        return hash((self.__class__.__name__, id(self)))


class _JsonLookupBase(_LookupNode):
    """Shared behaviour for JSON key lookups and JSON path lookups.

    :param node: the node being indexed into.
    :param parts: the accumulated lookup components.
    :param as_json: selects which operator the subclass renders for the
        final step (see ``JsonLookup.__sql__`` / ``JsonPath.__sql__``).
    """

    def __init__(self, node, parts, as_json=False):
        super(_JsonLookupBase, self).__init__(node, parts)
        self._as_json = as_json

    def clone(self):
        """Return a copy that also preserves the ``as_json`` flag."""
        return type(self)(self.node, list(self.parts), self._as_json)

    @Node.copy
    def as_json(self, as_json=True):
        """Set the ``as_json`` flag.

        Wrapped in ``Node.copy``; the methods below rely on the call
        returning a node (presumably a modified clone -- confirm against
        peewee's ``Node.copy``).
        """
        self._as_json = as_json

    def concat(self, rhs):
        """Build a ``OP.CONCAT`` expression with *rhs*.

        Values that are not already nodes are wrapped in ``Json``.
        """
        if not isinstance(rhs, Node):
            rhs = Json(rhs)
        return Expression(self.as_json(True), OP.CONCAT, rhs)

    def contains(self, other):
        """Build a containment test.

        A list or dict is wrapped in ``Json`` and compared with ``@>``;
        anything else is tested with ``?``.
        """
        clone = self.as_json(True)
        if isinstance(other, (list, dict)):
            return Expression(clone, JSONB_CONTAINS, Json(other))
        return Expression(clone, JSONB_EXISTS, other)

    def contains_any(self, *keys):
        """Build a ``?|`` expression against the list of *keys*."""
        return Expression(
            self.as_json(True),
            JSONB_CONTAINS_ANY_KEY,
            Value(list(keys), unpack=False))

    def contains_all(self, *keys):
        """Build a ``?&`` expression against the list of *keys*."""
        return Expression(
            self.as_json(True),
            JSONB_CONTAINS_ALL_KEYS,
            Value(list(keys), unpack=False))

    def has_key(self, key):
        """Build a ``?`` expression for a single *key*."""
        return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key)


class JsonLookup(_JsonLookupBase):
    """Chained JSON lookup, e.g. ``field['a']['b']``."""

    def __getitem__(self, value):
        """Return a new lookup extended by one more component."""
        return JsonLookup(self.node, self.parts + [value], self._as_json)

    def __sql__(self, ctx):
        ctx.sql(self.node)
        # Every step but the last is rendered with '->'.
        for part in self.parts[:-1]:
            ctx.literal('->').sql(part)
        # The last step uses '->' when as_json is set, otherwise '->>'.
        if self.parts:
            (ctx
             .literal('->' if self._as_json else '->>')
             .sql(self.parts[-1]))

        return ctx


class JsonPath(_JsonLookupBase):
    """JSON path lookup rendered with ``#>`` / ``#>>``."""

    def __sql__(self, ctx):
        # The parts are stringified and joined into a single '{a,b,c}'
        # value; '#>' is used when as_json is set, otherwise '#>>'.
        return (ctx
                .sql(self.node)
                .literal('#>' if self._as_json else '#>>')
                .sql(Value('{%s}' % ','.join(map(str, self.parts)))))


class ObjectSlice(_LookupNode):
    """Index or slice expression on a node, rendered as ``node[...]``."""

    @classmethod
    def create(cls, node, value):
        """Build an ``ObjectSlice`` from an index-like *value*.

        :param node: the node being indexed.
        :param value: a ``slice``, an ``int``, a ``Node``, or a string of
            colon-separated integers.
        """
        if isinstance(value, slice):
            # Missing bounds fall back to 0.
            parts = [value.start or 0, value.stop or 0]
        elif isinstance(value, int):
            parts = [value]
        elif isinstance(value, Node):
            parts = value
        else:
            # Assumes colon-separated integer indexes.
            parts = [int(i) for i in value.split(':')]
        return cls(node, parts)

    def __sql__(self, ctx):
        ctx.sql(self.node)
        if isinstance(self.parts, Node):
            ctx.literal('[').sql(self.parts).literal(']')
        else:
            # Each index is emitted as p + 1 -- presumably translating
            # Python's 0-based indexes to Postgres' 1-based arrays.
            ctx.literal('[%s]' % ':'.join(str(p + 1) for p in self.parts))
        return ctx

    def __getitem__(self, value):
        """Index into this slice, producing a nested ``ObjectSlice``."""
        return ObjectSlice.create(self, value)


class IndexedFieldMixin(object):
    """Field mixin that turns on ``index`` by default, with type GIN."""

    default_index_type = 'GIN'

    def __init__(self, *args, **kwargs):
        kwargs.setdefault('index', True)  # By default, use an index.
        super(IndexedFieldMixin, self).__init__(*args, **kwargs)


class ArrayField(IndexedFieldMixin, Field):
    """Array column whose element type is taken from another field class.

    :param field_class: field class describing the element type.
    :param field_kwargs: keyword arguments used to instantiate
        *field_class*.
    :param dimensions: number of ``[]`` suffixes emitted in the DDL type.
    :param convert_values: when true, each element is passed through the
        element field's ``db_value`` / ``python_value``.
    """

    passthrough = True

    def __init__(self, field_class=IntegerField, field_kwargs=None,
                 dimensions=1, convert_values=False, *args, **kwargs):
        self.__field = field_class(**(field_kwargs or {}))
        self.dimensions = dimensions
        self.convert_values = convert_values
        self.field_type = self.__field.field_type
        super(ArrayField, self).__init__(*args, **kwargs)

    def bind(self, model, name, set_attribute=True):
        """Bind this field, then bind the inner element field as well.

        The inner field is bound under ``'__array_<name>'`` with
        ``set_attribute`` disabled.
        """
        ret = super(ArrayField, self).bind(model, name, set_attribute)
        self.__field.bind(model, '__array_%s' % name, False)
        return ret

    def ddl_datatype(self, ctx):
        """Return the element DDL type followed by ``[]`` per dimension."""
        data_type = self.__field.ddl_datatype(ctx)
        return NodeList((data_type, SQL('[]' * self.dimensions)), glue='')

    def db_value(self, value):
        """Convert *value* for storage.

        ``None`` and nodes pass through untouched. Otherwise the value is
        converted element-wise when ``convert_values`` is set, or coerced
        to a ``list``.
        """
        if value is None or isinstance(value, Node):
            return value
        elif self.convert_values:
            return self._process(self.__field.db_value, value,
                                 self.dimensions)
        else:
            return value if isinstance(value, list) else list(value)

    def python_value(self, value):
        """Convert a database value, element-wise if ``convert_values``."""
        if self.convert_values and value is not None:
            conv = self.__field.python_value
            if isinstance(value, list):
                return self._process(conv, value, self.dimensions)
            else:
                return conv(value)
        else:
            return value

    def _process(self, conv, value, dimensions):
        """Apply *conv* to the innermost items of a nested list.

        Recurses one level per remaining dimension.
        """
        dimensions -= 1
        if dimensions == 0:
            return [conv(v) for v in value]
        else:
            return [self._process(conv, v, dimensions) for v in value]

    def __getitem__(self, value):
        """Return an ``ObjectSlice`` expression for an index or slice."""
        return ObjectSlice.create(self, value)

    def _e(op):
        # Class-body helper: builds a comparison method whose right-hand
        # side is wrapped in ArrayValue (i.e. cast to this array's type).
        def inner(self, rhs):
            return Expression(self, op, ArrayValue(self, rhs))
        return inner
    __eq__ = _e(OP.EQ)
    __ne__ = _e(OP.NE)
    __gt__ = _e(OP.GT)
    __ge__ = _e(OP.GTE)
    __lt__ = _e(OP.LT)
    __le__ = _e(OP.LTE)
    # Defining __eq__ would otherwise drop the inherited hash.
    __hash__ = Field.__hash__

    def contains(self, *items):
        """Build an ``@>`` expression against *items*."""
        return Expression(self, ACONTAINS, ArrayValue(self, items))

    def contains_any(self, *items):
        """Build an ``&&`` expression against *items*."""
        return Expression(self, ACONTAINS_ANY, ArrayValue(self, items))


class ArrayValue(Node):
    """A value rendered as ``<value>::<array ddl type of field>``."""

    def __init__(self, field, value):
        self.field = field
        self.value = value

    def __sql__(self, ctx):
        return (ctx
                .sql(Value(self.value, unpack=False))
                .literal('::')
                .sql(self.field.ddl_datatype(ctx)))


class DateTimeTZField(DateTimeField):
    """Date-time field declared with the ``TIMESTAMPTZ`` column type."""

    field_type = 'TIMESTAMPTZ'


class HStoreField(IndexedFieldMixin, Field):
    """Field declared with the ``HSTORE`` column type."""

    field_type = 'HSTORE'
    __hash__ = Field.__hash__

    def __getitem__(self, key):
        """Build a ``->`` lookup expression for *key*."""
        return Expression(self, HKEY, Value(key))

    def keys(self):
        """Call the SQL function ``akeys`` on this field."""
        return fn.akeys(self)

    def values(self):
        """Call the SQL function ``avals`` on this field."""
        return fn.avals(self)

    def items(self):
        """Call the SQL function ``hstore_to_matrix`` on this field."""
        return fn.hstore_to_matrix(self)

    def slice(self, *args):
        """Call the SQL function ``slice`` with the list of *args*."""
        return fn.slice(self, Value(list(args), unpack=False))

    def exists(self, key):
        """Call the SQL function ``exist`` with *key*."""
        return fn.exist(self, key)

    def defined(self, key):
        """Call the SQL function ``defined`` with *key*."""
        return fn.defined(self, key)

    def update(self, **data):
        """Build a ``||`` expression with the keyword arguments as a dict."""
        return Expression(self, HUPDATE, data)

    def delete(self, *keys):
        """Call the SQL function ``delete`` with the list of *keys*."""
        return fn.delete(self, Value(list(keys), unpack=False))

    def contains(self, value):
        """Build a containment test whose operator depends on *value*.

        dict -> ``@>``; list or tuple -> ``?&``; anything else -> ``?``.
        """
        if isinstance(value, dict):
            rhs = Value(value, unpack=False)
            return Expression(self, HCONTAINS_DICT, rhs)
        elif isinstance(value, (list, tuple)):
            rhs = Value(value, unpack=False)
            return Expression(self, HCONTAINS_KEYS, rhs)
        return Expression(self, HCONTAINS_KEY, value)

    def contains_any(self, *keys):
        """Build a ``?|`` expression against the list of *keys*."""
        return Expression(self, HCONTAINS_ANY_KEY,
                          Value(list(keys), unpack=False))


class JSONField(Field):
    """Field declared with the ``JSON`` column type.

    :param dumps: serializer used by ``db_value``; defaults to
        ``json.dumps``.
    :raises Exception: if ``psycopg2.extras.Json`` could not be imported.
    """

    field_type = 'JSON'
    _json_datatype = 'json'

    def __init__(self, dumps=None, *args, **kwargs):
        if Json is None:
            raise Exception('Your version of psycopg2 does not support JSON.')
        self.dumps = dumps or json.dumps
        super(JSONField, self).__init__(*args, **kwargs)

    def db_value(self, value):
        """Serialize *value* and cast it to this field's JSON datatype.

        ``None`` and ``Json`` instances are returned unchanged.
        """
        if value is None:
            return value
        if not isinstance(value, Json):
            return Cast(self.dumps(value), self._json_datatype)
        return value

    def __getitem__(self, value):
        """Start a ``JsonLookup`` chain on this field."""
        return JsonLookup(self, [value])

    def path(self, *keys):
        """Return a ``JsonPath`` lookup over *keys*."""
        return JsonPath(self, keys)

    def concat(self, value):
        """Concatenate with *value*, wrapping non-node values in ``Json``."""
        if not isinstance(value, Node):
            value = Json(value)
        return super(JSONField, self).concat(value)


def cast_jsonb(node):
    """Return *node* immediately followed by a ``::jsonb`` cast."""
    return NodeList((node, SQL('::jsonb')), glue='')


class BinaryJSONField(IndexedFieldMixin, JSONField):
    """Field declared with the ``JSONB`` column type, indexed by default."""

    field_type = 'JSONB'
    _json_datatype = 'jsonb'
    __hash__ = Field.__hash__

    def contains(self, other):
        """Build a containment test.

        A list or dict is wrapped in ``Json`` and compared with ``@>``; a
        ``JSONField`` is compared with ``@>`` directly; anything else is
        tested with ``?`` against this field cast to jsonb.
        """
        if isinstance(other, (list, dict)):
            return Expression(self, JSONB_CONTAINS, Json(other))
        elif isinstance(other, JSONField):
            return Expression(self, JSONB_CONTAINS, other)
        return Expression(cast_jsonb(self), JSONB_EXISTS, other)

    def contained_by(self, other):
        """Build a ``<@`` expression against ``Json(other)``."""
        return Expression(cast_jsonb(self), JSONB_CONTAINED_BY, Json(other))

    def contains_any(self, *items):
        """Build a ``?|`` expression against the list of *items*."""
        return Expression(
            cast_jsonb(self),
            JSONB_CONTAINS_ANY_KEY,
            Value(list(items), unpack=False))

    def contains_all(self, *items):
        """Build a ``?&`` expression against the list of *items*."""
        return Expression(
            cast_jsonb(self),
            JSONB_CONTAINS_ALL_KEYS,
            Value(list(items), unpack=False))

    def has_key(self, key):
        """Build a ``?`` expression for a single *key*."""
        return Expression(cast_jsonb(self), JSONB_CONTAINS_KEY, key)

    def remove(self, *items):
        """Build a ``-`` expression against the list of *items*."""
        return Expression(
            cast_jsonb(self),
            JSONB_REMOVE,
            Value(list(items), unpack=False))


class TSVectorField(IndexedFieldMixin, TextField):
    """Field declared with the ``TSVECTOR`` column type."""

    field_type = 'TSVECTOR'
    __hash__ = Field.__hash__

    def match(self, query, language=None, plain=False):
        """Build a ``@@`` expression matching this field against *query*.

        :param query: the search query passed to the tsquery function.
        :param language: optional first argument to the tsquery function.
        :param plain: use ``plainto_tsquery`` instead of ``to_tsquery``.
        """
        params = (language, query) if language is not None else (query,)
        func = fn.plainto_tsquery if plain else fn.to_tsquery
        return Expression(self, TS_MATCH, func(*params))


def Match(field, query, language=None):
    """Build ``to_tsvector(field) @@ to_tsquery(query)``.

    When *language* is given it is passed as the first argument to both
    SQL functions.
    """
    params = (language, query) if language is not None else (query,)
    field_params = (language, field) if language is not None else (field,)
    return Expression(
        fn.to_tsvector(*field_params),
        TS_MATCH,
        fn.to_tsquery(*params))


class IntervalField(Field):
    """Field declared with the ``INTERVAL`` column type."""

    field_type = 'INTERVAL'


class FetchManyCursor(object):
    """Cursor wrapper that serves ``fetchone`` from ``fetchmany`` batches.

    :param cursor: the underlying cursor.
    :param array_size: batch size; falls back to ``cursor.itersize``.
    """

    __slots__ = ('cursor', 'array_size', 'exhausted', 'iterable')

    def __init__(self, cursor, array_size=None):
        self.cursor = cursor
        self.array_size = array_size or cursor.itersize
        self.exhausted = False
        self.iterable = self.row_gen()

    @property
    def description(self):
        """The underlying cursor's ``description``."""
        return self.cursor.description

    def close(self):
        """Close the underlying cursor."""
        self.cursor.close()

    def row_gen(self):
        """Yield rows one at a time, fetching ``array_size`` per batch."""
        while True:
            rows = self.cursor.fetchmany(self.array_size)
            if not rows:
                return
            for row in rows:
                yield row

    def fetchone(self):
        """Return the next row, or ``None`` once the rows are exhausted."""
        if self.exhausted:
            return
        try:
            return next(self.iterable)
        except StopIteration:
            # Remember exhaustion so later calls return None immediately.
            self.exhausted = True


class ServerSideQuery(Node):
    """Wrap a query so that iterating it executes via a named cursor.

    :param query: the query to execute.
    :param array_size: batch size forwarded to ``database.execute``.
    """

    def __init__(self, query, array_size=None):
        self.query = query
        self.array_size = array_size
        self._cursor_wrapper = None

    def __sql__(self, ctx):
        return self.query.__sql__(ctx)

    def __iter__(self):
        # Execute lazily, on first iteration, against the query's database.
        if self._cursor_wrapper is None:
            self._execute(self.query._database)
        return iter(self._cursor_wrapper.iterator())

    def _execute(self, database):
        """Run the query once with ``named_cursor=True``; cache the result."""
        if self._cursor_wrapper is None:
            cursor = database.execute(self.query, named_cursor=True,
                                      array_size=self.array_size)
            self._cursor_wrapper = self.query._get_cursor_wrapper(cursor)
        return self._cursor_wrapper


def ServerSide(query, database=None, array_size=None):
    """Yield the rows of *query* via a named cursor inside a transaction.

    :param query: the query to execute.
    :param database: database to open the transaction on; defaults to
        ``query._database``.
    :param array_size: batch size forwarded to ``ServerSideQuery``.
    """
    if database is None:
        database = query._database
    with database.transaction():
        server_side_query = ServerSideQuery(query, array_size=array_size)
        for row in server_side_query:
            yield row


class _empty_object(object):
    """Stateless object that is always falsy; used as a unique sentinel."""

    __slots__ = ()

    def __nonzero__(self):
        return False
    # Python 3 spelling of the truthiness hook.
    __bool__ = __nonzero__


# Sentinel passed as ``commit`` to request a named cursor; it is falsy.
__named_cursor__ = _empty_object()


class PostgresqlExtDatabase(PostgresqlDatabase):
    """Postgres database with optional hstore and named-cursor support.

    Extra keyword arguments (both default to ``False``):

    :param register_hstore: call ``register_hstore`` on each new connection.
    :param server_side_cursors: use a named cursor for every statement
        whose SQL starts with ``select``.
    """

    def __init__(self, *args, **kwargs):
        self._register_hstore = kwargs.pop('register_hstore', False)
        self._server_side_cursors = kwargs.pop('server_side_cursors', False)
        super(PostgresqlExtDatabase, self).__init__(*args, **kwargs)

    def _connect(self):
        """Open a connection, registering hstore on it if configured."""
        conn = super(PostgresqlExtDatabase, self)._connect()
        if self._register_hstore:
            register_hstore(conn, globally=True)
        return conn

    def cursor(self, commit=None):
        """Return a cursor, connecting first when ``autoconnect`` is set.

        When *commit* is the ``__named_cursor__`` sentinel the cursor is
        created with a freshly generated UUID as its name.

        :raises InterfaceError: if the connection is closed and
            ``autoconnect`` is off.
        """
        if self.is_closed():
            if self.autoconnect:
                self.connect()
            else:
                raise InterfaceError('Error, database connection not opened.')
        if commit is __named_cursor__:
            return self._state.conn.cursor(name=str(uuid.uuid1()))
        return self._state.conn.cursor()

    def execute(self, query, commit=SENTINEL, named_cursor=False,
                array_size=None, **context_options):
        """Render and execute *query*, optionally through a named cursor.

        :param query: the query node to render and execute.
        :param commit: forwarded to ``execute_sql``; replaced by the
            ``__named_cursor__`` sentinel when a named cursor is used.
        :param named_cursor: force use of a named cursor.
        :param array_size: batch size for the ``FetchManyCursor`` wrapper.
        :param context_options: forwarded to ``get_sql_context``.
        """
        ctx = self.get_sql_context(**context_options)
        sql, params = ctx.sql(query).query()
        # A named cursor is used on request, or automatically for SELECTs
        # when the database was created with server_side_cursors.
        named_cursor = named_cursor or (self._server_side_cursors and
                                        sql[:6].lower() == 'select')
        if named_cursor:
            commit = __named_cursor__
        cursor = self.execute_sql(sql, params, commit=commit)
        if named_cursor:
            cursor = FetchManyCursor(cursor, array_size)
        return cursor