From fc757a7e546bd0424e075c287600eede71ab47f0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Louis=20V=C3=A9zina?= <5130500+morpheus65535@users.noreply.github.com>
Date: Mon, 29 Jul 2019 13:03:45 -0400
Subject: [PATCH] First commit of the Peewee data model; add the peewee module to libs.

---
 bazarr/database.py | 172 +
 libs/peewee.py | 7323 +++++++++++++++++++++++++++++++
 libs/playhouse/__init__.py | 0
 libs/playhouse/apsw_ext.py | 136 +
 libs/playhouse/dataset.py | 415 ++
 libs/playhouse/db_url.py | 124 +
 libs/playhouse/fields.py | 64 +
 libs/playhouse/flask_utils.py | 185 +
 libs/playhouse/hybrid.py | 50 +
 libs/playhouse/kv.py | 172 +
 libs/playhouse/migrate.py | 823 ++++
 libs/playhouse/mysql_ext.py | 34 +
 libs/playhouse/pool.py | 317 ++
 libs/playhouse/postgres_ext.py | 468 ++
 libs/playhouse/reflection.py | 799 ++++
 libs/playhouse/shortcuts.py | 224 +
 libs/playhouse/signals.py | 79 +
 libs/playhouse/sqlcipher_ext.py | 103 +
 libs/playhouse/sqlite_ext.py | 1261 ++++++
 libs/playhouse/sqlite_udf.py | 522 +++
 libs/playhouse/sqliteq.py | 330 ++
 libs/playhouse/test_utils.py | 62 +
 libs/pwiz.py | 221 +
 libs/version.txt | 1 +
 24 files changed, 13885 insertions(+)
 create mode 100644 bazarr/database.py
 create mode 100644 libs/peewee.py
 create mode 100644 libs/playhouse/__init__.py
 create mode 100644 libs/playhouse/apsw_ext.py
 create mode 100644 libs/playhouse/dataset.py
 create mode 100644 libs/playhouse/db_url.py
 create mode 100644 libs/playhouse/fields.py
 create mode 100644 libs/playhouse/flask_utils.py
 create mode 100644 libs/playhouse/hybrid.py
 create mode 100644 libs/playhouse/kv.py
 create mode 100644 libs/playhouse/migrate.py
 create mode 100644 libs/playhouse/mysql_ext.py
 create mode 100644 libs/playhouse/pool.py
 create mode 100644 libs/playhouse/postgres_ext.py
 create mode 100644 libs/playhouse/reflection.py
 create mode 100644 libs/playhouse/shortcuts.py
 create mode 100644 libs/playhouse/signals.py
 create mode 100644 libs/playhouse/sqlcipher_ext.py
 create mode 100644 libs/playhouse/sqlite_ext.py
 create mode 100644 libs/playhouse/sqlite_udf.py
 create mode 100644 libs/playhouse/sqliteq.py
 create mode 100644 libs/playhouse/test_utils.py
 create mode 100644 libs/pwiz.py

diff --git a/bazarr/database.py b/bazarr/database.py
new file mode 100644
index 000000000..9f17d28f1
--- /dev/null
+++ b/bazarr/database.py
@@ -0,0 +1,172 @@
+import os
+import atexit
+
+from get_args import args
+from peewee import *
+from playhouse.sqliteq import SqliteQueueDatabase
+
+from helper import path_replace, path_replace_movie, path_replace_reverse, path_replace_reverse_movie
+
+database = SqliteQueueDatabase(
+    os.path.join(args.config_dir, 'db', 'bazarr.db'),
+    use_gevent=False,      # Use the standard library "threading" module.
+    autostart=True,        # Start the worker thread automatically.
+    queue_max_size=256,    # Max. # of pending writes that can accumulate.
+    results_timeout=30.0)  # Max. time to wait for a query to be executed.
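For context, a minimal usage sketch (not part of the patch): with autostart=True the write-serializing worker thread is started for you, and the atexit hook at the end of this file stops it. The query below is illustrative only; it assumes the TableShows model and the path_substitution SQL function defined further down, and the 'mapped_path' alias is invented for this example.

    # Illustrative sketch: select show titles alongside their paths run through
    # the custom path_substitution() SQL function registered on this database.
    rows = (TableShows
            .select(TableShows.title,
                    fn.path_substitution(TableShows.path).alias('mapped_path'))
            .dicts())
    for row in rows:
        print('%s -> %s' % (row['title'], row['mapped_path']))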
+
+
+@database.func('path_substitution')
+def path_substitution(path):
+    return path_replace(path)
+
+
+class UnknownField(object):
+    def __init__(self, *_, **__): pass
+
+class BaseModel(Model):
+    class Meta:
+        database = database
+
+
+class SqliteSequence(BaseModel):
+    name = BareField(null=True)
+    seq = BareField(null=True)
+
+    class Meta:
+        table_name = 'sqlite_sequence'
+        primary_key = False
+
+
+class System(BaseModel):
+    configured = TextField(null=True)
+    updated = TextField(null=True)
+
+    class Meta:
+        table_name = 'system'
+        primary_key = False
+
+
+class TableShows(BaseModel):
+    alternate_titles = TextField(column_name='alternateTitles', null=True)
+    audio_language = TextField(null=True)
+    fanart = TextField(null=True)
+    forced = TextField(null=True)
+    hearing_impaired = TextField(null=True)
+    languages = TextField(null=True)
+    overview = TextField(null=True)
+    path = TextField(unique=True)
+    poster = TextField(null=True)
+    sonarr_series_id = IntegerField(column_name='sonarrSeriesId', unique=True)
+    sort_title = TextField(column_name='sortTitle', null=True)
+    title = TextField()
+    tvdb_id = AutoField(column_name='tvdbId')
+    year = TextField(null=True)
+
+    class Meta:
+        table_name = 'table_shows'
+
+
+class TableEpisodes(BaseModel):
+    audio_codec = TextField(null=True)
+    episode = IntegerField()
+    failed_attempts = TextField(column_name='failedAttempts', null=True)
+    format = TextField(null=True)
+    missing_subtitles = TextField(null=True)
+    monitored = TextField(null=True)
+    path = TextField()
+    resolution = TextField(null=True)
+    scene_name = TextField(null=True)
+    season = IntegerField()
+    sonarr_episode_id = IntegerField(column_name='sonarrEpisodeId', unique=True)
+    sonarr_series_id = ForeignKeyField(TableShows, field='sonarr_series_id', column_name='sonarrSeriesId')
+    subtitles = TextField(null=True)
+    title = TextField()
+    video_codec = TextField(null=True)
+
+    class Meta:
+        table_name = 'table_episodes'
+        primary_key = False
+
+
+class TableMovies(BaseModel):
+    alternative_titles = TextField(column_name='alternativeTitles', null=True)
+    audio_codec = TextField(null=True)
+    audio_language = TextField(null=True)
+    failed_attempts = TextField(column_name='failedAttempts', null=True)
+    fanart = TextField(null=True)
+    forced = TextField(null=True)
+    format = TextField(null=True)
+    hearing_impaired = TextField(null=True)
+    imdb_id = TextField(column_name='imdbId', null=True)
+    languages = TextField(null=True)
+    missing_subtitles = TextField(null=True)
+    monitored = TextField(null=True)
+    overview = TextField(null=True)
+    path = TextField(unique=True)
+    poster = TextField(null=True)
+    radarr_id = IntegerField(column_name='radarrId', unique=True)
+    resolution = TextField(null=True)
+    scene_name = TextField(column_name='sceneName', null=True)
+    sort_title = TextField(column_name='sortTitle', null=True)
+    subtitles = TextField(null=True)
+    title = TextField()
+    tmdb_id = TextField(column_name='tmdbId', primary_key=True)
+    video_codec = TextField(null=True)
+    year = TextField(null=True)
+
+    class Meta:
+        table_name = 'table_movies'
+
+
+class TableHistory(BaseModel):
+    action = IntegerField()
+    description = TextField()
+    language = TextField(null=True)
+    provider = TextField(null=True)
+    score = TextField(null=True)
+    sonarr_episode_id = ForeignKeyField(TableEpisodes, field='sonarr_episode_id', column_name='sonarrEpisodeId')
+    sonarr_series_id = ForeignKeyField(TableShows, field='sonarr_series_id', column_name='sonarrSeriesId')
+    timestamp = IntegerField()
+    video_path = TextField(null=True)
+
+    class Meta:
+        table_name = 'table_history'
+
+
+class TableHistoryMovie(BaseModel):
+    action = IntegerField()
+    description = TextField()
+    language = TextField(null=True)
+    provider = TextField(null=True)
+    radarr_id = ForeignKeyField(TableMovies, field='radarr_id', column_name='radarrId')
+    score = TextField(null=True)
+    timestamp = IntegerField()
+    video_path = TextField(null=True)
+
+    class Meta:
+        table_name = 'table_history_movie'
+
+
+class TableSettingsLanguages(BaseModel):
+    code2 = TextField(null=True)
+    code3 = TextField(primary_key=True)
+    code3b = TextField(null=True)
+    enabled = IntegerField(null=True)
+    name = TextField()
+
+    class Meta:
+        table_name = 'table_settings_languages'
+
+
+class TableSettingsNotifier(BaseModel):
+    enabled = IntegerField(null=True)
+    name = TextField(null=True, primary_key=True)
+    url = TextField(null=True)
+
+    class Meta:
+        table_name = 'table_settings_notifier'
+
+
+@atexit.register
+def _stop_worker_threads():
+    database.stop()
\ No newline at end of file
diff --git a/libs/peewee.py b/libs/peewee.py
new file mode 100644
index 000000000..3204edb34
--- /dev/null
+++ b/libs/peewee.py
@@ -0,0 +1,7323 @@
+from bisect import bisect_left
+from bisect import bisect_right
+from contextlib import contextmanager
+from copy import deepcopy
+from functools import wraps
+from inspect import isclass
+import calendar
+import collections
+import datetime
+import decimal
+import hashlib
+import itertools
+import logging
+import operator
+import re
+import socket
+import struct
+import sys
+import threading
+import time
+import uuid
+import warnings
+try:
+    from collections.abc import Mapping
+except ImportError:
+    from collections import Mapping
+
+try:
+    from pysqlite3 import dbapi2 as pysq3
+except ImportError:
+    try:
+        from pysqlite2 import dbapi2 as pysq3
+    except ImportError:
+        pysq3 = None
+try:
+    import sqlite3
+except ImportError:
+    sqlite3 = pysq3
+else:
+    if pysq3 and pysq3.sqlite_version_info >= sqlite3.sqlite_version_info:
+        sqlite3 = pysq3
+try:
+    from psycopg2cffi import compat
+    compat.register()
+except ImportError:
+    pass
+try:
+    import psycopg2
+    from psycopg2 import extensions as pg_extensions
+    try:
+        from psycopg2 import errors as pg_errors
+    except ImportError:
+        pg_errors = None
+except ImportError:
+    psycopg2 = pg_errors = None
+
+mysql_passwd = False
+try:
+    import pymysql as mysql
+except ImportError:
+    try:
+        import MySQLdb as mysql
+        mysql_passwd = True
+    except ImportError:
+        mysql = None
+
+
+__version__ = '3.9.6'
+__all__ = [
+    'AsIs',
+    'AutoField',
+    'BareField',
+    'BigAutoField',
+    'BigBitField',
+    'BigIntegerField',
+    'BinaryUUIDField',
+    'BitField',
+    'BlobField',
+    'BooleanField',
+    'Case',
+    'Cast',
+    'CharField',
+    'Check',
+    'chunked',
+    'Column',
+    'CompositeKey',
+    'Context',
+    'Database',
+    'DatabaseError',
+    'DatabaseProxy',
+    'DataError',
+    'DateField',
+    'DateTimeField',
+    'DecimalField',
+    'DeferredForeignKey',
+    'DeferredThroughModel',
+    'DJANGO_MAP',
+    'DoesNotExist',
+    'DoubleField',
+    'DQ',
+    'EXCLUDED',
+    'Field',
+    'FixedCharField',
+    'FloatField',
+    'fn',
+    'ForeignKeyField',
+    'IdentityField',
+    'ImproperlyConfigured',
+    'Index',
+    'IntegerField',
+    'IntegrityError',
+    'InterfaceError',
+    'InternalError',
+    'IPField',
+    'JOIN',
+    'ManyToManyField',
+    'Model',
+    'ModelIndex',
+    'MySQLDatabase',
+    'NotSupportedError',
+    'OP',
+    'OperationalError',
+    'PostgresqlDatabase',
+    'PrimaryKeyField',  # XXX: Deprecated, change to AutoField.
+ 'prefetch', + 'ProgrammingError', + 'Proxy', + 'QualifiedNames', + 'SchemaManager', + 'SmallIntegerField', + 'Select', + 'SQL', + 'SqliteDatabase', + 'Table', + 'TextField', + 'TimeField', + 'TimestampField', + 'Tuple', + 'UUIDField', + 'Value', + 'ValuesList', + 'Window', +] + +try: # Python 2.7+ + from logging import NullHandler +except ImportError: + class NullHandler(logging.Handler): + def emit(self, record): + pass + +logger = logging.getLogger('peewee') +logger.addHandler(NullHandler()) + + +if sys.version_info[0] == 2: + text_type = unicode + bytes_type = str + buffer_type = buffer + izip_longest = itertools.izip_longest + callable_ = callable + exec('def reraise(tp, value, tb=None): raise tp, value, tb') + def print_(s): + sys.stdout.write(s) + sys.stdout.write('\n') +else: + import builtins + try: + from collections.abc import Callable + except ImportError: + from collections import Callable + from functools import reduce + callable_ = lambda c: isinstance(c, Callable) + text_type = str + bytes_type = bytes + buffer_type = memoryview + basestring = str + long = int + print_ = getattr(builtins, 'print') + izip_longest = itertools.zip_longest + def reraise(tp, value, tb=None): + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + + +if sqlite3: + sqlite3.register_adapter(decimal.Decimal, str) + sqlite3.register_adapter(datetime.date, str) + sqlite3.register_adapter(datetime.time, str) + __sqlite_version__ = sqlite3.sqlite_version_info +else: + __sqlite_version__ = (0, 0, 0) + + +__date_parts__ = set(('year', 'month', 'day', 'hour', 'minute', 'second')) + +# Sqlite does not support the `date_part` SQL function, so we will define an +# implementation in python. +__sqlite_datetime_formats__ = ( + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d', + '%H:%M:%S', + '%H:%M:%S.%f', + '%H:%M') + +__sqlite_date_trunc__ = { + 'year': '%Y', + 'month': '%Y-%m', + 'day': '%Y-%m-%d', + 'hour': '%Y-%m-%d %H', + 'minute': '%Y-%m-%d %H:%M', + 'second': '%Y-%m-%d %H:%M:%S'} + +__mysql_date_trunc__ = __sqlite_date_trunc__.copy() +__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i' +__mysql_date_trunc__['second'] = '%Y-%m-%d %H:%i:%S' + +def _sqlite_date_part(lookup_type, datetime_string): + assert lookup_type in __date_parts__ + if not datetime_string: + return + dt = format_date_time(datetime_string, __sqlite_datetime_formats__) + return getattr(dt, lookup_type) + +def _sqlite_date_trunc(lookup_type, datetime_string): + assert lookup_type in __sqlite_date_trunc__ + if not datetime_string: + return + dt = format_date_time(datetime_string, __sqlite_datetime_formats__) + return dt.strftime(__sqlite_date_trunc__[lookup_type]) + + +def __deprecated__(s): + warnings.warn(s, DeprecationWarning) + + +class attrdict(dict): + def __getattr__(self, attr): + try: + return self[attr] + except KeyError: + raise AttributeError(attr) + def __setattr__(self, attr, value): self[attr] = value + def __iadd__(self, rhs): self.update(rhs); return self + def __add__(self, rhs): d = attrdict(self); d.update(rhs); return d + +SENTINEL = object() + +#: Operations for use in SQL expressions. 
+OP = attrdict( + AND='AND', + OR='OR', + ADD='+', + SUB='-', + MUL='*', + DIV='/', + BIN_AND='&', + BIN_OR='|', + XOR='#', + MOD='%', + EQ='=', + LT='<', + LTE='<=', + GT='>', + GTE='>=', + NE='!=', + IN='IN', + NOT_IN='NOT IN', + IS='IS', + IS_NOT='IS NOT', + LIKE='LIKE', + ILIKE='ILIKE', + BETWEEN='BETWEEN', + REGEXP='REGEXP', + IREGEXP='IREGEXP', + CONCAT='||', + BITWISE_NEGATION='~') + +# To support "django-style" double-underscore filters, create a mapping between +# operation name and operation code, e.g. "__eq" == OP.EQ. +DJANGO_MAP = attrdict({ + 'eq': operator.eq, + 'lt': operator.lt, + 'lte': operator.le, + 'gt': operator.gt, + 'gte': operator.ge, + 'ne': operator.ne, + 'in': operator.lshift, + 'is': lambda l, r: Expression(l, OP.IS, r), + 'like': lambda l, r: Expression(l, OP.LIKE, r), + 'ilike': lambda l, r: Expression(l, OP.ILIKE, r), + 'regexp': lambda l, r: Expression(l, OP.REGEXP, r), +}) + +#: Mapping of field type to the data-type supported by the database. Databases +#: may override or add to this list. +FIELD = attrdict( + AUTO='INTEGER', + BIGAUTO='BIGINT', + BIGINT='BIGINT', + BLOB='BLOB', + BOOL='SMALLINT', + CHAR='CHAR', + DATE='DATE', + DATETIME='DATETIME', + DECIMAL='DECIMAL', + DEFAULT='', + DOUBLE='REAL', + FLOAT='REAL', + INT='INTEGER', + SMALLINT='SMALLINT', + TEXT='TEXT', + TIME='TIME', + UUID='TEXT', + UUIDB='BLOB', + VARCHAR='VARCHAR') + +#: Join helpers (for convenience) -- all join types are supported, this object +#: is just to help avoid introducing errors by using strings everywhere. +JOIN = attrdict( + INNER='INNER', + LEFT_OUTER='LEFT OUTER', + RIGHT_OUTER='RIGHT OUTER', + FULL='FULL', + FULL_OUTER='FULL OUTER', + CROSS='CROSS', + NATURAL='NATURAL') + +# Row representations. +ROW = attrdict( + TUPLE=1, + DICT=2, + NAMED_TUPLE=3, + CONSTRUCTOR=4, + MODEL=5) + +SCOPE_NORMAL = 1 +SCOPE_SOURCE = 2 +SCOPE_VALUES = 4 +SCOPE_CTE = 8 +SCOPE_COLUMN = 16 + +# Rules for parentheses around subqueries in compound select. +CSQ_PARENTHESES_NEVER = 0 +CSQ_PARENTHESES_ALWAYS = 1 +CSQ_PARENTHESES_UNNESTED = 2 + +# Regular expressions used to convert class names to snake-case table names. +# First regex handles acronym followed by word or initial lower-word followed +# by a capitalized word. e.g. APIResponse -> API_Response / fooBar -> foo_Bar. +# Second regex handles the normal case of two title-cased words. +SNAKE_CASE_STEP1 = re.compile('(.)_*([A-Z][a-z]+)') +SNAKE_CASE_STEP2 = re.compile('([a-z0-9])_*([A-Z])') + +# Helper functions that are used in various parts of the codebase. 
+MODEL_BASE = '_metaclass_helper_' + +def with_metaclass(meta, base=object): + return meta(MODEL_BASE, (base,), {}) + +def merge_dict(source, overrides): + merged = source.copy() + if overrides: + merged.update(overrides) + return merged + +def quote(path, quote_chars): + if len(path) == 1: + return path[0].join(quote_chars) + return '.'.join([part.join(quote_chars) for part in path]) + +is_model = lambda o: isclass(o) and issubclass(o, Model) + +def ensure_tuple(value): + if value is not None: + return value if isinstance(value, (list, tuple)) else (value,) + +def ensure_entity(value): + if value is not None: + return value if isinstance(value, Node) else Entity(value) + +def make_snake_case(s): + first = SNAKE_CASE_STEP1.sub(r'\1_\2', s) + return SNAKE_CASE_STEP2.sub(r'\1_\2', first).lower() + +def chunked(it, n): + marker = object() + for group in (list(g) for g in izip_longest(*[iter(it)] * n, + fillvalue=marker)): + if group[-1] is marker: + del group[group.index(marker):] + yield group + + +class _callable_context_manager(object): + def __call__(self, fn): + @wraps(fn) + def inner(*args, **kwargs): + with self: + return fn(*args, **kwargs) + return inner + + +class Proxy(object): + """ + Create a proxy or placeholder for another object. + """ + __slots__ = ('obj', '_callbacks') + + def __init__(self): + self._callbacks = [] + self.initialize(None) + + def initialize(self, obj): + self.obj = obj + for callback in self._callbacks: + callback(obj) + + def attach_callback(self, callback): + self._callbacks.append(callback) + return callback + + def passthrough(method): + def inner(self, *args, **kwargs): + if self.obj is None: + raise AttributeError('Cannot use uninitialized Proxy.') + return getattr(self.obj, method)(*args, **kwargs) + return inner + + # Allow proxy to be used as a context-manager. + __enter__ = passthrough('__enter__') + __exit__ = passthrough('__exit__') + + def __getattr__(self, attr): + if self.obj is None: + raise AttributeError('Cannot use uninitialized Proxy.') + return getattr(self.obj, attr) + + def __setattr__(self, attr, value): + if attr not in self.__slots__: + raise AttributeError('Cannot set attribute on proxy.') + return super(Proxy, self).__setattr__(attr, value) + + +class DatabaseProxy(Proxy): + """ + Proxy implementation specifically for proxying `Database` objects. + """ + def connection_context(self): + return ConnectionContext(self) + def atomic(self): + return _atomic(self) + def manual_commit(self): + return _manual(self) + def transaction(self): + return _transaction(self) + def savepoint(self): + return _savepoint(self) + + +# SQL Generation. + + +class AliasManager(object): + __slots__ = ('_counter', '_current_index', '_mapping') + + def __init__(self): + # A list of dictionaries containing mappings at various depths. 
+ self._counter = 0 + self._current_index = 0 + self._mapping = [] + self.push() + + @property + def mapping(self): + return self._mapping[self._current_index - 1] + + def add(self, source): + if source not in self.mapping: + self._counter += 1 + self[source] = 't%d' % self._counter + return self.mapping[source] + + def get(self, source, any_depth=False): + if any_depth: + for idx in reversed(range(self._current_index)): + if source in self._mapping[idx]: + return self._mapping[idx][source] + return self.add(source) + + def __getitem__(self, source): + return self.get(source) + + def __setitem__(self, source, alias): + self.mapping[source] = alias + + def push(self): + self._current_index += 1 + if self._current_index > len(self._mapping): + self._mapping.append({}) + + def pop(self): + if self._current_index == 1: + raise ValueError('Cannot pop() from empty alias manager.') + self._current_index -= 1 + + +class State(collections.namedtuple('_State', ('scope', 'parentheses', + 'settings'))): + def __new__(cls, scope=SCOPE_NORMAL, parentheses=False, **kwargs): + return super(State, cls).__new__(cls, scope, parentheses, kwargs) + + def __call__(self, scope=None, parentheses=None, **kwargs): + # Scope and settings are "inherited" (parentheses is not, however). + scope = self.scope if scope is None else scope + + # Try to avoid unnecessary dict copying. + if kwargs and self.settings: + settings = self.settings.copy() # Copy original settings dict. + settings.update(kwargs) # Update copy with overrides. + elif kwargs: + settings = kwargs + else: + settings = self.settings + return State(scope, parentheses, **settings) + + def __getattr__(self, attr_name): + return self.settings.get(attr_name) + + +def __scope_context__(scope): + @contextmanager + def inner(self, **kwargs): + with self(scope=scope, **kwargs): + yield self + return inner + + +class Context(object): + __slots__ = ('stack', '_sql', '_values', 'alias_manager', 'state') + + def __init__(self, **settings): + self.stack = [] + self._sql = [] + self._values = [] + self.alias_manager = AliasManager() + self.state = State(**settings) + + def as_new(self): + return Context(**self.state.settings) + + def column_sort_key(self, item): + return item[0].get_sort_key(self) + + @property + def scope(self): + return self.state.scope + + @property + def parentheses(self): + return self.state.parentheses + + @property + def subquery(self): + return self.state.subquery + + def __call__(self, **overrides): + if overrides and overrides.get('scope') == self.scope: + del overrides['scope'] + + self.stack.append(self.state) + self.state = self.state(**overrides) + return self + + scope_normal = __scope_context__(SCOPE_NORMAL) + scope_source = __scope_context__(SCOPE_SOURCE) + scope_values = __scope_context__(SCOPE_VALUES) + scope_cte = __scope_context__(SCOPE_CTE) + scope_column = __scope_context__(SCOPE_COLUMN) + + def __enter__(self): + if self.parentheses: + self.literal('(') + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.parentheses: + self.literal(')') + self.state = self.stack.pop() + + @contextmanager + def push_alias(self): + self.alias_manager.push() + yield + self.alias_manager.pop() + + def sql(self, obj): + if isinstance(obj, (Node, Context)): + return obj.__sql__(self) + elif is_model(obj): + return obj._meta.table.__sql__(self) + else: + return self.sql(Value(obj)) + + def literal(self, keyword): + self._sql.append(keyword) + return self + + def value(self, value, converter=None, add_param=True): + if converter: + 
value = converter(value) + if isinstance(value, Node): + return self.sql(value) + elif converter is None and self.state.converter: + # Explicitly check for None so that "False" can be used to signify + # that no conversion should be applied. + value = self.state.converter(value) + + if isinstance(value, Node): + with self(converter=None): + return self.sql(value) + + self._values.append(value) + return self.literal(self.state.param or '?') if add_param else self + + def __sql__(self, ctx): + ctx._sql.extend(self._sql) + ctx._values.extend(self._values) + return ctx + + def parse(self, node): + return self.sql(node).query() + + def query(self): + return ''.join(self._sql), self._values + + +def query_to_string(query): + # NOTE: this function is not exported by default as it might be misused -- + # and this misuse could lead to sql injection vulnerabilities. This + # function is intended for debugging or logging purposes ONLY. + db = getattr(query, '_database', None) + if db is not None: + ctx = db.get_sql_context() + else: + ctx = Context() + + sql, params = ctx.sql(query).query() + if not params: + return sql + + param = ctx.state.param or '?' + if param == '?': + sql = sql.replace('?', '%s') + + return sql % tuple(map(_query_val_transform, params)) + +def _query_val_transform(v): + # Interpolate parameters. + if isinstance(v, (text_type, datetime.datetime, datetime.date, + datetime.time)): + v = "'%s'" % v + elif isinstance(v, bytes_type): + try: + v = v.decode('utf8') + except UnicodeDecodeError: + v = v.decode('raw_unicode_escape') + v = "'%s'" % v + elif isinstance(v, int): + v = '%s' % int(v) # Also handles booleans -> 1 or 0. + elif v is None: + v = 'NULL' + else: + v = str(v) + return v + + +# AST. + + +class Node(object): + _coerce = True + + def clone(self): + obj = self.__class__.__new__(self.__class__) + obj.__dict__ = self.__dict__.copy() + return obj + + def __sql__(self, ctx): + raise NotImplementedError + + @staticmethod + def copy(method): + def inner(self, *args, **kwargs): + clone = self.clone() + method(clone, *args, **kwargs) + return clone + return inner + + def coerce(self, _coerce=True): + if _coerce != self._coerce: + clone = self.clone() + clone._coerce = _coerce + return clone + return self + + def is_alias(self): + return False + + def unwrap(self): + return self + + +class ColumnFactory(object): + __slots__ = ('node',) + + def __init__(self, node): + self.node = node + + def __getattr__(self, attr): + return Column(self.node, attr) + + +class _DynamicColumn(object): + __slots__ = () + + def __get__(self, instance, instance_type=None): + if instance is not None: + return ColumnFactory(instance) # Implements __getattr__(). + return self + + +class _ExplicitColumn(object): + __slots__ = () + + def __get__(self, instance, instance_type=None): + if instance is not None: + raise AttributeError( + '%s specifies columns explicitly, and does not support ' + 'dynamic column lookups.' 
% instance) + return self + + +class Source(Node): + c = _DynamicColumn() + + def __init__(self, alias=None): + super(Source, self).__init__() + self._alias = alias + + @Node.copy + def alias(self, name): + self._alias = name + + def select(self, *columns): + if not columns: + columns = (SQL('*'),) + return Select((self,), columns) + + def join(self, dest, join_type='INNER', on=None): + return Join(self, dest, join_type, on) + + def left_outer_join(self, dest, on=None): + return Join(self, dest, JOIN.LEFT_OUTER, on) + + def cte(self, name, recursive=False, columns=None): + return CTE(name, self, recursive=recursive, columns=columns) + + def get_sort_key(self, ctx): + if self._alias: + return (self._alias,) + return (ctx.alias_manager[self],) + + def apply_alias(self, ctx): + # If we are defining the source, include the "AS alias" declaration. An + # alias is created for the source if one is not already defined. + if ctx.scope == SCOPE_SOURCE: + if self._alias: + ctx.alias_manager[self] = self._alias + ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) + return ctx + + def apply_column(self, ctx): + if self._alias: + ctx.alias_manager[self] = self._alias + return ctx.sql(Entity(ctx.alias_manager[self])) + + +class _HashableSource(object): + def __init__(self, *args, **kwargs): + super(_HashableSource, self).__init__(*args, **kwargs) + self._update_hash() + + @Node.copy + def alias(self, name): + self._alias = name + self._update_hash() + + def _update_hash(self): + self._hash = self._get_hash() + + def _get_hash(self): + return hash((self.__class__, self._path, self._alias)) + + def __hash__(self): + return self._hash + + def __eq__(self, other): + return self._hash == other._hash + + def __ne__(self, other): + return not (self == other) + + +def __bind_database__(meth): + @wraps(meth) + def inner(self, *args, **kwargs): + result = meth(self, *args, **kwargs) + if self._database: + return result.bind(self._database) + return result + return inner + + +def __join__(join_type='INNER', inverted=False): + def method(self, other): + if inverted: + self, other = other, self + return Join(self, other, join_type=join_type) + return method + + +class BaseTable(Source): + __and__ = __join__(JOIN.INNER) + __add__ = __join__(JOIN.LEFT_OUTER) + __sub__ = __join__(JOIN.RIGHT_OUTER) + __or__ = __join__(JOIN.FULL_OUTER) + __mul__ = __join__(JOIN.CROSS) + __rand__ = __join__(JOIN.INNER, inverted=True) + __radd__ = __join__(JOIN.LEFT_OUTER, inverted=True) + __rsub__ = __join__(JOIN.RIGHT_OUTER, inverted=True) + __ror__ = __join__(JOIN.FULL_OUTER, inverted=True) + __rmul__ = __join__(JOIN.CROSS, inverted=True) + + +class _BoundTableContext(_callable_context_manager): + def __init__(self, table, database): + self.table = table + self.database = database + + def __enter__(self): + self._orig_database = self.table._database + self.table.bind(self.database) + if self.table._model is not None: + self.table._model.bind(self.database) + return self.table + + def __exit__(self, exc_type, exc_val, exc_tb): + self.table.bind(self._orig_database) + if self.table._model is not None: + self.table._model.bind(self._orig_database) + + +class Table(_HashableSource, BaseTable): + def __init__(self, name, columns=None, primary_key=None, schema=None, + alias=None, _model=None, _database=None): + self.__name__ = name + self._columns = columns + self._primary_key = primary_key + self._schema = schema + self._path = (schema, name) if schema else (name,) + self._model = _model + self._database = _database + super(Table, 
self).__init__(alias=alias) + + # Allow tables to restrict what columns are available. + if columns is not None: + self.c = _ExplicitColumn() + for column in columns: + setattr(self, column, Column(self, column)) + + if primary_key: + col_src = self if self._columns else self.c + self.primary_key = getattr(col_src, primary_key) + else: + self.primary_key = None + + def clone(self): + # Ensure a deep copy of the column instances. + return Table( + self.__name__, + columns=self._columns, + primary_key=self._primary_key, + schema=self._schema, + alias=self._alias, + _model=self._model, + _database=self._database) + + def bind(self, database=None): + self._database = database + return self + + def bind_ctx(self, database=None): + return _BoundTableContext(self, database) + + def _get_hash(self): + return hash((self.__class__, self._path, self._alias, self._model)) + + @__bind_database__ + def select(self, *columns): + if not columns and self._columns: + columns = [Column(self, column) for column in self._columns] + return Select((self,), columns) + + @__bind_database__ + def insert(self, insert=None, columns=None, **kwargs): + if kwargs: + insert = {} if insert is None else insert + src = self if self._columns else self.c + for key, value in kwargs.items(): + insert[getattr(src, key)] = value + return Insert(self, insert=insert, columns=columns) + + @__bind_database__ + def replace(self, insert=None, columns=None, **kwargs): + return (self + .insert(insert=insert, columns=columns) + .on_conflict('REPLACE')) + + @__bind_database__ + def update(self, update=None, **kwargs): + if kwargs: + update = {} if update is None else update + for key, value in kwargs.items(): + src = self if self._columns else self.c + update[getattr(src, key)] = value + return Update(self, update=update) + + @__bind_database__ + def delete(self): + return Delete(self) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_VALUES: + # Return the quoted table name. + return ctx.sql(Entity(*self._path)) + + if self._alias: + ctx.alias_manager[self] = self._alias + + if ctx.scope == SCOPE_SOURCE: + # Define the table and its alias. + return self.apply_alias(ctx.sql(Entity(*self._path))) + else: + # Refer to the table using the alias. 
+ return self.apply_column(ctx) + + +class Join(BaseTable): + def __init__(self, lhs, rhs, join_type=JOIN.INNER, on=None, alias=None): + super(Join, self).__init__(alias=alias) + self.lhs = lhs + self.rhs = rhs + self.join_type = join_type + self._on = on + + def on(self, predicate): + self._on = predicate + return self + + def __sql__(self, ctx): + (ctx + .sql(self.lhs) + .literal(' %s JOIN ' % self.join_type) + .sql(self.rhs)) + if self._on is not None: + ctx.literal(' ON ').sql(self._on) + return ctx + + +class ValuesList(_HashableSource, BaseTable): + def __init__(self, values, columns=None, alias=None): + self._values = values + self._columns = columns + super(ValuesList, self).__init__(alias=alias) + + def _get_hash(self): + return hash((self.__class__, id(self._values), self._alias)) + + @Node.copy + def columns(self, *names): + self._columns = names + + def __sql__(self, ctx): + if self._alias: + ctx.alias_manager[self] = self._alias + + if ctx.scope == SCOPE_SOURCE or ctx.scope == SCOPE_NORMAL: + with ctx(parentheses=not ctx.parentheses): + ctx = (ctx + .literal('VALUES ') + .sql(CommaNodeList([ + EnclosedNodeList(row) for row in self._values]))) + + if ctx.scope == SCOPE_SOURCE: + ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) + if self._columns: + entities = [Entity(c) for c in self._columns] + ctx.sql(EnclosedNodeList(entities)) + else: + ctx.sql(Entity(ctx.alias_manager[self])) + + return ctx + + +class CTE(_HashableSource, Source): + def __init__(self, name, query, recursive=False, columns=None): + self._alias = name + self._query = query + self._recursive = recursive + if columns is not None: + columns = [Entity(c) if isinstance(c, basestring) else c + for c in columns] + self._columns = columns + query._cte_list = () + super(CTE, self).__init__(alias=name) + + def select_from(self, *columns): + if not columns: + raise ValueError('select_from() must specify one or more columns ' + 'from the CTE to select.') + + query = (Select((self,), columns) + .with_cte(self) + .bind(self._query._database)) + try: + query = query.objects(self._query.model) + except AttributeError: + pass + return query + + def _get_hash(self): + return hash((self.__class__, self._alias, id(self._query))) + + def union_all(self, rhs): + clone = self._query.clone() + return CTE(self._alias, clone + rhs, self._recursive, self._columns) + __add__ = union_all + + def __sql__(self, ctx): + if ctx.scope != SCOPE_CTE: + return ctx.sql(Entity(self._alias)) + + with ctx.push_alias(): + ctx.alias_manager[self] = self._alias + ctx.sql(Entity(self._alias)) + + if self._columns: + ctx.literal(' ').sql(EnclosedNodeList(self._columns)) + ctx.literal(' AS ') + with ctx.scope_normal(parentheses=True): + ctx.sql(self._query) + return ctx + + +class ColumnBase(Node): + def alias(self, alias): + if alias: + return Alias(self, alias) + return self + + def unalias(self): + return self + + def cast(self, as_type): + return Cast(self, as_type) + + def asc(self, collation=None, nulls=None): + return Asc(self, collation=collation, nulls=nulls) + __pos__ = asc + + def desc(self, collation=None, nulls=None): + return Desc(self, collation=collation, nulls=nulls) + __neg__ = desc + + def __invert__(self): + return Negated(self) + + def _e(op, inv=False): + """ + Lightweight factory which returns a method that builds an Expression + consisting of the left-hand and right-hand operands, using `op`. 
+ """ + def inner(self, rhs): + if inv: + return Expression(rhs, op, self) + return Expression(self, op, rhs) + return inner + __and__ = _e(OP.AND) + __or__ = _e(OP.OR) + + __add__ = _e(OP.ADD) + __sub__ = _e(OP.SUB) + __mul__ = _e(OP.MUL) + __div__ = __truediv__ = _e(OP.DIV) + __xor__ = _e(OP.XOR) + __radd__ = _e(OP.ADD, inv=True) + __rsub__ = _e(OP.SUB, inv=True) + __rmul__ = _e(OP.MUL, inv=True) + __rdiv__ = __rtruediv__ = _e(OP.DIV, inv=True) + __rand__ = _e(OP.AND, inv=True) + __ror__ = _e(OP.OR, inv=True) + __rxor__ = _e(OP.XOR, inv=True) + + def __eq__(self, rhs): + op = OP.IS if rhs is None else OP.EQ + return Expression(self, op, rhs) + def __ne__(self, rhs): + op = OP.IS_NOT if rhs is None else OP.NE + return Expression(self, op, rhs) + + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + __lshift__ = _e(OP.IN) + __rshift__ = _e(OP.IS) + __mod__ = _e(OP.LIKE) + __pow__ = _e(OP.ILIKE) + + bin_and = _e(OP.BIN_AND) + bin_or = _e(OP.BIN_OR) + in_ = _e(OP.IN) + not_in = _e(OP.NOT_IN) + regexp = _e(OP.REGEXP) + + # Special expressions. + def is_null(self, is_null=True): + op = OP.IS if is_null else OP.IS_NOT + return Expression(self, op, None) + def contains(self, rhs): + return Expression(self, OP.ILIKE, '%%%s%%' % rhs) + def startswith(self, rhs): + return Expression(self, OP.ILIKE, '%s%%' % rhs) + def endswith(self, rhs): + return Expression(self, OP.ILIKE, '%%%s' % rhs) + def between(self, lo, hi): + return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi))) + def concat(self, rhs): + return StringExpression(self, OP.CONCAT, rhs) + def regexp(self, rhs): + return Expression(self, OP.REGEXP, rhs) + def iregexp(self, rhs): + return Expression(self, OP.IREGEXP, rhs) + def __getitem__(self, item): + if isinstance(item, slice): + if item.start is None or item.stop is None: + raise ValueError('BETWEEN range must have both a start- and ' + 'end-point.') + return self.between(item.start, item.stop) + return self == item + + def distinct(self): + return NodeList((SQL('DISTINCT'), self)) + + def collate(self, collation): + return NodeList((self, SQL('COLLATE %s' % collation))) + + def get_sort_key(self, ctx): + return () + + +class Column(ColumnBase): + def __init__(self, source, name): + self.source = source + self.name = name + + def get_sort_key(self, ctx): + if ctx.scope == SCOPE_VALUES: + return (self.name,) + else: + return self.source.get_sort_key(ctx) + (self.name,) + + def __hash__(self): + return hash((self.source, self.name)) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_VALUES: + return ctx.sql(Entity(self.name)) + else: + with ctx.scope_column(): + return ctx.sql(self.source).literal('.').sql(Entity(self.name)) + + +class WrappedNode(ColumnBase): + def __init__(self, node): + self.node = node + self._coerce = getattr(node, '_coerce', True) + + def is_alias(self): + return self.node.is_alias() + + def unwrap(self): + return self.node.unwrap() + + +class EntityFactory(object): + __slots__ = ('node',) + def __init__(self, node): + self.node = node + def __getattr__(self, attr): + return Entity(self.node, attr) + + +class _DynamicEntity(object): + __slots__ = () + def __get__(self, instance, instance_type=None): + if instance is not None: + return EntityFactory(instance._alias) # Implements __getattr__(). 
+ return self + + +class Alias(WrappedNode): + c = _DynamicEntity() + + def __init__(self, node, alias): + super(Alias, self).__init__(node) + self._alias = alias + + def alias(self, alias=None): + if alias is None: + return self.node + else: + return Alias(self.node, alias) + + def unalias(self): + return self.node + + def is_alias(self): + return True + + def __sql__(self, ctx): + if ctx.scope == SCOPE_SOURCE: + return (ctx + .sql(self.node) + .literal(' AS ') + .sql(Entity(self._alias))) + else: + return ctx.sql(Entity(self._alias)) + + +class Negated(WrappedNode): + def __invert__(self): + return self.node + + def __sql__(self, ctx): + return ctx.literal('NOT ').sql(self.node) + + +class BitwiseMixin(object): + def __and__(self, other): + return self.bin_and(other) + + def __or__(self, other): + return self.bin_or(other) + + def __sub__(self, other): + return self.bin_and(other.bin_negated()) + + def __invert__(self): + return BitwiseNegated(self) + + +class BitwiseNegated(BitwiseMixin, WrappedNode): + def __invert__(self): + return self.node + + def __sql__(self, ctx): + if ctx.state.operations: + op_sql = ctx.state.operations.get(self.op, self.op) + else: + op_sql = self.op + return ctx.literal(op_sql).sql(self.node) + + +class Value(ColumnBase): + _multi_types = (list, tuple, frozenset, set) + + def __init__(self, value, converter=None, unpack=True): + self.value = value + self.converter = converter + self.multi = isinstance(self.value, self._multi_types) and unpack + if self.multi: + self.values = [] + for item in self.value: + if isinstance(item, Node): + self.values.append(item) + else: + self.values.append(Value(item, self.converter)) + + def __sql__(self, ctx): + if self.multi: + # For multi-part values (e.g. lists of IDs). + return ctx.sql(EnclosedNodeList(self.values)) + + return ctx.value(self.value, self.converter) + + +def AsIs(value): + return Value(value, unpack=False) + + +class Cast(WrappedNode): + def __init__(self, node, cast): + super(Cast, self).__init__(node) + self._cast = cast + self._coerce = False + + def __sql__(self, ctx): + return (ctx + .literal('CAST(') + .sql(self.node) + .literal(' AS %s)' % self._cast)) + + +class Ordering(WrappedNode): + def __init__(self, node, direction, collation=None, nulls=None): + super(Ordering, self).__init__(node) + self.direction = direction + self.collation = collation + self.nulls = nulls + if nulls and nulls.lower() not in ('first', 'last'): + raise ValueError('Ordering nulls= parameter must be "first" or ' + '"last", got: %s' % nulls) + + def collate(self, collation=None): + return Ordering(self.node, self.direction, collation) + + def _null_ordering_case(self, nulls): + if nulls.lower() == 'last': + ifnull, notnull = 1, 0 + elif nulls.lower() == 'first': + ifnull, notnull = 0, 1 + else: + raise ValueError('unsupported value for nulls= ordering.') + return Case(None, ((self.node.is_null(), ifnull),), notnull) + + def __sql__(self, ctx): + if self.nulls and not ctx.state.nulls_ordering: + ctx.sql(self._null_ordering_case(self.nulls)).literal(', ') + + ctx.sql(self.node).literal(' %s' % self.direction) + if self.collation: + ctx.literal(' COLLATE %s' % self.collation) + if self.nulls and ctx.state.nulls_ordering: + ctx.literal(' NULLS %s' % self.nulls) + return ctx + + +def Asc(node, collation=None, nulls=None): + return Ordering(node, 'ASC', collation, nulls) + + +def Desc(node, collation=None, nulls=None): + return Ordering(node, 'DESC', collation, nulls) + + +class Expression(ColumnBase): + def __init__(self, lhs, op, 
rhs, flat=False): + self.lhs = lhs + self.op = op + self.rhs = rhs + self.flat = flat + + def __sql__(self, ctx): + overrides = {'parentheses': not self.flat, 'in_expr': True} + if isinstance(self.lhs, Field): + overrides['converter'] = self.lhs.db_value + else: + overrides['converter'] = None + + if ctx.state.operations: + op_sql = ctx.state.operations.get(self.op, self.op) + else: + op_sql = self.op + + with ctx(**overrides): + # Postgresql reports an error for IN/NOT IN (), so convert to + # the equivalent boolean expression. + op_in = self.op == OP.IN or self.op == OP.NOT_IN + if op_in and ctx.as_new().parse(self.rhs)[0] == '()': + return ctx.literal('0 = 1' if self.op == OP.IN else '1 = 1') + + return (ctx + .sql(self.lhs) + .literal(' %s ' % op_sql) + .sql(self.rhs)) + + +class StringExpression(Expression): + def __add__(self, rhs): + return self.concat(rhs) + def __radd__(self, lhs): + return StringExpression(lhs, OP.CONCAT, self) + + +class Entity(ColumnBase): + def __init__(self, *path): + self._path = [part.replace('"', '""') for part in path if part] + + def __getattr__(self, attr): + return Entity(*self._path + [attr]) + + def get_sort_key(self, ctx): + return tuple(self._path) + + def __hash__(self): + return hash((self.__class__.__name__, tuple(self._path))) + + def __sql__(self, ctx): + return ctx.literal(quote(self._path, ctx.state.quote or '""')) + + +class SQL(ColumnBase): + def __init__(self, sql, params=None): + self.sql = sql + self.params = params + + def __sql__(self, ctx): + ctx.literal(self.sql) + if self.params: + for param in self.params: + ctx.value(param, False, add_param=False) + return ctx + + +def Check(constraint): + return SQL('CHECK (%s)' % constraint) + + +class Function(ColumnBase): + def __init__(self, name, arguments, coerce=True, python_value=None): + self.name = name + self.arguments = arguments + self._filter = None + self._python_value = python_value + if name and name.lower() in ('sum', 'count', 'cast'): + self._coerce = False + else: + self._coerce = coerce + + def __getattr__(self, attr): + def decorator(*args, **kwargs): + return Function(attr, args, **kwargs) + return decorator + + @Node.copy + def filter(self, where=None): + self._filter = where + + @Node.copy + def python_value(self, func=None): + self._python_value = func + + def over(self, partition_by=None, order_by=None, start=None, end=None, + frame_type=None, window=None, exclude=None): + if isinstance(partition_by, Window) and window is None: + window = partition_by + + if window is not None: + node = WindowAlias(window) + else: + node = Window(partition_by=partition_by, order_by=order_by, + start=start, end=end, frame_type=frame_type, + exclude=exclude, _inline=True) + return NodeList((self, SQL('OVER'), node)) + + def __sql__(self, ctx): + ctx.literal(self.name) + if not len(self.arguments): + ctx.literal('()') + else: + with ctx(in_function=True, function_arg_count=len(self.arguments)): + ctx.sql(EnclosedNodeList([ + (argument if isinstance(argument, Node) + else Value(argument, False)) + for argument in self.arguments])) + + if self._filter: + ctx.literal(' FILTER (WHERE ').sql(self._filter).literal(')') + return ctx + + +fn = Function(None, None) + + +class Window(Node): + # Frame start/end and frame exclusion. + CURRENT_ROW = SQL('CURRENT ROW') + GROUP = SQL('GROUP') + TIES = SQL('TIES') + NO_OTHERS = SQL('NO OTHERS') + + # Frame types. 
+ GROUPS = 'GROUPS' + RANGE = 'RANGE' + ROWS = 'ROWS' + + def __init__(self, partition_by=None, order_by=None, start=None, end=None, + frame_type=None, extends=None, exclude=None, alias=None, + _inline=False): + super(Window, self).__init__() + if start is not None and not isinstance(start, SQL): + start = SQL(start) + if end is not None and not isinstance(end, SQL): + end = SQL(end) + + self.partition_by = ensure_tuple(partition_by) + self.order_by = ensure_tuple(order_by) + self.start = start + self.end = end + if self.start is None and self.end is not None: + raise ValueError('Cannot specify WINDOW end without start.') + self._alias = alias or 'w' + self._inline = _inline + self.frame_type = frame_type + self._extends = extends + self._exclude = exclude + + def alias(self, alias=None): + self._alias = alias or 'w' + return self + + @Node.copy + def as_range(self): + self.frame_type = Window.RANGE + + @Node.copy + def as_rows(self): + self.frame_type = Window.ROWS + + @Node.copy + def as_groups(self): + self.frame_type = Window.GROUPS + + @Node.copy + def extends(self, window=None): + self._extends = window + + @Node.copy + def exclude(self, frame_exclusion=None): + if isinstance(frame_exclusion, basestring): + frame_exclusion = SQL(frame_exclusion) + self._exclude = frame_exclusion + + @staticmethod + def following(value=None): + if value is None: + return SQL('UNBOUNDED FOLLOWING') + return SQL('%d FOLLOWING' % value) + + @staticmethod + def preceding(value=None): + if value is None: + return SQL('UNBOUNDED PRECEDING') + return SQL('%d PRECEDING' % value) + + def __sql__(self, ctx): + if ctx.scope != SCOPE_SOURCE and not self._inline: + ctx.literal(self._alias) + ctx.literal(' AS ') + + with ctx(parentheses=True): + parts = [] + if self._extends is not None: + ext = self._extends + if isinstance(ext, Window): + ext = SQL(ext._alias) + elif isinstance(ext, basestring): + ext = SQL(ext) + parts.append(ext) + if self.partition_by: + parts.extend(( + SQL('PARTITION BY'), + CommaNodeList(self.partition_by))) + if self.order_by: + parts.extend(( + SQL('ORDER BY'), + CommaNodeList(self.order_by))) + if self.start is not None and self.end is not None: + frame = self.frame_type or 'ROWS' + parts.extend(( + SQL('%s BETWEEN' % frame), + self.start, + SQL('AND'), + self.end)) + elif self.start is not None: + parts.extend((SQL(self.frame_type or 'ROWS'), self.start)) + elif self.frame_type is not None: + parts.append(SQL('%s UNBOUNDED PRECEDING' % self.frame_type)) + if self._exclude is not None: + parts.extend((SQL('EXCLUDE'), self._exclude)) + ctx.sql(NodeList(parts)) + return ctx + + +class WindowAlias(Node): + def __init__(self, window): + self.window = window + + def alias(self, window_alias): + self.window._alias = window_alias + return self + + def __sql__(self, ctx): + return ctx.literal(self.window._alias or 'w') + + +def Case(predicate, expression_tuples, default=None): + clauses = [SQL('CASE')] + if predicate is not None: + clauses.append(predicate) + for expr, value in expression_tuples: + clauses.extend((SQL('WHEN'), expr, SQL('THEN'), value)) + if default is not None: + clauses.extend((SQL('ELSE'), default)) + clauses.append(SQL('END')) + return NodeList(clauses) + + +class NodeList(ColumnBase): + def __init__(self, nodes, glue=' ', parens=False): + self.nodes = nodes + self.glue = glue + self.parens = parens + if parens and len(self.nodes) == 1: + if isinstance(self.nodes[0], Expression): + # Hack to avoid double-parentheses. 
+ self.nodes[0].flat = True + + def __sql__(self, ctx): + n_nodes = len(self.nodes) + if n_nodes == 0: + return ctx.literal('()') if self.parens else ctx + with ctx(parentheses=self.parens): + for i in range(n_nodes - 1): + ctx.sql(self.nodes[i]) + ctx.literal(self.glue) + ctx.sql(self.nodes[n_nodes - 1]) + return ctx + + +def CommaNodeList(nodes): + return NodeList(nodes, ', ') + + +def EnclosedNodeList(nodes): + return NodeList(nodes, ', ', True) + + +class _Namespace(Node): + __slots__ = ('_name',) + def __init__(self, name): + self._name = name + def __getattr__(self, attr): + return NamespaceAttribute(self, attr) + __getitem__ = __getattr__ + +class NamespaceAttribute(ColumnBase): + def __init__(self, namespace, attribute): + self._namespace = namespace + self._attribute = attribute + + def __sql__(self, ctx): + return (ctx + .literal(self._namespace._name + '.') + .sql(Entity(self._attribute))) + +EXCLUDED = _Namespace('EXCLUDED') + + +class DQ(ColumnBase): + def __init__(self, **query): + super(DQ, self).__init__() + self.query = query + self._negated = False + + @Node.copy + def __invert__(self): + self._negated = not self._negated + + def clone(self): + node = DQ(**self.query) + node._negated = self._negated + return node + +#: Represent a row tuple. +Tuple = lambda *a: EnclosedNodeList(a) + + +class QualifiedNames(WrappedNode): + def __sql__(self, ctx): + with ctx.scope_column(): + return ctx.sql(self.node) + + +def qualify_names(node): + # Search a node heirarchy to ensure that any column-like objects are + # referenced using fully-qualified names. + if isinstance(node, Expression): + return node.__class__(qualify_names(node.lhs), node.op, + qualify_names(node.rhs), node.flat) + elif isinstance(node, ColumnBase): + return QualifiedNames(node) + return node + + +class OnConflict(Node): + def __init__(self, action=None, update=None, preserve=None, where=None, + conflict_target=None, conflict_where=None, + conflict_constraint=None): + self._action = action + self._update = update + self._preserve = ensure_tuple(preserve) + self._where = where + if conflict_target is not None and conflict_constraint is not None: + raise ValueError('only one of "conflict_target" and ' + '"conflict_constraint" may be specified.') + self._conflict_target = ensure_tuple(conflict_target) + self._conflict_where = conflict_where + self._conflict_constraint = conflict_constraint + + def get_conflict_statement(self, ctx, query): + return ctx.state.conflict_statement(self, query) + + def get_conflict_update(self, ctx, query): + return ctx.state.conflict_update(self, query) + + @Node.copy + def preserve(self, *columns): + self._preserve = columns + + @Node.copy + def update(self, _data=None, **kwargs): + if _data and kwargs and not isinstance(_data, dict): + raise ValueError('Cannot mix data with keyword arguments in the ' + 'OnConflict update method.') + _data = _data or {} + if kwargs: + _data.update(kwargs) + self._update = _data + + @Node.copy + def where(self, *expressions): + if self._where is not None: + expressions = (self._where,) + expressions + self._where = reduce(operator.and_, expressions) + + @Node.copy + def conflict_target(self, *constraints): + self._conflict_constraint = None + self._conflict_target = constraints + + @Node.copy + def conflict_where(self, *expressions): + if self._conflict_where is not None: + expressions = (self._conflict_where,) + expressions + self._conflict_where = reduce(operator.and_, expressions) + + @Node.copy + def conflict_constraint(self, constraint): + 
self._conflict_constraint = constraint + self._conflict_target = None + + +def database_required(method): + @wraps(method) + def inner(self, database=None, *args, **kwargs): + database = self._database if database is None else database + if not database: + raise InterfaceError('Query must be bound to a database in order ' + 'to call "%s".' % method.__name__) + return method(self, database, *args, **kwargs) + return inner + +# BASE QUERY INTERFACE. + +class BaseQuery(Node): + default_row_type = ROW.DICT + + def __init__(self, _database=None, **kwargs): + self._database = _database + self._cursor_wrapper = None + self._row_type = None + self._constructor = None + super(BaseQuery, self).__init__(**kwargs) + + def bind(self, database=None): + self._database = database + return self + + def clone(self): + query = super(BaseQuery, self).clone() + query._cursor_wrapper = None + return query + + @Node.copy + def dicts(self, as_dict=True): + self._row_type = ROW.DICT if as_dict else None + return self + + @Node.copy + def tuples(self, as_tuple=True): + self._row_type = ROW.TUPLE if as_tuple else None + return self + + @Node.copy + def namedtuples(self, as_namedtuple=True): + self._row_type = ROW.NAMED_TUPLE if as_namedtuple else None + return self + + @Node.copy + def objects(self, constructor=None): + self._row_type = ROW.CONSTRUCTOR if constructor else None + self._constructor = constructor + return self + + def _get_cursor_wrapper(self, cursor): + row_type = self._row_type or self.default_row_type + + if row_type == ROW.DICT: + return DictCursorWrapper(cursor) + elif row_type == ROW.TUPLE: + return CursorWrapper(cursor) + elif row_type == ROW.NAMED_TUPLE: + return NamedTupleCursorWrapper(cursor) + elif row_type == ROW.CONSTRUCTOR: + return ObjectCursorWrapper(cursor, self._constructor) + else: + raise ValueError('Unrecognized row type: "%s".' 
% row_type) + + def __sql__(self, ctx): + raise NotImplementedError + + def sql(self): + if self._database: + context = self._database.get_sql_context() + else: + context = Context() + return context.parse(self) + + @database_required + def execute(self, database): + return self._execute(database) + + def _execute(self, database): + raise NotImplementedError + + def iterator(self, database=None): + return iter(self.execute(database).iterator()) + + def _ensure_execution(self): + if not self._cursor_wrapper: + if not self._database: + raise ValueError('Query has not been executed.') + self.execute() + + def __iter__(self): + self._ensure_execution() + return iter(self._cursor_wrapper) + + def __getitem__(self, value): + self._ensure_execution() + if isinstance(value, slice): + index = value.stop + else: + index = value + if index is not None: + index = index + 1 if index >= 0 else 0 + self._cursor_wrapper.fill_cache(index) + return self._cursor_wrapper.row_cache[value] + + def __len__(self): + self._ensure_execution() + return len(self._cursor_wrapper) + + def __str__(self): + return query_to_string(self) + + +class RawQuery(BaseQuery): + def __init__(self, sql=None, params=None, **kwargs): + super(RawQuery, self).__init__(**kwargs) + self._sql = sql + self._params = params + + def __sql__(self, ctx): + ctx.literal(self._sql) + if self._params: + for param in self._params: + ctx.value(param, add_param=False) + return ctx + + def _execute(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self) + self._cursor_wrapper = self._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + +class Query(BaseQuery): + def __init__(self, where=None, order_by=None, limit=None, offset=None, + **kwargs): + super(Query, self).__init__(**kwargs) + self._where = where + self._order_by = order_by + self._limit = limit + self._offset = offset + + self._cte_list = None + + @Node.copy + def with_cte(self, *cte_list): + self._cte_list = cte_list + + @Node.copy + def where(self, *expressions): + if self._where is not None: + expressions = (self._where,) + expressions + self._where = reduce(operator.and_, expressions) + + @Node.copy + def orwhere(self, *expressions): + if self._where is not None: + expressions = (self._where,) + expressions + self._where = reduce(operator.or_, expressions) + + @Node.copy + def order_by(self, *values): + self._order_by = values + + @Node.copy + def order_by_extend(self, *values): + self._order_by = ((self._order_by or ()) + values) or None + + @Node.copy + def limit(self, value=None): + self._limit = value + + @Node.copy + def offset(self, value=None): + self._offset = value + + @Node.copy + def paginate(self, page, paginate_by=20): + if page > 0: + page -= 1 + self._limit = paginate_by + self._offset = page * paginate_by + + def _apply_ordering(self, ctx): + if self._order_by: + (ctx + .literal(' ORDER BY ') + .sql(CommaNodeList(self._order_by))) + if self._limit is not None or (self._offset is not None and + ctx.state.limit_max): + ctx.literal(' LIMIT ').sql(self._limit or ctx.state.limit_max) + if self._offset is not None: + ctx.literal(' OFFSET ').sql(self._offset) + return ctx + + def __sql__(self, ctx): + if self._cte_list: + # The CTE scope is only used at the very beginning of the query, + # when we are describing the various CTEs we will be using. + recursive = any(cte._recursive for cte in self._cte_list) + + # Explicitly disable the "subquery" flag here, so as to avoid + # unnecessary parentheses around subsequent selects. 
+ with ctx.scope_cte(subquery=False): + (ctx + .literal('WITH RECURSIVE ' if recursive else 'WITH ') + .sql(CommaNodeList(self._cte_list)) + .literal(' ')) + return ctx + + +def __compound_select__(operation, inverted=False): + def method(self, other): + if inverted: + self, other = other, self + return CompoundSelectQuery(self, operation, other) + return method + + +class SelectQuery(Query): + union_all = __add__ = __compound_select__('UNION ALL') + union = __or__ = __compound_select__('UNION') + intersect = __and__ = __compound_select__('INTERSECT') + except_ = __sub__ = __compound_select__('EXCEPT') + __radd__ = __compound_select__('UNION ALL', inverted=True) + __ror__ = __compound_select__('UNION', inverted=True) + __rand__ = __compound_select__('INTERSECT', inverted=True) + __rsub__ = __compound_select__('EXCEPT', inverted=True) + + def select_from(self, *columns): + if not columns: + raise ValueError('select_from() must specify one or more columns.') + + query = (Select((self,), columns) + .bind(self._database)) + if getattr(self, 'model', None) is not None: + # Bind to the sub-select's model type, if defined. + query = query.objects(self.model) + return query + + +class SelectBase(_HashableSource, Source, SelectQuery): + def _get_hash(self): + return hash((self.__class__, self._alias or id(self))) + + def _execute(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self) + self._cursor_wrapper = self._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + @database_required + def peek(self, database, n=1): + rows = self.execute(database)[:n] + if rows: + return rows[0] if n == 1 else rows + + @database_required + def first(self, database, n=1): + if self._limit != n: + self._limit = n + self._cursor_wrapper = None + return self.peek(database, n=n) + + @database_required + def scalar(self, database, as_tuple=False): + row = self.tuples().peek(database) + return row[0] if row and not as_tuple else row + + @database_required + def count(self, database, clear_limit=False): + clone = self.order_by().alias('_wrapped') + if clear_limit: + clone._limit = clone._offset = None + try: + if clone._having is None and clone._group_by is None and \ + clone._windows is None and clone._distinct is None and \ + clone._simple_distinct is not True: + clone = clone.select(SQL('1')) + except AttributeError: + pass + return Select([clone], [fn.COUNT(SQL('1'))]).scalar(database) + + @database_required + def exists(self, database): + clone = self.columns(SQL('1')) + clone._limit = 1 + clone._offset = None + return bool(clone.scalar()) + + @database_required + def get(self, database): + self._cursor_wrapper = None + try: + return self.execute(database)[0] + except IndexError: + pass + + +# QUERY IMPLEMENTATIONS. 
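As a rough illustration of the base query interface above, the sketch below builds a SELECT with where(), order_by() and paginate(), switches between the dict/tuple/namedtuple row types, and uses the count()/exists()/scalar() helpers. The Person model, its fields and the in-memory database are hypothetical and exist only for the example.

from peewee import SqliteDatabase, Model, CharField, IntegerField, fn

db = SqliteDatabase(':memory:')

class Person(Model):
    name = CharField()
    age = IntegerField()

    class Meta:
        database = db

db.connect()
db.create_tables([Person])
Person.insert_many([{'name': 'alice', 'age': 30},
                    {'name': 'bob', 'age': 25}]).execute()

# where()/orwhere() AND/OR expressions together; paginate(page, n) is just
# LIMIT n OFFSET (page - 1) * n.
adults = (Person
          .select()
          .where(Person.age >= 18)
          .order_by(Person.name)
          .paginate(1, 20))

# The row-type helpers control which cursor wrapper is used.
for row in adults.dicts():        # e.g. {'id': 1, 'name': 'alice', 'age': 30}
    print(row)
for row in adults.tuples():       # e.g. (1, 'alice', 30)
    print(row)
for row in adults.namedtuples():  # e.g. Row(id=1, name='alice', age=30)
    print(row.name)

# Convenience helpers defined on SelectBase above.
print(Person.select().count())                            # 2
print(Person.select().where(Person.age > 100).exists())   # False
print(Person.select(fn.MAX(Person.age)).scalar())         # 30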
+ + +class CompoundSelectQuery(SelectBase): + def __init__(self, lhs, op, rhs): + super(CompoundSelectQuery, self).__init__() + self.lhs = lhs + self.op = op + self.rhs = rhs + + @property + def _returning(self): + return self.lhs._returning + + def _get_query_key(self): + return (self.lhs.get_query_key(), self.rhs.get_query_key()) + + def _wrap_parens(self, ctx, subq): + csq_setting = ctx.state.compound_select_parentheses + + if not csq_setting or csq_setting == CSQ_PARENTHESES_NEVER: + return False + elif csq_setting == CSQ_PARENTHESES_ALWAYS: + return True + elif csq_setting == CSQ_PARENTHESES_UNNESTED: + return not isinstance(subq, CompoundSelectQuery) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_COLUMN: + return self.apply_column(ctx) + + outer_parens = ctx.subquery or (ctx.scope == SCOPE_SOURCE) + with ctx(parentheses=outer_parens): + # Should the left-hand query be wrapped in parentheses? + lhs_parens = self._wrap_parens(ctx, self.lhs) + with ctx.scope_normal(parentheses=lhs_parens, subquery=False): + ctx.sql(self.lhs) + ctx.literal(' %s ' % self.op) + with ctx.push_alias(): + # Should the right-hand query be wrapped in parentheses? + rhs_parens = self._wrap_parens(ctx, self.rhs) + with ctx.scope_normal(parentheses=rhs_parens, subquery=False): + ctx.sql(self.rhs) + + # Apply ORDER BY, LIMIT, OFFSET. We use the "values" scope so that + # entity names are not fully-qualified. This is a bit of a hack, as + # we're relying on the logic in Column.__sql__() to not fully + # qualify column names. + with ctx.scope_values(): + self._apply_ordering(ctx) + + return self.apply_alias(ctx) + + +class Select(SelectBase): + def __init__(self, from_list=None, columns=None, group_by=None, + having=None, distinct=None, windows=None, for_update=None, + **kwargs): + super(Select, self).__init__(**kwargs) + self._from_list = (list(from_list) if isinstance(from_list, tuple) + else from_list) or [] + self._returning = columns + self._group_by = group_by + self._having = having + self._windows = None + self._for_update = 'FOR UPDATE' if for_update is True else for_update + + self._distinct = self._simple_distinct = None + if distinct: + if isinstance(distinct, bool): + self._simple_distinct = distinct + else: + self._distinct = distinct + + self._cursor_wrapper = None + + def clone(self): + clone = super(Select, self).clone() + if clone._from_list: + clone._from_list = list(clone._from_list) + return clone + + @Node.copy + def columns(self, *columns, **kwargs): + self._returning = columns + select = columns + + @Node.copy + def select_extend(self, *columns): + self._returning = tuple(self._returning) + columns + + @Node.copy + def from_(self, *sources): + self._from_list = list(sources) + + @Node.copy + def join(self, dest, join_type='INNER', on=None): + if not self._from_list: + raise ValueError('No sources to join on.') + item = self._from_list.pop() + self._from_list.append(Join(item, dest, join_type, on)) + + @Node.copy + def group_by(self, *columns): + grouping = [] + for column in columns: + if isinstance(column, Table): + if not column._columns: + raise ValueError('Cannot pass a table to group_by() that ' + 'does not have columns explicitly ' + 'declared.') + grouping.extend([getattr(column, col_name) + for col_name in column._columns]) + else: + grouping.append(column) + self._group_by = grouping + + def group_by_extend(self, *values): + """@Node.copy used from group_by() call""" + group_by = tuple(self._group_by or ()) + values + return self.group_by(*group_by) + + @Node.copy + def 
having(self, *expressions): + if self._having is not None: + expressions = (self._having,) + expressions + self._having = reduce(operator.and_, expressions) + + @Node.copy + def distinct(self, *columns): + if len(columns) == 1 and (columns[0] is True or columns[0] is False): + self._simple_distinct = columns[0] + else: + self._simple_distinct = False + self._distinct = columns + + @Node.copy + def window(self, *windows): + self._windows = windows if windows else None + + @Node.copy + def for_update(self, for_update=True): + self._for_update = 'FOR UPDATE' if for_update is True else for_update + + def _get_query_key(self): + return self._alias + + def __sql_selection__(self, ctx, is_subquery=False): + return ctx.sql(CommaNodeList(self._returning)) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_COLUMN: + return self.apply_column(ctx) + + is_subquery = ctx.subquery + state = { + 'converter': None, + 'in_function': False, + 'parentheses': is_subquery or (ctx.scope == SCOPE_SOURCE), + 'subquery': True, + } + if ctx.state.in_function and ctx.state.function_arg_count == 1: + state['parentheses'] = False + + with ctx.scope_normal(**state): + # Defer calling parent SQL until here. This ensures that any CTEs + # for this query will be properly nested if this query is a + # sub-select or is used in an expression. See GH#1809 for example. + super(Select, self).__sql__(ctx) + + ctx.literal('SELECT ') + if self._simple_distinct or self._distinct is not None: + ctx.literal('DISTINCT ') + if self._distinct: + (ctx + .literal('ON ') + .sql(EnclosedNodeList(self._distinct)) + .literal(' ')) + + with ctx.scope_source(): + ctx = self.__sql_selection__(ctx, is_subquery) + + if self._from_list: + with ctx.scope_source(parentheses=False): + ctx.literal(' FROM ').sql(CommaNodeList(self._from_list)) + + if self._where is not None: + ctx.literal(' WHERE ').sql(self._where) + + if self._group_by: + ctx.literal(' GROUP BY ').sql(CommaNodeList(self._group_by)) + + if self._having is not None: + ctx.literal(' HAVING ').sql(self._having) + + if self._windows is not None: + ctx.literal(' WINDOW ') + ctx.sql(CommaNodeList(self._windows)) + + # Apply ORDER BY, LIMIT, OFFSET. + self._apply_ordering(ctx) + + if self._for_update: + if not ctx.state.for_update: + raise ValueError('FOR UPDATE specified but not supported ' + 'by database.') + ctx.literal(' ') + ctx.sql(SQL(self._for_update)) + + # If the subquery is inside a function -or- we are evaluating a + # subquery on either side of an expression w/o an explicit alias, do + # not generate an alias + AS clause. 
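A small, hypothetical two-model schema showing how the join(), group_by(), having() and distinct() methods defined above compose; the models and data are illustrative only.

from peewee import SqliteDatabase, Model, CharField, ForeignKeyField, fn

db = SqliteDatabase(':memory:')

class Author(Model):
    name = CharField()

    class Meta:
        database = db

class Book(Model):
    author = ForeignKeyField(Author, backref='books')
    title = CharField()

    class Meta:
        database = db

db.connect()
db.create_tables([Author, Book])

# Authors having more than one book: JOIN + GROUP BY + HAVING.
prolific = (Author
            .select(Author.name, fn.COUNT(Book.id).alias('book_count'))
            .join(Book)
            .group_by(Author.name)
            .having(fn.COUNT(Book.id) > 1)
            .order_by(fn.COUNT(Book.id).desc()))

for row in prolific.dicts():
    print(row['name'], row['book_count'])

# distinct() with no arguments (or True) emits a plain DISTINCT; passing
# columns produces DISTINCT ON (...) on backends that support it.
unique_titles = Book.select(Book.title).distinct()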
+ if ctx.state.in_function or (ctx.state.in_expr and + self._alias is None): + return ctx + + return self.apply_alias(ctx) + + +class _WriteQuery(Query): + def __init__(self, table, returning=None, **kwargs): + self.table = table + self._returning = returning + self._return_cursor = True if returning else False + super(_WriteQuery, self).__init__(**kwargs) + + @Node.copy + def returning(self, *returning): + self._returning = returning + self._return_cursor = True if returning else False + + def apply_returning(self, ctx): + if self._returning: + with ctx.scope_normal(): + ctx.literal(' RETURNING ').sql(CommaNodeList(self._returning)) + return ctx + + def _execute(self, database): + if self._returning: + cursor = self.execute_returning(database) + else: + cursor = database.execute(self) + return self.handle_result(database, cursor) + + def execute_returning(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self) + self._cursor_wrapper = self._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + def handle_result(self, database, cursor): + if self._return_cursor: + return cursor + return database.rows_affected(cursor) + + def _set_table_alias(self, ctx): + ctx.alias_manager[self.table] = self.table.__name__ + + def __sql__(self, ctx): + super(_WriteQuery, self).__sql__(ctx) + # We explicitly set the table alias to the table's name, which ensures + # that if a sub-select references a column on the outer table, we won't + # assign it a new alias (e.g. t2) but will refer to it as table.column. + self._set_table_alias(ctx) + return ctx + + +class Update(_WriteQuery): + def __init__(self, table, update=None, **kwargs): + super(Update, self).__init__(table, **kwargs) + self._update = update + self._from = None + + @Node.copy + def from_(self, *sources): + self._from = sources + + def __sql__(self, ctx): + super(Update, self).__sql__(ctx) + + with ctx.scope_values(subquery=True): + ctx.literal('UPDATE ') + + expressions = [] + for k, v in sorted(self._update.items(), key=ctx.column_sort_key): + if not isinstance(v, Node): + converter = k.db_value if isinstance(k, Field) else None + v = Value(v, converter=converter, unpack=False) + if not isinstance(v, Value): + v = qualify_names(v) + expressions.append(NodeList((k, SQL('='), v))) + + (ctx + .sql(self.table) + .literal(' SET ') + .sql(CommaNodeList(expressions))) + + if self._from: + with ctx.scope_source(parentheses=False): + ctx.literal(' FROM ').sql(CommaNodeList(self._from)) + + if self._where: + with ctx.scope_normal(): + ctx.literal(' WHERE ').sql(self._where) + self._apply_ordering(ctx) + return self.apply_returning(ctx) + + +class Insert(_WriteQuery): + SIMPLE = 0 + QUERY = 1 + MULTI = 2 + class DefaultValuesException(Exception): pass + + def __init__(self, table, insert=None, columns=None, on_conflict=None, + **kwargs): + super(Insert, self).__init__(table, **kwargs) + self._insert = insert + self._columns = columns + self._on_conflict = on_conflict + self._query_type = None + + def where(self, *expressions): + raise NotImplementedError('INSERT queries cannot have a WHERE clause.') + + @Node.copy + def on_conflict_ignore(self, ignore=True): + self._on_conflict = OnConflict('IGNORE') if ignore else None + + @Node.copy + def on_conflict_replace(self, replace=True): + self._on_conflict = OnConflict('REPLACE') if replace else None + + @Node.copy + def on_conflict(self, *args, **kwargs): + self._on_conflict = (OnConflict(*args, **kwargs) if (args or kwargs) + else None) + + def _simple_insert(self, ctx): + 
if not self._insert: + raise self.DefaultValuesException('Error: no data to insert.') + return self._generate_insert((self._insert,), ctx) + + def get_default_data(self): + return {} + + def get_default_columns(self): + if self.table._columns: + return [getattr(self.table, col) for col in self.table._columns + if col != self.table._primary_key] + + def _generate_insert(self, insert, ctx): + rows_iter = iter(insert) + columns = self._columns + + # Load and organize column defaults (if provided). + defaults = self.get_default_data() + value_lookups = {} + + # First figure out what columns are being inserted (if they weren't + # specified explicitly). Resulting columns are normalized and ordered. + if not columns: + try: + row = next(rows_iter) + except StopIteration: + raise self.DefaultValuesException('Error: no rows to insert.') + + if not isinstance(row, dict): + columns = self.get_default_columns() + if columns is None: + raise ValueError('Bulk insert must specify columns.') + else: + # Infer column names from the dict of data being inserted. + accum = [] + uses_strings = False # Are the dict keys strings or columns? + for key in row: + if isinstance(key, basestring): + column = getattr(self.table, key) + uses_strings = True + else: + column = key + accum.append(column) + value_lookups[column] = key + + # Add any columns present in the default data that are not + # accounted for by the dictionary of row data. + column_set = set(accum) + for col in (set(defaults) - column_set): + accum.append(col) + value_lookups[col] = col.name if uses_strings else col + + columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx)) + rows_iter = itertools.chain(iter((row,)), rows_iter) + else: + clean_columns = [] + for column in columns: + if isinstance(column, basestring): + column_obj = getattr(self.table, column) + else: + column_obj = column + value_lookups[column_obj] = column + clean_columns.append(column_obj) + + columns = clean_columns + for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)): + if col not in value_lookups: + columns.append(col) + value_lookups[col] = col + + ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ') + columns_converters = [ + (column, column.db_value if isinstance(column, Field) else None) + for column in columns] + + all_values = [] + for row in rows_iter: + values = [] + is_dict = isinstance(row, Mapping) + for i, (column, converter) in enumerate(columns_converters): + try: + if is_dict: + val = row[value_lookups[column]] + else: + val = row[i] + except (KeyError, IndexError): + if column in defaults: + val = defaults[column] + if callable_(val): + val = val() + else: + raise ValueError('Missing value for %s.' 
% column.name) + + if not isinstance(val, Node): + val = Value(val, converter=converter, unpack=False) + values.append(val) + + all_values.append(EnclosedNodeList(values)) + + if not all_values: + raise self.DefaultValuesException('Error: no data to insert.') + + with ctx.scope_values(subquery=True): + return ctx.sql(CommaNodeList(all_values)) + + def _query_insert(self, ctx): + return (ctx + .sql(EnclosedNodeList(self._columns)) + .literal(' ') + .sql(self._insert)) + + def _default_values(self, ctx): + if not self._database: + return ctx.literal('DEFAULT VALUES') + return self._database.default_values_insert(ctx) + + def __sql__(self, ctx): + super(Insert, self).__sql__(ctx) + with ctx.scope_values(): + stmt = None + if self._on_conflict is not None: + stmt = self._on_conflict.get_conflict_statement(ctx, self) + + (ctx + .sql(stmt or SQL('INSERT')) + .literal(' INTO ') + .sql(self.table) + .literal(' ')) + + if isinstance(self._insert, dict) and not self._columns: + try: + self._simple_insert(ctx) + except self.DefaultValuesException: + self._default_values(ctx) + self._query_type = Insert.SIMPLE + elif isinstance(self._insert, (SelectQuery, SQL)): + self._query_insert(ctx) + self._query_type = Insert.QUERY + else: + self._generate_insert(self._insert, ctx) + self._query_type = Insert.MULTI + + if self._on_conflict is not None: + update = self._on_conflict.get_conflict_update(ctx, self) + if update is not None: + ctx.literal(' ').sql(update) + + return self.apply_returning(ctx) + + def _execute(self, database): + if self._returning is None and database.returning_clause \ + and self.table._primary_key: + self._returning = (self.table._primary_key,) + try: + return super(Insert, self)._execute(database) + except self.DefaultValuesException: + pass + + def handle_result(self, database, cursor): + if self._return_cursor: + return cursor + return database.last_insert_id(cursor, self._query_type) + + +class Delete(_WriteQuery): + def __sql__(self, ctx): + super(Delete, self).__sql__(ctx) + + with ctx.scope_values(subquery=True): + ctx.literal('DELETE FROM ').sql(self.table) + if self._where is not None: + with ctx.scope_normal(): + ctx.literal(' WHERE ').sql(self._where) + + self._apply_ordering(ctx) + return self.apply_returning(ctx) + + +class Index(Node): + def __init__(self, name, table, expressions, unique=False, safe=False, + where=None, using=None): + self._name = name + self._table = Entity(table) if not isinstance(table, Table) else table + self._expressions = expressions + self._where = where + self._unique = unique + self._safe = safe + self._using = using + + @Node.copy + def safe(self, _safe=True): + self._safe = _safe + + @Node.copy + def where(self, *expressions): + if self._where is not None: + expressions = (self._where,) + expressions + self._where = reduce(operator.and_, expressions) + + @Node.copy + def using(self, _using=None): + self._using = _using + + def __sql__(self, ctx): + statement = 'CREATE UNIQUE INDEX ' if self._unique else 'CREATE INDEX ' + with ctx.scope_values(subquery=True): + ctx.literal(statement) + if self._safe: + ctx.literal('IF NOT EXISTS ') + + # Sqlite uses CREATE INDEX . ON , whereas most + # others use: CREATE INDEX ON .
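The INSERT machinery above distinguishes simple, multi-row and query-based inserts and layers ON CONFLICT handling on top. A minimal sketch, assuming a hypothetical KV model with a unique key column and SQLite 3.24+ for the upsert form (older SQLite is limited to on_conflict_ignore() and on_conflict_replace()):

from peewee import SqliteDatabase, Model, CharField, IntegerField

db = SqliteDatabase(':memory:')

class KV(Model):
    key = CharField(unique=True)
    value = IntegerField()

    class Meta:
        database = db

db.connect()
db.create_tables([KV])

# SIMPLE: one row from keyword arguments (a dict internally).
KV.insert(key='k1', value=1).execute()

# MULTI: a list of dicts; columns are inferred from the first row.
KV.insert_many([{'key': 'k2', 'value': 2},
                {'key': 'k3', 'value': 3}]).execute()

# Upsert: when the unique "key" collides, keep the incoming value
# (preserve generates "value = EXCLUDED.value" on SQLite/Postgres).
(KV.insert(key='k1', value=10)
   .on_conflict(conflict_target=[KV.key], preserve=[KV.value])
   .execute())

# Or simply skip duplicates (INSERT OR IGNORE on SQLite).
KV.insert(key='k1', value=99).on_conflict_ignore().execute()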
. + if ctx.state.index_schema_prefix and \ + isinstance(self._table, Table) and self._table._schema: + index_name = Entity(self._table._schema, self._name) + table_name = Entity(self._table.__name__) + else: + index_name = Entity(self._name) + table_name = self._table + + (ctx + .sql(index_name) + .literal(' ON ') + .sql(table_name) + .literal(' ')) + if self._using is not None: + ctx.literal('USING %s ' % self._using) + + ctx.sql(EnclosedNodeList([ + SQL(expr) if isinstance(expr, basestring) else expr + for expr in self._expressions])) + if self._where is not None: + ctx.literal(' WHERE ').sql(self._where) + + return ctx + + +class ModelIndex(Index): + def __init__(self, model, fields, unique=False, safe=True, where=None, + using=None, name=None): + self._model = model + if name is None: + name = self._generate_name_from_fields(model, fields) + if using is None: + for field in fields: + if isinstance(field, Field) and hasattr(field, 'index_type'): + using = field.index_type + super(ModelIndex, self).__init__( + name=name, + table=model._meta.table, + expressions=fields, + unique=unique, + safe=safe, + where=where, + using=using) + + def _generate_name_from_fields(self, model, fields): + accum = [] + for field in fields: + if isinstance(field, basestring): + accum.append(field.split()[0]) + else: + if isinstance(field, Node) and not isinstance(field, Field): + field = field.unwrap() + if isinstance(field, Field): + accum.append(field.column_name) + + if not accum: + raise ValueError('Unable to generate a name for the index, please ' + 'explicitly specify a name.') + + clean_field_names = re.sub('[^\w]+', '', '_'.join(accum)) + meta = model._meta + prefix = meta.name if meta.legacy_table_names else meta.table_name + return _truncate_constraint_name('_'.join((prefix, clean_field_names))) + + +def _truncate_constraint_name(constraint, maxlen=64): + if len(constraint) > maxlen: + name_hash = hashlib.md5(constraint.encode('utf-8')).hexdigest() + constraint = '%s_%s' % (constraint[:(maxlen - 8)], name_hash[:7]) + return constraint + + +# DB-API 2.0 EXCEPTIONS. + + +class PeeweeException(Exception): pass +class ImproperlyConfigured(PeeweeException): pass +class DatabaseError(PeeweeException): pass +class DataError(DatabaseError): pass +class IntegrityError(DatabaseError): pass +class InterfaceError(PeeweeException): pass +class InternalError(DatabaseError): pass +class NotSupportedError(DatabaseError): pass +class OperationalError(DatabaseError): pass +class ProgrammingError(DatabaseError): pass + + +class ExceptionWrapper(object): + __slots__ = ('exceptions',) + def __init__(self, exceptions): + self.exceptions = exceptions + def __enter__(self): pass + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + return + # psycopg2.8 shits out a million cute error types. Try to catch em all. 
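Driver-level errors are translated through ExceptionWrapper into the DB-API style hierarchy above, so application code can catch peewee's IntegrityError (or OperationalError, and so on) regardless of the backend in use. A short, hypothetical example:

from peewee import SqliteDatabase, Model, CharField, IntegrityError

db = SqliteDatabase(':memory:')

class User(Model):
    username = CharField(unique=True)

    class Meta:
        database = db

db.connect()
db.create_tables([User])

User.create(username='alice')
try:
    User.create(username='alice')   # violates the UNIQUE index
except IntegrityError as exc:
    print('duplicate username:', exc)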
+ if pg_errors is not None and exc_type.__name__ not in self.exceptions \ + and issubclass(exc_type, pg_errors.Error): + exc_type = exc_type.__bases__[0] + if exc_type.__name__ in self.exceptions: + new_type = self.exceptions[exc_type.__name__] + exc_args = exc_value.args + reraise(new_type, new_type(*exc_args), traceback) + + +EXCEPTIONS = { + 'ConstraintError': IntegrityError, + 'DatabaseError': DatabaseError, + 'DataError': DataError, + 'IntegrityError': IntegrityError, + 'InterfaceError': InterfaceError, + 'InternalError': InternalError, + 'NotSupportedError': NotSupportedError, + 'OperationalError': OperationalError, + 'ProgrammingError': ProgrammingError} + +__exception_wrapper__ = ExceptionWrapper(EXCEPTIONS) + + +# DATABASE INTERFACE AND CONNECTION MANAGEMENT. + + +IndexMetadata = collections.namedtuple( + 'IndexMetadata', + ('name', 'sql', 'columns', 'unique', 'table')) +ColumnMetadata = collections.namedtuple( + 'ColumnMetadata', + ('name', 'data_type', 'null', 'primary_key', 'table', 'default')) +ForeignKeyMetadata = collections.namedtuple( + 'ForeignKeyMetadata', + ('column', 'dest_table', 'dest_column', 'table')) +ViewMetadata = collections.namedtuple('ViewMetadata', ('name', 'sql')) + + +class _ConnectionState(object): + def __init__(self, **kwargs): + super(_ConnectionState, self).__init__(**kwargs) + self.reset() + + def reset(self): + self.closed = True + self.conn = None + self.ctx = [] + self.transactions = [] + + def set_connection(self, conn): + self.conn = conn + self.closed = False + self.ctx = [] + self.transactions = [] + + +class _ConnectionLocal(_ConnectionState, threading.local): pass +class _NoopLock(object): + __slots__ = () + def __enter__(self): return self + def __exit__(self, exc_type, exc_val, exc_tb): pass + + +class ConnectionContext(_callable_context_manager): + __slots__ = ('db',) + def __init__(self, db): self.db = db + def __enter__(self): + if self.db.is_closed(): + self.db.connect() + def __exit__(self, exc_type, exc_val, exc_tb): self.db.close() + + +class Database(_callable_context_manager): + context_class = Context + field_types = {} + operations = {} + param = '?' + quote = '""' + server_version = None + + # Feature toggles. + commit_select = False + compound_select_parentheses = CSQ_PARENTHESES_NEVER + for_update = False + index_schema_prefix = False + limit_max = None + nulls_ordering = False + returning_clause = False + safe_create_index = True + safe_drop_index = True + sequences = False + truncate_table = True + + def __init__(self, database, thread_safe=True, autorollback=False, + field_types=None, operations=None, autocommit=None, **kwargs): + self._field_types = merge_dict(FIELD, self.field_types) + self._operations = merge_dict(OP, self.operations) + if field_types: + self._field_types.update(field_types) + if operations: + self._operations.update(operations) + + self.autorollback = autorollback + self.thread_safe = thread_safe + if thread_safe: + self._state = _ConnectionLocal() + self._lock = threading.Lock() + else: + self._state = _ConnectionState() + self._lock = _NoopLock() + + if autocommit is not None: + __deprecated__('Peewee no longer uses the "autocommit" option, as ' + 'the semantics now require it to always be True. 
' + 'Because some database-drivers also use the ' + '"autocommit" parameter, you are receiving a ' + 'warning so you may update your code and remove ' + 'the parameter, as in the future, specifying ' + 'autocommit could impact the behavior of the ' + 'database driver you are using.') + + self.connect_params = {} + self.init(database, **kwargs) + + def init(self, database, **kwargs): + if not self.is_closed(): + self.close() + self.database = database + self.connect_params.update(kwargs) + self.deferred = not bool(database) + + def __enter__(self): + if self.is_closed(): + self.connect() + ctx = self.atomic() + self._state.ctx.append(ctx) + ctx.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + ctx = self._state.ctx.pop() + try: + ctx.__exit__(exc_type, exc_val, exc_tb) + finally: + if not self._state.ctx: + self.close() + + def connection_context(self): + return ConnectionContext(self) + + def _connect(self): + raise NotImplementedError + + def connect(self, reuse_if_open=False): + with self._lock: + if self.deferred: + raise InterfaceError('Error, database must be initialized ' + 'before opening a connection.') + if not self._state.closed: + if reuse_if_open: + return False + raise OperationalError('Connection already opened.') + + self._state.reset() + with __exception_wrapper__: + self._state.set_connection(self._connect()) + if self.server_version is None: + self._set_server_version(self._state.conn) + self._initialize_connection(self._state.conn) + return True + + def _initialize_connection(self, conn): + pass + + def _set_server_version(self, conn): + self.server_version = 0 + + def close(self): + with self._lock: + if self.deferred: + raise InterfaceError('Error, database must be initialized ' + 'before opening a connection.') + if self.in_transaction(): + raise OperationalError('Attempting to close database while ' + 'transaction is open.') + is_open = not self._state.closed + try: + if is_open: + with __exception_wrapper__: + self._close(self._state.conn) + finally: + self._state.reset() + return is_open + + def _close(self, conn): + conn.close() + + def is_closed(self): + return self._state.closed + + def connection(self): + if self.is_closed(): + self.connect() + return self._state.conn + + def cursor(self, commit=None): + if self.is_closed(): + self.connect() + return self._state.conn.cursor() + + def execute_sql(self, sql, params=None, commit=SENTINEL): + logger.debug((sql, params)) + if commit is SENTINEL: + if self.in_transaction(): + commit = False + elif self.commit_select: + commit = True + else: + commit = not sql[:6].lower().startswith('select') + + with __exception_wrapper__: + cursor = self.cursor(commit) + try: + cursor.execute(sql, params or ()) + except Exception: + if self.autorollback and not self.in_transaction(): + self.rollback() + raise + else: + if commit and not self.in_transaction(): + self.commit() + return cursor + + def execute(self, query, commit=SENTINEL, **context_options): + ctx = self.get_sql_context(**context_options) + sql, params = ctx.sql(query).query() + return self.execute_sql(sql, params, commit=commit) + + def get_context_options(self): + return { + 'field_types': self._field_types, + 'operations': self._operations, + 'param': self.param, + 'quote': self.quote, + 'compound_select_parentheses': self.compound_select_parentheses, + 'conflict_statement': self.conflict_statement, + 'conflict_update': self.conflict_update, + 'for_update': self.for_update, + 'index_schema_prefix': self.index_schema_prefix, + 
'limit_max': self.limit_max, + 'nulls_ordering': self.nulls_ordering, + } + + def get_sql_context(self, **context_options): + context = self.get_context_options() + if context_options: + context.update(context_options) + return self.context_class(**context) + + def conflict_statement(self, on_conflict, query): + raise NotImplementedError + + def conflict_update(self, on_conflict, query): + raise NotImplementedError + + def _build_on_conflict_update(self, on_conflict, query): + if on_conflict._conflict_target: + stmt = SQL('ON CONFLICT') + target = EnclosedNodeList([ + Entity(col) if isinstance(col, basestring) else col + for col in on_conflict._conflict_target]) + if on_conflict._conflict_where is not None: + target = NodeList([target, SQL('WHERE'), + on_conflict._conflict_where]) + else: + stmt = SQL('ON CONFLICT ON CONSTRAINT') + target = on_conflict._conflict_constraint + if isinstance(target, basestring): + target = Entity(target) + + updates = [] + if on_conflict._preserve: + for column in on_conflict._preserve: + excluded = NodeList((SQL('EXCLUDED'), ensure_entity(column)), + glue='.') + expression = NodeList((ensure_entity(column), SQL('='), + excluded)) + updates.append(expression) + + if on_conflict._update: + for k, v in on_conflict._update.items(): + if not isinstance(v, Node): + # Attempt to resolve string field-names to their respective + # field object, to apply data-type conversions. + if isinstance(k, basestring): + k = getattr(query.table, k) + converter = k.db_value if isinstance(k, Field) else None + v = Value(v, converter=converter, unpack=False) + else: + v = QualifiedNames(v) + updates.append(NodeList((ensure_entity(k), SQL('='), v))) + + parts = [stmt, target, SQL('DO UPDATE SET'), CommaNodeList(updates)] + if on_conflict._where: + parts.extend((SQL('WHERE'), QualifiedNames(on_conflict._where))) + + return NodeList(parts) + + def last_insert_id(self, cursor, query_type=None): + return cursor.lastrowid + + def rows_affected(self, cursor): + return cursor.rowcount + + def default_values_insert(self, ctx): + return ctx.literal('DEFAULT VALUES') + + def session_start(self): + with self._lock: + return self.transaction().__enter__() + + def session_commit(self): + with self._lock: + try: + txn = self.pop_transaction() + except IndexError: + return False + txn.commit(begin=self.in_transaction()) + return True + + def session_rollback(self): + with self._lock: + try: + txn = self.pop_transaction() + except IndexError: + return False + txn.rollback(begin=self.in_transaction()) + return True + + def in_transaction(self): + return bool(self._state.transactions) + + def push_transaction(self, transaction): + self._state.transactions.append(transaction) + + def pop_transaction(self): + return self._state.transactions.pop() + + def transaction_depth(self): + return len(self._state.transactions) + + def top_transaction(self): + if self._state.transactions: + return self._state.transactions[-1] + + def atomic(self): + return _atomic(self) + + def manual_commit(self): + return _manual(self) + + def transaction(self): + return _transaction(self) + + def savepoint(self): + return _savepoint(self) + + def begin(self): + if self.is_closed(): + self.connect() + + def commit(self): + return self._state.conn.commit() + + def rollback(self): + return self._state.conn.rollback() + + def batch_commit(self, it, n): + for group in chunked(it, n): + with self.atomic(): + for obj in group: + yield obj + + def table_exists(self, table_name, schema=None): + return table_name in 
self.get_tables(schema=schema) + + def get_tables(self, schema=None): + raise NotImplementedError + + def get_indexes(self, table, schema=None): + raise NotImplementedError + + def get_columns(self, table, schema=None): + raise NotImplementedError + + def get_primary_keys(self, table, schema=None): + raise NotImplementedError + + def get_foreign_keys(self, table, schema=None): + raise NotImplementedError + + def sequence_exists(self, seq): + raise NotImplementedError + + def create_tables(self, models, **options): + for model in sort_models(models): + model.create_table(**options) + + def drop_tables(self, models, **kwargs): + for model in reversed(sort_models(models)): + model.drop_table(**kwargs) + + def extract_date(self, date_part, date_field): + raise NotImplementedError + + def truncate_date(self, date_part, date_field): + raise NotImplementedError + + def bind(self, models, bind_refs=True, bind_backrefs=True): + for model in models: + model.bind(self, bind_refs=bind_refs, bind_backrefs=bind_backrefs) + + def bind_ctx(self, models, bind_refs=True, bind_backrefs=True): + return _BoundModelsContext(models, self, bind_refs, bind_backrefs) + + def get_noop_select(self, ctx): + return ctx.sql(Select().columns(SQL('0')).where(SQL('0'))) + + +def __pragma__(name): + def __get__(self): + return self.pragma(name) + def __set__(self, value): + return self.pragma(name, value) + return property(__get__, __set__) + + +class SqliteDatabase(Database): + field_types = { + 'BIGAUTO': FIELD.AUTO, + 'BIGINT': FIELD.INT, + 'BOOL': FIELD.INT, + 'DOUBLE': FIELD.FLOAT, + 'SMALLINT': FIELD.INT, + 'UUID': FIELD.TEXT} + operations = { + 'LIKE': 'GLOB', + 'ILIKE': 'LIKE'} + index_schema_prefix = True + limit_max = -1 + server_version = __sqlite_version__ + truncate_table = False + + def __init__(self, database, *args, **kwargs): + self._pragmas = kwargs.pop('pragmas', ()) + super(SqliteDatabase, self).__init__(database, *args, **kwargs) + self._aggregates = {} + self._collations = {} + self._functions = {} + self._window_functions = {} + self._table_functions = [] + self._extensions = set() + self._attached = {} + self.register_function(_sqlite_date_part, 'date_part', 2) + self.register_function(_sqlite_date_trunc, 'date_trunc', 2) + + def init(self, database, pragmas=None, timeout=5, **kwargs): + if pragmas is not None: + self._pragmas = pragmas + if isinstance(self._pragmas, dict): + self._pragmas = list(self._pragmas.items()) + self._timeout = timeout + super(SqliteDatabase, self).init(database, **kwargs) + + def _set_server_version(self, conn): + pass + + def _connect(self): + if sqlite3 is None: + raise ImproperlyConfigured('SQLite driver not installed!') + conn = sqlite3.connect(self.database, timeout=self._timeout, + isolation_level=None, **self.connect_params) + try: + self._add_conn_hooks(conn) + except: + conn.close() + raise + return conn + + def _add_conn_hooks(self, conn): + if self._attached: + self._attach_databases(conn) + if self._pragmas: + self._set_pragmas(conn) + self._load_aggregates(conn) + self._load_collations(conn) + self._load_functions(conn) + if self.server_version >= (3, 25, 0): + self._load_window_functions(conn) + if self._table_functions: + for table_function in self._table_functions: + table_function.register(conn) + if self._extensions: + self._load_extensions(conn) + + def _set_pragmas(self, conn): + cursor = conn.cursor() + for pragma, value in self._pragmas: + cursor.execute('PRAGMA %s = %s;' % (pragma, value)) + cursor.close() + + def _attach_databases(self, conn): + 
cursor = conn.cursor() + for name, db in self._attached.items(): + cursor.execute('ATTACH DATABASE "%s" AS "%s"' % (db, name)) + cursor.close() + + def pragma(self, key, value=SENTINEL, permanent=False, schema=None): + if schema is not None: + key = '"%s".%s' % (schema, key) + sql = 'PRAGMA %s' % key + if value is not SENTINEL: + sql += ' = %s' % (value or 0) + if permanent: + pragmas = dict(self._pragmas or ()) + pragmas[key] = value + self._pragmas = list(pragmas.items()) + elif permanent: + raise ValueError('Cannot specify a permanent pragma without value') + row = self.execute_sql(sql).fetchone() + if row: + return row[0] + + cache_size = __pragma__('cache_size') + foreign_keys = __pragma__('foreign_keys') + journal_mode = __pragma__('journal_mode') + journal_size_limit = __pragma__('journal_size_limit') + mmap_size = __pragma__('mmap_size') + page_size = __pragma__('page_size') + read_uncommitted = __pragma__('read_uncommitted') + synchronous = __pragma__('synchronous') + wal_autocheckpoint = __pragma__('wal_autocheckpoint') + + @property + def timeout(self): + return self._timeout + + @timeout.setter + def timeout(self, seconds): + if self._timeout == seconds: + return + + self._timeout = seconds + if not self.is_closed(): + # PySQLite multiplies user timeout by 1000, but the unit of the + # timeout PRAGMA is actually milliseconds. + self.execute_sql('PRAGMA busy_timeout=%d;' % (seconds * 1000)) + + def _load_aggregates(self, conn): + for name, (klass, num_params) in self._aggregates.items(): + conn.create_aggregate(name, num_params, klass) + + def _load_collations(self, conn): + for name, fn in self._collations.items(): + conn.create_collation(name, fn) + + def _load_functions(self, conn): + for name, (fn, num_params) in self._functions.items(): + conn.create_function(name, num_params, fn) + + def _load_window_functions(self, conn): + for name, (klass, num_params) in self._window_functions.items(): + conn.create_window_function(name, num_params, klass) + + def register_aggregate(self, klass, name=None, num_params=-1): + self._aggregates[name or klass.__name__.lower()] = (klass, num_params) + if not self.is_closed(): + self._load_aggregates(self.connection()) + + def aggregate(self, name=None, num_params=-1): + def decorator(klass): + self.register_aggregate(klass, name, num_params) + return klass + return decorator + + def register_collation(self, fn, name=None): + name = name or fn.__name__ + def _collation(*args): + expressions = args + (SQL('collate %s' % name),) + return NodeList(expressions) + fn.collation = _collation + self._collations[name] = fn + if not self.is_closed(): + self._load_collations(self.connection()) + + def collation(self, name=None): + def decorator(fn): + self.register_collation(fn, name) + return fn + return decorator + + def register_function(self, fn, name=None, num_params=-1): + self._functions[name or fn.__name__] = (fn, num_params) + if not self.is_closed(): + self._load_functions(self.connection()) + + def func(self, name=None, num_params=-1): + def decorator(fn): + self.register_function(fn, name, num_params) + return fn + return decorator + + def register_window_function(self, klass, name=None, num_params=-1): + name = name or klass.__name__.lower() + self._window_functions[name] = (klass, num_params) + if not self.is_closed(): + self._load_window_functions(self.connection()) + + def window_function(self, name=None, num_params=-1): + def decorator(klass): + self.register_window_function(klass, name, num_params) + return klass + return decorator + + 
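A sketch of the SQLite-specific hooks above: pragmas applied to every new connection, a user-defined SQL function registered through the func() decorator, and the pragma properties for runtime inspection. The file path and function name are illustrative.

from peewee import SqliteDatabase

db = SqliteDatabase('/tmp/example.db', pragmas={
    'journal_mode': 'wal',         # becomes "PRAGMA journal_mode = wal"
    'foreign_keys': 1,
    'cache_size': -1024 * 64})     # 64MB page cache

@db.func('title_case')
def title_case(s):
    return s.title() if s else s

db.connect()
# The registered function is callable from SQL on this connection.
cursor = db.execute_sql('SELECT title_case(?)', ('hello world',))
print(cursor.fetchone()[0])        # 'Hello World'

# The __pragma__ properties read (and write) pragmas at runtime.
print(db.journal_mode)             # 'wal'
db.cache_size = -1024 * 32
db.close()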
def register_table_function(self, klass, name=None): + if name is not None: + klass.name = name + self._table_functions.append(klass) + if not self.is_closed(): + klass.register(self.connection()) + + def table_function(self, name=None): + def decorator(klass): + self.register_table_function(klass, name) + return klass + return decorator + + def unregister_aggregate(self, name): + del(self._aggregates[name]) + + def unregister_collation(self, name): + del(self._collations[name]) + + def unregister_function(self, name): + del(self._functions[name]) + + def unregister_window_function(self, name): + del(self._window_functions[name]) + + def unregister_table_function(self, name): + for idx, klass in enumerate(self._table_functions): + if klass.name == name: + break + else: + return False + self._table_functions.pop(idx) + return True + + def _load_extensions(self, conn): + conn.enable_load_extension(True) + for extension in self._extensions: + conn.load_extension(extension) + + def load_extension(self, extension): + self._extensions.add(extension) + if not self.is_closed(): + conn = self.connection() + conn.enable_load_extension(True) + conn.load_extension(extension) + + def unload_extension(self, extension): + self._extensions.remove(extension) + + def attach(self, filename, name): + if name in self._attached: + if self._attached[name] == filename: + return False + raise OperationalError('schema "%s" already attached.' % name) + + self._attached[name] = filename + if not self.is_closed(): + self.execute_sql('ATTACH DATABASE "%s" AS "%s"' % (filename, name)) + return True + + def detach(self, name): + if name not in self._attached: + return False + + del self._attached[name] + if not self.is_closed(): + self.execute_sql('DETACH DATABASE "%s"' % name) + return True + + def atomic(self, lock_type=None): + return _atomic(self, lock_type=lock_type) + + def transaction(self, lock_type=None): + return _transaction(self, lock_type=lock_type) + + def begin(self, lock_type=None): + statement = 'BEGIN %s' % lock_type if lock_type else 'BEGIN' + self.execute_sql(statement, commit=False) + + def get_tables(self, schema=None): + schema = schema or 'main' + cursor = self.execute_sql('SELECT name FROM "%s".sqlite_master WHERE ' + 'type=? ORDER BY name' % schema, ('table',)) + return [row for row, in cursor.fetchall()] + + def get_views(self, schema=None): + sql = ('SELECT name, sql FROM "%s".sqlite_master WHERE type=? ' + 'ORDER BY name') % (schema or 'main') + return [ViewMetadata(*row) for row in self.execute_sql(sql, ('view',))] + + def get_indexes(self, table, schema=None): + schema = schema or 'main' + query = ('SELECT name, sql FROM "%s".sqlite_master ' + 'WHERE tbl_name = ? AND type = ? ORDER BY name') % schema + cursor = self.execute_sql(query, (table, 'index')) + index_to_sql = dict(cursor.fetchall()) + + # Determine which indexes have a unique constraint. + unique_indexes = set() + cursor = self.execute_sql('PRAGMA "%s".index_list("%s")' % + (schema, table)) + for row in cursor.fetchall(): + name = row[1] + is_unique = int(row[2]) == 1 + if is_unique: + unique_indexes.add(name) + + # Retrieve the indexed columns. 
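The introspection methods above (get_tables, get_columns, get_indexes, get_foreign_keys) return the metadata namedtuples defined earlier and are what reflection tools such as pwiz build on. A minimal, self-contained illustration against a throwaway table:

from peewee import SqliteDatabase

db = SqliteDatabase(':memory:')
db.connect()
db.execute_sql('CREATE TABLE tag (id INTEGER PRIMARY KEY, label TEXT UNIQUE)')

print(db.get_tables())                    # ['tag']
for col in db.get_columns('tag'):         # ColumnMetadata namedtuples
    print(col.name, col.data_type, col.null, col.primary_key)
for idx in db.get_indexes('tag'):         # IndexMetadata namedtuples
    print(idx.name, idx.columns, idx.unique)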
+ index_columns = {} + for index_name in sorted(index_to_sql): + cursor = self.execute_sql('PRAGMA "%s".index_info("%s")' % + (schema, index_name)) + index_columns[index_name] = [row[2] for row in cursor.fetchall()] + + return [ + IndexMetadata( + name, + index_to_sql[name], + index_columns[name], + name in unique_indexes, + table) + for name in sorted(index_to_sql)] + + def get_columns(self, table, schema=None): + cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % + (schema or 'main', table)) + return [ColumnMetadata(r[1], r[2], not r[3], bool(r[5]), table, r[4]) + for r in cursor.fetchall()] + + def get_primary_keys(self, table, schema=None): + cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % + (schema or 'main', table)) + return [row[1] for row in filter(lambda r: r[-1], cursor.fetchall())] + + def get_foreign_keys(self, table, schema=None): + cursor = self.execute_sql('PRAGMA "%s".foreign_key_list("%s")' % + (schema or 'main', table)) + return [ForeignKeyMetadata(row[3], row[2], row[4], table) + for row in cursor.fetchall()] + + def get_binary_type(self): + return sqlite3.Binary + + def conflict_statement(self, on_conflict, query): + action = on_conflict._action.lower() if on_conflict._action else '' + if action and action not in ('nothing', 'update'): + return SQL('INSERT OR %s' % on_conflict._action.upper()) + + def conflict_update(self, oc, query): + # Sqlite prior to 3.24.0 does not support Postgres-style upsert. + if self.server_version < (3, 24, 0) and \ + any((oc._preserve, oc._update, oc._where, oc._conflict_target, + oc._conflict_constraint)): + raise ValueError('SQLite does not support specifying which values ' + 'to preserve or update.') + + action = oc._action.lower() if oc._action else '' + if action and action not in ('nothing', 'update', ''): + return + + if action == 'nothing': + return SQL('ON CONFLICT DO NOTHING') + elif not oc._update and not oc._preserve: + raise ValueError('If you are not performing any updates (or ' + 'preserving any INSERTed values), then the ' + 'conflict resolution action should be set to ' + '"NOTHING".') + elif oc._conflict_constraint: + raise ValueError('SQLite does not support specifying named ' + 'constraints for conflict resolution.') + elif not oc._conflict_target: + raise ValueError('SQLite requires that a conflict target be ' + 'specified when doing an upsert.') + + return self._build_on_conflict_update(oc, query) + + def extract_date(self, date_part, date_field): + return fn.date_part(date_part, date_field) + + def truncate_date(self, date_part, date_field): + return fn.date_trunc(date_part, date_field) + + +class PostgresqlDatabase(Database): + field_types = { + 'AUTO': 'SERIAL', + 'BIGAUTO': 'BIGSERIAL', + 'BLOB': 'BYTEA', + 'BOOL': 'BOOLEAN', + 'DATETIME': 'TIMESTAMP', + 'DECIMAL': 'NUMERIC', + 'DOUBLE': 'DOUBLE PRECISION', + 'UUID': 'UUID', + 'UUIDB': 'BYTEA'} + operations = {'REGEXP': '~', 'IREGEXP': '~*'} + param = '%s' + + commit_select = True + compound_select_parentheses = CSQ_PARENTHESES_ALWAYS + for_update = True + nulls_ordering = True + returning_clause = True + safe_create_index = False + sequences = True + + def init(self, database, register_unicode=True, encoding=None, **kwargs): + self._register_unicode = register_unicode + self._encoding = encoding + super(PostgresqlDatabase, self).init(database, **kwargs) + + def _connect(self): + if psycopg2 is None: + raise ImproperlyConfigured('Postgres driver not installed!') + conn = psycopg2.connect(database=self.database, **self.connect_params) + if 
self._register_unicode: + pg_extensions.register_type(pg_extensions.UNICODE, conn) + pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn) + if self._encoding: + conn.set_client_encoding(self._encoding) + return conn + + def _set_server_version(self, conn): + self.server_version = conn.server_version + if self.server_version >= 90600: + self.safe_create_index = True + + def last_insert_id(self, cursor, query_type=None): + try: + return cursor if query_type else cursor[0][0] + except (IndexError, KeyError, TypeError): + pass + + def get_tables(self, schema=None): + query = ('SELECT tablename FROM pg_catalog.pg_tables ' + 'WHERE schemaname = %s ORDER BY tablename') + cursor = self.execute_sql(query, (schema or 'public',)) + return [table for table, in cursor.fetchall()] + + def get_views(self, schema=None): + query = ('SELECT viewname, definition FROM pg_catalog.pg_views ' + 'WHERE schemaname = %s ORDER BY viewname') + cursor = self.execute_sql(query, (schema or 'public',)) + return [ViewMetadata(v, sql.strip()) for (v, sql) in cursor.fetchall()] + + def get_indexes(self, table, schema=None): + query = """ + SELECT + i.relname, idxs.indexdef, idx.indisunique, + array_to_string(array_agg(cols.attname), ',') + FROM pg_catalog.pg_class AS t + INNER JOIN pg_catalog.pg_index AS idx ON t.oid = idx.indrelid + INNER JOIN pg_catalog.pg_class AS i ON idx.indexrelid = i.oid + INNER JOIN pg_catalog.pg_indexes AS idxs ON + (idxs.tablename = t.relname AND idxs.indexname = i.relname) + LEFT OUTER JOIN pg_catalog.pg_attribute AS cols ON + (cols.attrelid = t.oid AND cols.attnum = ANY(idx.indkey)) + WHERE t.relname = %s AND t.relkind = %s AND idxs.schemaname = %s + GROUP BY i.relname, idxs.indexdef, idx.indisunique + ORDER BY idx.indisunique DESC, i.relname;""" + cursor = self.execute_sql(query, (table, 'r', schema or 'public')) + return [IndexMetadata(row[0], row[1], row[3].split(','), row[2], table) + for row in cursor.fetchall()] + + def get_columns(self, table, schema=None): + query = """ + SELECT column_name, is_nullable, data_type, column_default + FROM information_schema.columns + WHERE table_name = %s AND table_schema = %s + ORDER BY ordinal_position""" + cursor = self.execute_sql(query, (table, schema or 'public')) + pks = set(self.get_primary_keys(table, schema)) + return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) + for name, null, dt, df in cursor.fetchall()] + + def get_primary_keys(self, table, schema=None): + query = """ + SELECT kc.column_name + FROM information_schema.table_constraints AS tc + INNER JOIN information_schema.key_column_usage AS kc ON ( + tc.table_name = kc.table_name AND + tc.table_schema = kc.table_schema AND + tc.constraint_name = kc.constraint_name) + WHERE + tc.constraint_type = %s AND + tc.table_name = %s AND + tc.table_schema = %s""" + ctype = 'PRIMARY KEY' + cursor = self.execute_sql(query, (ctype, table, schema or 'public')) + return [pk for pk, in cursor.fetchall()] + + def get_foreign_keys(self, table, schema=None): + sql = """ + SELECT + kcu.column_name, ccu.table_name, ccu.column_name + FROM information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON (tc.constraint_name = kcu.constraint_name AND + tc.constraint_schema = kcu.constraint_schema) + JOIN information_schema.constraint_column_usage AS ccu + ON (ccu.constraint_name = tc.constraint_name AND + ccu.constraint_schema = tc.constraint_schema) + WHERE + tc.constraint_type = 'FOREIGN KEY' AND + tc.table_name = %s AND + tc.table_schema = %s""" + 
cursor = self.execute_sql(sql, (table, schema or 'public')) + return [ForeignKeyMetadata(row[0], row[1], row[2], table) + for row in cursor.fetchall()] + + def sequence_exists(self, sequence): + res = self.execute_sql(""" + SELECT COUNT(*) FROM pg_class, pg_namespace + WHERE relkind='S' + AND pg_class.relnamespace = pg_namespace.oid + AND relname=%s""", (sequence,)) + return bool(res.fetchone()[0]) + + def get_binary_type(self): + return psycopg2.Binary + + def conflict_statement(self, on_conflict, query): + return + + def conflict_update(self, oc, query): + action = oc._action.lower() if oc._action else '' + if action in ('ignore', 'nothing'): + return SQL('ON CONFLICT DO NOTHING') + elif action and action != 'update': + raise ValueError('The only supported actions for conflict ' + 'resolution with Postgresql are "ignore" or ' + '"update".') + elif not oc._update and not oc._preserve: + raise ValueError('If you are not performing any updates (or ' + 'preserving any INSERTed values), then the ' + 'conflict resolution action should be set to ' + '"IGNORE".') + elif not (oc._conflict_target or oc._conflict_constraint): + raise ValueError('Postgres requires that a conflict target be ' + 'specified when doing an upsert.') + + return self._build_on_conflict_update(oc, query) + + def extract_date(self, date_part, date_field): + return fn.EXTRACT(NodeList((date_part, SQL('FROM'), date_field))) + + def truncate_date(self, date_part, date_field): + return fn.DATE_TRUNC(date_part, date_field) + + def get_noop_select(self, ctx): + return ctx.sql(Select().columns(SQL('0')).where(SQL('false'))) + + def set_time_zone(self, timezone): + self.execute_sql('set time zone "%s";' % timezone) + + +class MySQLDatabase(Database): + field_types = { + 'AUTO': 'INTEGER AUTO_INCREMENT', + 'BIGAUTO': 'BIGINT AUTO_INCREMENT', + 'BOOL': 'BOOL', + 'DECIMAL': 'NUMERIC', + 'DOUBLE': 'DOUBLE PRECISION', + 'FLOAT': 'FLOAT', + 'UUID': 'VARCHAR(40)', + 'UUIDB': 'VARBINARY(16)'} + operations = { + 'LIKE': 'LIKE BINARY', + 'ILIKE': 'LIKE', + 'REGEXP': 'REGEXP BINARY', + 'IREGEXP': 'REGEXP', + 'XOR': 'XOR'} + param = '%s' + quote = '``' + + commit_select = True + compound_select_parentheses = CSQ_PARENTHESES_UNNESTED + for_update = True + limit_max = 2 ** 64 - 1 + safe_create_index = False + safe_drop_index = False + + def init(self, database, **kwargs): + params = {'charset': 'utf8', 'use_unicode': True} + params.update(kwargs) + if 'password' in params and mysql_passwd: + params['passwd'] = params.pop('password') + super(MySQLDatabase, self).init(database, **params) + + def _connect(self): + if mysql is None: + raise ImproperlyConfigured('MySQL driver not installed!') + conn = mysql.connect(db=self.database, **self.connect_params) + return conn + + def _set_server_version(self, conn): + try: + version_raw = conn.server_version + except AttributeError: + version_raw = conn.get_server_info() + self.server_version = self._extract_server_version(version_raw) + + def _extract_server_version(self, version): + version = version.lower() + if 'maria' in version: + match_obj = re.search(r'(1\d\.\d+\.\d+)', version) + else: + match_obj = re.search(r'(\d\.\d+\.\d+)', version) + if match_obj is not None: + return tuple(int(num) for num in match_obj.groups()[0].split('.')) + + warnings.warn('Unable to determine MySQL version: "%s"' % version) + return (0, 0, 0) # Unable to determine version! 
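The same model and query layer runs unchanged against the PostgresqlDatabase and MySQLDatabase classes above; only the Database subclass and its driver (psycopg2, MySQLdb or pymysql) change. The hostnames and credentials below are placeholders.

from peewee import PostgresqlDatabase, MySQLDatabase

pg_db = PostgresqlDatabase('app', user='postgres', password='secret',
                           host='127.0.0.1', port=5432)

my_db = MySQLDatabase('app', user='root', password='secret',
                      host='127.0.0.1', port=3306, charset='utf8mb4')

# Models can be bound (or re-bound) to a database at runtime:
# pg_db.bind([SomeModel, OtherModel], bind_refs=True, bind_backrefs=True)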
+ + def default_values_insert(self, ctx): + return ctx.literal('() VALUES ()') + + def get_tables(self, schema=None): + query = ('SELECT table_name FROM information_schema.tables ' + 'WHERE table_schema = DATABASE() AND table_type != %s ' + 'ORDER BY table_name') + return [table for table, in self.execute_sql(query, ('VIEW',))] + + def get_views(self, schema=None): + query = ('SELECT table_name, view_definition ' + 'FROM information_schema.views ' + 'WHERE table_schema = DATABASE() ORDER BY table_name') + cursor = self.execute_sql(query) + return [ViewMetadata(*row) for row in cursor.fetchall()] + + def get_indexes(self, table, schema=None): + cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) + unique = set() + indexes = {} + for row in cursor.fetchall(): + if not row[1]: + unique.add(row[2]) + indexes.setdefault(row[2], []) + indexes[row[2]].append(row[4]) + return [IndexMetadata(name, None, indexes[name], name in unique, table) + for name in indexes] + + def get_columns(self, table, schema=None): + sql = """ + SELECT column_name, is_nullable, data_type, column_default + FROM information_schema.columns + WHERE table_name = %s AND table_schema = DATABASE()""" + cursor = self.execute_sql(sql, (table,)) + pks = set(self.get_primary_keys(table)) + return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) + for name, null, dt, df in cursor.fetchall()] + + def get_primary_keys(self, table, schema=None): + cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) + return [row[4] for row in + filter(lambda row: row[2] == 'PRIMARY', cursor.fetchall())] + + def get_foreign_keys(self, table, schema=None): + query = """ + SELECT column_name, referenced_table_name, referenced_column_name + FROM information_schema.key_column_usage + WHERE table_name = %s + AND table_schema = DATABASE() + AND referenced_table_name IS NOT NULL + AND referenced_column_name IS NOT NULL""" + cursor = self.execute_sql(query, (table,)) + return [ + ForeignKeyMetadata(column, dest_table, dest_column, table) + for column, dest_table, dest_column in cursor.fetchall()] + + def get_binary_type(self): + return mysql.Binary + + def conflict_statement(self, on_conflict, query): + if not on_conflict._action: return + + action = on_conflict._action.lower() + if action == 'replace': + return SQL('REPLACE') + elif action == 'ignore': + return SQL('INSERT IGNORE') + elif action != 'update': + raise ValueError('Un-supported action for conflict resolution. ' + 'MySQL supports REPLACE, IGNORE and UPDATE.') + + def conflict_update(self, on_conflict, query): + if on_conflict._where or on_conflict._conflict_target or \ + on_conflict._conflict_constraint: + raise ValueError('MySQL does not support the specification of ' + 'where clauses or conflict targets for conflict ' + 'resolution.') + + updates = [] + if on_conflict._preserve: + # Here we need to determine which function to use, which varies + # depending on the MySQL server version. MySQL and MariaDB prior to + # 10.3.3 use "VALUES", while MariaDB 10.3.3+ use "VALUE". 
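A sketch of how the generic on_conflict() API maps onto the MySQL statements produced by conflict_statement() and conflict_update() above. The Setting model is hypothetical and no connection is opened, since sql() only needs the database's SQL context to render the query.

from peewee import MySQLDatabase, Model, CharField, IntegerField

mysql_db = MySQLDatabase('app', user='root', password='secret')

class Setting(Model):
    name = CharField(unique=True)
    value = IntegerField()

    class Meta:
        database = mysql_db

print(Setting.insert(name='n', value=1).on_conflict('replace').sql()[0])
# -> REPLACE INTO `setting` (...) VALUES (...)

print(Setting.insert(name='n', value=1).on_conflict('ignore').sql()[0])
# -> INSERT IGNORE INTO `setting` (...) VALUES (...)

print(Setting.insert(name='n', value=1)
      .on_conflict(preserve=[Setting.value]).sql()[0])
# -> INSERT INTO ... ON DUPLICATE KEY UPDATE `value` = VALUES(`value`)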
+ version = self.server_version or (0,) + if version[0] == 10 and version >= (10, 3, 3): + VALUE_FN = fn.VALUE + else: + VALUE_FN = fn.VALUES + + for column in on_conflict._preserve: + entity = ensure_entity(column) + expression = NodeList(( + ensure_entity(column), + SQL('='), + VALUE_FN(entity))) + updates.append(expression) + + if on_conflict._update: + for k, v in on_conflict._update.items(): + if not isinstance(v, Node): + # Attempt to resolve string field-names to their respective + # field object, to apply data-type conversions. + if isinstance(k, basestring): + k = getattr(query.table, k) + converter = k.db_value if isinstance(k, Field) else None + v = Value(v, converter=converter, unpack=False) + updates.append(NodeList((ensure_entity(k), SQL('='), v))) + + if updates: + return NodeList((SQL('ON DUPLICATE KEY UPDATE'), + CommaNodeList(updates))) + + def extract_date(self, date_part, date_field): + return fn.EXTRACT(NodeList((SQL(date_part), SQL('FROM'), date_field))) + + def truncate_date(self, date_part, date_field): + return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part]) + + def get_noop_select(self, ctx): + return ctx.literal('DO 0') + + +# TRANSACTION CONTROL. + + +class _manual(_callable_context_manager): + def __init__(self, db): + self.db = db + + def __enter__(self): + top = self.db.top_transaction() + if top and not isinstance(self.db.top_transaction(), _manual): + raise ValueError('Cannot enter manual commit block while a ' + 'transaction is active.') + self.db.push_transaction(self) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.db.pop_transaction() is not self: + raise ValueError('Transaction stack corrupted while exiting ' + 'manual commit block.') + + +class _atomic(_callable_context_manager): + def __init__(self, db, lock_type=None): + self.db = db + self._lock_type = lock_type + self._transaction_args = (lock_type,) if lock_type is not None else () + + def __enter__(self): + if self.db.transaction_depth() == 0: + self._helper = self.db.transaction(*self._transaction_args) + elif isinstance(self.db.top_transaction(), _manual): + raise ValueError('Cannot enter atomic commit block while in ' + 'manual commit mode.') + else: + self._helper = self.db.savepoint() + return self._helper.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return self._helper.__exit__(exc_type, exc_val, exc_tb) + + +class _transaction(_callable_context_manager): + def __init__(self, db, lock_type=None): + self.db = db + self._lock_type = lock_type + + def _begin(self): + if self._lock_type: + self.db.begin(self._lock_type) + else: + self.db.begin() + + def commit(self, begin=True): + self.db.commit() + if begin: + self._begin() + + def rollback(self, begin=True): + self.db.rollback() + if begin: + self._begin() + + def __enter__(self): + if self.db.transaction_depth() == 0: + self._begin() + self.db.push_transaction(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + if exc_type: + self.rollback(False) + elif self.db.transaction_depth() == 1: + try: + self.commit(False) + except: + self.rollback(False) + raise + finally: + self.db.pop_transaction() + + +class _savepoint(_callable_context_manager): + def __init__(self, db, sid=None): + self.db = db + self.sid = sid or 's' + uuid.uuid4().hex + self.quoted_sid = self.sid.join(self.db.quote) + + def _begin(self): + self.db.execute_sql('SAVEPOINT %s;' % self.quoted_sid) + + def commit(self, begin=True): + self.db.execute_sql('RELEASE SAVEPOINT %s;' % self.quoted_sid) + if begin: 
self._begin() + + def rollback(self): + self.db.execute_sql('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) + + def __enter__(self): + self._begin() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type: + self.rollback() + else: + try: + self.commit(begin=False) + except: + self.rollback() + raise + + +# CURSOR REPRESENTATIONS. + + +class CursorWrapper(object): + def __init__(self, cursor): + self.cursor = cursor + self.count = 0 + self.index = 0 + self.initialized = False + self.populated = False + self.row_cache = [] + + def __iter__(self): + if self.populated: + return iter(self.row_cache) + return ResultIterator(self) + + def __getitem__(self, item): + if isinstance(item, slice): + stop = item.stop + if stop is None or stop < 0: + self.fill_cache() + else: + self.fill_cache(stop) + return self.row_cache[item] + elif isinstance(item, int): + self.fill_cache(item if item > 0 else 0) + return self.row_cache[item] + else: + raise ValueError('CursorWrapper only supports integer and slice ' + 'indexes.') + + def __len__(self): + self.fill_cache() + return self.count + + def initialize(self): + pass + + def iterate(self, cache=True): + row = self.cursor.fetchone() + if row is None: + self.populated = True + self.cursor.close() + raise StopIteration + elif not self.initialized: + self.initialize() # Lazy initialization. + self.initialized = True + self.count += 1 + result = self.process_row(row) + if cache: + self.row_cache.append(result) + return result + + def process_row(self, row): + return row + + def iterator(self): + """Efficient one-pass iteration over the result set.""" + while True: + try: + yield self.iterate(False) + except StopIteration: + return + + def fill_cache(self, n=0): + n = n or float('Inf') + if n < 0: + raise ValueError('Negative values are not supported.') + + iterator = ResultIterator(self) + iterator.index = self.count + while not self.populated and (n > self.count): + try: + iterator.next() + except StopIteration: + break + + +class DictCursorWrapper(CursorWrapper): + def _initialize_columns(self): + description = self.cursor.description + self.columns = [t[0][t[0].find('.') + 1:].strip('"') + for t in description] + self.ncols = len(description) + + initialize = _initialize_columns + + def _row_to_dict(self, row): + result = {} + for i in range(self.ncols): + result.setdefault(self.columns[i], row[i]) # Do not overwrite. 
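The transaction helpers above (_manual, _atomic, _transaction, _savepoint) are normally reached through the Database API rather than instantiated directly. A minimal usage sketch follows; it is illustrative only and not part of this patch, and the in-memory database is a stand-in.

# Illustrative sketch: transaction control via the Database API.
from peewee import SqliteDatabase

db = SqliteDatabase(':memory:')  # hypothetical throwaway database
db.connect()

# The outermost atomic() opens a transaction; nested atomic() blocks fall
# back to savepoints instead of opening a second transaction.
with db.atomic():
    with db.atomic():  # handled by _savepoint, not a new _transaction
        pass

# transaction() exposes explicit commit()/rollback() inside the block.
with db.transaction() as txn:
    txn.commit()    # commit work so far and begin a fresh transaction
    txn.rollback()  # discard anything done after the last commit

# manual_commit() suspends the automatic wrapping entirely; the caller
# issues BEGIN/COMMIT themselves.
with db.manual_commit():
    db.begin()
    db.commit()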
+ return result + + process_row = _row_to_dict + + +class NamedTupleCursorWrapper(CursorWrapper): + def initialize(self): + description = self.cursor.description + self.tuple_class = collections.namedtuple( + 'Row', + [col[0][col[0].find('.') + 1:].strip('"') for col in description]) + + def process_row(self, row): + return self.tuple_class(*row) + + +class ObjectCursorWrapper(DictCursorWrapper): + def __init__(self, cursor, constructor): + super(ObjectCursorWrapper, self).__init__(cursor) + self.constructor = constructor + + def process_row(self, row): + row_dict = self._row_to_dict(row) + return self.constructor(**row_dict) + + +class ResultIterator(object): + def __init__(self, cursor_wrapper): + self.cursor_wrapper = cursor_wrapper + self.index = 0 + + def __iter__(self): + return self + + def next(self): + if self.index < self.cursor_wrapper.count: + obj = self.cursor_wrapper.row_cache[self.index] + elif not self.cursor_wrapper.populated: + self.cursor_wrapper.iterate() + obj = self.cursor_wrapper.row_cache[self.index] + else: + raise StopIteration + self.index += 1 + return obj + + __next__ = next + +# FIELDS + +class FieldAccessor(object): + def __init__(self, model, field, name): + self.model = model + self.field = field + self.name = name + + def __get__(self, instance, instance_type=None): + if instance is not None: + return instance.__data__.get(self.name) + return self.field + + def __set__(self, instance, value): + instance.__data__[self.name] = value + instance._dirty.add(self.name) + + +class ForeignKeyAccessor(FieldAccessor): + def __init__(self, model, field, name): + super(ForeignKeyAccessor, self).__init__(model, field, name) + self.rel_model = field.rel_model + + def get_rel_instance(self, instance): + value = instance.__data__.get(self.name) + if value is not None or self.name in instance.__rel__: + if self.name not in instance.__rel__: + obj = self.rel_model.get(self.field.rel_field == value) + instance.__rel__[self.name] = obj + return instance.__rel__[self.name] + elif not self.field.null: + raise self.rel_model.DoesNotExist + return value + + def __get__(self, instance, instance_type=None): + if instance is not None: + return self.get_rel_instance(instance) + return self.field + + def __set__(self, instance, obj): + if isinstance(obj, self.rel_model): + instance.__data__[self.name] = getattr(obj, self.field.rel_field.name) + instance.__rel__[self.name] = obj + else: + fk_value = instance.__data__.get(self.name) + instance.__data__[self.name] = obj + if obj != fk_value and self.name in instance.__rel__: + del instance.__rel__[self.name] + instance._dirty.add(self.name) + + +class NoQueryForeignKeyAccessor(ForeignKeyAccessor): + def get_rel_instance(self, instance): + value = instance.__data__.get(self.name) + if value is not None: + return instance.__rel__.get(self.name, value) + elif not self.field.null: + raise self.rel_model.DoesNotExist + + +class BackrefAccessor(object): + def __init__(self, field): + self.field = field + self.model = field.rel_model + self.rel_model = field.model + + def __get__(self, instance, instance_type=None): + if instance is not None: + dest = self.field.rel_field.name + return (self.rel_model + .select() + .where(self.field == getattr(instance, dest))) + return self + + +class ObjectIdAccessor(object): + """Gives direct access to the underlying id""" + def __init__(self, field): + self.field = field + + def __get__(self, instance, instance_type=None): + if instance is not None: + return instance.__data__.get(self.field.name) + return 
self.field + + def __set__(self, instance, value): + setattr(instance, self.field.name, value) + + +class Field(ColumnBase): + _field_counter = 0 + _order = 0 + accessor_class = FieldAccessor + auto_increment = False + default_index_type = None + field_type = 'DEFAULT' + + def __init__(self, null=False, index=False, unique=False, column_name=None, + default=None, primary_key=False, constraints=None, + sequence=None, collation=None, unindexed=False, choices=None, + help_text=None, verbose_name=None, index_type=None, + db_column=None, _hidden=False): + if db_column is not None: + __deprecated__('"db_column" has been deprecated in favor of ' + '"column_name" for Field objects.') + column_name = db_column + + self.null = null + self.index = index + self.unique = unique + self.column_name = column_name + self.default = default + self.primary_key = primary_key + self.constraints = constraints # List of column constraints. + self.sequence = sequence # Name of sequence, e.g. foo_id_seq. + self.collation = collation + self.unindexed = unindexed + self.choices = choices + self.help_text = help_text + self.verbose_name = verbose_name + self.index_type = index_type or self.default_index_type + self._hidden = _hidden + + # Used internally for recovering the order in which Fields were defined + # on the Model class. + Field._field_counter += 1 + self._order = Field._field_counter + self._sort_key = (self.primary_key and 1 or 2), self._order + + def __hash__(self): + return hash(self.name + '.' + self.model.__name__) + + def __repr__(self): + if hasattr(self, 'model') and getattr(self, 'name', None): + return '<%s: %s.%s>' % (type(self).__name__, + self.model.__name__, + self.name) + return '<%s: (unbound)>' % type(self).__name__ + + def bind(self, model, name, set_attribute=True): + self.model = model + self.name = name + self.column_name = self.column_name or name + if set_attribute: + setattr(model, name, self.accessor_class(model, self, name)) + + @property + def column(self): + return Column(self.model._meta.table, self.column_name) + + def adapt(self, value): + return value + + def db_value(self, value): + return value if value is None else self.adapt(value) + + def python_value(self, value): + return value if value is None else self.adapt(value) + + def get_sort_key(self, ctx): + return self._sort_key + + def __sql__(self, ctx): + return ctx.sql(self.column) + + def get_modifiers(self): + return + + def ddl_datatype(self, ctx): + if ctx and ctx.state.field_types: + column_type = ctx.state.field_types.get(self.field_type, + self.field_type) + else: + column_type = self.field_type + + modifiers = self.get_modifiers() + if column_type and modifiers: + modifier_literal = ', '.join([str(m) for m in modifiers]) + return SQL('%s(%s)' % (column_type, modifier_literal)) + else: + return SQL(column_type) + + def ddl(self, ctx): + accum = [Entity(self.column_name)] + data_type = self.ddl_datatype(ctx) + if data_type: + accum.append(data_type) + if self.unindexed: + accum.append(SQL('UNINDEXED')) + if not self.null: + accum.append(SQL('NOT NULL')) + if self.primary_key: + accum.append(SQL('PRIMARY KEY')) + if self.sequence: + accum.append(SQL("DEFAULT NEXTVAL('%s')" % self.sequence)) + if self.constraints: + accum.extend(self.constraints) + if self.collation: + accum.append(SQL('COLLATE %s' % self.collation)) + return NodeList(accum) + + +class IntegerField(Field): + field_type = 'INT' + adapt = int + + +class BigIntegerField(IntegerField): + field_type = 'BIGINT' + + +class SmallIntegerField(IntegerField): + 
field_type = 'SMALLINT' + + +class AutoField(IntegerField): + auto_increment = True + field_type = 'AUTO' + + def __init__(self, *args, **kwargs): + if kwargs.get('primary_key') is False: + raise ValueError('%s must always be a primary key.' % type(self)) + kwargs['primary_key'] = True + super(AutoField, self).__init__(*args, **kwargs) + + +class BigAutoField(AutoField): + field_type = 'BIGAUTO' + + +class IdentityField(AutoField): + field_type = 'INT GENERATED BY DEFAULT AS IDENTITY' + + +class PrimaryKeyField(AutoField): + def __init__(self, *args, **kwargs): + __deprecated__('"PrimaryKeyField" has been renamed to "AutoField". ' + 'Please update your code accordingly as this will be ' + 'completely removed in a subsequent release.') + super(PrimaryKeyField, self).__init__(*args, **kwargs) + + +class FloatField(Field): + field_type = 'FLOAT' + adapt = float + + +class DoubleField(FloatField): + field_type = 'DOUBLE' + + +class DecimalField(Field): + field_type = 'DECIMAL' + + def __init__(self, max_digits=10, decimal_places=5, auto_round=False, + rounding=None, *args, **kwargs): + self.max_digits = max_digits + self.decimal_places = decimal_places + self.auto_round = auto_round + self.rounding = rounding or decimal.DefaultContext.rounding + super(DecimalField, self).__init__(*args, **kwargs) + + def get_modifiers(self): + return [self.max_digits, self.decimal_places] + + def db_value(self, value): + D = decimal.Decimal + if not value: + return value if value is None else D(0) + if self.auto_round: + exp = D(10) ** (-self.decimal_places) + rounding = self.rounding + return D(text_type(value)).quantize(exp, rounding=rounding) + return value + + def python_value(self, value): + if value is not None: + if isinstance(value, decimal.Decimal): + return value + return decimal.Decimal(text_type(value)) + + +class _StringField(Field): + def adapt(self, value): + if isinstance(value, text_type): + return value + elif isinstance(value, bytes_type): + return value.decode('utf-8') + return text_type(value) + + def __add__(self, other): return self.concat(other) + def __radd__(self, other): return other.concat(self) + + +class CharField(_StringField): + field_type = 'VARCHAR' + + def __init__(self, max_length=255, *args, **kwargs): + self.max_length = max_length + super(CharField, self).__init__(*args, **kwargs) + + def get_modifiers(self): + return self.max_length and [self.max_length] or None + + +class FixedCharField(CharField): + field_type = 'CHAR' + + def python_value(self, value): + value = super(FixedCharField, self).python_value(value) + if value: + value = value.strip() + return value + + +class TextField(_StringField): + field_type = 'TEXT' + + +class BlobField(Field): + field_type = 'BLOB' + + def _db_hook(self, database): + if database is None: + self._constructor = bytearray + else: + self._constructor = database.get_binary_type() + + def bind(self, model, name, set_attribute=True): + self._constructor = bytearray + if model._meta.database: + if isinstance(model._meta.database, Proxy): + model._meta.database.attach_callback(self._db_hook) + else: + self._db_hook(model._meta.database) + + # Attach a hook to the model metadata; in the event the database is + # changed or set at run-time, we will be sure to apply our callback and + # use the proper data-type for our database driver. 
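To make the field classes above concrete, here is a small model declaration exercising several of them. This is illustrative only and not part of this patch; the model and column names are hypothetical.

# Illustrative sketch: declaring a model with the field types defined above.
from peewee import (SqliteDatabase, Model, AutoField, CharField, TextField,
                    DecimalField, IntegerField, BlobField)

db = SqliteDatabase(':memory:')  # hypothetical database

class Product(Model):
    id = AutoField()                      # auto-incrementing integer primary key
    sku = CharField(max_length=32, unique=True, column_name='skuCode')
    description = TextField(null=True)
    price = DecimalField(max_digits=10, decimal_places=2, auto_round=True)
    stock = IntegerField(default=0)
    thumbnail = BlobField(null=True)      # stored using the driver's binary type

    class Meta:
        database = db
        table_name = 'product'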
+ model._meta._db_hooks.append(self._db_hook) + return super(BlobField, self).bind(model, name, set_attribute) + + def db_value(self, value): + if isinstance(value, text_type): + value = value.encode('raw_unicode_escape') + if isinstance(value, bytes_type): + return self._constructor(value) + return value + + +class BitField(BitwiseMixin, BigIntegerField): + def __init__(self, *args, **kwargs): + kwargs.setdefault('default', 0) + super(BitField, self).__init__(*args, **kwargs) + self.__current_flag = 1 + + def flag(self, value=None): + if value is None: + value = self.__current_flag + self.__current_flag <<= 1 + else: + self.__current_flag = value << 1 + + class FlagDescriptor(object): + def __init__(self, field, value): + self._field = field + self._value = value + def __get__(self, instance, instance_type=None): + if instance is None: + return self._field.bin_and(self._value) != 0 + value = getattr(instance, self._field.name) or 0 + return (value & self._value) != 0 + def __set__(self, instance, is_set): + if is_set not in (True, False): + raise ValueError('Value must be either True or False') + value = getattr(instance, self._field.name) or 0 + if is_set: + value |= self._value + else: + value &= ~self._value + setattr(instance, self._field.name, value) + return FlagDescriptor(self, value) + + +class BigBitFieldData(object): + def __init__(self, instance, name): + self.instance = instance + self.name = name + value = self.instance.__data__.get(self.name) + if not value: + value = bytearray() + elif not isinstance(value, bytearray): + value = bytearray(value) + self._buffer = self.instance.__data__[self.name] = value + + def _ensure_length(self, idx): + byte_num, byte_offset = divmod(idx, 8) + cur_size = len(self._buffer) + if cur_size <= byte_num: + self._buffer.extend(b'\x00' * ((byte_num + 1) - cur_size)) + return byte_num, byte_offset + + def set_bit(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + self._buffer[byte_num] |= (1 << byte_offset) + + def clear_bit(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + self._buffer[byte_num] &= ~(1 << byte_offset) + + def toggle_bit(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + self._buffer[byte_num] ^= (1 << byte_offset) + return bool(self._buffer[byte_num] & (1 << byte_offset)) + + def is_set(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + return bool(self._buffer[byte_num] & (1 << byte_offset)) + + def __repr__(self): + return repr(self._buffer) + + +class BigBitFieldAccessor(FieldAccessor): + def __get__(self, instance, instance_type=None): + if instance is None: + return self.field + return BigBitFieldData(instance, self.name) + def __set__(self, instance, value): + if isinstance(value, memoryview): + value = value.tobytes() + elif isinstance(value, buffer_type): + value = bytes(value) + elif isinstance(value, bytearray): + value = bytes_type(value) + elif isinstance(value, BigBitFieldData): + value = bytes_type(value._buffer) + elif isinstance(value, text_type): + value = value.encode('utf-8') + elif not isinstance(value, bytes_type): + raise ValueError('Value must be either a bytes, memoryview or ' + 'BigBitFieldData instance.') + super(BigBitFieldAccessor, self).__set__(instance, value) + + +class BigBitField(BlobField): + accessor_class = BigBitFieldAccessor + + def __init__(self, *args, **kwargs): + kwargs.setdefault('default', bytes_type) + super(BigBitField, self).__init__(*args, **kwargs) + + def db_value(self, value): + return bytes_type(value) if value is not 
None else value + + +class UUIDField(Field): + field_type = 'UUID' + + def db_value(self, value): + if isinstance(value, basestring) and len(value) == 32: + # Hex string. No transformation is necessary. + return value + elif isinstance(value, bytes) and len(value) == 16: + # Allow raw binary representation. + value = uuid.UUID(bytes=value) + if isinstance(value, uuid.UUID): + return value.hex + try: + return uuid.UUID(value).hex + except: + return value + + def python_value(self, value): + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(value) if value is not None else None + + +class BinaryUUIDField(BlobField): + field_type = 'UUIDB' + + def db_value(self, value): + if isinstance(value, bytes) and len(value) == 16: + # Raw binary value. No transformation is necessary. + return self._constructor(value) + elif isinstance(value, basestring) and len(value) == 32: + # Allow hex string representation. + value = uuid.UUID(hex=value) + if isinstance(value, uuid.UUID): + return self._constructor(value.bytes) + elif value is not None: + raise ValueError('value for binary UUID field must be UUID(), ' + 'a hexadecimal string, or a bytes object.') + + def python_value(self, value): + if isinstance(value, uuid.UUID): + return value + elif isinstance(value, memoryview): + value = value.tobytes() + elif value and not isinstance(value, bytes): + value = bytes(value) + return uuid.UUID(bytes=value) if value is not None else None + + +def _date_part(date_part): + def dec(self): + return self.model._meta.database.extract_date(date_part, self) + return dec + +def format_date_time(value, formats, post_process=None): + post_process = post_process or (lambda x: x) + for fmt in formats: + try: + return post_process(datetime.datetime.strptime(value, fmt)) + except ValueError: + pass + return value + + +class _BaseFormattedField(Field): + formats = None + + def __init__(self, formats=None, *args, **kwargs): + if formats is not None: + self.formats = formats + super(_BaseFormattedField, self).__init__(*args, **kwargs) + + +class DateTimeField(_BaseFormattedField): + field_type = 'DATETIME' + formats = [ + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d', + ] + + def adapt(self, value): + if value and isinstance(value, basestring): + return format_date_time(value, self.formats) + return value + + year = property(_date_part('year')) + month = property(_date_part('month')) + day = property(_date_part('day')) + hour = property(_date_part('hour')) + minute = property(_date_part('minute')) + second = property(_date_part('second')) + + +class DateField(_BaseFormattedField): + field_type = 'DATE' + formats = [ + '%Y-%m-%d', + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f', + ] + + def adapt(self, value): + if value and isinstance(value, basestring): + pp = lambda x: x.date() + return format_date_time(value, self.formats, pp) + elif value and isinstance(value, datetime.datetime): + return value.date() + return value + + year = property(_date_part('year')) + month = property(_date_part('month')) + day = property(_date_part('day')) + + +class TimeField(_BaseFormattedField): + field_type = 'TIME' + formats = [ + '%H:%M:%S.%f', + '%H:%M:%S', + '%H:%M', + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d %H:%M:%S', + ] + + def adapt(self, value): + if value: + if isinstance(value, basestring): + pp = lambda x: x.time() + return format_date_time(value, self.formats, pp) + elif isinstance(value, datetime.datetime): + return value.time() + if value is not None and isinstance(value, datetime.timedelta): + return 
(datetime.datetime.min + value).time() + return value + + hour = property(_date_part('hour')) + minute = property(_date_part('minute')) + second = property(_date_part('second')) + + +class TimestampField(BigIntegerField): + # Support second -> microsecond resolution. + valid_resolutions = [10**i for i in range(7)] + + def __init__(self, *args, **kwargs): + self.resolution = kwargs.pop('resolution', None) + if not self.resolution: + self.resolution = 1 + elif self.resolution in range(7): + self.resolution = 10 ** self.resolution + elif self.resolution not in self.valid_resolutions: + raise ValueError('TimestampField resolution must be one of: %s' % + ', '.join(str(i) for i in self.valid_resolutions)) + + self.utc = kwargs.pop('utc', False) or False + dflt = datetime.datetime.utcnow if self.utc else datetime.datetime.now + kwargs.setdefault('default', dflt) + super(TimestampField, self).__init__(*args, **kwargs) + + def local_to_utc(self, dt): + # Convert naive local datetime into naive UTC, e.g.: + # 2019-03-01T12:00:00 (local=US/Central) -> 2019-03-01T18:00:00. + # 2019-05-01T12:00:00 (local=US/Central) -> 2019-05-01T17:00:00. + # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00. + return datetime.datetime(*time.gmtime(time.mktime(dt.timetuple()))[:6]) + + def utc_to_local(self, dt): + # Convert a naive UTC datetime into local time, e.g.: + # 2019-03-01T18:00:00 (local=US/Central) -> 2019-03-01T12:00:00. + # 2019-05-01T17:00:00 (local=US/Central) -> 2019-05-01T12:00:00. + # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00. + ts = calendar.timegm(dt.utctimetuple()) + return datetime.datetime.fromtimestamp(ts) + + def db_value(self, value): + if value is None: + return + + if isinstance(value, datetime.datetime): + pass + elif isinstance(value, datetime.date): + value = datetime.datetime(value.year, value.month, value.day) + else: + return int(round(value * self.resolution)) + + if self.utc: + # If utc-mode is on, then we assume all naive datetimes are in UTC. 
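The flag and timestamp fields above are easiest to understand from usage. A short sketch, assuming an in-memory SQLite database and hypothetical model names (not part of this patch):

# Illustrative sketch: BitField flags and an integer-backed TimestampField.
from peewee import SqliteDatabase, Model, BitField, TimestampField, DateTimeField

db = SqliteDatabase(':memory:')

class Post(Model):
    flags = BitField()                 # packs several booleans into one integer column
    is_sticky = flags.flag(1)
    is_favorite = flags.flag(2)
    # Stored as an integer; resolution=1000 keeps millisecond precision.
    created = TimestampField(resolution=1000, utc=True)
    edited = DateTimeField(null=True)  # strings are parsed using the `formats` list

    class Meta:
        database = db

db.create_tables([Post])
post = Post.create(edited='2019-07-29 13:03:45')
post.is_sticky = True                           # sets the corresponding bit in `flags`
post.save()
sticky = Post.select().where(Post.is_sticky)    # class-level access yields an expression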
+ timestamp = calendar.timegm(value.utctimetuple()) + else: + timestamp = time.mktime(value.timetuple()) + + if self.resolution > 1: + timestamp += (value.microsecond * .000001) + timestamp *= self.resolution + return int(round(timestamp)) + + def python_value(self, value): + if value is not None and isinstance(value, (int, float, long)): + if self.resolution > 1: + ticks_to_microsecond = 1000000 // self.resolution + value, ticks = divmod(value, self.resolution) + microseconds = int(ticks * ticks_to_microsecond) + else: + microseconds = 0 + + if self.utc: + value = datetime.datetime.utcfromtimestamp(value) + else: + value = datetime.datetime.fromtimestamp(value) + + if microseconds: + value = value.replace(microsecond=microseconds) + + return value + + +class IPField(BigIntegerField): + def db_value(self, val): + if val is not None: + return struct.unpack('!I', socket.inet_aton(val))[0] + + def python_value(self, val): + if val is not None: + return socket.inet_ntoa(struct.pack('!I', val)) + + +class BooleanField(Field): + field_type = 'BOOL' + adapt = bool + + +class BareField(Field): + def __init__(self, adapt=None, *args, **kwargs): + super(BareField, self).__init__(*args, **kwargs) + if adapt is not None: + self.adapt = adapt + + def ddl_datatype(self, ctx): + return + + +class ForeignKeyField(Field): + accessor_class = ForeignKeyAccessor + + def __init__(self, model, field=None, backref=None, on_delete=None, + on_update=None, deferrable=None, _deferred=None, + rel_model=None, to_field=None, object_id_name=None, + lazy_load=True, related_name=None, *args, **kwargs): + kwargs.setdefault('index', True) + + # If lazy_load is disable, we use a different descriptor/accessor that + # will ensure we don't accidentally perform a query. + if not lazy_load: + self.accessor_class = NoQueryForeignKeyAccessor + + super(ForeignKeyField, self).__init__(*args, **kwargs) + + if rel_model is not None: + __deprecated__('"rel_model" has been deprecated in favor of ' + '"model" for ForeignKeyField objects.') + model = rel_model + if to_field is not None: + __deprecated__('"to_field" has been deprecated in favor of ' + '"field" for ForeignKeyField objects.') + field = to_field + if related_name is not None: + __deprecated__('"related_name" has been deprecated in favor of ' + '"backref" for Field objects.') + backref = related_name + + self.rel_model = model + self.rel_field = field + self.declared_backref = backref + self.backref = None + self.on_delete = on_delete + self.on_update = on_update + self.deferrable = deferrable + self.deferred = _deferred + self.object_id_name = object_id_name + self.lazy_load = lazy_load + + @property + def field_type(self): + if not isinstance(self.rel_field, AutoField): + return self.rel_field.field_type + elif isinstance(self.rel_field, BigAutoField): + return BigIntegerField.field_type + return IntegerField.field_type + + def get_modifiers(self): + if not isinstance(self.rel_field, AutoField): + return self.rel_field.get_modifiers() + return super(ForeignKeyField, self).get_modifiers() + + def adapt(self, value): + return self.rel_field.adapt(value) + + def db_value(self, value): + if isinstance(value, self.rel_model): + value = value.get_id() + return self.rel_field.db_value(value) + + def python_value(self, value): + if isinstance(value, self.rel_model): + return value + return self.rel_field.python_value(value) + + def bind(self, model, name, set_attribute=True): + if not self.column_name: + self.column_name = name if name.endswith('_id') else name + '_id' + if not 
self.object_id_name: + self.object_id_name = self.column_name + if self.object_id_name == name: + self.object_id_name += '_id' + elif self.object_id_name == name: + raise ValueError('ForeignKeyField "%s"."%s" specifies an ' + 'object_id_name that conflicts with its field ' + 'name.' % (model._meta.name, name)) + if self.rel_model == 'self': + self.rel_model = model + if isinstance(self.rel_field, basestring): + self.rel_field = getattr(self.rel_model, self.rel_field) + elif self.rel_field is None: + self.rel_field = self.rel_model._meta.primary_key + + # Bind field before assigning backref, so field is bound when + # calling declared_backref() (if callable). + super(ForeignKeyField, self).bind(model, name, set_attribute) + + if callable_(self.declared_backref): + self.backref = self.declared_backref(self) + else: + self.backref, self.declared_backref = self.declared_backref, None + if not self.backref: + self.backref = '%s_set' % model._meta.name + + if set_attribute: + setattr(model, self.object_id_name, ObjectIdAccessor(self)) + if self.backref not in '!+': + setattr(self.rel_model, self.backref, BackrefAccessor(self)) + + def foreign_key_constraint(self): + parts = [ + SQL('FOREIGN KEY'), + EnclosedNodeList((self,)), + SQL('REFERENCES'), + self.rel_model, + EnclosedNodeList((self.rel_field,))] + if self.on_delete: + parts.append(SQL('ON DELETE %s' % self.on_delete)) + if self.on_update: + parts.append(SQL('ON UPDATE %s' % self.on_update)) + if self.deferrable: + parts.append(SQL('DEFERRABLE %s' % self.deferrable)) + return NodeList(parts) + + def __getattr__(self, attr): + if attr.startswith('__'): + # Prevent recursion error when deep-copying. + raise AttributeError('Cannot look-up non-existant "__" methods.') + if attr in self.rel_model._meta.fields: + return self.rel_model._meta.fields[attr] + raise AttributeError('Foreign-key has no attribute %s, nor is it a ' + 'valid field on the related model.' 
% attr) + + +class DeferredForeignKey(Field): + _unresolved = set() + + def __init__(self, rel_model_name, **kwargs): + self.field_kwargs = kwargs + self.rel_model_name = rel_model_name.lower() + DeferredForeignKey._unresolved.add(self) + super(DeferredForeignKey, self).__init__( + column_name=kwargs.get('column_name'), + null=kwargs.get('null')) + + __hash__ = object.__hash__ + + def __deepcopy__(self, memo=None): + return DeferredForeignKey(self.rel_model_name, **self.field_kwargs) + + def set_model(self, rel_model): + field = ForeignKeyField(rel_model, _deferred=True, **self.field_kwargs) + self.model._meta.add_field(self.name, field) + + @staticmethod + def resolve(model_cls): + unresolved = sorted(DeferredForeignKey._unresolved, + key=operator.attrgetter('_order')) + for dr in unresolved: + if dr.rel_model_name == model_cls.__name__.lower(): + dr.set_model(model_cls) + DeferredForeignKey._unresolved.discard(dr) + + +class DeferredThroughModel(object): + def __init__(self): + self._refs = [] + + def set_field(self, model, field, name): + self._refs.append((model, field, name)) + + def set_model(self, through_model): + for src_model, m2mfield, name in self._refs: + m2mfield.through_model = through_model + src_model._meta.add_field(name, m2mfield) + + +class MetaField(Field): + column_name = default = model = name = None + primary_key = False + + +class ManyToManyFieldAccessor(FieldAccessor): + def __init__(self, model, field, name): + super(ManyToManyFieldAccessor, self).__init__(model, field, name) + self.model = field.model + self.rel_model = field.rel_model + self.through_model = field.through_model + src_fks = self.through_model._meta.model_refs[self.model] + dest_fks = self.through_model._meta.model_refs[self.rel_model] + if not src_fks: + raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' % + (self.model, self.through_model)) + elif not dest_fks: + raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' % + (self.rel_model, self.through_model)) + self.src_fk = src_fks[0] + self.dest_fk = dest_fks[0] + + def __get__(self, instance, instance_type=None, force_query=False): + if instance is not None: + if not force_query and self.src_fk.backref != '+': + backref = getattr(instance, self.src_fk.backref) + if isinstance(backref, list): + return [getattr(obj, self.dest_fk.name) for obj in backref] + + src_id = getattr(instance, self.src_fk.rel_field.name) + return (ManyToManyQuery(instance, self, self.rel_model) + .join(self.through_model) + .join(self.model) + .where(self.src_fk == src_id)) + + return self.field + + def __set__(self, instance, value): + query = self.__get__(instance, force_query=True) + query.add(value, clear_existing=True) + + +class ManyToManyField(MetaField): + accessor_class = ManyToManyFieldAccessor + + def __init__(self, model, backref=None, through_model=None, on_delete=None, + on_update=None, _is_backref=False): + if through_model is not None: + if not (isinstance(through_model, DeferredThroughModel) or + is_model(through_model)): + raise TypeError('Unexpected value for through_model. 
Expected ' + 'Model or DeferredThroughModel.') + if not _is_backref and (on_delete is not None or on_update is not None): + raise ValueError('Cannot specify on_delete or on_update when ' + 'through_model is specified.') + self.rel_model = model + self.backref = backref + self._through_model = through_model + self._on_delete = on_delete + self._on_update = on_update + self._is_backref = _is_backref + + def _get_descriptor(self): + return ManyToManyFieldAccessor(self) + + def bind(self, model, name, set_attribute=True): + if isinstance(self._through_model, DeferredThroughModel): + self._through_model.set_field(model, self, name) + return + + super(ManyToManyField, self).bind(model, name, set_attribute) + + if not self._is_backref: + many_to_many_field = ManyToManyField( + self.model, + backref=name, + through_model=self.through_model, + on_delete=self._on_delete, + on_update=self._on_update, + _is_backref=True) + self.backref = self.backref or model._meta.name + 's' + self.rel_model._meta.add_field(self.backref, many_to_many_field) + + def get_models(self): + return [model for _, model in sorted(( + (self._is_backref, self.model), + (not self._is_backref, self.rel_model)))] + + @property + def through_model(self): + if self._through_model is None: + self._through_model = self._create_through_model() + return self._through_model + + @through_model.setter + def through_model(self, value): + self._through_model = value + + def _create_through_model(self): + lhs, rhs = self.get_models() + tables = [model._meta.table_name for model in (lhs, rhs)] + + class Meta: + database = self.model._meta.database + schema = self.model._meta.schema + table_name = '%s_%s_through' % tuple(tables) + indexes = ( + ((lhs._meta.name, rhs._meta.name), + True),) + + params = {'on_delete': self._on_delete, 'on_update': self._on_update} + attrs = { + lhs._meta.name: ForeignKeyField(lhs, **params), + rhs._meta.name: ForeignKeyField(rhs, **params), + 'Meta': Meta} + + klass_name = '%s%sThrough' % (lhs.__name__, rhs.__name__) + return type(klass_name, (Model,), attrs) + + def get_through_model(self): + # XXX: Deprecated. Just use the "through_model" property. 
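A short sketch of how ForeignKeyField and ManyToManyField wire up accessors, backrefs and the generated through model. Models here are hypothetical and not part of this patch.

# Illustrative sketch: foreign keys, backrefs and a many-to-many relation.
from peewee import (SqliteDatabase, Model, CharField, ForeignKeyField,
                    ManyToManyField)

db = SqliteDatabase(':memory:')

class BaseModel(Model):
    class Meta:
        database = db

class Author(BaseModel):
    name = CharField()

class Book(BaseModel):
    title = CharField()
    # Creates Book.author_id (ObjectIdAccessor) and Author.books (BackrefAccessor).
    author = ForeignKeyField(Author, backref='books', on_delete='CASCADE')

class Tag(BaseModel):
    label = CharField(unique=True)
    # The intermediary table is generated automatically (through_model property).
    books = ManyToManyField(Book, backref='tags')

db.create_tables([Author, Book, Tag, Tag.books.through_model])

author = Author.create(name='Charles')
book = Book.create(title='Peewee Internals', author=author)
book.tags.add(Tag.create(label='orm'))          # ManyToManyQuery.add()
for b in author.books:                          # backref resolves to a filtered SELECT
    print(b.title, [t.label for t in b.tags])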
+ return self.through_model + + +class VirtualField(MetaField): + field_class = None + + def __init__(self, field_class=None, *args, **kwargs): + Field = field_class if field_class is not None else self.field_class + self.field_instance = Field() if Field is not None else None + super(VirtualField, self).__init__(*args, **kwargs) + + def db_value(self, value): + if self.field_instance is not None: + return self.field_instance.db_value(value) + return value + + def python_value(self, value): + if self.field_instance is not None: + return self.field_instance.python_value(value) + return value + + def bind(self, model, name, set_attribute=True): + self.model = model + self.column_name = self.name = name + setattr(model, name, self.accessor_class(model, self, name)) + + +class CompositeKey(MetaField): + sequence = None + + def __init__(self, *field_names): + self.field_names = field_names + + def __get__(self, instance, instance_type=None): + if instance is not None: + return tuple([getattr(instance, field_name) + for field_name in self.field_names]) + return self + + def __set__(self, instance, value): + if not isinstance(value, (list, tuple)): + raise TypeError('A list or tuple must be used to set the value of ' + 'a composite primary key.') + if len(value) != len(self.field_names): + raise ValueError('The length of the value must equal the number ' + 'of columns of the composite primary key.') + for idx, field_value in enumerate(value): + setattr(instance, self.field_names[idx], field_value) + + def __eq__(self, other): + expressions = [(self.model._meta.fields[field] == value) + for field, value in zip(self.field_names, other)] + return reduce(operator.and_, expressions) + + def __ne__(self, other): + return ~(self == other) + + def __hash__(self): + return hash((self.model.__name__, self.field_names)) + + def __sql__(self, ctx): + # If the composite PK is being selected, do not use parens. Elsewhere, + # such as in an expression, we want to use parentheses and treat it as + # a row value. 
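CompositeKey is used by assigning it as the Meta primary_key; the descriptor then exposes the key as a tuple. A minimal sketch with hypothetical names, not part of this patch:

# Illustrative sketch: a composite primary key built from two columns.
from peewee import SqliteDatabase, Model, IntegerField, CharField, CompositeKey

db = SqliteDatabase(':memory:')

class Enrollment(Model):
    student_id = IntegerField()
    course_code = CharField()
    grade = CharField(null=True)

    class Meta:
        database = db
        primary_key = CompositeKey('student_id', 'course_code')

db.create_tables([Enrollment])
row = Enrollment.create(student_id=1, course_code='CS101', grade='A')
print(row._pk)                               # (1, 'CS101'), via CompositeKey.__get__
match = Enrollment.get_by_id((1, 'CS101'))   # CompositeKey.__eq__ builds the AND clause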
+ parens = ctx.scope != SCOPE_SOURCE + return ctx.sql(NodeList([self.model._meta.fields[field] + for field in self.field_names], ', ', parens)) + + def bind(self, model, name, set_attribute=True): + self.model = model + self.column_name = self.name = name + setattr(model, self.name, self) + + +class _SortedFieldList(object): + __slots__ = ('_keys', '_items') + + def __init__(self): + self._keys = [] + self._items = [] + + def __getitem__(self, i): + return self._items[i] + + def __iter__(self): + return iter(self._items) + + def __contains__(self, item): + k = item._sort_key + i = bisect_left(self._keys, k) + j = bisect_right(self._keys, k) + return item in self._items[i:j] + + def index(self, field): + return self._keys.index(field._sort_key) + + def insert(self, item): + k = item._sort_key + i = bisect_left(self._keys, k) + self._keys.insert(i, k) + self._items.insert(i, item) + + def remove(self, item): + idx = self.index(item) + del self._items[idx] + del self._keys[idx] + + +# MODELS + + +class SchemaManager(object): + def __init__(self, model, database=None, **context_options): + self.model = model + self._database = database + context_options.setdefault('scope', SCOPE_VALUES) + self.context_options = context_options + + @property + def database(self): + db = self._database or self.model._meta.database + if db is None: + raise ImproperlyConfigured('database attribute does not appear to ' + 'be set on the model: %s' % self.model) + return db + + @database.setter + def database(self, value): + self._database = value + + def _create_context(self): + return self.database.get_sql_context(**self.context_options) + + def _create_table(self, safe=True, **options): + is_temp = options.pop('temporary', False) + ctx = self._create_context() + ctx.literal('CREATE TEMPORARY TABLE ' if is_temp else 'CREATE TABLE ') + if safe: + ctx.literal('IF NOT EXISTS ') + ctx.sql(self.model).literal(' ') + + columns = [] + constraints = [] + meta = self.model._meta + if meta.composite_key: + pk_columns = [meta.fields[field_name].column + for field_name in meta.primary_key.field_names] + constraints.append(NodeList((SQL('PRIMARY KEY'), + EnclosedNodeList(pk_columns)))) + + for field in meta.sorted_fields: + columns.append(field.ddl(ctx)) + if isinstance(field, ForeignKeyField) and not field.deferred: + constraints.append(field.foreign_key_constraint()) + + if meta.constraints: + constraints.extend(meta.constraints) + + constraints.extend(self._create_table_option_sql(options)) + ctx.sql(EnclosedNodeList(columns + constraints)) + + if meta.table_settings is not None: + table_settings = ensure_tuple(meta.table_settings) + for setting in table_settings: + if not isinstance(setting, basestring): + raise ValueError('table_settings must be strings') + ctx.literal(' ').literal(setting) + + if meta.without_rowid: + ctx.literal(' WITHOUT ROWID') + return ctx + + def _create_table_option_sql(self, options): + accum = [] + options = merge_dict(self.model._meta.options or {}, options) + if not options: + return accum + + for key, value in sorted(options.items()): + if not isinstance(value, Node): + if is_model(value): + value = value._meta.table + else: + value = SQL(str(value)) + accum.append(NodeList((SQL(key), value), glue='=')) + return accum + + def create_table(self, safe=True, **options): + self.database.execute(self._create_table(safe=safe, **options)) + + def _create_table_as(self, table_name, query, safe=True, **meta): + ctx = (self._create_context() + .literal('CREATE TEMPORARY TABLE ' + if meta.get('temporary') 
else 'CREATE TABLE ')) + if safe: + ctx.literal('IF NOT EXISTS ') + return (ctx + .sql(Entity(table_name)) + .literal(' AS ') + .sql(query)) + + def create_table_as(self, table_name, query, safe=True, **meta): + ctx = self._create_table_as(table_name, query, safe=safe, **meta) + self.database.execute(ctx) + + def _drop_table(self, safe=True, **options): + ctx = (self._create_context() + .literal('DROP TABLE IF EXISTS ' if safe else 'DROP TABLE ') + .sql(self.model)) + if options.get('cascade'): + ctx = ctx.literal(' CASCADE') + elif options.get('restrict'): + ctx = ctx.literal(' RESTRICT') + return ctx + + def drop_table(self, safe=True, **options): + self.database.execute(self._drop_table(safe=safe, **options)) + + def _truncate_table(self, restart_identity=False, cascade=False): + db = self.database + if not db.truncate_table: + return (self._create_context() + .literal('DELETE FROM ').sql(self.model)) + + ctx = self._create_context().literal('TRUNCATE TABLE ').sql(self.model) + if restart_identity: + ctx = ctx.literal(' RESTART IDENTITY') + if cascade: + ctx = ctx.literal(' CASCADE') + return ctx + + def truncate_table(self, restart_identity=False, cascade=False): + self.database.execute(self._truncate_table(restart_identity, cascade)) + + def _create_indexes(self, safe=True): + return [self._create_index(index, safe) + for index in self.model._meta.fields_to_index()] + + def _create_index(self, index, safe=True): + if isinstance(index, Index): + if not self.database.safe_create_index: + index = index.safe(False) + elif index._safe != safe: + index = index.safe(safe) + return self._create_context().sql(index) + + def create_indexes(self, safe=True): + for query in self._create_indexes(safe=safe): + self.database.execute(query) + + def _drop_indexes(self, safe=True): + return [self._drop_index(index, safe) + for index in self.model._meta.fields_to_index() + if isinstance(index, Index)] + + def _drop_index(self, index, safe): + statement = 'DROP INDEX ' + if safe and self.database.safe_drop_index: + statement += 'IF EXISTS ' + if isinstance(index._table, Table) and index._table._schema: + index_name = Entity(index._table._schema, index._name) + else: + index_name = Entity(index._name) + return (self + ._create_context() + .literal(statement) + .sql(index_name)) + + def drop_indexes(self, safe=True): + for query in self._drop_indexes(safe=safe): + self.database.execute(query) + + def _check_sequences(self, field): + if not field.sequence or not self.database.sequences: + raise ValueError('Sequences are either not supported, or are not ' + 'defined for "%s".' 
% field.name) + + def _sequence_for_field(self, field): + if field.model._meta.schema: + return Entity(field.model._meta.schema, field.sequence) + else: + return Entity(field.sequence) + + def _create_sequence(self, field): + self._check_sequences(field) + if not self.database.sequence_exists(field.sequence): + return (self + ._create_context() + .literal('CREATE SEQUENCE ') + .sql(self._sequence_for_field(field))) + + def create_sequence(self, field): + seq_ctx = self._create_sequence(field) + if seq_ctx is not None: + self.database.execute(seq_ctx) + + def _drop_sequence(self, field): + self._check_sequences(field) + if self.database.sequence_exists(field.sequence): + return (self + ._create_context() + .literal('DROP SEQUENCE ') + .sql(self._sequence_for_field(field))) + + def drop_sequence(self, field): + seq_ctx = self._drop_sequence(field) + if seq_ctx is not None: + self.database.execute(seq_ctx) + + def _create_foreign_key(self, field): + name = 'fk_%s_%s_refs_%s' % (field.model._meta.table_name, + field.column_name, + field.rel_model._meta.table_name) + return (self + ._create_context() + .literal('ALTER TABLE ') + .sql(field.model) + .literal(' ADD CONSTRAINT ') + .sql(Entity(_truncate_constraint_name(name))) + .literal(' ') + .sql(field.foreign_key_constraint())) + + def create_foreign_key(self, field): + self.database.execute(self._create_foreign_key(field)) + + def create_sequences(self): + if self.database.sequences: + for field in self.model._meta.sorted_fields: + if field.sequence: + self.create_sequence(field) + + def create_all(self, safe=True, **table_options): + self.create_sequences() + self.create_table(safe, **table_options) + self.create_indexes(safe=safe) + + def drop_sequences(self): + if self.database.sequences: + for field in self.model._meta.sorted_fields: + if field.sequence: + self.drop_sequence(field) + + def drop_all(self, safe=True, drop_sequences=True, **options): + self.drop_table(safe, **options) + if drop_sequences: + self.drop_sequences() + + +class Metadata(object): + def __init__(self, model, database=None, table_name=None, indexes=None, + primary_key=None, constraints=None, schema=None, + only_save_dirty=False, depends_on=None, options=None, + db_table=None, table_function=None, table_settings=None, + without_rowid=False, temporary=False, legacy_table_names=True, + **kwargs): + if db_table is not None: + __deprecated__('"db_table" has been deprecated in favor of ' + '"table_name" for Models.') + table_name = db_table + self.model = model + self.database = database + + self.fields = {} + self.columns = {} + self.combined = {} + + self._sorted_field_list = _SortedFieldList() + self.sorted_fields = [] + self.sorted_field_names = [] + + self.defaults = {} + self._default_by_name = {} + self._default_dict = {} + self._default_callables = {} + self._default_callable_list = [] + + self.name = model.__name__.lower() + self.table_function = table_function + self.legacy_table_names = legacy_table_names + if not table_name: + table_name = (self.table_function(model) + if self.table_function + else self.make_table_name()) + self.table_name = table_name + self._table = None + + self.indexes = list(indexes) if indexes else [] + self.constraints = constraints + self._schema = schema + self.primary_key = primary_key + self.composite_key = self.auto_increment = None + self.only_save_dirty = only_save_dirty + self.depends_on = depends_on + self.table_settings = table_settings + self.without_rowid = without_rowid + self.temporary = temporary + + self.refs = {} + 
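The SchemaManager above is what Model.create_table(), drop_table() and the index helpers delegate to through Model._schema. A brief sketch, assuming a throwaway SQLite database and a hypothetical model (not part of this patch):

# Illustrative sketch: table and index DDL driven through the model API.
from peewee import SqliteDatabase, Model, CharField, IntegerField

db = SqliteDatabase(':memory:')

class Note(Model):
    body = CharField()
    priority = IntegerField(default=0)

    class Meta:
        database = db
        table_name = 'notes'

# Extra composite index, picked up later by Metadata.fields_to_index().
Note.add_index(Note.body, Note.priority, unique=True)

Note.create_table(safe=True)    # CREATE TABLE IF NOT EXISTS "notes" (...) plus indexes
assert Note.table_exists()
Note.drop_table(safe=True)      # DROP TABLE IF EXISTS "notes"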
self.backrefs = {} + self.model_refs = collections.defaultdict(list) + self.model_backrefs = collections.defaultdict(list) + self.manytomany = {} + + self.options = options or {} + for key, value in kwargs.items(): + setattr(self, key, value) + self._additional_keys = set(kwargs.keys()) + + # Allow objects to register hooks that are called if the model is bound + # to a different database. For example, BlobField uses a different + # Python data-type depending on the db driver / python version. When + # the database changes, we need to update any BlobField so they can use + # the appropriate data-type. + self._db_hooks = [] + + def make_table_name(self): + if self.legacy_table_names: + return re.sub('[^\w]+', '_', self.name) + return make_snake_case(self.model.__name__) + + def model_graph(self, refs=True, backrefs=True, depth_first=True): + if not refs and not backrefs: + raise ValueError('One of `refs` or `backrefs` must be True.') + + accum = [(None, self.model, None)] + seen = set() + queue = collections.deque((self,)) + method = queue.pop if depth_first else queue.popleft + + while queue: + curr = method() + if curr in seen: continue + seen.add(curr) + + if refs: + for fk, model in curr.refs.items(): + accum.append((fk, model, False)) + queue.append(model._meta) + if backrefs: + for fk, model in curr.backrefs.items(): + accum.append((fk, model, True)) + queue.append(model._meta) + + return accum + + def add_ref(self, field): + rel = field.rel_model + self.refs[field] = rel + self.model_refs[rel].append(field) + rel._meta.backrefs[field] = self.model + rel._meta.model_backrefs[self.model].append(field) + + def remove_ref(self, field): + rel = field.rel_model + del self.refs[field] + self.model_refs[rel].remove(field) + del rel._meta.backrefs[field] + rel._meta.model_backrefs[self.model].remove(field) + + def add_manytomany(self, field): + self.manytomany[field.name] = field + + def remove_manytomany(self, field): + del self.manytomany[field.name] + + @property + def table(self): + if self._table is None: + self._table = Table( + self.table_name, + [field.column_name for field in self.sorted_fields], + schema=self.schema, + _model=self.model, + _database=self.database) + return self._table + + @table.setter + def table(self, value): + raise AttributeError('Cannot set the "table".') + + @table.deleter + def table(self): + self._table = None + + @property + def schema(self): + return self._schema + + @schema.setter + def schema(self, value): + self._schema = value + del self.table + + @property + def entity(self): + if self._schema: + return Entity(self._schema, self.table_name) + else: + return Entity(self.table_name) + + def _update_sorted_fields(self): + self.sorted_fields = list(self._sorted_field_list) + self.sorted_field_names = [f.name for f in self.sorted_fields] + + def get_rel_for_model(self, model): + if isinstance(model, ModelAlias): + model = model.model + forwardrefs = self.model_refs.get(model, []) + backrefs = self.model_backrefs.get(model, []) + return (forwardrefs, backrefs) + + def add_field(self, field_name, field, set_attribute=True): + if field_name in self.fields: + self.remove_field(field_name) + elif field_name in self.manytomany: + self.remove_manytomany(self.manytomany[field_name]) + + if not isinstance(field, MetaField): + del self.table + field.bind(self.model, field_name, set_attribute) + self.fields[field.name] = field + self.columns[field.column_name] = field + self.combined[field.name] = field + self.combined[field.column_name] = field + + 
self._sorted_field_list.insert(field) + self._update_sorted_fields() + + if field.default is not None: + # This optimization helps speed up model instance construction. + self.defaults[field] = field.default + if callable_(field.default): + self._default_callables[field] = field.default + self._default_callable_list.append((field.name, + field.default)) + else: + self._default_dict[field] = field.default + self._default_by_name[field.name] = field.default + else: + field.bind(self.model, field_name, set_attribute) + + if isinstance(field, ForeignKeyField): + self.add_ref(field) + elif isinstance(field, ManyToManyField) and field.name: + self.add_manytomany(field) + + def remove_field(self, field_name): + if field_name not in self.fields: + return + + del self.table + original = self.fields.pop(field_name) + del self.columns[original.column_name] + del self.combined[field_name] + try: + del self.combined[original.column_name] + except KeyError: + pass + self._sorted_field_list.remove(original) + self._update_sorted_fields() + + if original.default is not None: + del self.defaults[original] + if self._default_callables.pop(original, None): + for i, (name, _) in enumerate(self._default_callable_list): + if name == field_name: + self._default_callable_list.pop(i) + break + else: + self._default_dict.pop(original, None) + self._default_by_name.pop(original.name, None) + + if isinstance(original, ForeignKeyField): + self.remove_ref(original) + + def set_primary_key(self, name, field): + self.composite_key = isinstance(field, CompositeKey) + self.add_field(name, field) + self.primary_key = field + self.auto_increment = ( + field.auto_increment or + bool(field.sequence)) + + def get_primary_keys(self): + if self.composite_key: + return tuple([self.fields[field_name] + for field_name in self.primary_key.field_names]) + else: + return (self.primary_key,) if self.primary_key is not False else () + + def get_default_dict(self): + dd = self._default_by_name.copy() + for field_name, default in self._default_callable_list: + dd[field_name] = default() + return dd + + def fields_to_index(self): + indexes = [] + for f in self.sorted_fields: + if f.primary_key: + continue + if f.index or f.unique: + indexes.append(ModelIndex(self.model, (f,), unique=f.unique, + using=f.index_type)) + + for index_obj in self.indexes: + if isinstance(index_obj, Node): + indexes.append(index_obj) + elif isinstance(index_obj, (list, tuple)): + index_parts, unique = index_obj + fields = [] + for part in index_parts: + if isinstance(part, basestring): + fields.append(self.combined[part]) + elif isinstance(part, Node): + fields.append(part) + else: + raise ValueError('Expected either a field name or a ' + 'subclass of Node. Got: %s' % part) + indexes.append(ModelIndex(self.model, fields, unique=unique)) + + return indexes + + def set_database(self, database): + self.database = database + self.model._schema._database = database + del self.table + + # Apply any hooks that have been registered. 
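Every model exposes an instance of the Metadata class above as _meta. A small introspection sketch with a hypothetical model (not part of this patch):

# Illustrative sketch: inspecting model metadata at runtime.
import datetime
from peewee import SqliteDatabase, Model, CharField, DateTimeField

db = SqliteDatabase(':memory:')

class Event(Model):
    name = CharField(column_name='eventName')
    created = DateTimeField(default=datetime.datetime.now)

    class Meta:
        database = db
        table_name = 'events'
        only_save_dirty = True

print(Event._meta.table_name)              # 'events'
print(Event._meta.sorted_field_names)      # ['id', 'name', 'created']
print(Event._meta.combined['eventName'])   # column names resolve to Field objects
defaults = Event._meta.get_default_dict()  # callable defaults (datetime.now) are invoked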
+ for hook in self._db_hooks: + hook(database) + + def set_table_name(self, table_name): + self.table_name = table_name + del self.table + + +class SubclassAwareMetadata(Metadata): + models = [] + + def __init__(self, model, *args, **kwargs): + super(SubclassAwareMetadata, self).__init__(model, *args, **kwargs) + self.models.append(model) + + def map_models(self, fn): + for model in self.models: + fn(model) + + +class DoesNotExist(Exception): pass + + +class ModelBase(type): + inheritable = set(['constraints', 'database', 'indexes', 'primary_key', + 'options', 'schema', 'table_function', 'temporary', + 'only_save_dirty', 'legacy_table_names', + 'table_settings']) + + def __new__(cls, name, bases, attrs): + if name == MODEL_BASE or bases[0].__name__ == MODEL_BASE: + return super(ModelBase, cls).__new__(cls, name, bases, attrs) + + meta_options = {} + meta = attrs.pop('Meta', None) + if meta: + for k, v in meta.__dict__.items(): + if not k.startswith('_'): + meta_options[k] = v + + pk = getattr(meta, 'primary_key', None) + pk_name = parent_pk = None + + # Inherit any field descriptors by deep copying the underlying field + # into the attrs of the new model, additionally see if the bases define + # inheritable model options and swipe them. + for b in bases: + if not hasattr(b, '_meta'): + continue + + base_meta = b._meta + if parent_pk is None: + parent_pk = deepcopy(base_meta.primary_key) + all_inheritable = cls.inheritable | base_meta._additional_keys + for k in base_meta.__dict__: + if k in all_inheritable and k not in meta_options: + meta_options[k] = base_meta.__dict__[k] + meta_options.setdefault('schema', base_meta.schema) + + for (k, v) in b.__dict__.items(): + if k in attrs: continue + + if isinstance(v, FieldAccessor) and not v.field.primary_key: + attrs[k] = deepcopy(v.field) + + sopts = meta_options.pop('schema_options', None) or {} + Meta = meta_options.get('model_metadata_class', Metadata) + Schema = meta_options.get('schema_manager_class', SchemaManager) + + # Construct the new class. + cls = super(ModelBase, cls).__new__(cls, name, bases, attrs) + cls.__data__ = cls.__rel__ = None + + cls._meta = Meta(cls, **meta_options) + cls._schema = Schema(cls, **sopts) + + fields = [] + for key, value in cls.__dict__.items(): + if isinstance(value, Field): + if value.primary_key and pk: + raise ValueError('over-determined primary key %s.' % name) + elif value.primary_key: + pk, pk_name = value, key + else: + fields.append((key, value)) + + if pk is None: + if parent_pk is not False: + pk, pk_name = ((parent_pk, parent_pk.name) + if parent_pk is not None else + (AutoField(), 'id')) + else: + pk = False + elif isinstance(pk, CompositeKey): + pk_name = '__composite_key__' + cls._meta.composite_key = True + + if pk is not False: + cls._meta.set_primary_key(pk_name, pk) + + for name, field in fields: + cls._meta.add_field(name, field) + + # Create a repr and error class before finalizing. + if hasattr(cls, '__str__') and '__repr__' not in attrs: + setattr(cls, '__repr__', lambda self: '<%s: %s>' % ( + cls.__name__, self.__str__())) + + exc_name = '%sDoesNotExist' % cls.__name__ + exc_attrs = {'__module__': cls.__module__} + exception_class = type(exc_name, (DoesNotExist,), exc_attrs) + cls.DoesNotExist = exception_class + + # Call validation hook, allowing additional model validation. 
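A sketch of how the Meta options listed in ModelBase.inheritable propagate from a base class to its subclasses. The models are hypothetical and not part of this patch.

# Illustrative sketch: inheriting database and naming options via a base model.
from peewee import SqliteDatabase, Model, CharField

db = SqliteDatabase(':memory:')

class BaseModel(Model):
    class Meta:
        database = db              # inherited by every subclass
        legacy_table_names = False # snake-case table names instead of lowercased names

class UserProfile(BaseModel):
    username = CharField(unique=True)

print(UserProfile._meta.database is db)   # True: 'database' is an inheritable option
print(UserProfile._meta.table_name)       # 'user_profile' with legacy_table_names=False
print(UserProfile.DoesNotExist)           # per-model subclass created by the metaclass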
+ cls.validate_model() + DeferredForeignKey.resolve(cls) + return cls + + def __repr__(self): + return '' % self.__name__ + + def __iter__(self): + return iter(self.select()) + + def __getitem__(self, key): + return self.get_by_id(key) + + def __setitem__(self, key, value): + self.set_by_id(key, value) + + def __delitem__(self, key): + self.delete_by_id(key) + + def __contains__(self, key): + try: + self.get_by_id(key) + except self.DoesNotExist: + return False + else: + return True + + def __len__(self): + return self.select().count() + def __bool__(self): return True + __nonzero__ = __bool__ # Python 2. + + +class _BoundModelsContext(_callable_context_manager): + def __init__(self, models, database, bind_refs, bind_backrefs): + self.models = models + self.database = database + self.bind_refs = bind_refs + self.bind_backrefs = bind_backrefs + + def __enter__(self): + self._orig_database = [] + for model in self.models: + self._orig_database.append(model._meta.database) + model.bind(self.database, self.bind_refs, self.bind_backrefs) + return self.models + + def __exit__(self, exc_type, exc_val, exc_tb): + for model, db in zip(self.models, self._orig_database): + model.bind(db, self.bind_refs, self.bind_backrefs) + + +class Model(with_metaclass(ModelBase, Node)): + def __init__(self, *args, **kwargs): + if kwargs.pop('__no_default__', None): + self.__data__ = {} + else: + self.__data__ = self._meta.get_default_dict() + self._dirty = set(self.__data__) + self.__rel__ = {} + + for k in kwargs: + setattr(self, k, kwargs[k]) + + def __str__(self): + return str(self._pk) if self._meta.primary_key is not False else 'n/a' + + @classmethod + def validate_model(cls): + pass + + @classmethod + def alias(cls, alias=None): + return ModelAlias(cls, alias) + + @classmethod + def select(cls, *fields): + is_default = not fields + if not fields: + fields = cls._meta.sorted_fields + return ModelSelect(cls, fields, is_default=is_default) + + @classmethod + def _normalize_data(cls, data, kwargs): + normalized = {} + if data: + if not isinstance(data, dict): + if kwargs: + raise ValueError('Data cannot be mixed with keyword ' + 'arguments: %s' % data) + return data + for key in data: + try: + field = (key if isinstance(key, Field) + else cls._meta.combined[key]) + except KeyError: + raise ValueError('Unrecognized field name: "%s" in %s.' 
% + (key, data)) + normalized[field] = data[key] + if kwargs: + for key in kwargs: + try: + normalized[cls._meta.combined[key]] = kwargs[key] + except KeyError: + normalized[getattr(cls, key)] = kwargs[key] + return normalized + + @classmethod + def update(cls, __data=None, **update): + return ModelUpdate(cls, cls._normalize_data(__data, update)) + + @classmethod + def insert(cls, __data=None, **insert): + return ModelInsert(cls, cls._normalize_data(__data, insert)) + + @classmethod + def insert_many(cls, rows, fields=None): + return ModelInsert(cls, insert=rows, columns=fields) + + @classmethod + def insert_from(cls, query, fields): + columns = [getattr(cls, field) if isinstance(field, basestring) + else field for field in fields] + return ModelInsert(cls, insert=query, columns=columns) + + @classmethod + def replace(cls, __data=None, **insert): + return cls.insert(__data, **insert).on_conflict('REPLACE') + + @classmethod + def replace_many(cls, rows, fields=None): + return (cls + .insert_many(rows=rows, fields=fields) + .on_conflict('REPLACE')) + + @classmethod + def raw(cls, sql, *params): + return ModelRaw(cls, sql, params) + + @classmethod + def delete(cls): + return ModelDelete(cls) + + @classmethod + def create(cls, **query): + inst = cls(**query) + inst.save(force_insert=True) + return inst + + @classmethod + def bulk_create(cls, model_list, batch_size=None): + if batch_size is not None: + batches = chunked(model_list, batch_size) + else: + batches = [model_list] + + field_names = list(cls._meta.sorted_field_names) + if cls._meta.auto_increment: + pk_name = cls._meta.primary_key.name + field_names.remove(pk_name) + ids_returned = cls._meta.database.returning_clause + else: + ids_returned = False + + fields = [cls._meta.fields[field_name] for field_name in field_names] + for batch in batches: + accum = ([getattr(model, f) for f in field_names] + for model in batch) + res = cls.insert_many(accum, fields=fields).execute() + if ids_returned and res is not None: + for (obj_id,), model in zip(res, batch): + setattr(model, pk_name, obj_id) + + @classmethod + def bulk_update(cls, model_list, fields, batch_size=None): + if isinstance(cls._meta.primary_key, CompositeKey): + raise ValueError('bulk_update() is not supported for models with ' + 'a composite primary key.') + + # First normalize list of fields so all are field instances. + fields = [cls._meta.fields[f] if isinstance(f, basestring) else f + for f in fields] + # Now collect list of attribute names to use for values. + attrs = [field.object_id_name if isinstance(field, ForeignKeyField) + else field.name for field in fields] + + if batch_size is not None: + batches = chunked(model_list, batch_size) + else: + batches = [model_list] + + n = 0 + for batch in batches: + id_list = [model._pk for model in batch] + update = {} + for field, attr in zip(fields, attrs): + accum = [] + for model in batch: + value = getattr(model, attr) + if not isinstance(value, Node): + value = Value(value, converter=field.db_value) + accum.append((model._pk, value)) + case = Case(cls._meta.primary_key, accum) + update[field] = case + + n += (cls.update(update) + .where(cls._meta.primary_key.in_(id_list)) + .execute()) + return n + + @classmethod + def noop(cls): + return NoopModelSelect(cls, ()) + + @classmethod + def get(cls, *query, **filters): + sq = cls.select() + if query: + # Handle simple lookup using just the primary key. 
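The query-building classmethods above cover routine CRUD. A usage sketch with a hypothetical model and an in-memory database (not part of this patch):

# Illustrative sketch: everyday create/read/update/delete through the Model API.
from peewee import SqliteDatabase, Model, CharField, IntegerField

db = SqliteDatabase(':memory:')

class Person(Model):
    name = CharField()
    age = IntegerField(default=0)

    class Meta:
        database = db

db.create_tables([Person])

huey = Person.create(name='Huey', age=4)                 # INSERT via save(force_insert=True)
Person.insert_many([{'name': 'Mickey', 'age': 6},
                    {'name': 'Zaizee', 'age': 3}]).execute()

adult = Person.get_or_none(Person.age >= 18)             # None instead of DoesNotExist
same = Person.get_by_id(huey.id)                         # primary-key lookup
Person.update(age=Person.age + 1).where(Person.name == 'Huey').execute()
Person.delete().where(Person.age < 4).execute()

obj, created = Person.get_or_create(name='Mickey', defaults={'age': 6})

# bulk_update() rewrites the changed column as a CASE over the primary keys.
people = list(Person.select())
for p in people:
    p.age += 1
Person.bulk_update(people, fields=[Person.age], batch_size=50)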
+ if len(query) == 1 and isinstance(query[0], int): + sq = sq.where(cls._meta.primary_key == query[0]) + else: + sq = sq.where(*query) + if filters: + sq = sq.filter(**filters) + return sq.get() + + @classmethod + def get_or_none(cls, *query, **filters): + try: + return cls.get(*query, **filters) + except DoesNotExist: + pass + + @classmethod + def get_by_id(cls, pk): + return cls.get(cls._meta.primary_key == pk) + + @classmethod + def set_by_id(cls, key, value): + if key is None: + return cls.insert(value).execute() + else: + return (cls.update(value) + .where(cls._meta.primary_key == key).execute()) + + @classmethod + def delete_by_id(cls, pk): + return cls.delete().where(cls._meta.primary_key == pk).execute() + + @classmethod + def get_or_create(cls, **kwargs): + defaults = kwargs.pop('defaults', {}) + query = cls.select() + for field, value in kwargs.items(): + query = query.where(getattr(cls, field) == value) + + try: + return query.get(), False + except cls.DoesNotExist: + try: + if defaults: + kwargs.update(defaults) + with cls._meta.database.atomic(): + return cls.create(**kwargs), True + except IntegrityError as exc: + try: + return query.get(), False + except cls.DoesNotExist: + raise exc + + @classmethod + def filter(cls, *dq_nodes, **filters): + return cls.select().filter(*dq_nodes, **filters) + + def get_id(self): + return getattr(self, self._meta.primary_key.name) + + _pk = property(get_id) + + @_pk.setter + def _pk(self, value): + setattr(self, self._meta.primary_key.name, value) + + def _pk_expr(self): + return self._meta.primary_key == self._pk + + def _prune_fields(self, field_dict, only): + new_data = {} + for field in only: + if isinstance(field, basestring): + field = self._meta.combined[field] + if field.name in field_dict: + new_data[field.name] = field_dict[field.name] + return new_data + + def _populate_unsaved_relations(self, field_dict): + for foreign_key_field in self._meta.refs: + foreign_key = foreign_key_field.name + conditions = ( + foreign_key in field_dict and + field_dict[foreign_key] is None and + self.__rel__.get(foreign_key) is not None) + if conditions: + setattr(self, foreign_key, getattr(self, foreign_key)) + field_dict[foreign_key] = self.__data__[foreign_key] + + def save(self, force_insert=False, only=None): + field_dict = self.__data__.copy() + if self._meta.primary_key is not False: + pk_field = self._meta.primary_key + pk_value = self._pk + else: + pk_field = pk_value = None + if only: + field_dict = self._prune_fields(field_dict, only) + elif self._meta.only_save_dirty and not force_insert: + field_dict = self._prune_fields(field_dict, self.dirty_fields) + if not field_dict: + self._dirty.clear() + return False + + self._populate_unsaved_relations(field_dict) + rows = 1 + + if pk_value is not None and not force_insert: + if self._meta.composite_key: + for pk_part_name in pk_field.field_names: + field_dict.pop(pk_part_name, None) + else: + field_dict.pop(pk_field.name, None) + if not field_dict: + raise ValueError('no data to save!') + rows = self.update(**field_dict).where(self._pk_expr()).execute() + elif pk_field is not None: + pk = self.insert(**field_dict).execute() + if pk is not None and (self._meta.auto_increment or + pk_value is None): + self._pk = pk + else: + self.insert(**field_dict).execute() + + self._dirty.clear() + return rows + + def is_dirty(self): + return bool(self._dirty) + + @property + def dirty_fields(self): + return [f for f in self._meta.sorted_fields if f.name in self._dirty] + + def dependencies(self, 
search_nullable=False): + model_class = type(self) + stack = [(type(self), None)] + seen = set() + + while stack: + klass, query = stack.pop() + if klass in seen: + continue + seen.add(klass) + for fk, rel_model in klass._meta.backrefs.items(): + if rel_model is model_class or query is None: + node = (fk == self.__data__[fk.rel_field.name]) + else: + node = fk << query + subquery = (rel_model.select(rel_model._meta.primary_key) + .where(node)) + if not fk.null or search_nullable: + stack.append((rel_model, subquery)) + yield (node, fk) + + def delete_instance(self, recursive=False, delete_nullable=False): + if recursive: + dependencies = self.dependencies(delete_nullable) + for query, fk in reversed(list(dependencies)): + model = fk.model + if fk.null and not delete_nullable: + model.update(**{fk.name: None}).where(query).execute() + else: + model.delete().where(query).execute() + return type(self).delete().where(self._pk_expr()).execute() + + def __hash__(self): + return hash((self.__class__, self._pk)) + + def __eq__(self, other): + return ( + other.__class__ == self.__class__ and + self._pk is not None and + other._pk == self._pk) + + def __ne__(self, other): + return not self == other + + def __sql__(self, ctx): + return ctx.sql(getattr(self, self._meta.primary_key.name)) + + @classmethod + def bind(cls, database, bind_refs=True, bind_backrefs=True): + is_different = cls._meta.database is not database + cls._meta.set_database(database) + if bind_refs or bind_backrefs: + G = cls._meta.model_graph(refs=bind_refs, backrefs=bind_backrefs) + for _, model, is_backref in G: + model._meta.set_database(database) + return is_different + + @classmethod + def bind_ctx(cls, database, bind_refs=True, bind_backrefs=True): + return _BoundModelsContext((cls,), database, bind_refs, bind_backrefs) + + @classmethod + def table_exists(cls): + M = cls._meta + return cls._schema.database.table_exists(M.table.__name__, M.schema) + + @classmethod + def create_table(cls, safe=True, **options): + if 'fail_silently' in options: + __deprecated__('"fail_silently" has been deprecated in favor of ' + '"safe" for the create_table() method.') + safe = options.pop('fail_silently') + + if safe and not cls._schema.database.safe_create_index \ + and cls.table_exists(): + return + if cls._meta.temporary: + options.setdefault('temporary', cls._meta.temporary) + cls._schema.create_all(safe, **options) + + @classmethod + def drop_table(cls, safe=True, drop_sequences=True, **options): + if safe and not cls._schema.database.safe_drop_index \ + and not cls.table_exists(): + return + if cls._meta.temporary: + options.setdefault('temporary', cls._meta.temporary) + cls._schema.drop_all(safe, drop_sequences, **options) + + @classmethod + def truncate_table(cls, **options): + cls._schema.truncate_table(**options) + + @classmethod + def index(cls, *fields, **kwargs): + return ModelIndex(cls, fields, **kwargs) + + @classmethod + def add_index(cls, *fields, **kwargs): + if len(fields) == 1 and isinstance(fields[0], (SQL, Index)): + cls._meta.indexes.append(fields[0]) + else: + cls._meta.indexes.append(ModelIndex(cls, fields, **kwargs)) + + +class ModelAlias(Node): + """Provide a separate reference to a model in a query.""" + def __init__(self, model, alias=None): + self.__dict__['model'] = model + self.__dict__['alias'] = alias + + def __getattr__(self, attr): + model_attr = getattr(self.model, attr) + if isinstance(model_attr, Field): + self.__dict__[attr] = FieldAlias.create(self, model_attr) + return self.__dict__[attr] + return 
model_attr + + def __setattr__(self, attr, value): + raise AttributeError('Cannot set attributes on model aliases.') + + def get_field_aliases(self): + return [getattr(self, n) for n in self.model._meta.sorted_field_names] + + def select(self, *selection): + if not selection: + selection = self.get_field_aliases() + return ModelSelect(self, selection) + + def __call__(self, **kwargs): + return self.model(**kwargs) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_VALUES: + # Return the quoted table name. + return ctx.sql(self.model) + + if self.alias: + ctx.alias_manager[self] = self.alias + + if ctx.scope == SCOPE_SOURCE: + # Define the table and its alias. + return (ctx + .sql(self.model._meta.entity) + .literal(' AS ') + .sql(Entity(ctx.alias_manager[self]))) + else: + # Refer to the table using the alias. + return ctx.sql(Entity(ctx.alias_manager[self])) + + +class FieldAlias(Field): + def __init__(self, source, field): + self.source = source + self.model = source.model + self.field = field + + @classmethod + def create(cls, source, field): + class _FieldAlias(cls, type(field)): + pass + return _FieldAlias(source, field) + + def clone(self): + return FieldAlias(self.source, self.field) + + def adapt(self, value): return self.field.adapt(value) + def python_value(self, value): return self.field.python_value(value) + def db_value(self, value): return self.field.db_value(value) + def __getattr__(self, attr): + return self.source if attr == 'model' else getattr(self.field, attr) + + def __sql__(self, ctx): + return ctx.sql(Column(self.source, self.field.column_name)) + + +def sort_models(models): + models = set(models) + seen = set() + ordering = [] + def dfs(model): + if model in models and model not in seen: + seen.add(model) + for foreign_key, rel_model in model._meta.refs.items(): + # Do not depth-first search deferred foreign-keys as this can + # cause tables to be created in the incorrect order. + if not foreign_key.deferred: + dfs(rel_model) + if model._meta.depends_on: + for dependency in model._meta.depends_on: + dfs(dependency) + ordering.append(model) + + names = lambda m: (m._meta.name, m._meta.table_name) + for m in sorted(models, key=names): + dfs(m) + return ordering + + +class _ModelQueryHelper(object): + default_row_type = ROW.MODEL + + def __init__(self, *args, **kwargs): + super(_ModelQueryHelper, self).__init__(*args, **kwargs) + if not self._database: + self._database = self.model._meta.database + + @Node.copy + def objects(self, constructor=None): + self._row_type = ROW.CONSTRUCTOR + self._constructor = self.model if constructor is None else constructor + + def _get_cursor_wrapper(self, cursor): + row_type = self._row_type or self.default_row_type + if row_type == ROW.MODEL: + return self._get_model_cursor_wrapper(cursor) + elif row_type == ROW.DICT: + return ModelDictCursorWrapper(cursor, self.model, self._returning) + elif row_type == ROW.TUPLE: + return ModelTupleCursorWrapper(cursor, self.model, self._returning) + elif row_type == ROW.NAMED_TUPLE: + return ModelNamedTupleCursorWrapper(cursor, self.model, + self._returning) + elif row_type == ROW.CONSTRUCTOR: + return ModelObjectCursorWrapper(cursor, self.model, + self._returning, self._constructor) + else: + raise ValueError('Unrecognized row type: "%s".' 
% row_type) + + def _get_model_cursor_wrapper(self, cursor): + return ModelObjectCursorWrapper(cursor, self.model, [], self.model) + + +class ModelRaw(_ModelQueryHelper, RawQuery): + def __init__(self, model, sql, params, **kwargs): + self.model = model + self._returning = () + super(ModelRaw, self).__init__(sql=sql, params=params, **kwargs) + + def get(self): + try: + return self.execute()[0] + except IndexError: + sql, params = self.sql() + raise self.model.DoesNotExist('%s instance matching query does ' + 'not exist:\nSQL: %s\nParams: %s' % + (self.model, sql, params)) + + +class BaseModelSelect(_ModelQueryHelper): + def union_all(self, rhs): + return ModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) + __add__ = union_all + + def union(self, rhs): + return ModelCompoundSelectQuery(self.model, self, 'UNION', rhs) + __or__ = union + + def intersect(self, rhs): + return ModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) + __and__ = intersect + + def except_(self, rhs): + return ModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) + __sub__ = except_ + + def __iter__(self): + if not self._cursor_wrapper: + self.execute() + return iter(self._cursor_wrapper) + + def prefetch(self, *subqueries): + return prefetch(self, *subqueries) + + def get(self, database=None): + clone = self.paginate(1, 1) + clone._cursor_wrapper = None + try: + return clone.execute(database)[0] + except IndexError: + sql, params = clone.sql() + raise self.model.DoesNotExist('%s instance matching query does ' + 'not exist:\nSQL: %s\nParams: %s' % + (clone.model, sql, params)) + + @Node.copy + def group_by(self, *columns): + grouping = [] + for column in columns: + if is_model(column): + grouping.extend(column._meta.sorted_fields) + elif isinstance(column, Table): + if not column._columns: + raise ValueError('Cannot pass a table to group_by() that ' + 'does not have columns explicitly ' + 'declared.') + grouping.extend([getattr(column, col_name) + for col_name in column._columns]) + else: + grouping.append(column) + self._group_by = grouping + + +class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): + def __init__(self, model, *args, **kwargs): + self.model = model + super(ModelCompoundSelectQuery, self).__init__(*args, **kwargs) + + def _get_model_cursor_wrapper(self, cursor): + return self.lhs._get_model_cursor_wrapper(cursor) + + +def _normalize_model_select(fields_or_models): + fields = [] + for fm in fields_or_models: + if is_model(fm): + fields.extend(fm._meta.sorted_fields) + elif isinstance(fm, ModelAlias): + fields.extend(fm.get_field_aliases()) + elif isinstance(fm, Table) and fm._columns: + fields.extend([getattr(fm, col) for col in fm._columns]) + else: + fields.append(fm) + return fields + + +class ModelSelect(BaseModelSelect, Select): + def __init__(self, model, fields_or_models, is_default=False): + self.model = self._join_ctx = model + self._joins = {} + self._is_default = is_default + fields = _normalize_model_select(fields_or_models) + super(ModelSelect, self).__init__([model], fields) + + def clone(self): + clone = super(ModelSelect, self).clone() + if clone._joins: + clone._joins = dict(clone._joins) + return clone + + def select(self, *fields_or_models): + if fields_or_models or not self._is_default: + self._is_default = False + fields = _normalize_model_select(fields_or_models) + return super(ModelSelect, self).select(*fields) + return self + + def switch(self, ctx=None): + self._join_ctx = self.model if ctx is None else ctx + return self + + def 
_get_model(self, src): + if is_model(src): + return src, True + elif isinstance(src, Table) and src._model: + return src._model, False + elif isinstance(src, ModelAlias): + return src.model, False + elif isinstance(src, ModelSelect): + return src.model, False + return None, False + + def _normalize_join(self, src, dest, on, attr): + # Allow "on" expression to have an alias that determines the + # destination attribute for the joined data. + on_alias = isinstance(on, Alias) + if on_alias: + attr = attr or on._alias + on = on.alias() + + # Obtain references to the source and destination models being joined. + src_model, src_is_model = self._get_model(src) + dest_model, dest_is_model = self._get_model(dest) + + if src_model and dest_model: + self._join_ctx = dest + constructor = dest_model + + # In the case where the "on" clause is a Column or Field, we will + # convert that field into the appropriate predicate expression. + if not (src_is_model and dest_is_model) and isinstance(on, Column): + if on.source is src: + to_field = src_model._meta.columns[on.name] + elif on.source is dest: + to_field = dest_model._meta.columns[on.name] + else: + raise AttributeError('"on" clause Column %s does not ' + 'belong to %s or %s.' % + (on, src_model, dest_model)) + on = None + elif isinstance(on, Field): + to_field = on + on = None + else: + to_field = None + + fk_field, is_backref = self._generate_on_clause( + src_model, dest_model, to_field, on) + + if on is None: + src_attr = 'name' if src_is_model else 'column_name' + dest_attr = 'name' if dest_is_model else 'column_name' + if is_backref: + lhs = getattr(dest, getattr(fk_field, dest_attr)) + rhs = getattr(src, getattr(fk_field.rel_field, src_attr)) + else: + lhs = getattr(src, getattr(fk_field, src_attr)) + rhs = getattr(dest, getattr(fk_field.rel_field, dest_attr)) + on = (lhs == rhs) + + if not attr: + if fk_field is not None and not is_backref: + attr = fk_field.name + else: + attr = dest_model._meta.name + elif on_alias and fk_field is not None and \ + attr == fk_field.object_id_name and not is_backref: + raise ValueError('Cannot assign join alias to "%s", as this ' + 'attribute is the object_id_name for the ' + 'foreign-key field "%s"' % (attr, fk_field)) + + elif isinstance(dest, Source): + constructor = dict + attr = attr or dest._alias + if not attr and isinstance(dest, Table): + attr = attr or dest.__name__ + + return (on, attr, constructor) + + def _generate_on_clause(self, src, dest, to_field=None, on=None): + meta = src._meta + is_backref = fk_fields = False + + # Get all the foreign keys between source and dest, and determine if + # the join is via a back-reference. + if dest in meta.model_refs: + fk_fields = meta.model_refs[dest] + elif dest in meta.model_backrefs: + fk_fields = meta.model_backrefs[dest] + is_backref = True + + if not fk_fields: + if on is not None: + return None, False + raise ValueError('Unable to find foreign key between %s and %s. ' + 'Please specify an explicit join condition.' % + (src, dest)) + elif to_field is not None: + # If the foreign-key field was specified explicitly, remove all + # other foreign-key fields from the list. + target = (to_field.field if isinstance(to_field, FieldAlias) + else to_field) + fk_fields = [f for f in fk_fields if ( + (f is target) or + (is_backref and f.rel_field is to_field))] + + if len(fk_fields) == 1: + return fk_fields[0], is_backref + + if on is None: + raise ValueError('More than one foreign key between %s and %s.' + ' Please specify which you are joining on.' 
% + (src, dest)) + + # If there are multiple foreign-keys to choose from and the join + # predicate is an expression, we'll try to figure out which + # foreign-key field we're joining on so that we can assign to the + # correct attribute when resolving the model graph. + to_field = None + if isinstance(on, Expression): + lhs, rhs = on.lhs, on.rhs + # Coerce to set() so that we force Python to compare using the + # object's hash rather than equality test, which returns a + # false-positive due to overriding __eq__. + fk_set = set(fk_fields) + + if isinstance(lhs, Field): + lhs_f = lhs.field if isinstance(lhs, FieldAlias) else lhs + if lhs_f in fk_set: + to_field = lhs_f + elif isinstance(rhs, Field): + rhs_f = rhs.field if isinstance(rhs, FieldAlias) else rhs + if rhs_f in fk_set: + to_field = rhs_f + + return to_field, False + + @Node.copy + def join(self, dest, join_type='INNER', on=None, src=None, attr=None): + src = self._join_ctx if src is None else src + + if join_type != JOIN.CROSS: + on, attr, constructor = self._normalize_join(src, dest, on, attr) + if attr: + self._joins.setdefault(src, []) + self._joins[src].append((dest, attr, constructor)) + elif on is not None: + raise ValueError('Cannot specify on clause with cross join.') + + if not self._from_list: + raise ValueError('No sources to join on.') + + item = self._from_list.pop() + self._from_list.append(Join(item, dest, join_type, on)) + + def join_from(self, src, dest, join_type='INNER', on=None, attr=None): + return self.join(dest, join_type, on, src, attr) + + def _get_model_cursor_wrapper(self, cursor): + if len(self._from_list) == 1 and not self._joins: + return ModelObjectCursorWrapper(cursor, self.model, + self._returning, self.model) + return ModelCursorWrapper(cursor, self.model, self._returning, + self._from_list, self._joins) + + def ensure_join(self, lm, rm, on=None, **join_kwargs): + join_ctx = self._join_ctx + for dest, attr, constructor in self._joins.get(lm, []): + if dest == rm: + return self + return self.switch(lm).join(rm, on=on, **join_kwargs).switch(join_ctx) + + def convert_dict_to_node(self, qdict): + accum = [] + joins = [] + fks = (ForeignKeyField, BackrefAccessor) + for key, value in sorted(qdict.items()): + curr = self.model + if '__' in key and key.rsplit('__', 1)[1] in DJANGO_MAP: + key, op = key.rsplit('__', 1) + op = DJANGO_MAP[op] + elif value is None: + op = DJANGO_MAP['is'] + else: + op = DJANGO_MAP['eq'] + + if '__' not in key: + # Handle simplest case. This avoids joining over-eagerly when a + # direct FK lookup is all that is required. + model_attr = getattr(curr, key) + else: + for piece in key.split('__'): + for dest, attr, _ in self._joins.get(curr, ()): + if attr == piece or (isinstance(dest, ModelAlias) and + dest.alias == piece): + curr = dest + break + else: + model_attr = getattr(curr, piece) + if value is not None and isinstance(model_attr, fks): + curr = model_attr.rel_model + joins.append(model_attr) + accum.append(op(model_attr, value)) + return accum, joins + + def filter(self, *args, **kwargs): + # normalize args and kwargs into a new expression + dq_node = ColumnBase() + if args: + dq_node &= reduce(operator.and_, [a.clone() for a in args]) + if kwargs: + dq_node &= DQ(**kwargs) + + # dq_node should now be an Expression, lhs = Node(), rhs = ... 
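+        # Walk the expression tree breadth-first: each DQ() node is swapped
+        # for the column expressions built by convert_dict_to_node(), and any
+        # foreign keys implied by "related__field" lookups are collected so
+        # the corresponding joins can be ensured on the cloned query below.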
+ q = collections.deque([dq_node]) + dq_joins = set() + while q: + curr = q.popleft() + if not isinstance(curr, Expression): + continue + for side, piece in (('lhs', curr.lhs), ('rhs', curr.rhs)): + if isinstance(piece, DQ): + query, joins = self.convert_dict_to_node(piece.query) + dq_joins.update(joins) + expression = reduce(operator.and_, query) + # Apply values from the DQ object. + if piece._negated: + expression = Negated(expression) + #expression._alias = piece._alias + setattr(curr, side, expression) + else: + q.append(piece) + + dq_node = dq_node.rhs + + query = self.clone() + for field in dq_joins: + if isinstance(field, ForeignKeyField): + lm, rm = field.model, field.rel_model + field_obj = field + elif isinstance(field, BackrefAccessor): + lm, rm = field.model, field.rel_model + field_obj = field.field + query = query.ensure_join(lm, rm, field_obj) + return query.where(dq_node) + + def create_table(self, name, safe=True, **meta): + return self.model._schema.create_table_as(name, self, safe, **meta) + + def __sql_selection__(self, ctx, is_subquery=False): + if self._is_default and is_subquery and len(self._returning) > 1 and \ + self.model._meta.primary_key is not False: + return ctx.sql(self.model._meta.primary_key) + + return ctx.sql(CommaNodeList(self._returning)) + + +class NoopModelSelect(ModelSelect): + def __sql__(self, ctx): + return self.model._meta.database.get_noop_select(ctx) + + def _get_cursor_wrapper(self, cursor): + return CursorWrapper(cursor) + + +class _ModelWriteQueryHelper(_ModelQueryHelper): + def __init__(self, model, *args, **kwargs): + self.model = model + super(_ModelWriteQueryHelper, self).__init__(model, *args, **kwargs) + + def returning(self, *returning): + accum = [] + for item in returning: + if is_model(item): + accum.extend(item._meta.sorted_fields) + else: + accum.append(item) + return super(_ModelWriteQueryHelper, self).returning(*accum) + + def _set_table_alias(self, ctx): + table = self.model._meta.table + ctx.alias_manager[table] = table.__name__ + + +class ModelUpdate(_ModelWriteQueryHelper, Update): + pass + + +class ModelInsert(_ModelWriteQueryHelper, Insert): + default_row_type = ROW.TUPLE + + def __init__(self, *args, **kwargs): + super(ModelInsert, self).__init__(*args, **kwargs) + if self._returning is None and self.model._meta.database is not None: + if self.model._meta.database.returning_clause: + self._returning = self.model._meta.get_primary_keys() + + def returning(self, *returning): + # By default ModelInsert will yield a `tuple` containing the + # primary-key of the newly inserted row. But if we are explicitly + # specifying a returning clause and have not set a row type, we will + # default to returning model instances instead. 
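+        # Sketch (assumes a database with RETURNING support, e.g. Postgres,
+        # and a hypothetical User model):
+        #   query = User.insert(username='alice').returning(User)
+        #   for user in query.execute():  # yields User instances, not tuples
+        #       ...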
+ if returning and self._row_type is None: + self._row_type = ROW.MODEL + return super(ModelInsert, self).returning(*returning) + + def get_default_data(self): + return self.model._meta.defaults + + def get_default_columns(self): + fields = self.model._meta.sorted_fields + return fields[1:] if self.model._meta.auto_increment else fields + + +class ModelDelete(_ModelWriteQueryHelper, Delete): + pass + + +class ManyToManyQuery(ModelSelect): + def __init__(self, instance, accessor, rel, *args, **kwargs): + self._instance = instance + self._accessor = accessor + self._src_attr = accessor.src_fk.rel_field.name + self._dest_attr = accessor.dest_fk.rel_field.name + super(ManyToManyQuery, self).__init__(rel, (rel,), *args, **kwargs) + + def _id_list(self, model_or_id_list): + if isinstance(model_or_id_list[0], Model): + return [getattr(obj, self._dest_attr) for obj in model_or_id_list] + return model_or_id_list + + def add(self, value, clear_existing=False): + if clear_existing: + self.clear() + + accessor = self._accessor + src_id = getattr(self._instance, self._src_attr) + if isinstance(value, SelectQuery): + query = value.columns( + Value(src_id), + accessor.dest_fk.rel_field) + accessor.through_model.insert_from( + fields=[accessor.src_fk, accessor.dest_fk], + query=query).execute() + else: + value = ensure_tuple(value) + if not value: return + + inserts = [{ + accessor.src_fk.name: src_id, + accessor.dest_fk.name: rel_id} + for rel_id in self._id_list(value)] + accessor.through_model.insert_many(inserts).execute() + + def remove(self, value): + src_id = getattr(self._instance, self._src_attr) + if isinstance(value, SelectQuery): + column = getattr(value.model, self._dest_attr) + subquery = value.columns(column) + return (self._accessor.through_model + .delete() + .where( + (self._accessor.dest_fk << subquery) & + (self._accessor.src_fk == src_id)) + .execute()) + else: + value = ensure_tuple(value) + if not value: + return + return (self._accessor.through_model + .delete() + .where( + (self._accessor.dest_fk << self._id_list(value)) & + (self._accessor.src_fk == src_id)) + .execute()) + + def clear(self): + src_id = getattr(self._instance, self._src_attr) + return (self._accessor.through_model + .delete() + .where(self._accessor.src_fk == src_id) + .execute()) + + +def safe_python_value(conv_func): + def validate(value): + try: + return conv_func(value) + except (TypeError, ValueError): + return value + return validate + + +class BaseModelCursorWrapper(DictCursorWrapper): + def __init__(self, cursor, model, columns): + super(BaseModelCursorWrapper, self).__init__(cursor) + self.model = model + self.select = columns or [] + + def _initialize_columns(self): + combined = self.model._meta.combined + table = self.model._meta.table + description = self.cursor.description + + self.ncols = len(self.cursor.description) + self.columns = [] + self.converters = converters = [None] * self.ncols + self.fields = fields = [None] * self.ncols + + for idx, description_item in enumerate(description): + column = description_item[0] + dot_index = column.find('.') + if dot_index != -1: + column = column[dot_index + 1:] + + column = column.strip('"') + self.columns.append(column) + try: + raw_node = self.select[idx] + except IndexError: + if column in combined: + raw_node = node = combined[column] + else: + continue + else: + node = raw_node.unwrap() + + # Heuristics used to attempt to get the field associated with a + # given SELECT column, so that we can accurately convert the value + # returned by the 
database-cursor into a Python object. + if isinstance(node, Field): + if raw_node._coerce: + converters[idx] = node.python_value + fields[idx] = node + if (column == node.name or column == node.column_name) and \ + not raw_node.is_alias(): + self.columns[idx] = node.name + elif isinstance(node, Function) and node._coerce: + if node._python_value is not None: + converters[idx] = node._python_value + elif node.arguments and isinstance(node.arguments[0], Node): + # If the first argument is a field or references a column + # on a Model, try using that field's conversion function. + # This usually works, but we use "safe_python_value()" so + # that if a TypeError or ValueError occurs during + # conversion we can just fall-back to the raw cursor value. + first = node.arguments[0].unwrap() + if isinstance(first, Entity): + path = first._path[-1] # Try to look-up by name. + first = combined.get(path) + if isinstance(first, Field): + converters[idx] = safe_python_value(first.python_value) + elif column in combined: + if node._coerce: + converters[idx] = combined[column].python_value + if isinstance(node, Column) and node.source == table: + fields[idx] = combined[column] + + initialize = _initialize_columns + + def process_row(self, row): + raise NotImplementedError + + +class ModelDictCursorWrapper(BaseModelCursorWrapper): + def process_row(self, row): + result = {} + columns, converters = self.columns, self.converters + fields = self.fields + + for i in range(self.ncols): + attr = columns[i] + if attr in result: continue # Don't overwrite if we have dupes. + if converters[i] is not None: + result[attr] = converters[i](row[i]) + else: + result[attr] = row[i] + + return result + + +class ModelTupleCursorWrapper(ModelDictCursorWrapper): + constructor = tuple + + def process_row(self, row): + columns, converters = self.columns, self.converters + return self.constructor([ + (converters[i](row[i]) if converters[i] is not None else row[i]) + for i in range(self.ncols)]) + + +class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper): + def initialize(self): + self._initialize_columns() + attributes = [] + for i in range(self.ncols): + attributes.append(self.columns[i]) + self.tuple_class = collections.namedtuple('Row', attributes) + self.constructor = lambda row: self.tuple_class(*row) + + +class ModelObjectCursorWrapper(ModelDictCursorWrapper): + def __init__(self, cursor, model, select, constructor): + self.constructor = constructor + self.is_model = is_model(constructor) + super(ModelObjectCursorWrapper, self).__init__(cursor, model, select) + + def process_row(self, row): + data = super(ModelObjectCursorWrapper, self).process_row(row) + if self.is_model: + # Clear out any dirty fields before returning to the user. 
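+            # "__no_default__" skips applying field defaults in Model.__init__,
+            # since every attribute is populated from the database row itself.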
+ obj = self.constructor(__no_default__=1, **data) + obj._dirty.clear() + return obj + else: + return self.constructor(**data) + + +class ModelCursorWrapper(BaseModelCursorWrapper): + def __init__(self, cursor, model, select, from_list, joins): + super(ModelCursorWrapper, self).__init__(cursor, model, select) + self.from_list = from_list + self.joins = joins + + def initialize(self): + self._initialize_columns() + selected_src = set([field.model for field in self.fields + if field is not None]) + select, columns = self.select, self.columns + + self.key_to_constructor = {self.model: self.model} + self.src_is_dest = {} + self.src_to_dest = [] + accum = collections.deque(self.from_list) + dests = set() + while accum: + curr = accum.popleft() + if isinstance(curr, Join): + accum.append(curr.lhs) + accum.append(curr.rhs) + continue + + if curr not in self.joins: + continue + + for key, attr, constructor in self.joins[curr]: + if key not in self.key_to_constructor: + self.key_to_constructor[key] = constructor + self.src_to_dest.append((curr, attr, key, + isinstance(curr, dict))) + dests.add(key) + accum.append(key) + + # Ensure that we accommodate everything selected. + for src in selected_src: + if src not in self.key_to_constructor: + if is_model(src): + self.key_to_constructor[src] = src + elif isinstance(src, ModelAlias): + self.key_to_constructor[src] = src.model + + # Indicate which sources are also dests. + for src, _, dest, _ in self.src_to_dest: + self.src_is_dest[src] = src in dests and (dest in selected_src + or src in selected_src) + + self.column_keys = [] + for idx, node in enumerate(select): + key = self.model + field = self.fields[idx] + if field is not None: + if isinstance(field, FieldAlias): + key = field.source + else: + key = field.model + else: + if isinstance(node, Node): + node = node.unwrap() + if isinstance(node, Column): + key = node.source + + self.column_keys.append(key) + + def process_row(self, row): + objects = {} + object_list = [] + for key, constructor in self.key_to_constructor.items(): + objects[key] = constructor(__no_default__=True) + object_list.append(objects[key]) + + set_keys = set() + for idx, key in enumerate(self.column_keys): + instance = objects[key] + column = self.columns[idx] + value = row[idx] + if value is not None: + set_keys.add(key) + if self.converters[idx]: + value = self.converters[idx](value) + + if isinstance(instance, dict): + instance[column] = value + else: + setattr(instance, column, value) + + # Need to do some analysis on the joins before this. + for (src, attr, dest, is_dict) in self.src_to_dest: + instance = objects[src] + try: + joined_instance = objects[dest] + except KeyError: + continue + + # If no fields were set on the destination instance then do not + # assign an "empty" instance. + if instance is None or dest is None or \ + (dest not in set_keys and not self.src_is_dest.get(dest)): + continue + + if is_dict: + instance[attr] = joined_instance + else: + setattr(instance, attr, joined_instance) + + # When instantiating models from a cursor, we clear the dirty fields. 
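+        # This keeps instances hydrated from a query from looking "modified":
+        # is_dirty() stays False until the caller actually changes a field.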
+ for instance in object_list: + if isinstance(instance, Model): + instance._dirty.clear() + + return objects[self.model] + + +class PrefetchQuery(collections.namedtuple('_PrefetchQuery', ( + 'query', 'fields', 'is_backref', 'rel_models', 'field_to_name', 'model'))): + def __new__(cls, query, fields=None, is_backref=None, rel_models=None, + field_to_name=None, model=None): + if fields: + if is_backref: + if rel_models is None: + rel_models = [field.model for field in fields] + foreign_key_attrs = [field.rel_field.name for field in fields] + else: + if rel_models is None: + rel_models = [field.rel_model for field in fields] + foreign_key_attrs = [field.name for field in fields] + field_to_name = list(zip(fields, foreign_key_attrs)) + model = query.model + return super(PrefetchQuery, cls).__new__( + cls, query, fields, is_backref, rel_models, field_to_name, model) + + def populate_instance(self, instance, id_map): + if self.is_backref: + for field in self.fields: + identifier = instance.__data__[field.name] + key = (field, identifier) + if key in id_map: + setattr(instance, field.name, id_map[key]) + else: + for field, attname in self.field_to_name: + identifier = instance.__data__[field.rel_field.name] + key = (field, identifier) + rel_instances = id_map.get(key, []) + for inst in rel_instances: + setattr(inst, attname, instance) + setattr(instance, field.backref, rel_instances) + + def store_instance(self, instance, id_map): + for field, attname in self.field_to_name: + identity = field.rel_field.python_value(instance.__data__[attname]) + key = (field, identity) + if self.is_backref: + id_map[key] = instance + else: + id_map.setdefault(key, []) + id_map[key].append(instance) + + +def prefetch_add_subquery(sq, subqueries): + fixed_queries = [PrefetchQuery(sq)] + for i, subquery in enumerate(subqueries): + if isinstance(subquery, tuple): + subquery, target_model = subquery + else: + target_model = None + if not isinstance(subquery, Query) and is_model(subquery) or \ + isinstance(subquery, ModelAlias): + subquery = subquery.select() + subquery_model = subquery.model + fks = backrefs = None + for j in reversed(range(i + 1)): + fixed = fixed_queries[j] + last_query = fixed.query + last_model = last_obj = fixed.model + if isinstance(last_model, ModelAlias): + last_model = last_model.model + rels = subquery_model._meta.model_refs.get(last_model, []) + if rels: + fks = [getattr(subquery_model, fk.name) for fk in rels] + pks = [getattr(last_obj, fk.rel_field.name) for fk in rels] + else: + backrefs = subquery_model._meta.model_backrefs.get(last_model) + if (fks or backrefs) and ((target_model is last_obj) or + (target_model is None)): + break + + if not fks and not backrefs: + tgt_err = ' using %s' % target_model if target_model else '' + raise AttributeError('Error: unable to find foreign key for ' + 'query: %s%s' % (subquery, tgt_err)) + + dest = (target_model,) if target_model else None + + if fks: + expr = reduce(operator.or_, [ + (fk << last_query.select(pk)) + for (fk, pk) in zip(fks, pks)]) + subquery = subquery.where(expr) + fixed_queries.append(PrefetchQuery(subquery, fks, False, dest)) + elif backrefs: + expressions = [] + for backref in backrefs: + rel_field = getattr(subquery_model, backref.rel_field.name) + fk_field = getattr(last_obj, backref.name) + expressions.append(rel_field << last_query.select(fk_field)) + subquery = subquery.where(reduce(operator.or_, expressions)) + fixed_queries.append(PrefetchQuery(subquery, backrefs, True, dest)) + + return fixed_queries + + +def 
prefetch(sq, *subqueries): + if not subqueries: + return sq + + fixed_queries = prefetch_add_subquery(sq, subqueries) + deps = {} + rel_map = {} + for pq in reversed(fixed_queries): + query_model = pq.model + if pq.fields: + for rel_model in pq.rel_models: + rel_map.setdefault(rel_model, []) + rel_map[rel_model].append(pq) + + deps[query_model] = {} + id_map = deps[query_model] + has_relations = bool(rel_map.get(query_model)) + + for instance in pq.query: + if pq.fields: + pq.store_instance(instance, id_map) + if has_relations: + for rel in rel_map[query_model]: + rel.populate_instance(instance, deps[rel.model]) + + return list(pq.query) diff --git a/libs/playhouse/__init__.py b/libs/playhouse/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/playhouse/apsw_ext.py b/libs/playhouse/apsw_ext.py new file mode 100644 index 000000000..0aa35939b --- /dev/null +++ b/libs/playhouse/apsw_ext.py @@ -0,0 +1,136 @@ +""" +Peewee integration with APSW, "another python sqlite wrapper". + +Project page: https://rogerbinns.github.io/apsw/ + +APSW is a really neat library that provides a thin wrapper on top of SQLite's +C interface. + +Here are just a few reasons to use APSW, taken from the documentation: + +* APSW gives all functionality of SQLite, including virtual tables, virtual + file system, blob i/o, backups and file control. +* Connections can be shared across threads without any additional locking. +* Transactions are managed explicitly by your code. +* APSW can handle nested transactions. +* Unicode is handled correctly. +* APSW is faster. +""" +import apsw +from peewee import * +from peewee import __exception_wrapper__ +from peewee import BooleanField as _BooleanField +from peewee import DateField as _DateField +from peewee import DateTimeField as _DateTimeField +from peewee import DecimalField as _DecimalField +from peewee import TimeField as _TimeField +from peewee import logger + +from playhouse.sqlite_ext import SqliteExtDatabase + + +class APSWDatabase(SqliteExtDatabase): + server_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.')) + + def __init__(self, database, **kwargs): + self._modules = {} + super(APSWDatabase, self).__init__(database, **kwargs) + + def register_module(self, mod_name, mod_inst): + self._modules[mod_name] = mod_inst + if not self.is_closed(): + self.connection().createmodule(mod_name, mod_inst) + + def unregister_module(self, mod_name): + del(self._modules[mod_name]) + + def _connect(self): + conn = apsw.Connection(self.database, **self.connect_params) + if self._timeout is not None: + conn.setbusytimeout(self._timeout * 1000) + try: + self._add_conn_hooks(conn) + except: + conn.close() + raise + return conn + + def _add_conn_hooks(self, conn): + super(APSWDatabase, self)._add_conn_hooks(conn) + self._load_modules(conn) # APSW-only. 
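+        # Virtual-table modules registered with register_module() are
+        # re-created on every new connection by _load_modules() below, so
+        # they survive reconnects. Hypothetical usage sketch:
+        #   db = APSWDatabase('app.db')
+        #   db.register_module('my_vtab', MyVirtualTableModule())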
+ + def _load_modules(self, conn): + for mod_name, mod_inst in self._modules.items(): + conn.createmodule(mod_name, mod_inst) + return conn + + def _load_aggregates(self, conn): + for name, (klass, num_params) in self._aggregates.items(): + def make_aggregate(): + return (klass(), klass.step, klass.finalize) + conn.createaggregatefunction(name, make_aggregate) + + def _load_collations(self, conn): + for name, fn in self._collations.items(): + conn.createcollation(name, fn) + + def _load_functions(self, conn): + for name, (fn, num_params) in self._functions.items(): + conn.createscalarfunction(name, fn, num_params) + + def _load_extensions(self, conn): + conn.enableloadextension(True) + for extension in self._extensions: + conn.loadextension(extension) + + def load_extension(self, extension): + self._extensions.add(extension) + if not self.is_closed(): + conn = self.connection() + conn.enableloadextension(True) + conn.loadextension(extension) + + def last_insert_id(self, cursor, query_type=None): + return cursor.getconnection().last_insert_rowid() + + def rows_affected(self, cursor): + return cursor.getconnection().changes() + + def begin(self, lock_type='deferred'): + self.cursor().execute('begin %s;' % lock_type) + + def commit(self): + self.cursor().execute('commit;') + + def rollback(self): + self.cursor().execute('rollback;') + + def execute_sql(self, sql, params=None, commit=True): + logger.debug((sql, params)) + with __exception_wrapper__: + cursor = self.cursor() + cursor.execute(sql, params or ()) + return cursor + + +def nh(s, v): + if v is not None: + return str(v) + +class BooleanField(_BooleanField): + def db_value(self, v): + v = super(BooleanField, self).db_value(v) + if v is not None: + return v and 1 or 0 + +class DateField(_DateField): + db_value = nh + +class TimeField(_TimeField): + db_value = nh + +class DateTimeField(_DateTimeField): + db_value = nh + +class DecimalField(_DecimalField): + db_value = nh diff --git a/libs/playhouse/dataset.py b/libs/playhouse/dataset.py new file mode 100644 index 000000000..27f8189bb --- /dev/null +++ b/libs/playhouse/dataset.py @@ -0,0 +1,415 @@ +import csv +import datetime +from decimal import Decimal +import json +import operator +try: + from urlparse import urlparse +except ImportError: + from urllib.parse import urlparse +import sys + +from peewee import * +from playhouse.db_url import connect +from playhouse.migrate import migrate +from playhouse.migrate import SchemaMigrator +from playhouse.reflection import Introspector + +if sys.version_info[0] == 3: + basestring = str + from functools import reduce + def open_file(f, mode): + return open(f, mode, encoding='utf8') +else: + open_file = open + + +class DataSet(object): + def __init__(self, url, bare_fields=False): + if isinstance(url, Database): + self._url = None + self._database = url + self._database_path = self._database.database + else: + self._url = url + parse_result = urlparse(url) + self._database_path = parse_result.path[1:] + + # Connect to the database. + self._database = connect(url) + + self._database.connect() + + # Introspect the database and generate models. 
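+        # The Introspector maps each existing table to a generated Model
+        # class; the classes are cached in self._models and rebuilt by
+        # update_cache() when tables or columns are added.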
+ self._introspector = Introspector.from_database(self._database) + self._models = self._introspector.generate_models( + skip_invalid=True, + literal_column_names=True, + bare_fields=bare_fields) + self._migrator = SchemaMigrator.from_database(self._database) + + class BaseModel(Model): + class Meta: + database = self._database + self._base_model = BaseModel + self._export_formats = self.get_export_formats() + self._import_formats = self.get_import_formats() + + def __repr__(self): + return '' % self._database_path + + def get_export_formats(self): + return { + 'csv': CSVExporter, + 'json': JSONExporter} + + def get_import_formats(self): + return { + 'csv': CSVImporter, + 'json': JSONImporter} + + def __getitem__(self, table): + if table not in self._models and table in self.tables: + self.update_cache(table) + return Table(self, table, self._models.get(table)) + + @property + def tables(self): + return self._database.get_tables() + + def __contains__(self, table): + return table in self.tables + + def connect(self): + self._database.connect() + + def close(self): + self._database.close() + + def update_cache(self, table=None): + if table: + dependencies = [table] + if table in self._models: + model_class = self._models[table] + dependencies.extend([ + related._meta.table_name for _, related, _ in + model_class._meta.model_graph()]) + else: + dependencies.extend(self.get_table_dependencies(table)) + else: + dependencies = None # Update all tables. + self._models = {} + updated = self._introspector.generate_models( + skip_invalid=True, + table_names=dependencies, + literal_column_names=True) + self._models.update(updated) + + def get_table_dependencies(self, table): + stack = [table] + accum = [] + seen = set() + while stack: + table = stack.pop() + for fk_meta in self._database.get_foreign_keys(table): + dest = fk_meta.dest_table + if dest not in seen: + stack.append(dest) + accum.append(dest) + return accum + + def __enter__(self): + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._database.is_closed(): + self.close() + + def query(self, sql, params=None, commit=True): + return self._database.execute_sql(sql, params, commit) + + def transaction(self): + if self._database.transaction_depth() == 0: + return self._database.transaction() + else: + return self._database.savepoint() + + def _check_arguments(self, filename, file_obj, format, format_dict): + if filename and file_obj: + raise ValueError('file is over-specified. Please use either ' + 'filename or file_obj, but not both.') + if not filename and not file_obj: + raise ValueError('A filename or file-like object must be ' + 'specified.') + if format not in format_dict: + valid_formats = ', '.join(sorted(format_dict.keys())) + raise ValueError('Unsupported format "%s". Use one of %s.' 
% ( + format, valid_formats)) + + def freeze(self, query, format='csv', filename=None, file_obj=None, + **kwargs): + self._check_arguments(filename, file_obj, format, self._export_formats) + if filename: + file_obj = open_file(filename, 'w') + + exporter = self._export_formats[format](query) + exporter.export(file_obj, **kwargs) + + if filename: + file_obj.close() + + def thaw(self, table, format='csv', filename=None, file_obj=None, + strict=False, **kwargs): + self._check_arguments(filename, file_obj, format, self._export_formats) + if filename: + file_obj = open_file(filename, 'r') + + importer = self._import_formats[format](self[table], strict) + count = importer.load(file_obj, **kwargs) + + if filename: + file_obj.close() + + return count + + +class Table(object): + def __init__(self, dataset, name, model_class): + self.dataset = dataset + self.name = name + if model_class is None: + model_class = self._create_model() + model_class.create_table() + self.dataset._models[name] = model_class + + @property + def model_class(self): + return self.dataset._models[self.name] + + def __repr__(self): + return '' % self.name + + def __len__(self): + return self.find().count() + + def __iter__(self): + return iter(self.find().iterator()) + + def _create_model(self): + class Meta: + table_name = self.name + return type( + str(self.name), + (self.dataset._base_model,), + {'Meta': Meta}) + + def create_index(self, columns, unique=False): + self.dataset._database.create_index( + self.model_class, + columns, + unique=unique) + + def _guess_field_type(self, value): + if isinstance(value, basestring): + return TextField + if isinstance(value, (datetime.date, datetime.datetime)): + return DateTimeField + elif value is True or value is False: + return BooleanField + elif isinstance(value, int): + return IntegerField + elif isinstance(value, float): + return FloatField + elif isinstance(value, Decimal): + return DecimalField + return TextField + + @property + def columns(self): + return [f.name for f in self.model_class._meta.sorted_fields] + + def _migrate_new_columns(self, data): + new_keys = set(data) - set(self.model_class._meta.fields) + if new_keys: + operations = [] + for key in new_keys: + field_class = self._guess_field_type(data[key]) + field = field_class(null=True) + operations.append( + self.dataset._migrator.add_column(self.name, key, field)) + field.bind(self.model_class, key) + + migrate(*operations) + + self.dataset.update_cache(self.name) + + def insert(self, **data): + self._migrate_new_columns(data) + return self.model_class.insert(**data).execute() + + def _apply_where(self, query, filters, conjunction=None): + conjunction = conjunction or operator.and_ + if filters: + expressions = [ + (self.model_class._meta.fields[column] == value) + for column, value in filters.items()] + query = query.where(reduce(conjunction, expressions)) + return query + + def update(self, columns=None, conjunction=None, **data): + self._migrate_new_columns(data) + filters = {} + if columns: + for column in columns: + filters[column] = data.pop(column) + + return self._apply_where( + self.model_class.update(**data), + filters, + conjunction).execute() + + def _query(self, **query): + return self._apply_where(self.model_class.select(), query) + + def find(self, **query): + return self._query(**query).dicts() + + def find_one(self, **query): + try: + return self.find(**query).get() + except self.model_class.DoesNotExist: + return None + + def all(self): + return self.find() + + def delete(self, **query): + return 
self._apply_where(self.model_class.delete(), query).execute() + + def freeze(self, *args, **kwargs): + return self.dataset.freeze(self.all(), *args, **kwargs) + + def thaw(self, *args, **kwargs): + return self.dataset.thaw(self.name, *args, **kwargs) + + +class Exporter(object): + def __init__(self, query): + self.query = query + + def export(self, file_obj): + raise NotImplementedError + + +class JSONExporter(Exporter): + def __init__(self, query, iso8601_datetimes=False): + super(JSONExporter, self).__init__(query) + self.iso8601_datetimes = iso8601_datetimes + + def _make_default(self): + datetime_types = (datetime.datetime, datetime.date, datetime.time) + + if self.iso8601_datetimes: + def default(o): + if isinstance(o, datetime_types): + return o.isoformat() + elif isinstance(o, Decimal): + return str(o) + raise TypeError('Unable to serialize %r as JSON' % o) + else: + def default(o): + if isinstance(o, datetime_types + (Decimal,)): + return str(o) + raise TypeError('Unable to serialize %r as JSON' % o) + return default + + def export(self, file_obj, **kwargs): + json.dump( + list(self.query), + file_obj, + default=self._make_default(), + **kwargs) + + +class CSVExporter(Exporter): + def export(self, file_obj, header=True, **kwargs): + writer = csv.writer(file_obj, **kwargs) + tuples = self.query.tuples().execute() + tuples.initialize() + if header and getattr(tuples, 'columns', None): + writer.writerow([column for column in tuples.columns]) + for row in tuples: + writer.writerow(row) + + +class Importer(object): + def __init__(self, table, strict=False): + self.table = table + self.strict = strict + + model = self.table.model_class + self.columns = model._meta.columns + self.columns.update(model._meta.fields) + + def load(self, file_obj): + raise NotImplementedError + + +class JSONImporter(Importer): + def load(self, file_obj, **kwargs): + data = json.load(file_obj, **kwargs) + count = 0 + + for row in data: + if self.strict: + obj = {} + for key in row: + field = self.columns.get(key) + if field is not None: + obj[field.name] = field.python_value(row[key]) + else: + obj = row + + if obj: + self.table.insert(**obj) + count += 1 + + return count + + +class CSVImporter(Importer): + def load(self, file_obj, header=True, **kwargs): + count = 0 + reader = csv.reader(file_obj, **kwargs) + if header: + try: + header_keys = next(reader) + except StopIteration: + return count + + if self.strict: + header_fields = [] + for idx, key in enumerate(header_keys): + if key in self.columns: + header_fields.append((idx, self.columns[key])) + else: + header_fields = list(enumerate(header_keys)) + else: + header_fields = list(enumerate(self.model._meta.sorted_fields)) + + if not header_fields: + return count + + for row in reader: + obj = {} + for idx, field in header_fields: + if self.strict: + obj[field.name] = field.python_value(row[idx]) + else: + obj[field] = row[idx] + + self.table.insert(**obj) + count += 1 + + return count diff --git a/libs/playhouse/db_url.py b/libs/playhouse/db_url.py new file mode 100644 index 000000000..fcc4ab87a --- /dev/null +++ b/libs/playhouse/db_url.py @@ -0,0 +1,124 @@ +try: + from urlparse import parse_qsl, unquote, urlparse +except ImportError: + from urllib.parse import parse_qsl, unquote, urlparse + +from peewee import * +from playhouse.pool import PooledMySQLDatabase +from playhouse.pool import PooledPostgresqlDatabase +from playhouse.pool import PooledSqliteDatabase +from playhouse.pool import PooledSqliteExtDatabase +from playhouse.sqlite_ext import 
SqliteExtDatabase + + +schemes = { + 'mysql': MySQLDatabase, + 'mysql+pool': PooledMySQLDatabase, + 'postgres': PostgresqlDatabase, + 'postgresql': PostgresqlDatabase, + 'postgres+pool': PooledPostgresqlDatabase, + 'postgresql+pool': PooledPostgresqlDatabase, + 'sqlite': SqliteDatabase, + 'sqliteext': SqliteExtDatabase, + 'sqlite+pool': PooledSqliteDatabase, + 'sqliteext+pool': PooledSqliteExtDatabase, +} + +def register_database(db_class, *names): + global schemes + for name in names: + schemes[name] = db_class + +def parseresult_to_dict(parsed, unquote_password=False): + + # urlparse in python 2.6 is broken so query will be empty and instead + # appended to path complete with '?' + path_parts = parsed.path[1:].split('?') + try: + query = path_parts[1] + except IndexError: + query = parsed.query + + connect_kwargs = {'database': path_parts[0]} + if parsed.username: + connect_kwargs['user'] = parsed.username + if parsed.password: + connect_kwargs['password'] = parsed.password + if unquote_password: + connect_kwargs['password'] = unquote(connect_kwargs['password']) + if parsed.hostname: + connect_kwargs['host'] = parsed.hostname + if parsed.port: + connect_kwargs['port'] = parsed.port + + # Adjust parameters for MySQL. + if parsed.scheme == 'mysql' and 'password' in connect_kwargs: + connect_kwargs['passwd'] = connect_kwargs.pop('password') + elif 'sqlite' in parsed.scheme and not connect_kwargs['database']: + connect_kwargs['database'] = ':memory:' + + # Get additional connection args from the query string + qs_args = parse_qsl(query, keep_blank_values=True) + for key, value in qs_args: + if value.lower() == 'false': + value = False + elif value.lower() == 'true': + value = True + elif value.isdigit(): + value = int(value) + elif '.' in value and all(p.isdigit() for p in value.split('.', 1)): + try: + value = float(value) + except ValueError: + pass + elif value.lower() in ('null', 'none'): + value = None + + connect_kwargs[key] = value + + return connect_kwargs + +def parse(url, unquote_password=False): + parsed = urlparse(url) + return parseresult_to_dict(parsed, unquote_password) + +def connect(url, unquote_password=False, **connect_params): + parsed = urlparse(url) + connect_kwargs = parseresult_to_dict(parsed, unquote_password) + connect_kwargs.update(connect_params) + database_class = schemes.get(parsed.scheme) + + if database_class is None: + if database_class in schemes: + raise RuntimeError('Attempted to use "%s" but a required library ' + 'could not be imported.' % parsed.scheme) + else: + raise RuntimeError('Unrecognized or unsupported scheme: "%s".' % + parsed.scheme) + + return database_class(**connect_kwargs) + +# Conditionally register additional databases. 
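+# Additional URL schemes can be wired up the same way; hypothetical sketch:
+#   from playhouse.db_url import register_database, connect
+#   register_database(MyDatabase, 'mydb')
+#   db = connect('mydb://user:pass@host:3306/app')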
+try: + from playhouse.pool import PooledPostgresqlExtDatabase +except ImportError: + pass +else: + register_database( + PooledPostgresqlExtDatabase, + 'postgresext+pool', + 'postgresqlext+pool') + +try: + from playhouse.apsw_ext import APSWDatabase +except ImportError: + pass +else: + register_database(APSWDatabase, 'apsw') + +try: + from playhouse.postgres_ext import PostgresqlExtDatabase +except ImportError: + pass +else: + register_database(PostgresqlExtDatabase, 'postgresext', 'postgresqlext') diff --git a/libs/playhouse/fields.py b/libs/playhouse/fields.py new file mode 100644 index 000000000..fce1a3d6d --- /dev/null +++ b/libs/playhouse/fields.py @@ -0,0 +1,64 @@ +try: + import bz2 +except ImportError: + bz2 = None +try: + import zlib +except ImportError: + zlib = None +try: + import cPickle as pickle +except ImportError: + import pickle +import sys + +from peewee import BlobField +from peewee import buffer_type + + +PY2 = sys.version_info[0] == 2 + + +class CompressedField(BlobField): + ZLIB = 'zlib' + BZ2 = 'bz2' + algorithm_to_import = { + ZLIB: zlib, + BZ2: bz2, + } + + def __init__(self, compression_level=6, algorithm=ZLIB, *args, + **kwargs): + self.compression_level = compression_level + if algorithm not in self.algorithm_to_import: + raise ValueError('Unrecognized algorithm %s' % algorithm) + compress_module = self.algorithm_to_import[algorithm] + if compress_module is None: + raise ValueError('Missing library required for %s.' % algorithm) + + self.algorithm = algorithm + self.compress = compress_module.compress + self.decompress = compress_module.decompress + super(CompressedField, self).__init__(*args, **kwargs) + + def python_value(self, value): + if value is not None: + return self.decompress(value) + + def db_value(self, value): + if value is not None: + return self._constructor( + self.compress(value, self.compression_level)) + + +class PickleField(BlobField): + def python_value(self, value): + if value is not None: + if isinstance(value, buffer_type): + value = bytes(value) + return pickle.loads(value) + + def db_value(self, value): + if value is not None: + pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL) + return self._constructor(pickled) diff --git a/libs/playhouse/flask_utils.py b/libs/playhouse/flask_utils.py new file mode 100644 index 000000000..76a2a62f4 --- /dev/null +++ b/libs/playhouse/flask_utils.py @@ -0,0 +1,185 @@ +import math +import sys + +from flask import abort +from flask import render_template +from flask import request +from peewee import Database +from peewee import DoesNotExist +from peewee import Model +from peewee import Proxy +from peewee import SelectQuery +from playhouse.db_url import connect as db_url_connect + + +class PaginatedQuery(object): + def __init__(self, query_or_model, paginate_by, page_var='page', page=None, + check_bounds=False): + self.paginate_by = paginate_by + self.page_var = page_var + self.page = page or None + self.check_bounds = check_bounds + + if isinstance(query_or_model, SelectQuery): + self.query = query_or_model + self.model = self.query.model + else: + self.model = query_or_model + self.query = self.model.select() + + def get_page(self): + if self.page is not None: + return self.page + + curr_page = request.args.get(self.page_var) + if curr_page and curr_page.isdigit(): + return max(1, int(curr_page)) + return 1 + + def get_page_count(self): + if not hasattr(self, '_page_count'): + self._page_count = int(math.ceil( + float(self.query.count()) / self.paginate_by)) + return self._page_count + + def 
get_object_list(self): + if self.check_bounds and self.get_page() > self.get_page_count(): + abort(404) + return self.query.paginate(self.get_page(), self.paginate_by) + + +def get_object_or_404(query_or_model, *query): + if not isinstance(query_or_model, SelectQuery): + query_or_model = query_or_model.select() + try: + return query_or_model.where(*query).get() + except DoesNotExist: + abort(404) + +def object_list(template_name, query, context_variable='object_list', + paginate_by=20, page_var='page', page=None, check_bounds=True, + **kwargs): + paginated_query = PaginatedQuery( + query, + paginate_by=paginate_by, + page_var=page_var, + page=page, + check_bounds=check_bounds) + kwargs[context_variable] = paginated_query.get_object_list() + return render_template( + template_name, + pagination=paginated_query, + page=paginated_query.get_page(), + **kwargs) + +def get_current_url(): + if not request.query_string: + return request.path + return '%s?%s' % (request.path, request.query_string) + +def get_next_url(default='/'): + if request.args.get('next'): + return request.args['next'] + elif request.form.get('next'): + return request.form['next'] + return default + +class FlaskDB(object): + def __init__(self, app=None, database=None, model_class=Model): + self.database = None # Reference to actual Peewee database instance. + self.base_model_class = model_class + self._app = app + self._db = database # dict, url, Database, or None (default). + if app is not None: + self.init_app(app) + + def init_app(self, app): + self._app = app + + if self._db is None: + if 'DATABASE' in app.config: + initial_db = app.config['DATABASE'] + elif 'DATABASE_URL' in app.config: + initial_db = app.config['DATABASE_URL'] + else: + raise ValueError('Missing required configuration data for ' + 'database: DATABASE or DATABASE_URL.') + else: + initial_db = self._db + + self._load_database(app, initial_db) + self._register_handlers(app) + + def _load_database(self, app, config_value): + if isinstance(config_value, Database): + database = config_value + elif isinstance(config_value, dict): + database = self._load_from_config_dict(dict(config_value)) + else: + # Assume a database connection URL. + database = db_url_connect(config_value) + + if isinstance(self.database, Proxy): + self.database.initialize(database) + else: + self.database = database + + def _load_from_config_dict(self, config_dict): + try: + name = config_dict.pop('name') + engine = config_dict.pop('engine') + except KeyError: + raise RuntimeError('DATABASE configuration must specify a ' + '`name` and `engine`.') + + if '.' 
in engine: + path, class_name = engine.rsplit('.', 1) + else: + path, class_name = 'peewee', engine + + try: + __import__(path) + module = sys.modules[path] + database_class = getattr(module, class_name) + assert issubclass(database_class, Database) + except ImportError: + raise RuntimeError('Unable to import %s' % engine) + except AttributeError: + raise RuntimeError('Database engine not found %s' % engine) + except AssertionError: + raise RuntimeError('Database engine not a subclass of ' + 'peewee.Database: %s' % engine) + + return database_class(name, **config_dict) + + def _register_handlers(self, app): + app.before_request(self.connect_db) + app.teardown_request(self.close_db) + + def get_model_class(self): + if self.database is None: + raise RuntimeError('Database must be initialized.') + + class BaseModel(self.base_model_class): + class Meta: + database = self.database + + return BaseModel + + @property + def Model(self): + if self._app is None: + database = getattr(self, 'database', None) + if database is None: + self.database = Proxy() + + if not hasattr(self, '_model_class'): + self._model_class = self.get_model_class() + return self._model_class + + def connect_db(self): + self.database.connect() + + def close_db(self, exc): + if not self.database.is_closed(): + self.database.close() diff --git a/libs/playhouse/hybrid.py b/libs/playhouse/hybrid.py new file mode 100644 index 000000000..53f226288 --- /dev/null +++ b/libs/playhouse/hybrid.py @@ -0,0 +1,50 @@ +# Hybrid methods/attributes, based on similar functionality in SQLAlchemy: +# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html +class hybrid_method(object): + def __init__(self, func, expr=None): + self.func = func + self.expr = expr or func + + def __get__(self, instance, instance_type): + if instance is None: + return self.expr.__get__(instance_type, instance_type.__class__) + return self.func.__get__(instance, instance_type) + + def expression(self, expr): + self.expr = expr + return self + + +class hybrid_property(object): + def __init__(self, fget, fset=None, fdel=None, expr=None): + self.fget = fget + self.fset = fset + self.fdel = fdel + self.expr = expr or fget + + def __get__(self, instance, instance_type): + if instance is None: + return self.expr(instance_type) + return self.fget(instance) + + def __set__(self, instance, value): + if self.fset is None: + raise AttributeError('Cannot set attribute.') + self.fset(instance, value) + + def __delete__(self, instance): + if self.fdel is None: + raise AttributeError('Cannot delete attribute.') + self.fdel(instance) + + def setter(self, fset): + self.fset = fset + return self + + def deleter(self, fdel): + self.fdel = fdel + return self + + def expression(self, expr): + self.expr = expr + return self diff --git a/libs/playhouse/kv.py b/libs/playhouse/kv.py new file mode 100644 index 000000000..742b49cad --- /dev/null +++ b/libs/playhouse/kv.py @@ -0,0 +1,172 @@ +import operator + +from peewee import * +from peewee import Expression +from playhouse.fields import PickleField +try: + from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase +except ImportError: + from playhouse.sqlite_ext import SqliteExtDatabase + + +Sentinel = type('Sentinel', (object,), {}) + + +class KeyValue(object): + """ + Persistent dictionary. + + :param Field key_field: field to use for key. Defaults to CharField. + :param Field value_field: field to use for value. Defaults to PickleField. + :param bool ordered: data should be returned in key-sorted order. 
+ :param Database database: database where key/value data is stored. + :param str table_name: table name for data. + """ + def __init__(self, key_field=None, value_field=None, ordered=False, + database=None, table_name='keyvalue'): + if key_field is None: + key_field = CharField(max_length=255, primary_key=True) + if not key_field.primary_key: + raise ValueError('key_field must have primary_key=True.') + + if value_field is None: + value_field = PickleField() + + self._key_field = key_field + self._value_field = value_field + self._ordered = ordered + self._database = database or SqliteExtDatabase(':memory:') + self._table_name = table_name + if isinstance(self._database, PostgresqlDatabase): + self.upsert = self._postgres_upsert + self.update = self._postgres_update + else: + self.upsert = self._upsert + self.update = self._update + + self.model = self.create_model() + self.key = self.model.key + self.value = self.model.value + + # Ensure table exists. + self.model.create_table() + + def create_model(self): + class KeyValue(Model): + key = self._key_field + value = self._value_field + class Meta: + database = self._database + table_name = self._table_name + return KeyValue + + def query(self, *select): + query = self.model.select(*select).tuples() + if self._ordered: + query = query.order_by(self.key) + return query + + def convert_expression(self, expr): + if not isinstance(expr, Expression): + return (self.key == expr), True + return expr, False + + def __contains__(self, key): + expr, _ = self.convert_expression(key) + return self.model.select().where(expr).exists() + + def __len__(self): + return len(self.model) + + def __getitem__(self, expr): + converted, is_single = self.convert_expression(expr) + query = self.query(self.value).where(converted) + item_getter = operator.itemgetter(0) + result = [item_getter(row) for row in query] + if len(result) == 0 and is_single: + raise KeyError(expr) + elif is_single: + return result[0] + return result + + def _upsert(self, key, value): + (self.model + .insert(key=key, value=value) + .on_conflict('replace') + .execute()) + + def _postgres_upsert(self, key, value): + (self.model + .insert(key=key, value=value) + .on_conflict(conflict_target=[self.key], + preserve=[self.value]) + .execute()) + + def __setitem__(self, expr, value): + if isinstance(expr, Expression): + self.model.update(value=value).where(expr).execute() + else: + self.upsert(expr, value) + + def __delitem__(self, expr): + converted, _ = self.convert_expression(expr) + self.model.delete().where(converted).execute() + + def __iter__(self): + return iter(self.query().execute()) + + def keys(self): + return map(operator.itemgetter(0), self.query(self.key)) + + def values(self): + return map(operator.itemgetter(0), self.query(self.value)) + + def items(self): + return iter(self.query().execute()) + + def _update(self, __data=None, **mapping): + if __data is not None: + mapping.update(__data) + return (self.model + .insert_many(list(mapping.items()), + fields=[self.key, self.value]) + .on_conflict('replace') + .execute()) + + def _postgres_update(self, __data=None, **mapping): + if __data is not None: + mapping.update(__data) + return (self.model + .insert_many(list(mapping.items()), + fields=[self.key, self.value]) + .on_conflict(conflict_target=[self.key], + preserve=[self.value]) + .execute()) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + 
self[key] = default + return default + + def pop(self, key, default=Sentinel): + with self._database.atomic(): + try: + result = self[key] + except KeyError: + if default is Sentinel: + raise + return default + del self[key] + + return result + + def clear(self): + self.model.delete().execute() diff --git a/libs/playhouse/migrate.py b/libs/playhouse/migrate.py new file mode 100644 index 000000000..0abde2123 --- /dev/null +++ b/libs/playhouse/migrate.py @@ -0,0 +1,823 @@ +""" +Lightweight schema migrations. + +NOTE: Currently tested with SQLite and Postgresql. MySQL may be missing some +features. + +Example Usage +------------- + +Instantiate a migrator: + + # Postgres example: + my_db = PostgresqlDatabase(...) + migrator = PostgresqlMigrator(my_db) + + # SQLite example: + my_db = SqliteDatabase('my_database.db') + migrator = SqliteMigrator(my_db) + +Then you will use the `migrate` function to run various `Operation`s which +are generated by the migrator: + + migrate( + migrator.add_column('some_table', 'column_name', CharField(default='')) + ) + +Migrations are not run inside a transaction, so if you wish the migration to +run in a transaction you will need to wrap the call to `migrate` in a +transaction block, e.g.: + + with my_db.transaction(): + migrate(...) + +Supported Operations +-------------------- + +Add new field(s) to an existing model: + + # Create your field instances. For non-null fields you must specify a + # default value. + pubdate_field = DateTimeField(null=True) + comment_field = TextField(default='') + + # Run the migration, specifying the database table, field name and field. + migrate( + migrator.add_column('comment_tbl', 'pub_date', pubdate_field), + migrator.add_column('comment_tbl', 'comment', comment_field), + ) + +Renaming a field: + + # Specify the table, original name of the column, and its new name. + migrate( + migrator.rename_column('story', 'pub_date', 'publish_date'), + migrator.rename_column('story', 'mod_date', 'modified_date'), + ) + +Dropping a field: + + migrate( + migrator.drop_column('story', 'some_old_field'), + ) + +Making a field nullable or not nullable: + + # Note that when making a field not null that field must not have any + # NULL values present. + migrate( + # Make `pub_date` allow NULL values. + migrator.drop_not_null('story', 'pub_date'), + + # Prevent `modified_date` from containing NULL values. + migrator.add_not_null('story', 'modified_date'), + ) + +Renaming a table: + + migrate( + migrator.rename_table('story', 'stories_tbl'), + ) + +Adding an index: + + # Specify the table, column names, and whether the index should be + # UNIQUE or not. + migrate( + # Create an index on the `pub_date` column. + migrator.add_index('story', ('pub_date',), False), + + # Create a multi-column index on the `pub_date` and `status` fields. + migrator.add_index('story', ('pub_date', 'status'), False), + + # Create a unique index on the category and title fields. + migrator.add_index('story', ('category_id', 'title'), True), + ) + +Dropping an index: + + # Specify the index name. + migrate(migrator.drop_index('story', 'story_pub_date_status')) + +Adding or dropping table constraints: + +.. code-block:: python + + # Add a CHECK() constraint to enforce the price cannot be negative. + migrate(migrator.add_constraint( + 'products', + 'price_check', + Check('price >= 0'))) + + # Remove the price check constraint. + migrate(migrator.drop_constraint('products', 'price_check')) + + # Add a UNIQUE constraint on the first and last names. 
+ migrate(migrator.add_unique('person', 'first_name', 'last_name')) +""" +from collections import namedtuple +import functools +import hashlib +import re + +from peewee import * +from peewee import CommaNodeList +from peewee import EnclosedNodeList +from peewee import Entity +from peewee import Expression +from peewee import Node +from peewee import NodeList +from peewee import OP +from peewee import callable_ +from peewee import sort_models +from peewee import _truncate_constraint_name + + +class Operation(object): + """Encapsulate a single schema altering operation.""" + def __init__(self, migrator, method, *args, **kwargs): + self.migrator = migrator + self.method = method + self.args = args + self.kwargs = kwargs + + def execute(self, node): + self.migrator.database.execute(node) + + def _handle_result(self, result): + if isinstance(result, (Node, Context)): + self.execute(result) + elif isinstance(result, Operation): + result.run() + elif isinstance(result, (list, tuple)): + for item in result: + self._handle_result(item) + + def run(self): + kwargs = self.kwargs.copy() + kwargs['with_context'] = True + method = getattr(self.migrator, self.method) + self._handle_result(method(*self.args, **kwargs)) + + +def operation(fn): + @functools.wraps(fn) + def inner(self, *args, **kwargs): + with_context = kwargs.pop('with_context', False) + if with_context: + return fn(self, *args, **kwargs) + return Operation(self, fn.__name__, *args, **kwargs) + return inner + + +def make_index_name(table_name, columns): + index_name = '_'.join((table_name,) + tuple(columns)) + if len(index_name) > 64: + index_hash = hashlib.md5(index_name.encode('utf-8')).hexdigest() + index_name = '%s_%s' % (index_name[:56], index_hash[:7]) + return index_name + + +class SchemaMigrator(object): + explicit_create_foreign_key = False + explicit_delete_foreign_key = False + + def __init__(self, database): + self.database = database + + def make_context(self): + return self.database.get_sql_context() + + @classmethod + def from_database(cls, database): + if isinstance(database, PostgresqlDatabase): + return PostgresqlMigrator(database) + elif isinstance(database, MySQLDatabase): + return MySQLMigrator(database) + elif isinstance(database, SqliteDatabase): + return SqliteMigrator(database) + raise ValueError('Unsupported database: %s' % database) + + @operation + def apply_default(self, table, column_name, field): + default = field.default + if callable_(default): + default = default() + + return (self.make_context() + .literal('UPDATE ') + .sql(Entity(table)) + .literal(' SET ') + .sql(Expression( + Entity(column_name), + OP.EQ, + field.db_value(default), + flat=True))) + + def _alter_table(self, ctx, table): + return ctx.literal('ALTER TABLE ').sql(Entity(table)) + + def _alter_column(self, ctx, table, column): + return (self + ._alter_table(ctx, table) + .literal(' ALTER COLUMN ') + .sql(Entity(column))) + + @operation + def alter_add_column(self, table, column_name, field): + # Make field null at first. 
+ ctx = self.make_context() + field_null, field.null = field.null, True + field.name = field.column_name = column_name + (self + ._alter_table(ctx, table) + .literal(' ADD COLUMN ') + .sql(field.ddl(ctx))) + + field.null = field_null + if isinstance(field, ForeignKeyField): + self.add_inline_fk_sql(ctx, field) + return ctx + + @operation + def add_constraint(self, table, name, constraint): + return (self + ._alter_table(self.make_context(), table) + .literal(' ADD CONSTRAINT ') + .sql(Entity(name)) + .literal(' ') + .sql(constraint)) + + @operation + def add_unique(self, table, *column_names): + constraint_name = 'uniq_%s' % '_'.join(column_names) + constraint = NodeList(( + SQL('UNIQUE'), + EnclosedNodeList([Entity(column) for column in column_names]))) + return self.add_constraint(table, constraint_name, constraint) + + @operation + def drop_constraint(self, table, name): + return (self + ._alter_table(self.make_context(), table) + .literal(' DROP CONSTRAINT ') + .sql(Entity(name))) + + def add_inline_fk_sql(self, ctx, field): + ctx = (ctx + .literal(' REFERENCES ') + .sql(Entity(field.rel_model._meta.table_name)) + .literal(' ') + .sql(EnclosedNodeList((Entity(field.rel_field.column_name),)))) + if field.on_delete is not None: + ctx = ctx.literal(' ON DELETE %s' % field.on_delete) + if field.on_update is not None: + ctx = ctx.literal(' ON UPDATE %s' % field.on_update) + return ctx + + @operation + def add_foreign_key_constraint(self, table, column_name, rel, rel_column, + on_delete=None, on_update=None): + constraint = 'fk_%s_%s_refs_%s' % (table, column_name, rel) + ctx = (self + .make_context() + .literal('ALTER TABLE ') + .sql(Entity(table)) + .literal(' ADD CONSTRAINT ') + .sql(Entity(_truncate_constraint_name(constraint))) + .literal(' FOREIGN KEY ') + .sql(EnclosedNodeList((Entity(column_name),))) + .literal(' REFERENCES ') + .sql(Entity(rel)) + .literal(' (') + .sql(Entity(rel_column)) + .literal(')')) + if on_delete is not None: + ctx = ctx.literal(' ON DELETE %s' % on_delete) + if on_update is not None: + ctx = ctx.literal(' ON UPDATE %s' % on_update) + return ctx + + @operation + def add_column(self, table, column_name, field): + # Adding a column is complicated by the fact that if there are rows + # present and the field is non-null, then we need to first add the + # column as a nullable field, then set the value, then add a not null + # constraint. + if not field.null and field.default is None: + raise ValueError('%s is not null but has no default' % column_name) + + is_foreign_key = isinstance(field, ForeignKeyField) + if is_foreign_key and not field.rel_field: + raise ValueError('Foreign keys must specify a `field`.') + + operations = [self.alter_add_column(table, column_name, field)] + + # In the event the field is *not* nullable, update with the default + # value and set not null. 
+ if not field.null: + operations.extend([ + self.apply_default(table, column_name, field), + self.add_not_null(table, column_name)]) + + if is_foreign_key and self.explicit_create_foreign_key: + operations.append( + self.add_foreign_key_constraint( + table, + column_name, + field.rel_model._meta.table_name, + field.rel_field.column_name, + field.on_delete, + field.on_update)) + + if field.index or field.unique: + using = getattr(field, 'index_type', None) + operations.append(self.add_index(table, (column_name,), + field.unique, using)) + + return operations + + @operation + def drop_foreign_key_constraint(self, table, column_name): + raise NotImplementedError + + @operation + def drop_column(self, table, column_name, cascade=True): + ctx = self.make_context() + (self._alter_table(ctx, table) + .literal(' DROP COLUMN ') + .sql(Entity(column_name))) + + if cascade: + ctx.literal(' CASCADE') + + fk_columns = [ + foreign_key.column + for foreign_key in self.database.get_foreign_keys(table)] + if column_name in fk_columns and self.explicit_delete_foreign_key: + return [self.drop_foreign_key_constraint(table, column_name), ctx] + + return ctx + + @operation + def rename_column(self, table, old_name, new_name): + return (self + ._alter_table(self.make_context(), table) + .literal(' RENAME COLUMN ') + .sql(Entity(old_name)) + .literal(' TO ') + .sql(Entity(new_name))) + + @operation + def add_not_null(self, table, column): + return (self + ._alter_column(self.make_context(), table, column) + .literal(' SET NOT NULL')) + + @operation + def drop_not_null(self, table, column): + return (self + ._alter_column(self.make_context(), table, column) + .literal(' DROP NOT NULL')) + + @operation + def rename_table(self, old_name, new_name): + return (self + ._alter_table(self.make_context(), old_name) + .literal(' RENAME TO ') + .sql(Entity(new_name))) + + @operation + def add_index(self, table, columns, unique=False, using=None): + ctx = self.make_context() + index_name = make_index_name(table, columns) + table_obj = Table(table) + cols = [getattr(table_obj.c, column) for column in columns] + index = Index(index_name, table_obj, cols, unique=unique, using=using) + return ctx.sql(index) + + @operation + def drop_index(self, table, index_name): + return (self + .make_context() + .literal('DROP INDEX ') + .sql(Entity(index_name))) + + +class PostgresqlMigrator(SchemaMigrator): + def _primary_key_columns(self, tbl): + query = """ + SELECT pg_attribute.attname + FROM pg_index, pg_class, pg_attribute + WHERE + pg_class.oid = '%s'::regclass AND + indrelid = pg_class.oid AND + pg_attribute.attrelid = pg_class.oid AND + pg_attribute.attnum = any(pg_index.indkey) AND + indisprimary; + """ + cursor = self.database.execute_sql(query % tbl) + return [row[0] for row in cursor.fetchall()] + + @operation + def set_search_path(self, schema_name): + return (self + .make_context() + .literal('SET search_path TO %s' % schema_name)) + + @operation + def rename_table(self, old_name, new_name): + pk_names = self._primary_key_columns(old_name) + ParentClass = super(PostgresqlMigrator, self) + + operations = [ + ParentClass.rename_table(old_name, new_name, with_context=True)] + + if len(pk_names) == 1: + # Check for existence of primary key sequence. 
+ seq_name = '%s_%s_seq' % (old_name, pk_names[0]) + query = """ + SELECT 1 + FROM information_schema.sequences + WHERE LOWER(sequence_name) = LOWER(%s) + """ + cursor = self.database.execute_sql(query, (seq_name,)) + if bool(cursor.fetchone()): + new_seq_name = '%s_%s_seq' % (new_name, pk_names[0]) + operations.append(ParentClass.rename_table( + seq_name, new_seq_name)) + + return operations + + +class MySQLColumn(namedtuple('_Column', ('name', 'definition', 'null', 'pk', + 'default', 'extra'))): + @property + def is_pk(self): + return self.pk == 'PRI' + + @property + def is_unique(self): + return self.pk == 'UNI' + + @property + def is_null(self): + return self.null == 'YES' + + def sql(self, column_name=None, is_null=None): + if is_null is None: + is_null = self.is_null + if column_name is None: + column_name = self.name + parts = [ + Entity(column_name), + SQL(self.definition)] + if self.is_unique: + parts.append(SQL('UNIQUE')) + if is_null: + parts.append(SQL('NULL')) + else: + parts.append(SQL('NOT NULL')) + if self.is_pk: + parts.append(SQL('PRIMARY KEY')) + if self.extra: + parts.append(SQL(self.extra)) + return NodeList(parts) + + +class MySQLMigrator(SchemaMigrator): + explicit_create_foreign_key = True + explicit_delete_foreign_key = True + + @operation + def rename_table(self, old_name, new_name): + return (self + .make_context() + .literal('RENAME TABLE ') + .sql(Entity(old_name)) + .literal(' TO ') + .sql(Entity(new_name))) + + def _get_column_definition(self, table, column_name): + cursor = self.database.execute_sql('DESCRIBE `%s`;' % table) + rows = cursor.fetchall() + for row in rows: + column = MySQLColumn(*row) + if column.name == column_name: + return column + return False + + def get_foreign_key_constraint(self, table, column_name): + cursor = self.database.execute_sql( + ('SELECT constraint_name ' + 'FROM information_schema.key_column_usage WHERE ' + 'table_schema = DATABASE() AND ' + 'table_name = %s AND ' + 'column_name = %s AND ' + 'referenced_table_name IS NOT NULL AND ' + 'referenced_column_name IS NOT NULL;'), + (table, column_name)) + result = cursor.fetchone() + if not result: + raise AttributeError( + 'Unable to find foreign key constraint for ' + '"%s" on table "%s".' 
% (table, column_name)) + return result[0] + + @operation + def drop_foreign_key_constraint(self, table, column_name): + fk_constraint = self.get_foreign_key_constraint(table, column_name) + return (self + .make_context() + .literal('ALTER TABLE ') + .sql(Entity(table)) + .literal(' DROP FOREIGN KEY ') + .sql(Entity(fk_constraint))) + + def add_inline_fk_sql(self, ctx, field): + pass + + @operation + def add_not_null(self, table, column): + column_def = self._get_column_definition(table, column) + add_not_null = (self + .make_context() + .literal('ALTER TABLE ') + .sql(Entity(table)) + .literal(' MODIFY ') + .sql(column_def.sql(is_null=False))) + + fk_objects = dict( + (fk.column, fk) + for fk in self.database.get_foreign_keys(table)) + if column not in fk_objects: + return add_not_null + + fk_metadata = fk_objects[column] + return (self.drop_foreign_key_constraint(table, column), + add_not_null, + self.add_foreign_key_constraint( + table, + column, + fk_metadata.dest_table, + fk_metadata.dest_column)) + + @operation + def drop_not_null(self, table, column): + column = self._get_column_definition(table, column) + if column.is_pk: + raise ValueError('Primary keys can not be null') + return (self + .make_context() + .literal('ALTER TABLE ') + .sql(Entity(table)) + .literal(' MODIFY ') + .sql(column.sql(is_null=True))) + + @operation + def rename_column(self, table, old_name, new_name): + fk_objects = dict( + (fk.column, fk) + for fk in self.database.get_foreign_keys(table)) + is_foreign_key = old_name in fk_objects + + column = self._get_column_definition(table, old_name) + rename_ctx = (self + .make_context() + .literal('ALTER TABLE ') + .sql(Entity(table)) + .literal(' CHANGE ') + .sql(Entity(old_name)) + .literal(' ') + .sql(column.sql(column_name=new_name))) + if is_foreign_key: + fk_metadata = fk_objects[old_name] + return [ + self.drop_foreign_key_constraint(table, old_name), + rename_ctx, + self.add_foreign_key_constraint( + table, + new_name, + fk_metadata.dest_table, + fk_metadata.dest_column), + ] + else: + return rename_ctx + + @operation + def drop_index(self, table, index_name): + return (self + .make_context() + .literal('DROP INDEX ') + .sql(Entity(index_name)) + .literal(' ON ') + .sql(Entity(table))) + + +class SqliteMigrator(SchemaMigrator): + """ + SQLite supports a subset of ALTER TABLE queries, view the docs for the + full details http://sqlite.org/lang_altertable.html + """ + column_re = re.compile('(.+?)\((.+)\)') + column_split_re = re.compile(r'(?:[^,(]|\([^)]*\))+') + column_name_re = re.compile('["`\']?([\w]+)') + fk_re = re.compile('FOREIGN KEY\s+\("?([\w]+)"?\)\s+', re.I) + + def _get_column_names(self, table): + res = self.database.execute_sql('select * from "%s" limit 1' % table) + return [item[0] for item in res.description] + + def _get_create_table(self, table): + res = self.database.execute_sql( + ('select name, sql from sqlite_master ' + 'where type=? and LOWER(name)=?'), + ['table', table.lower()]) + return res.fetchone() + + @operation + def _update_column(self, table, column_to_update, fn): + columns = set(column.name.lower() + for column in self.database.get_columns(table)) + if column_to_update.lower() not in columns: + raise ValueError('Column "%s" does not exist on "%s"' % + (column_to_update, table)) + + # Get the SQL used to create the given table. + table, create_table = self._get_create_table(table) + + # Get the indexes and SQL to re-create indexes. + indexes = self.database.get_indexes(table) + + # Find any foreign keys we may need to remove. 
+ self.database.get_foreign_keys(table) + + # Make sure the create_table does not contain any newlines or tabs, + # allowing the regex to work correctly. + create_table = re.sub(r'\s+', ' ', create_table) + + # Parse out the `CREATE TABLE` and column list portions of the query. + raw_create, raw_columns = self.column_re.search(create_table).groups() + + # Clean up the individual column definitions. + split_columns = self.column_split_re.findall(raw_columns) + column_defs = [col.strip() for col in split_columns] + + new_column_defs = [] + new_column_names = [] + original_column_names = [] + constraint_terms = ('foreign ', 'primary ', 'constraint ') + + for column_def in column_defs: + column_name, = self.column_name_re.match(column_def).groups() + + if column_name == column_to_update: + new_column_def = fn(column_name, column_def) + if new_column_def: + new_column_defs.append(new_column_def) + original_column_names.append(column_name) + column_name, = self.column_name_re.match( + new_column_def).groups() + new_column_names.append(column_name) + else: + new_column_defs.append(column_def) + + # Avoid treating constraints as columns. + if not column_def.lower().startswith(constraint_terms): + new_column_names.append(column_name) + original_column_names.append(column_name) + + # Create a mapping of original columns to new columns. + original_to_new = dict(zip(original_column_names, new_column_names)) + new_column = original_to_new.get(column_to_update) + + fk_filter_fn = lambda column_def: column_def + if not new_column: + # Remove any foreign keys associated with this column. + fk_filter_fn = lambda column_def: None + elif new_column != column_to_update: + # Update any foreign keys for this column. + fk_filter_fn = lambda column_def: self.fk_re.sub( + 'FOREIGN KEY ("%s") ' % new_column, + column_def) + + cleaned_columns = [] + for column_def in new_column_defs: + match = self.fk_re.match(column_def) + if match is not None and match.groups()[0] == column_to_update: + column_def = fk_filter_fn(column_def) + if column_def: + cleaned_columns.append(column_def) + + # Update the name of the new CREATE TABLE query. + temp_table = table + '__tmp__' + rgx = re.compile('("?)%s("?)' % table, re.I) + create = rgx.sub( + '\\1%s\\2' % temp_table, + raw_create) + + # Create the new table. + columns = ', '.join(cleaned_columns) + queries = [ + NodeList([SQL('DROP TABLE IF EXISTS'), Entity(temp_table)]), + SQL('%s (%s)' % (create.strip(), columns))] + + # Populate new table. + populate_table = NodeList(( + SQL('INSERT INTO'), + Entity(temp_table), + EnclosedNodeList([Entity(col) for col in new_column_names]), + SQL('SELECT'), + CommaNodeList([Entity(col) for col in original_column_names]), + SQL('FROM'), + Entity(table))) + drop_original = NodeList([SQL('DROP TABLE'), Entity(table)]) + + # Drop existing table and rename temp table. + queries += [ + populate_table, + drop_original, + self.rename_table(temp_table, table)] + + # Re-create user-defined indexes. User-defined indexes will have a + # non-empty SQL attribute. + for index in filter(lambda idx: idx.sql, indexes): + if column_to_update not in index.columns: + queries.append(SQL(index.sql)) + elif new_column: + sql = self._fix_index(index.sql, column_to_update, new_column) + if sql is not None: + queries.append(SQL(sql)) + + return queries + + def _fix_index(self, sql, column_to_update, new_column): + # Split on the name of the column to update. If it splits into two + # pieces, then there's no ambiguity and we can simply replace the + # old with the new. 
parts = sql.split(column_to_update)
+        if len(parts) == 2:
+            return sql.replace(column_to_update, new_column)
+
+        # Find the list of columns in the index expression.
+        lhs, rhs = sql.rsplit('(', 1)
+
+        # Apply the same "split in two" logic to the column list portion of
+        # the query.
+        if len(rhs.split(column_to_update)) == 2:
+            return '%s(%s' % (lhs, rhs.replace(column_to_update, new_column))
+
+        # Strip off the trailing parentheses and go through each column.
+        parts = rhs.rsplit(')', 1)[0].split(',')
+        columns = [part.strip('"`[]\' ') for part in parts]
+
+        # `columns` looks something like: ['status', 'timestamp" DESC']
+        # https://www.sqlite.org/lang_keywords.html
+        # Strip out any junk after the column name.
+        clean = []
+        for column in columns:
+            if re.match('%s(?:[\'"`\]]?\s|$)' % column_to_update, column):
+                column = new_column + column[len(column_to_update):]
+            clean.append(column)
+
+        return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean))
+
+    @operation
+    def drop_column(self, table, column_name, cascade=True):
+        return self._update_column(table, column_name, lambda a, b: None)
+
+    @operation
+    def rename_column(self, table, old_name, new_name):
+        def _rename(column_name, column_def):
+            return column_def.replace(column_name, new_name)
+        return self._update_column(table, old_name, _rename)
+
+    @operation
+    def add_not_null(self, table, column):
+        def _add_not_null(column_name, column_def):
+            return column_def + ' NOT NULL'
+        return self._update_column(table, column, _add_not_null)
+
+    @operation
+    def drop_not_null(self, table, column):
+        def _drop_not_null(column_name, column_def):
+            return column_def.replace('NOT NULL', '')
+        return self._update_column(table, column, _drop_not_null)
+
+    @operation
+    def add_constraint(self, table, name, constraint):
+        raise NotImplementedError
+
+    @operation
+    def drop_constraint(self, table, name):
+        raise NotImplementedError
+
+    @operation
+    def add_foreign_key_constraint(self, table, column_name, field,
+                                   on_delete=None, on_update=None):
+        raise NotImplementedError
+
+
+def migrate(*operations, **kwargs):
+    for operation in operations:
+        operation.run()
diff --git a/libs/playhouse/mysql_ext.py b/libs/playhouse/mysql_ext.py
new file mode 100644
index 000000000..8eb2a43fa
--- /dev/null
+++ b/libs/playhouse/mysql_ext.py
@@ -0,0 +1,34 @@
+import json
+
+try:
+    import mysql.connector as mysql_connector
+except ImportError:
+    mysql_connector = None
+
+from peewee import ImproperlyConfigured
+from peewee import MySQLDatabase
+from peewee import TextField
+
+
+class MySQLConnectorDatabase(MySQLDatabase):
+    def _connect(self):
+        if mysql_connector is None:
+            raise ImproperlyConfigured('MySQL connector not installed!')
+        return mysql_connector.connect(db=self.database, **self.connect_params)
+
+    def cursor(self, commit=None):
+        if self.is_closed():
+            self.connect()
+        return self._state.conn.cursor(buffered=True)
+
+
+class JSONField(TextField):
+    field_type = 'JSON'
+
+    def db_value(self, value):
+        if value is not None:
+            return json.dumps(value)
+
+    def python_value(self, value):
+        if value is not None:
+            return json.loads(value)
diff --git a/libs/playhouse/pool.py b/libs/playhouse/pool.py
new file mode 100644
index 000000000..9ade1da94
--- /dev/null
+++ b/libs/playhouse/pool.py
@@ -0,0 +1,317 @@
+"""
+Lightweight connection pooling for peewee.
+
+In a multi-threaded application, up to `max_connections` will be opened. Each
+thread (or, if using gevent, greenlet) will have its own connection.
+ +In a single-threaded application, only one connection will be created. It will +be continually recycled until either it exceeds the stale timeout or is closed +explicitly (using `.manual_close()`). + +By default, all your application needs to do is ensure that connections are +closed when you are finished with them, and they will be returned to the pool. +For web applications, this typically means that at the beginning of a request, +you will open a connection, and when you return a response, you will close the +connection. + +Simple Postgres pool example code: + + # Use the special postgresql extensions. + from playhouse.pool import PooledPostgresqlExtDatabase + + db = PooledPostgresqlExtDatabase( + 'my_app', + max_connections=32, + stale_timeout=300, # 5 minutes. + user='postgres') + + class BaseModel(Model): + class Meta: + database = db + +That's it! +""" +import heapq +import logging +import time +from collections import namedtuple +from itertools import chain + +try: + from psycopg2.extensions import TRANSACTION_STATUS_IDLE + from psycopg2.extensions import TRANSACTION_STATUS_INERROR + from psycopg2.extensions import TRANSACTION_STATUS_UNKNOWN +except ImportError: + TRANSACTION_STATUS_IDLE = \ + TRANSACTION_STATUS_INERROR = \ + TRANSACTION_STATUS_UNKNOWN = None + +from peewee import MySQLDatabase +from peewee import PostgresqlDatabase +from peewee import SqliteDatabase + +logger = logging.getLogger('peewee.pool') + + +def make_int(val): + if val is not None and not isinstance(val, (int, float)): + return int(val) + return val + + +class MaxConnectionsExceeded(ValueError): pass + + +PoolConnection = namedtuple('PoolConnection', ('timestamp', 'connection', + 'checked_out')) + + +class PooledDatabase(object): + def __init__(self, database, max_connections=20, stale_timeout=None, + timeout=None, **kwargs): + self._max_connections = make_int(max_connections) + self._stale_timeout = make_int(stale_timeout) + self._wait_timeout = make_int(timeout) + if self._wait_timeout == 0: + self._wait_timeout = float('inf') + + # Available / idle connections stored in a heap, sorted oldest first. + self._connections = [] + + # Mapping of connection id to PoolConnection. Ordinarily we would want + # to use something like a WeakKeyDictionary, but Python typically won't + # allow us to create weak references to connection objects. + self._in_use = {} + + # Use the memory address of the connection as the key in the event the + # connection object is not hashable. Connections will not get + # garbage-collected, however, because a reference to them will persist + # in "_in_use" as long as the conn has not been closed. 
+ self.conn_key = id + + super(PooledDatabase, self).__init__(database, **kwargs) + + def init(self, database, max_connections=None, stale_timeout=None, + timeout=None, **connect_kwargs): + super(PooledDatabase, self).init(database, **connect_kwargs) + if max_connections is not None: + self._max_connections = make_int(max_connections) + if stale_timeout is not None: + self._stale_timeout = make_int(stale_timeout) + if timeout is not None: + self._wait_timeout = make_int(timeout) + if self._wait_timeout == 0: + self._wait_timeout = float('inf') + + def connect(self, reuse_if_open=False): + if not self._wait_timeout: + return super(PooledDatabase, self).connect(reuse_if_open) + + expires = time.time() + self._wait_timeout + while expires > time.time(): + try: + ret = super(PooledDatabase, self).connect(reuse_if_open) + except MaxConnectionsExceeded: + time.sleep(0.1) + else: + return ret + raise MaxConnectionsExceeded('Max connections exceeded, timed out ' + 'attempting to connect.') + + def _connect(self): + while True: + try: + # Remove the oldest connection from the heap. + ts, conn = heapq.heappop(self._connections) + key = self.conn_key(conn) + except IndexError: + ts = conn = None + logger.debug('No connection available in pool.') + break + else: + if self._is_closed(conn): + # This connecton was closed, but since it was not stale + # it got added back to the queue of available conns. We + # then closed it and marked it as explicitly closed, so + # it's safe to throw it away now. + # (Because Database.close() calls Database._close()). + logger.debug('Connection %s was closed.', key) + ts = conn = None + elif self._stale_timeout and self._is_stale(ts): + # If we are attempting to check out a stale connection, + # then close it. We don't need to mark it in the "closed" + # set, because it is not in the list of available conns + # anymore. + logger.debug('Connection %s was stale, closing.', key) + self._close(conn, True) + ts = conn = None + else: + break + + if conn is None: + if self._max_connections and ( + len(self._in_use) >= self._max_connections): + raise MaxConnectionsExceeded('Exceeded maximum connections.') + conn = super(PooledDatabase, self)._connect() + ts = time.time() + key = self.conn_key(conn) + logger.debug('Created new connection %s.', key) + + self._in_use[key] = PoolConnection(ts, conn, time.time()) + return conn + + def _is_stale(self, timestamp): + # Called on check-out and check-in to ensure the connection has + # not outlived the stale timeout. + return (time.time() - timestamp) > self._stale_timeout + + def _is_closed(self, conn): + return False + + def _can_reuse(self, conn): + # Called on check-in to make sure the connection can be re-used. + return True + + def _close(self, conn, close_conn=False): + key = self.conn_key(conn) + if close_conn: + super(PooledDatabase, self)._close(conn) + elif key in self._in_use: + pool_conn = self._in_use.pop(key) + if self._stale_timeout and self._is_stale(pool_conn.timestamp): + logger.debug('Closing stale connection %s.', key) + super(PooledDatabase, self)._close(conn) + elif self._can_reuse(conn): + logger.debug('Returning %s to pool.', key) + heapq.heappush(self._connections, (pool_conn.timestamp, conn)) + else: + logger.debug('Closed %s.', key) + + def manual_close(self): + """ + Close the underlying connection without returning it to the pool. + """ + if self.is_closed(): + return False + + # Obtain reference to the connection in-use by the calling thread. 
+ conn = self.connection() + + # A connection will only be re-added to the available list if it is + # marked as "in use" at the time it is closed. We will explicitly + # remove it from the "in use" list, call "close()" for the + # side-effects, and then explicitly close the connection. + self._in_use.pop(self.conn_key(conn), None) + self.close() + self._close(conn, close_conn=True) + + def close_idle(self): + # Close any open connections that are not currently in-use. + with self._lock: + for _, conn in self._connections: + self._close(conn, close_conn=True) + self._connections = [] + + def close_stale(self, age=600): + # Close any connections that are in-use but were checked out quite some + # time ago and can be considered stale. + with self._lock: + in_use = {} + cutoff = time.time() - age + n = 0 + for key, pool_conn in self._in_use.items(): + if pool_conn.checked_out < cutoff: + self._close(pool_conn.connection, close_conn=True) + n += 1 + else: + in_use[key] = pool_conn + self._in_use = in_use + return n + + def close_all(self): + # Close all connections -- available and in-use. Warning: may break any + # active connections used by other threads. + self.close() + with self._lock: + for _, conn in self._connections: + self._close(conn, close_conn=True) + for pool_conn in self._in_use.values(): + self._close(pool_conn.connection, close_conn=True) + self._connections = [] + self._in_use = {} + + +class PooledMySQLDatabase(PooledDatabase, MySQLDatabase): + def _is_closed(self, conn): + try: + conn.ping(False) + except: + return True + else: + return False + + +class _PooledPostgresqlDatabase(PooledDatabase): + def _is_closed(self, conn): + if conn.closed: + return True + + txn_status = conn.get_transaction_status() + if txn_status == TRANSACTION_STATUS_UNKNOWN: + return True + elif txn_status != TRANSACTION_STATUS_IDLE: + conn.rollback() + return False + + def _can_reuse(self, conn): + txn_status = conn.get_transaction_status() + # Do not return connection in an error state, as subsequent queries + # will all fail. If the status is unknown then we lost the connection + # to the server and the connection should not be re-used. 
+ if txn_status == TRANSACTION_STATUS_UNKNOWN: + return False + elif txn_status == TRANSACTION_STATUS_INERROR: + conn.reset() + elif txn_status != TRANSACTION_STATUS_IDLE: + conn.rollback() + return True + +class PooledPostgresqlDatabase(_PooledPostgresqlDatabase, PostgresqlDatabase): + pass + +try: + from playhouse.postgres_ext import PostgresqlExtDatabase + + class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase, PostgresqlExtDatabase): + pass +except ImportError: + PooledPostgresqlExtDatabase = None + + +class _PooledSqliteDatabase(PooledDatabase): + def _is_closed(self, conn): + try: + conn.total_changes + except: + return True + else: + return False + +class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase): + pass + +try: + from playhouse.sqlite_ext import SqliteExtDatabase + + class PooledSqliteExtDatabase(_PooledSqliteDatabase, SqliteExtDatabase): + pass +except ImportError: + PooledSqliteExtDatabase = None + +try: + from playhouse.sqlite_ext import CSqliteExtDatabase + + class PooledCSqliteExtDatabase(_PooledSqliteDatabase, CSqliteExtDatabase): + pass +except ImportError: + PooledCSqliteExtDatabase = None diff --git a/libs/playhouse/postgres_ext.py b/libs/playhouse/postgres_ext.py new file mode 100644 index 000000000..6a2893eb5 --- /dev/null +++ b/libs/playhouse/postgres_ext.py @@ -0,0 +1,468 @@ +""" +Collection of postgres-specific extensions, currently including: + +* Support for hstore, a key/value type storage +""" +import logging +import uuid + +from peewee import * +from peewee import ColumnBase +from peewee import Expression +from peewee import Node +from peewee import NodeList +from peewee import SENTINEL +from peewee import __exception_wrapper__ + +try: + from psycopg2cffi import compat + compat.register() +except ImportError: + pass + +from psycopg2.extras import register_hstore +try: + from psycopg2.extras import Json +except: + Json = None + + +logger = logging.getLogger('peewee') + + +HCONTAINS_DICT = '@>' +HCONTAINS_KEYS = '?&' +HCONTAINS_KEY = '?' +HCONTAINS_ANY_KEY = '?|' +HKEY = '->' +HUPDATE = '||' +ACONTAINS = '@>' +ACONTAINS_ANY = '&&' +TS_MATCH = '@@' +JSONB_CONTAINS = '@>' +JSONB_CONTAINED_BY = '<@' +JSONB_CONTAINS_KEY = '?' +JSONB_CONTAINS_ANY_KEY = '?|' +JSONB_CONTAINS_ALL_KEYS = '?&' +JSONB_EXISTS = '?' 
+JSONB_REMOVE = '-' + + +class _LookupNode(ColumnBase): + def __init__(self, node, parts): + self.node = node + self.parts = parts + super(_LookupNode, self).__init__() + + def clone(self): + return type(self)(self.node, list(self.parts)) + + +class _JsonLookupBase(_LookupNode): + def __init__(self, node, parts, as_json=False): + super(_JsonLookupBase, self).__init__(node, parts) + self._as_json = as_json + + def clone(self): + return type(self)(self.node, list(self.parts), self._as_json) + + @Node.copy + def as_json(self, as_json=True): + self._as_json = as_json + + def concat(self, rhs): + return Expression(self.as_json(True), OP.CONCAT, Json(rhs)) + + def contains(self, other): + clone = self.as_json(True) + if isinstance(other, (list, dict)): + return Expression(clone, JSONB_CONTAINS, Json(other)) + return Expression(clone, JSONB_EXISTS, other) + + def contains_any(self, *keys): + return Expression( + self.as_json(True), + JSONB_CONTAINS_ANY_KEY, + Value(list(keys), unpack=False)) + + def contains_all(self, *keys): + return Expression( + self.as_json(True), + JSONB_CONTAINS_ALL_KEYS, + Value(list(keys), unpack=False)) + + def has_key(self, key): + return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key) + + +class JsonLookup(_JsonLookupBase): + def __getitem__(self, value): + return JsonLookup(self.node, self.parts + [value], self._as_json) + + def __sql__(self, ctx): + ctx.sql(self.node) + for part in self.parts[:-1]: + ctx.literal('->').sql(part) + if self.parts: + (ctx + .literal('->' if self._as_json else '->>') + .sql(self.parts[-1])) + + return ctx + + +class JsonPath(_JsonLookupBase): + def __sql__(self, ctx): + return (ctx + .sql(self.node) + .literal('#>' if self._as_json else '#>>') + .sql(Value('{%s}' % ','.join(map(str, self.parts))))) + + +class ObjectSlice(_LookupNode): + @classmethod + def create(cls, node, value): + if isinstance(value, slice): + parts = [value.start or 0, value.stop or 0] + elif isinstance(value, int): + parts = [value] + else: + parts = map(int, value.split(':')) + return cls(node, parts) + + def __sql__(self, ctx): + return (ctx + .sql(self.node) + .literal('[%s]' % ':'.join(str(p + 1) for p in self.parts))) + + def __getitem__(self, value): + return ObjectSlice.create(self, value) + + +class IndexedFieldMixin(object): + default_index_type = 'GIN' + + def __init__(self, *args, **kwargs): + kwargs.setdefault('index', True) # By default, use an index. 
+ super(IndexedFieldMixin, self).__init__(*args, **kwargs) + + +class ArrayField(IndexedFieldMixin, Field): + passthrough = True + + def __init__(self, field_class=IntegerField, field_kwargs=None, + dimensions=1, convert_values=False, *args, **kwargs): + self.__field = field_class(**(field_kwargs or {})) + self.dimensions = dimensions + self.convert_values = convert_values + self.field_type = self.__field.field_type + super(ArrayField, self).__init__(*args, **kwargs) + + def bind(self, model, name, set_attribute=True): + ret = super(ArrayField, self).bind(model, name, set_attribute) + self.__field.bind(model, '__array_%s' % name, False) + return ret + + def ddl_datatype(self, ctx): + data_type = self.__field.ddl_datatype(ctx) + return NodeList((data_type, SQL('[]' * self.dimensions)), glue='') + + def db_value(self, value): + if value is None or isinstance(value, Node): + return value + elif self.convert_values: + return self._process(self.__field.db_value, value, self.dimensions) + else: + return value if isinstance(value, list) else list(value) + + def python_value(self, value): + if self.convert_values and value is not None: + conv = self.__field.python_value + if isinstance(value, list): + return self._process(conv, value, self.dimensions) + else: + return conv(value) + else: + return value + + def _process(self, conv, value, dimensions): + dimensions -= 1 + if dimensions == 0: + return [conv(v) for v in value] + else: + return [self._process(conv, v, dimensions) for v in value] + + def __getitem__(self, value): + return ObjectSlice.create(self, value) + + def _e(op): + def inner(self, rhs): + return Expression(self, op, ArrayValue(self, rhs)) + return inner + __eq__ = _e(OP.EQ) + __ne__ = _e(OP.NE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __hash__ = Field.__hash__ + + def contains(self, *items): + return Expression(self, ACONTAINS, ArrayValue(self, items)) + + def contains_any(self, *items): + return Expression(self, ACONTAINS_ANY, ArrayValue(self, items)) + + +class ArrayValue(Node): + def __init__(self, field, value): + self.field = field + self.value = value + + def __sql__(self, ctx): + return (ctx + .sql(Value(self.value, unpack=False)) + .literal('::') + .sql(self.field.ddl_datatype(ctx))) + + +class DateTimeTZField(DateTimeField): + field_type = 'TIMESTAMPTZ' + + +class HStoreField(IndexedFieldMixin, Field): + field_type = 'HSTORE' + __hash__ = Field.__hash__ + + def __getitem__(self, key): + return Expression(self, HKEY, Value(key)) + + def keys(self): + return fn.akeys(self) + + def values(self): + return fn.avals(self) + + def items(self): + return fn.hstore_to_matrix(self) + + def slice(self, *args): + return fn.slice(self, Value(list(args), unpack=False)) + + def exists(self, key): + return fn.exist(self, key) + + def defined(self, key): + return fn.defined(self, key) + + def update(self, **data): + return Expression(self, HUPDATE, data) + + def delete(self, *keys): + return fn.delete(self, Value(list(keys), unpack=False)) + + def contains(self, value): + if isinstance(value, dict): + rhs = Value(value, unpack=False) + return Expression(self, HCONTAINS_DICT, rhs) + elif isinstance(value, (list, tuple)): + rhs = Value(value, unpack=False) + return Expression(self, HCONTAINS_KEYS, rhs) + return Expression(self, HCONTAINS_KEY, value) + + def contains_any(self, *keys): + return Expression(self, HCONTAINS_ANY_KEY, Value(list(keys), + unpack=False)) + + +class JSONField(Field): + field_type = 'JSON' + + def __init__(self, 
dumps=None, *args, **kwargs): + if Json is None: + raise Exception('Your version of psycopg2 does not support JSON.') + self.dumps = dumps + super(JSONField, self).__init__(*args, **kwargs) + + def db_value(self, value): + if value is None: + return value + if not isinstance(value, Json): + return Json(value, dumps=self.dumps) + return value + + def __getitem__(self, value): + return JsonLookup(self, [value]) + + def path(self, *keys): + return JsonPath(self, keys) + + def concat(self, value): + return super(JSONField, self).concat(Json(value)) + + +def cast_jsonb(node): + return NodeList((node, SQL('::jsonb')), glue='') + + +class BinaryJSONField(IndexedFieldMixin, JSONField): + field_type = 'JSONB' + __hash__ = Field.__hash__ + + def contains(self, other): + if isinstance(other, (list, dict)): + return Expression(self, JSONB_CONTAINS, Json(other)) + return Expression(cast_jsonb(self), JSONB_EXISTS, other) + + def contained_by(self, other): + return Expression(cast_jsonb(self), JSONB_CONTAINED_BY, Json(other)) + + def contains_any(self, *items): + return Expression( + cast_jsonb(self), + JSONB_CONTAINS_ANY_KEY, + Value(list(items), unpack=False)) + + def contains_all(self, *items): + return Expression( + cast_jsonb(self), + JSONB_CONTAINS_ALL_KEYS, + Value(list(items), unpack=False)) + + def has_key(self, key): + return Expression(cast_jsonb(self), JSONB_CONTAINS_KEY, key) + + def remove(self, *items): + return Expression( + cast_jsonb(self), + JSONB_REMOVE, + Value(list(items), unpack=False)) + + +class TSVectorField(IndexedFieldMixin, TextField): + field_type = 'TSVECTOR' + __hash__ = Field.__hash__ + + def match(self, query, language=None, plain=False): + params = (language, query) if language is not None else (query,) + func = fn.plainto_tsquery if plain else fn.to_tsquery + return Expression(self, TS_MATCH, func(*params)) + + +def Match(field, query, language=None): + params = (language, query) if language is not None else (query,) + field_params = (language, field) if language is not None else (field,) + return Expression( + fn.to_tsvector(*field_params), + TS_MATCH, + fn.to_tsquery(*params)) + + +class IntervalField(Field): + field_type = 'INTERVAL' + + +class FetchManyCursor(object): + __slots__ = ('cursor', 'array_size', 'exhausted', 'iterable') + + def __init__(self, cursor, array_size=None): + self.cursor = cursor + self.array_size = array_size or cursor.itersize + self.exhausted = False + self.iterable = self.row_gen() + + @property + def description(self): + return self.cursor.description + + def close(self): + self.cursor.close() + + def row_gen(self): + while True: + rows = self.cursor.fetchmany(self.array_size) + if not rows: + return + for row in rows: + yield row + + def fetchone(self): + if self.exhausted: + return + try: + return next(self.iterable) + except StopIteration: + self.exhausted = True + + +class ServerSideQuery(Node): + def __init__(self, query, array_size=None): + self.query = query + self.array_size = array_size + self._cursor_wrapper = None + + def __sql__(self, ctx): + return self.query.__sql__(ctx) + + def __iter__(self): + if self._cursor_wrapper is None: + self._execute(self.query._database) + return iter(self._cursor_wrapper.iterator()) + + def _execute(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self.query, named_cursor=True, + array_size=self.array_size) + self._cursor_wrapper = self.query._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + +def ServerSide(query, database=None, array_size=None): + if 
database is None: + database = query._database + with database.transaction(): + server_side_query = ServerSideQuery(query, array_size=array_size) + for row in server_side_query: + yield row + + +class _empty_object(object): + __slots__ = () + def __nonzero__(self): + return False + __bool__ = __nonzero__ + +__named_cursor__ = _empty_object() + + +class PostgresqlExtDatabase(PostgresqlDatabase): + def __init__(self, *args, **kwargs): + self._register_hstore = kwargs.pop('register_hstore', False) + self._server_side_cursors = kwargs.pop('server_side_cursors', False) + super(PostgresqlExtDatabase, self).__init__(*args, **kwargs) + + def _connect(self): + conn = super(PostgresqlExtDatabase, self)._connect() + if self._register_hstore: + register_hstore(conn, globally=True) + return conn + + def cursor(self, commit=None): + if self.is_closed(): + self.connect() + if commit is __named_cursor__: + return self._state.conn.cursor(name=str(uuid.uuid1())) + return self._state.conn.cursor() + + def execute(self, query, commit=SENTINEL, named_cursor=False, + array_size=None, **context_options): + ctx = self.get_sql_context(**context_options) + sql, params = ctx.sql(query).query() + named_cursor = named_cursor or (self._server_side_cursors and + sql[:6].lower() == 'select') + if named_cursor: + commit = __named_cursor__ + cursor = self.execute_sql(sql, params, commit=commit) + if named_cursor: + cursor = FetchManyCursor(cursor, array_size) + return cursor diff --git a/libs/playhouse/reflection.py b/libs/playhouse/reflection.py new file mode 100644 index 000000000..14ab1ba4b --- /dev/null +++ b/libs/playhouse/reflection.py @@ -0,0 +1,799 @@ +try: + from collections import OrderedDict +except ImportError: + OrderedDict = dict +from collections import namedtuple +from inspect import isclass +import re + +from peewee import * +from peewee import _StringField +from peewee import _query_val_transform +from peewee import CommaNodeList +from peewee import SCOPE_VALUES +from peewee import make_snake_case +from peewee import text_type +try: + from pymysql.constants import FIELD_TYPE +except ImportError: + try: + from MySQLdb.constants import FIELD_TYPE + except ImportError: + FIELD_TYPE = None +try: + from playhouse import postgres_ext +except ImportError: + postgres_ext = None + +RESERVED_WORDS = set([ + 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif', + 'else', 'except', 'exec', 'finally', 'for', 'from', 'global', 'if', + 'import', 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', 'raise', + 'return', 'try', 'while', 'with', 'yield', +]) + + +class UnknownField(object): + pass + + +class Column(object): + """ + Store metadata about a database column. + """ + primary_key_types = (IntegerField, AutoField) + + def __init__(self, name, field_class, raw_column_type, nullable, + primary_key=False, column_name=None, index=False, + unique=False, default=None, extra_parameters=None): + self.name = name + self.field_class = field_class + self.raw_column_type = raw_column_type + self.nullable = nullable + self.primary_key = primary_key + self.column_name = column_name + self.index = index + self.unique = unique + self.default = default + self.extra_parameters = extra_parameters + + # Foreign key metadata. 
+ self.rel_model = None + self.related_name = None + self.to_field = None + + def __repr__(self): + attrs = [ + 'field_class', + 'raw_column_type', + 'nullable', + 'primary_key', + 'column_name'] + keyword_args = ', '.join( + '%s=%s' % (attr, getattr(self, attr)) + for attr in attrs) + return 'Column(%s, %s)' % (self.name, keyword_args) + + def get_field_parameters(self): + params = {} + if self.extra_parameters is not None: + params.update(self.extra_parameters) + + # Set up default attributes. + if self.nullable: + params['null'] = True + if self.field_class is ForeignKeyField or self.name != self.column_name: + params['column_name'] = "'%s'" % self.column_name + if self.primary_key and not issubclass(self.field_class, AutoField): + params['primary_key'] = True + if self.default is not None: + params['constraints'] = '[SQL("DEFAULT %s")]' % self.default + + # Handle ForeignKeyField-specific attributes. + if self.is_foreign_key(): + params['model'] = self.rel_model + if self.to_field: + params['field'] = "'%s'" % self.to_field + if self.related_name: + params['backref'] = "'%s'" % self.related_name + + # Handle indexes on column. + if not self.is_primary_key(): + if self.unique: + params['unique'] = 'True' + elif self.index and not self.is_foreign_key(): + params['index'] = 'True' + + return params + + def is_primary_key(self): + return self.field_class is AutoField or self.primary_key + + def is_foreign_key(self): + return self.field_class is ForeignKeyField + + def is_self_referential_fk(self): + return (self.field_class is ForeignKeyField and + self.rel_model == "'self'") + + def set_foreign_key(self, foreign_key, model_names, dest=None, + related_name=None): + self.foreign_key = foreign_key + self.field_class = ForeignKeyField + if foreign_key.dest_table == foreign_key.table: + self.rel_model = "'self'" + else: + self.rel_model = model_names[foreign_key.dest_table] + self.to_field = dest and dest.name or None + self.related_name = related_name or None + + def get_field(self): + # Generate the field definition for this column. + field_params = {} + for key, value in self.get_field_parameters().items(): + if isclass(value) and issubclass(value, Field): + value = value.__name__ + field_params[key] = value + + param_str = ', '.join('%s=%s' % (k, v) + for k, v in sorted(field_params.items())) + field = '%s = %s(%s)' % ( + self.name, + self.field_class.__name__, + param_str) + + if self.field_class is UnknownField: + field = '%s # %s' % (field, self.raw_column_type) + + return field + + +class Metadata(object): + column_map = {} + extension_import = '' + + def __init__(self, database): + self.database = database + self.requires_extension = False + + def execute(self, sql, *params): + return self.database.execute_sql(sql, params) + + def get_columns(self, table, schema=None): + metadata = OrderedDict( + (metadata.name, metadata) + for metadata in self.database.get_columns(table, schema)) + + # Look up the actual column type for each column. + column_types, extra_params = self.get_column_types(table, schema) + + # Look up the primary keys. 
+ pk_names = self.get_primary_keys(table, schema) + if len(pk_names) == 1: + pk = pk_names[0] + if column_types[pk] is IntegerField: + column_types[pk] = AutoField + elif column_types[pk] is BigIntegerField: + column_types[pk] = BigAutoField + + columns = OrderedDict() + for name, column_data in metadata.items(): + field_class = column_types[name] + default = self._clean_default(field_class, column_data.default) + + columns[name] = Column( + name, + field_class=field_class, + raw_column_type=column_data.data_type, + nullable=column_data.null, + primary_key=column_data.primary_key, + column_name=name, + default=default, + extra_parameters=extra_params.get(name)) + + return columns + + def get_column_types(self, table, schema=None): + raise NotImplementedError + + def _clean_default(self, field_class, default): + if default is None or field_class in (AutoField, BigAutoField) or \ + default.lower() == 'null': + return + if issubclass(field_class, _StringField) and \ + isinstance(default, text_type) and not default.startswith("'"): + default = "'%s'" % default + return default or "''" + + def get_foreign_keys(self, table, schema=None): + return self.database.get_foreign_keys(table, schema) + + def get_primary_keys(self, table, schema=None): + return self.database.get_primary_keys(table, schema) + + def get_indexes(self, table, schema=None): + return self.database.get_indexes(table, schema) + + +class PostgresqlMetadata(Metadata): + column_map = { + 16: BooleanField, + 17: BlobField, + 20: BigIntegerField, + 21: IntegerField, + 23: IntegerField, + 25: TextField, + 700: FloatField, + 701: FloatField, + 1042: CharField, # blank-padded CHAR + 1043: CharField, + 1082: DateField, + 1114: DateTimeField, + 1184: DateTimeField, + 1083: TimeField, + 1266: TimeField, + 1700: DecimalField, + 2950: TextField, # UUID + } + array_types = { + 1000: BooleanField, + 1001: BlobField, + 1005: SmallIntegerField, + 1007: IntegerField, + 1009: TextField, + 1014: CharField, + 1015: CharField, + 1016: BigIntegerField, + 1115: DateTimeField, + 1182: DateField, + 1183: TimeField, + } + extension_import = 'from playhouse.postgres_ext import *' + + def __init__(self, database): + super(PostgresqlMetadata, self).__init__(database) + + if postgres_ext is not None: + # Attempt to add types like HStore and JSON. + cursor = self.execute('select oid, typname, format_type(oid, NULL)' + ' from pg_type;') + results = cursor.fetchall() + + for oid, typname, formatted_type in results: + if typname == 'json': + self.column_map[oid] = postgres_ext.JSONField + elif typname == 'jsonb': + self.column_map[oid] = postgres_ext.BinaryJSONField + elif typname == 'hstore': + self.column_map[oid] = postgres_ext.HStoreField + elif typname == 'tsvector': + self.column_map[oid] = postgres_ext.TSVectorField + + for oid in self.array_types: + self.column_map[oid] = postgres_ext.ArrayField + + def get_column_types(self, table, schema): + column_types = {} + extra_params = {} + extension_types = set(( + postgres_ext.ArrayField, + postgres_ext.BinaryJSONField, + postgres_ext.JSONField, + postgres_ext.TSVectorField, + postgres_ext.HStoreField)) if postgres_ext is not None else set() + + # Look up the actual column type for each column. + identifier = '"%s"."%s"' % (schema, table) + cursor = self.execute('SELECT * FROM %s LIMIT 1' % identifier) + + # Store column metadata in dictionary keyed by column name. 
+ for column_description in cursor.description: + name = column_description.name + oid = column_description.type_code + column_types[name] = self.column_map.get(oid, UnknownField) + if column_types[name] in extension_types: + self.requires_extension = True + if oid in self.array_types: + extra_params[name] = {'field_class': self.array_types[oid]} + + return column_types, extra_params + + def get_columns(self, table, schema=None): + schema = schema or 'public' + return super(PostgresqlMetadata, self).get_columns(table, schema) + + def get_foreign_keys(self, table, schema=None): + schema = schema or 'public' + return super(PostgresqlMetadata, self).get_foreign_keys(table, schema) + + def get_primary_keys(self, table, schema=None): + schema = schema or 'public' + return super(PostgresqlMetadata, self).get_primary_keys(table, schema) + + def get_indexes(self, table, schema=None): + schema = schema or 'public' + return super(PostgresqlMetadata, self).get_indexes(table, schema) + + +class MySQLMetadata(Metadata): + if FIELD_TYPE is None: + column_map = {} + else: + column_map = { + FIELD_TYPE.BLOB: TextField, + FIELD_TYPE.CHAR: CharField, + FIELD_TYPE.DATE: DateField, + FIELD_TYPE.DATETIME: DateTimeField, + FIELD_TYPE.DECIMAL: DecimalField, + FIELD_TYPE.DOUBLE: FloatField, + FIELD_TYPE.FLOAT: FloatField, + FIELD_TYPE.INT24: IntegerField, + FIELD_TYPE.LONG_BLOB: TextField, + FIELD_TYPE.LONG: IntegerField, + FIELD_TYPE.LONGLONG: BigIntegerField, + FIELD_TYPE.MEDIUM_BLOB: TextField, + FIELD_TYPE.NEWDECIMAL: DecimalField, + FIELD_TYPE.SHORT: IntegerField, + FIELD_TYPE.STRING: CharField, + FIELD_TYPE.TIMESTAMP: DateTimeField, + FIELD_TYPE.TIME: TimeField, + FIELD_TYPE.TINY_BLOB: TextField, + FIELD_TYPE.TINY: IntegerField, + FIELD_TYPE.VAR_STRING: CharField, + } + + def __init__(self, database, **kwargs): + if 'password' in kwargs: + kwargs['passwd'] = kwargs.pop('password') + super(MySQLMetadata, self).__init__(database, **kwargs) + + def get_column_types(self, table, schema=None): + column_types = {} + + # Look up the actual column type for each column. + cursor = self.execute('SELECT * FROM `%s` LIMIT 1' % table) + + # Store column metadata in dictionary keyed by column name. + for column_description in cursor.description: + name, type_code = column_description[:2] + column_types[name] = self.column_map.get(type_code, UnknownField) + + return column_types, {} + + +class SqliteMetadata(Metadata): + column_map = { + 'bigint': BigIntegerField, + 'blob': BlobField, + 'bool': BooleanField, + 'boolean': BooleanField, + 'char': CharField, + 'date': DateField, + 'datetime': DateTimeField, + 'decimal': DecimalField, + 'float': FloatField, + 'integer': IntegerField, + 'integer unsigned': IntegerField, + 'int': IntegerField, + 'long': BigIntegerField, + 'numeric': DecimalField, + 'real': FloatField, + 'smallinteger': IntegerField, + 'smallint': IntegerField, + 'smallint unsigned': IntegerField, + 'text': TextField, + 'time': TimeField, + 'varchar': CharField, + } + + begin = '(?:["\[\(]+)?' + end = '(?:["\]\)]+)?' + re_foreign_key = ( + '(?:FOREIGN KEY\s*)?' + '{begin}(.+?){end}\s+(?:.+\s+)?' 
+ 'references\s+{begin}(.+?){end}' + '\s*\(["|\[]?(.+?)["|\]]?\)').format(begin=begin, end=end) + re_varchar = r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$' + + def _map_col(self, column_type): + raw_column_type = column_type.lower() + if raw_column_type in self.column_map: + field_class = self.column_map[raw_column_type] + elif re.search(self.re_varchar, raw_column_type): + field_class = CharField + else: + column_type = re.sub('\(.+\)', '', raw_column_type) + if column_type == '': + field_class = BareField + else: + field_class = self.column_map.get(column_type, UnknownField) + return field_class + + def get_column_types(self, table, schema=None): + column_types = {} + columns = self.database.get_columns(table) + + for column in columns: + column_types[column.name] = self._map_col(column.data_type) + + return column_types, {} + + +_DatabaseMetadata = namedtuple('_DatabaseMetadata', ( + 'columns', + 'primary_keys', + 'foreign_keys', + 'model_names', + 'indexes')) + + +class DatabaseMetadata(_DatabaseMetadata): + def multi_column_indexes(self, table): + accum = [] + for index in self.indexes[table]: + if len(index.columns) > 1: + field_names = [self.columns[table][column].name + for column in index.columns + if column in self.columns[table]] + accum.append((field_names, index.unique)) + return accum + + def column_indexes(self, table): + accum = {} + for index in self.indexes[table]: + if len(index.columns) == 1: + accum[index.columns[0]] = index.unique + return accum + + +class Introspector(object): + pk_classes = [AutoField, IntegerField] + + def __init__(self, metadata, schema=None): + self.metadata = metadata + self.schema = schema + + def __repr__(self): + return '' % self.metadata.database + + @classmethod + def from_database(cls, database, schema=None): + if isinstance(database, PostgresqlDatabase): + metadata = PostgresqlMetadata(database) + elif isinstance(database, MySQLDatabase): + metadata = MySQLMetadata(database) + elif isinstance(database, SqliteDatabase): + metadata = SqliteMetadata(database) + else: + raise ValueError('Introspection not supported for %r' % database) + return cls(metadata, schema=schema) + + def get_database_class(self): + return type(self.metadata.database) + + def get_database_name(self): + return self.metadata.database.database + + def get_database_kwargs(self): + return self.metadata.database.connect_params + + def get_additional_imports(self): + if self.metadata.requires_extension: + return '\n' + self.metadata.extension_import + return '' + + def make_model_name(self, table, snake_case=True): + if snake_case: + table = make_snake_case(table) + model = re.sub('[^\w]+', '', table) + model_name = ''.join(sub.title() for sub in model.split('_')) + if not model_name[0].isalpha(): + model_name = 'T' + model_name + return model_name + + def make_column_name(self, column, is_foreign_key=False, snake_case=True): + column = column.strip() + if snake_case: + column = make_snake_case(column) + column = column.lower() + if is_foreign_key: + # Strip "_id" from foreign keys, unless the foreign-key happens to + # be named "_id", in which case the name is retained. + column = re.sub('_id$', '', column) or column + + # Remove characters that are invalid for Python identifiers. 
+ column = re.sub('[^\w]+', '_', column) + if column in RESERVED_WORDS: + column += '_' + if len(column) and column[0].isdigit(): + column = '_' + column + return column + + def introspect(self, table_names=None, literal_column_names=False, + include_views=False, snake_case=True): + # Retrieve all the tables in the database. + tables = self.metadata.database.get_tables(schema=self.schema) + if include_views: + views = self.metadata.database.get_views(schema=self.schema) + tables.extend([view.name for view in views]) + + if table_names is not None: + tables = [table for table in tables if table in table_names] + table_set = set(tables) + + # Store a mapping of table name -> dictionary of columns. + columns = {} + + # Store a mapping of table name -> set of primary key columns. + primary_keys = {} + + # Store a mapping of table -> foreign keys. + foreign_keys = {} + + # Store a mapping of table name -> model name. + model_names = {} + + # Store a mapping of table name -> indexes. + indexes = {} + + # Gather the columns for each table. + for table in tables: + table_indexes = self.metadata.get_indexes(table, self.schema) + table_columns = self.metadata.get_columns(table, self.schema) + try: + foreign_keys[table] = self.metadata.get_foreign_keys( + table, self.schema) + except ValueError as exc: + err(*exc.args) + foreign_keys[table] = [] + else: + # If there is a possibility we could exclude a dependent table, + # ensure that we introspect it so FKs will work. + if table_names is not None: + for foreign_key in foreign_keys[table]: + if foreign_key.dest_table not in table_set: + tables.append(foreign_key.dest_table) + table_set.add(foreign_key.dest_table) + + model_names[table] = self.make_model_name(table, snake_case) + + # Collect sets of all the column names as well as all the + # foreign-key column names. + lower_col_names = set(column_name.lower() + for column_name in table_columns) + fks = set(fk_col.column for fk_col in foreign_keys[table]) + + for col_name, column in table_columns.items(): + if literal_column_names: + new_name = re.sub('[^\w]+', '_', col_name) + else: + new_name = self.make_column_name(col_name, col_name in fks, + snake_case) + + # If we have two columns, "parent" and "parent_id", ensure + # that when we don't introduce naming conflicts. + lower_name = col_name.lower() + if lower_name.endswith('_id') and new_name in lower_col_names: + new_name = col_name.lower() + + column.name = new_name + + for index in table_indexes: + if len(index.columns) == 1: + column = index.columns[0] + if column in table_columns: + table_columns[column].unique = index.unique + table_columns[column].index = True + + primary_keys[table] = self.metadata.get_primary_keys( + table, self.schema) + columns[table] = table_columns + indexes[table] = table_indexes + + # Gather all instances where we might have a `related_name` conflict, + # either due to multiple FKs on a table pointing to the same table, + # or a related_name that would conflict with an existing field. + related_names = {} + sort_fn = lambda foreign_key: foreign_key.column + for table in tables: + models_referenced = set() + for foreign_key in sorted(foreign_keys[table], key=sort_fn): + try: + column = columns[table][foreign_key.column] + except KeyError: + continue + + dest_table = foreign_key.dest_table + if dest_table in models_referenced: + related_names[column] = '%s_%s_set' % ( + dest_table, + column.name) + else: + models_referenced.add(dest_table) + + # On the second pass convert all foreign keys. 
+ for table in tables: + for foreign_key in foreign_keys[table]: + src = columns[foreign_key.table][foreign_key.column] + try: + dest = columns[foreign_key.dest_table][ + foreign_key.dest_column] + except KeyError: + dest = None + + src.set_foreign_key( + foreign_key=foreign_key, + model_names=model_names, + dest=dest, + related_name=related_names.get(src)) + + return DatabaseMetadata( + columns, + primary_keys, + foreign_keys, + model_names, + indexes) + + def generate_models(self, skip_invalid=False, table_names=None, + literal_column_names=False, bare_fields=False, + include_views=False): + database = self.introspect(table_names, literal_column_names, + include_views) + models = {} + + class BaseModel(Model): + class Meta: + database = self.metadata.database + schema = self.schema + + def _create_model(table, models): + for foreign_key in database.foreign_keys[table]: + dest = foreign_key.dest_table + + if dest not in models and dest != table: + _create_model(dest, models) + + primary_keys = [] + columns = database.columns[table] + for column_name, column in columns.items(): + if column.primary_key: + primary_keys.append(column.name) + + multi_column_indexes = database.multi_column_indexes(table) + column_indexes = database.column_indexes(table) + + class Meta: + indexes = multi_column_indexes + table_name = table + + # Fix models with multi-column primary keys. + composite_key = False + if len(primary_keys) == 0: + primary_keys = columns.keys() + if len(primary_keys) > 1: + Meta.primary_key = CompositeKey(*[ + field.name for col, field in columns.items() + if col in primary_keys]) + composite_key = True + + attrs = {'Meta': Meta} + for column_name, column in columns.items(): + FieldClass = column.field_class + if FieldClass is not ForeignKeyField and bare_fields: + FieldClass = BareField + elif FieldClass is UnknownField: + FieldClass = BareField + + params = { + 'column_name': column_name, + 'null': column.nullable} + if column.primary_key and composite_key: + if FieldClass is AutoField: + FieldClass = IntegerField + params['primary_key'] = False + elif column.primary_key and FieldClass is not AutoField: + params['primary_key'] = True + if column.is_foreign_key(): + if column.is_self_referential_fk(): + params['model'] = 'self' + else: + dest_table = column.foreign_key.dest_table + params['model'] = models[dest_table] + if column.to_field: + params['field'] = column.to_field + + # Generate a unique related name. + params['backref'] = '%s_%s_rel' % (table, column_name) + + if column.default is not None: + constraint = SQL('DEFAULT %s' % column.default) + params['constraints'] = [constraint] + + if column_name in column_indexes and not \ + column.is_primary_key(): + if column_indexes[column_name]: + params['unique'] = True + elif not column.is_foreign_key(): + params['index'] = True + + attrs[column.name] = FieldClass(**params) + + try: + models[table] = type(str(table), (BaseModel,), attrs) + except ValueError: + if not skip_invalid: + raise + + # Actually generate Model classes. 
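The reflection machinery above is normally driven through the module-level generate_models()/introspect() helpers defined at the end of this file; a small sketch, with a hypothetical database path:

    from peewee import SqliteDatabase
    from playhouse.reflection import generate_models, print_model, print_table_sql

    db = SqliteDatabase('/path/to/existing.db')   # hypothetical path
    models = generate_models(db)                  # {table_name: generated Model class}

    for name, model_class in sorted(models.items()):
        print_model(model_class)        # field / index summary
        print_table_sql(model_class)    # equivalent CREATE TABLE statement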
+ for table, model in sorted(database.model_names.items()): + if table not in models: + _create_model(table, models) + + return models + + +def introspect(database, schema=None): + introspector = Introspector.from_database(database, schema=schema) + return introspector.introspect() + + +def generate_models(database, schema=None, **options): + introspector = Introspector.from_database(database, schema=schema) + return introspector.generate_models(**options) + + +def print_model(model, indexes=True, inline_indexes=False): + print(model._meta.name) + for field in model._meta.sorted_fields: + parts = [' %s %s' % (field.name, field.field_type)] + if field.primary_key: + parts.append(' PK') + elif inline_indexes: + if field.unique: + parts.append(' UNIQUE') + elif field.index: + parts.append(' INDEX') + if isinstance(field, ForeignKeyField): + parts.append(' FK: %s.%s' % (field.rel_model.__name__, + field.rel_field.name)) + print(''.join(parts)) + + if indexes: + index_list = model._meta.fields_to_index() + if not index_list: + return + + print('\nindex(es)') + for index in index_list: + parts = [' '] + ctx = model._meta.database.get_sql_context() + with ctx.scope_values(param='%s', quote='""'): + ctx.sql(CommaNodeList(index._expressions)) + if index._where: + ctx.literal(' WHERE ') + ctx.sql(index._where) + sql, params = ctx.query() + + clean = sql % tuple(map(_query_val_transform, params)) + parts.append(clean.replace('"', '')) + + if index._unique: + parts.append(' UNIQUE') + print(''.join(parts)) + + +def get_table_sql(model): + sql, params = model._schema._create_table().query() + if model._meta.database.param != '%s': + sql = sql.replace(model._meta.database.param, '%s') + + # Format and indent the table declaration, simplest possible approach. + match_obj = re.match('^(.+?\()(.+)(\).*)', sql) + create, columns, extra = match_obj.groups() + indented = ',\n'.join(' %s' % column for column in columns.split(', ')) + + clean = '\n'.join((create, indented, extra)).strip() + return clean % tuple(map(_query_val_transform, params)) + +def print_table_sql(model): + print(get_table_sql(model)) diff --git a/libs/playhouse/shortcuts.py b/libs/playhouse/shortcuts.py new file mode 100644 index 000000000..e1851b181 --- /dev/null +++ b/libs/playhouse/shortcuts.py @@ -0,0 +1,224 @@ +from peewee import * +from peewee import Alias +from peewee import SENTINEL +from peewee import callable_ + + +_clone_set = lambda s: set(s) if s else set() + + +def model_to_dict(model, recurse=True, backrefs=False, only=None, + exclude=None, seen=None, extra_attrs=None, + fields_from_query=None, max_depth=None, manytomany=False): + """ + Convert a model instance (and any related objects) to a dictionary. + + :param bool recurse: Whether foreign-keys should be recursed. + :param bool backrefs: Whether lists of related objects should be recursed. + :param only: A list (or set) of field instances indicating which fields + should be included. + :param exclude: A list (or set) of field instances that should be + excluded from the dictionary. + :param list extra_attrs: Names of model instance attributes or methods + that should be included. + :param SelectQuery fields_from_query: Query that was source of model. Take + fields explicitly selected by the query and serialize them. + :param int max_depth: Maximum depth to recurse, value <= 0 means no max. + :param bool manytomany: Process many-to-many fields. 
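A small, self-contained sketch of the serialization helpers documented above; the User model is hypothetical:

    from peewee import Model, SqliteDatabase, TextField
    from playhouse.shortcuts import model_to_dict, dict_to_model

    db = SqliteDatabase(':memory:')

    class User(Model):
        username = TextField()

        class Meta:
            database = db

    db.create_tables([User])
    user = User.create(username='charlie')

    print(model_to_dict(user))                  # e.g. {'id': 1, 'username': 'charlie'}
    restored = dict_to_model(User, model_to_dict(user))
    print(restored.username)                    # 'charlie'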
+ """ + max_depth = -1 if max_depth is None else max_depth + if max_depth == 0: + recurse = False + + only = _clone_set(only) + extra_attrs = _clone_set(extra_attrs) + should_skip = lambda n: (n in exclude) or (only and (n not in only)) + + if fields_from_query is not None: + for item in fields_from_query._returning: + if isinstance(item, Field): + only.add(item) + elif isinstance(item, Alias): + extra_attrs.add(item._alias) + + data = {} + exclude = _clone_set(exclude) + seen = _clone_set(seen) + exclude |= seen + model_class = type(model) + + if manytomany: + for name, m2m in model._meta.manytomany.items(): + if should_skip(name): + continue + + exclude.update((m2m, m2m.rel_model._meta.manytomany[m2m.backref])) + for fkf in m2m.through_model._meta.refs: + exclude.add(fkf) + + accum = [] + for rel_obj in getattr(model, name): + accum.append(model_to_dict( + rel_obj, + recurse=recurse, + backrefs=backrefs, + only=only, + exclude=exclude, + max_depth=max_depth - 1)) + data[name] = accum + + for field in model._meta.sorted_fields: + if should_skip(field): + continue + + field_data = model.__data__.get(field.name) + if isinstance(field, ForeignKeyField) and recurse: + if field_data: + seen.add(field) + rel_obj = getattr(model, field.name) + field_data = model_to_dict( + rel_obj, + recurse=recurse, + backrefs=backrefs, + only=only, + exclude=exclude, + seen=seen, + max_depth=max_depth - 1) + else: + field_data = None + + data[field.name] = field_data + + if extra_attrs: + for attr_name in extra_attrs: + attr = getattr(model, attr_name) + if callable_(attr): + data[attr_name] = attr() + else: + data[attr_name] = attr + + if backrefs and recurse: + for foreign_key, rel_model in model._meta.backrefs.items(): + if foreign_key.backref == '+': continue + descriptor = getattr(model_class, foreign_key.backref) + if descriptor in exclude or foreign_key in exclude: + continue + if only and (descriptor not in only) and (foreign_key not in only): + continue + + accum = [] + exclude.add(foreign_key) + related_query = getattr(model, foreign_key.backref) + + for rel_obj in related_query: + accum.append(model_to_dict( + rel_obj, + recurse=recurse, + backrefs=backrefs, + only=only, + exclude=exclude, + max_depth=max_depth - 1)) + + data[foreign_key.backref] = accum + + return data + + +def update_model_from_dict(instance, data, ignore_unknown=False): + meta = instance._meta + backrefs = dict([(fk.backref, fk) for fk in meta.backrefs]) + + for key, value in data.items(): + if key in meta.combined: + field = meta.combined[key] + is_backref = False + elif key in backrefs: + field = backrefs[key] + is_backref = True + elif ignore_unknown: + setattr(instance, key, value) + continue + else: + raise AttributeError('Unrecognized attribute "%s" for model ' + 'class %s.' 
% (key, type(instance))) + + is_foreign_key = isinstance(field, ForeignKeyField) + + if not is_backref and is_foreign_key and isinstance(value, dict): + try: + rel_instance = instance.__rel__[field.name] + except KeyError: + rel_instance = field.rel_model() + setattr( + instance, + field.name, + update_model_from_dict(rel_instance, value, ignore_unknown)) + elif is_backref and isinstance(value, (list, tuple)): + instances = [ + dict_to_model(field.model, row_data, ignore_unknown) + for row_data in value] + for rel_instance in instances: + setattr(rel_instance, field.name, instance) + setattr(instance, field.backref, instances) + else: + setattr(instance, field.name, value) + + return instance + + +def dict_to_model(model_class, data, ignore_unknown=False): + return update_model_from_dict(model_class(), data, ignore_unknown) + + +class ReconnectMixin(object): + """ + Mixin class that attempts to automatically reconnect to the database under + certain error conditions. + + For example, MySQL servers will typically close connections that are idle + for 28800 seconds ("wait_timeout" setting). If your application makes use + of long-lived connections, you may find your connections are closed after + a period of no activity. This mixin will attempt to reconnect automatically + when these errors occur. + + This mixin class probably should not be used with Postgres (unless you + REALLY know what you are doing) and definitely has no business being used + with Sqlite. If you wish to use with Postgres, you will need to adapt the + `reconnect_errors` attribute to something appropriate for Postgres. + """ + reconnect_errors = ( + # Error class, error message fragment (or empty string for all). + (OperationalError, '2006'), # MySQL server has gone away. + (OperationalError, '2013'), # Lost connection to MySQL server. + (OperationalError, '2014'), # Commands out of sync. + ) + + def __init__(self, *args, **kwargs): + super(ReconnectMixin, self).__init__(*args, **kwargs) + + # Normalize the reconnect errors to a more efficient data-structure. + self._reconnect_errors = {} + for exc_class, err_fragment in self.reconnect_errors: + self._reconnect_errors.setdefault(exc_class, []) + self._reconnect_errors[exc_class].append(err_fragment.lower()) + + def execute_sql(self, sql, params=None, commit=SENTINEL): + try: + return super(ReconnectMixin, self).execute_sql(sql, params, commit) + except Exception as exc: + exc_class = type(exc) + if exc_class not in self._reconnect_errors: + raise exc + + exc_repr = str(exc).lower() + for err_fragment in self._reconnect_errors[exc_class]: + if err_fragment in exc_repr: + break + else: + raise exc + + if not self.is_closed(): + self.close() + self.connect() + + return super(ReconnectMixin, self).execute_sql(sql, params, commit) diff --git a/libs/playhouse/signals.py b/libs/playhouse/signals.py new file mode 100644 index 000000000..f070bdfdb --- /dev/null +++ b/libs/playhouse/signals.py @@ -0,0 +1,79 @@ +""" +Provide django-style hooks for model events. 
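The ReconnectMixin above is meant to be combined with a concrete database class via multiple inheritance; a hedged sketch with hypothetical MySQL credentials:

    from peewee import MySQLDatabase
    from playhouse.shortcuts import ReconnectMixin

    class ReconnectMySQLDatabase(ReconnectMixin, MySQLDatabase):
        pass

    # Queries that fail with "MySQL server has gone away" (2006) or a lost
    # connection (2013) are retried once on a fresh connection.
    db = ReconnectMySQLDatabase('my_app', user='root', password='secret',
                                host='127.0.0.1')   # hypothetical credentials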
+""" +from peewee import Model as _Model + + +class Signal(object): + def __init__(self): + self._flush() + + def _flush(self): + self._receivers = set() + self._receiver_list = [] + + def connect(self, receiver, name=None, sender=None): + name = name or receiver.__name__ + key = (name, sender) + if key not in self._receivers: + self._receivers.add(key) + self._receiver_list.append((name, receiver, sender)) + else: + raise ValueError('receiver named %s (for sender=%s) already ' + 'connected' % (name, sender or 'any')) + + def disconnect(self, receiver=None, name=None, sender=None): + if receiver: + name = name or receiver.__name__ + if not name: + raise ValueError('a receiver or a name must be provided') + + key = (name, sender) + if key not in self._receivers: + raise ValueError('receiver named %s for sender=%s not found.' % + (name, sender or 'any')) + + self._receivers.remove(key) + self._receiver_list = [(n, r, s) for n, r, s in self._receiver_list + if n != name and s != sender] + + def __call__(self, name=None, sender=None): + def decorator(fn): + self.connect(fn, name, sender) + return fn + return decorator + + def send(self, instance, *args, **kwargs): + sender = type(instance) + responses = [] + for n, r, s in self._receiver_list: + if s is None or isinstance(instance, s): + responses.append((r, r(sender, instance, *args, **kwargs))) + return responses + + +pre_save = Signal() +post_save = Signal() +pre_delete = Signal() +post_delete = Signal() +pre_init = Signal() + + +class Model(_Model): + def __init__(self, *args, **kwargs): + super(Model, self).__init__(*args, **kwargs) + pre_init.send(self) + + def save(self, *args, **kwargs): + pk_value = self._pk + created = kwargs.get('force_insert', False) or not bool(pk_value) + pre_save.send(self, created=created) + ret = super(Model, self).save(*args, **kwargs) + post_save.send(self, created=created) + return ret + + def delete_instance(self, *args, **kwargs): + pre_delete.send(self) + ret = super(Model, self).delete_instance(*args, **kwargs) + post_delete.send(self) + return ret diff --git a/libs/playhouse/sqlcipher_ext.py b/libs/playhouse/sqlcipher_ext.py new file mode 100644 index 000000000..9bad1eca6 --- /dev/null +++ b/libs/playhouse/sqlcipher_ext.py @@ -0,0 +1,103 @@ +""" +Peewee integration with pysqlcipher. + +Project page: https://github.com/leapcode/pysqlcipher/ + +**WARNING!!! EXPERIMENTAL!!!** + +* Although this extention's code is short, it has not been properly + peer-reviewed yet and may have introduced vulnerabilities. + +Also note that this code relies on pysqlcipher and sqlcipher, and +the code there might have vulnerabilities as well, but since these +are widely used crypto modules, we can expect "short zero days" there. + +Example usage: + + from peewee.playground.ciphersql_ext import SqlCipherDatabase + db = SqlCipherDatabase('/path/to/my.db', passphrase="don'tuseme4real") + +* `passphrase`: should be "long enough". + Note that *length beats vocabulary* (much exponential), and even + a lowercase-only passphrase like easytorememberyethardforotherstoguess + packs more noise than 8 random printable characters and *can* be memorized. + +When opening an existing database, passphrase should be the one used when the +database was created. If the passphrase is incorrect, an exception will only be +raised **when you access the database**. + +If you need to ask for an interactive passphrase, here's example code you can +put after the `db = ...` line: + + try: # Just access the database so that it checks the encryption. 
+ db.get_tables() + # We're looking for a DatabaseError with a specific error message. + except peewee.DatabaseError as e: + # Check whether the message *means* "passphrase is wrong" + if e.args[0] == 'file is encrypted or is not a database': + raise Exception('Developer should Prompt user for passphrase ' + 'again.') + else: + # A different DatabaseError. Raise it. + raise e + +See a more elaborate example with this code at +https://gist.github.com/thedod/11048875 +""" +import datetime +import decimal +import sys + +from peewee import * +from playhouse.sqlite_ext import SqliteExtDatabase +if sys.version_info[0] != 3: + from pysqlcipher import dbapi2 as sqlcipher +else: + try: + from sqlcipher3 import dbapi2 as sqlcipher + except ImportError: + from pysqlcipher3 import dbapi2 as sqlcipher + +sqlcipher.register_adapter(decimal.Decimal, str) +sqlcipher.register_adapter(datetime.date, str) +sqlcipher.register_adapter(datetime.time, str) + + +class _SqlCipherDatabase(object): + def _connect(self): + params = dict(self.connect_params) + passphrase = params.pop('passphrase', '').replace("'", "''") + + conn = sqlcipher.connect(self.database, isolation_level=None, **params) + try: + if passphrase: + conn.execute("PRAGMA key='%s'" % passphrase) + self._add_conn_hooks(conn) + except: + conn.close() + raise + return conn + + def set_passphrase(self, passphrase): + if not self.is_closed(): + raise ImproperlyConfigured('Cannot set passphrase when database ' + 'is open. To change passphrase of an ' + 'open database use the rekey() method.') + + self.connect_params['passphrase'] = passphrase + + def rekey(self, passphrase): + if self.is_closed(): + self.connect() + + self.execute_sql("PRAGMA rekey='%s'" % passphrase.replace("'", "''")) + self.connect_params['passphrase'] = passphrase + return True + + +class SqlCipherDatabase(_SqlCipherDatabase, SqliteDatabase): + pass + + +class SqlCipherExtDatabase(_SqlCipherDatabase, SqliteExtDatabase): + pass diff --git a/libs/playhouse/sqlite_ext.py b/libs/playhouse/sqlite_ext.py new file mode 100644 index 000000000..c97cbd252 --- /dev/null +++ b/libs/playhouse/sqlite_ext.py @@ -0,0 +1,1261 @@ +import json +import math +import re +import struct +import sys + +from peewee import * +from peewee import ColumnBase +from peewee import EnclosedNodeList +from peewee import Entity +from peewee import Expression +from peewee import Node +from peewee import NodeList +from peewee import OP +from peewee import VirtualField +from peewee import merge_dict +from peewee import sqlite3 +try: + from playhouse._sqlite_ext import ( + backup, + backup_to_file, + Blob, + ConnectionHelper, + register_bloomfilter, + register_hash_functions, + register_rank_functions, + sqlite_get_db_status, + sqlite_get_status, + TableFunction, + ZeroBlob, + ) + CYTHON_SQLITE_EXTENSIONS = True +except ImportError: + CYTHON_SQLITE_EXTENSIONS = False + + +if sys.version_info[0] == 3: + basestring = str + + +FTS3_MATCHINFO = 'pcx' +FTS4_MATCHINFO = 'pcnalx' +if sqlite3 is not None: + FTS_VERSION = 4 if sqlite3.sqlite_version_info[:3] >= (3, 7, 4) else 3 +else: + FTS_VERSION = 3 + +FTS5_MIN_SQLITE_VERSION = (3, 9, 0) + + +class RowIDField(AutoField): + auto_increment = True + column_name = name = required_name = 'rowid' + + def bind(self, model, name, *args): + if name != self.required_name: + raise ValueError('%s must be named "%s".' 
% + (type(self), self.required_name)) + super(RowIDField, self).bind(model, name, *args) + + +class DocIDField(RowIDField): + column_name = name = required_name = 'docid' + + +class AutoIncrementField(AutoField): + def ddl(self, ctx): + node_list = super(AutoIncrementField, self).ddl(ctx) + return NodeList((node_list, SQL('AUTOINCREMENT'))) + + +class JSONPath(ColumnBase): + def __init__(self, field, path=None): + super(JSONPath, self).__init__() + self._field = field + self._path = path or () + + @property + def path(self): + return Value('$%s' % ''.join(self._path)) + + def __getitem__(self, idx): + if isinstance(idx, int): + item = '[%s]' % idx + else: + item = '.%s' % idx + return JSONPath(self._field, self._path + (item,)) + + def set(self, value, as_json=None): + if as_json or isinstance(value, (list, dict)): + value = fn.json(self._field._json_dumps(value)) + return fn.json_set(self._field, self.path, value) + + def update(self, value): + return self.set(fn.json_patch(self, self._field._json_dumps(value))) + + def remove(self): + return fn.json_remove(self._field, self.path) + + def json_type(self): + return fn.json_type(self._field, self.path) + + def length(self): + return fn.json_array_length(self._field, self.path) + + def children(self): + return fn.json_each(self._field, self.path) + + def tree(self): + return fn.json_tree(self._field, self.path) + + def __sql__(self, ctx): + return ctx.sql(fn.json_extract(self._field, self.path) + if self._path else self._field) + + +class JSONField(TextField): + field_type = 'JSON' + + def __init__(self, json_dumps=None, json_loads=None, **kwargs): + self._json_dumps = json_dumps or json.dumps + self._json_loads = json_loads or json.loads + super(JSONField, self).__init__(**kwargs) + + def python_value(self, value): + if value is not None: + try: + return self._json_loads(value) + except (TypeError, ValueError): + return value + + def db_value(self, value): + if value is not None: + if not isinstance(value, Node): + value = fn.json(self._json_dumps(value)) + return value + + def _e(op): + def inner(self, rhs): + if isinstance(rhs, (list, dict)): + rhs = Value(rhs, converter=self.db_value, unpack=False) + return Expression(self, op, rhs) + return inner + __eq__ = _e(OP.EQ) + __ne__ = _e(OP.NE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __hash__ = Field.__hash__ + + def __getitem__(self, item): + return JSONPath(self)[item] + + def set(self, value, as_json=None): + return JSONPath(self).set(value, as_json) + + def update(self, data): + return JSONPath(self).update(data) + + def remove(self): + return JSONPath(self).remove() + + def json_type(self): + return fn.json_type(self) + + def length(self): + return fn.json_array_length(self) + + def children(self): + """ + Schema of `json_each` and `json_tree`: + + key, + value, + type TEXT (object, array, string, etc), + atom (value for primitive/scalar types, NULL for array and object) + id INTEGER (unique identifier for element) + parent INTEGER (unique identifier of parent element or NULL) + fullkey TEXT (full path describing element) + path TEXT (path to the container of the current element) + json JSON hidden (1st input parameter to function) + root TEXT hidden (2nd input parameter, path at which to start) + """ + return fn.json_each(self) + + def tree(self): + return fn.json_tree(self) + + +class SearchField(Field): + def __init__(self, unindexed=False, column_name=None, **k): + if k: + raise ValueError('SearchField does not accept these keyword ' + 
'arguments: %s.' % sorted(k)) + super(SearchField, self).__init__(unindexed=unindexed, + column_name=column_name, null=True) + + def match(self, term): + return match(self, term) + + +class VirtualTableSchemaManager(SchemaManager): + def _create_virtual_table(self, safe=True, **options): + options = self.model.clean_options( + merge_dict(self.model._meta.options, options)) + + # Structure: + # CREATE VIRTUAL TABLE + # USING + # ([prefix_arguments, ...] fields, ... [arguments, ...], [options...]) + ctx = self._create_context() + ctx.literal('CREATE VIRTUAL TABLE ') + if safe: + ctx.literal('IF NOT EXISTS ') + (ctx + .sql(self.model) + .literal(' USING ')) + + ext_module = self.model._meta.extension_module + if isinstance(ext_module, Node): + return ctx.sql(ext_module) + + ctx.sql(SQL(ext_module)).literal(' ') + arguments = [] + meta = self.model._meta + + if meta.prefix_arguments: + arguments.extend([SQL(a) for a in meta.prefix_arguments]) + + # Constraints, data-types, foreign and primary keys are all omitted. + for field in meta.sorted_fields: + if isinstance(field, (RowIDField)) or field._hidden: + continue + field_def = [Entity(field.column_name)] + if field.unindexed: + field_def.append(SQL('UNINDEXED')) + arguments.append(NodeList(field_def)) + + if meta.arguments: + arguments.extend([SQL(a) for a in meta.arguments]) + + if options: + arguments.extend(self._create_table_option_sql(options)) + return ctx.sql(EnclosedNodeList(arguments)) + + def _create_table(self, safe=True, **options): + if issubclass(self.model, VirtualModel): + return self._create_virtual_table(safe, **options) + + return super(VirtualTableSchemaManager, self)._create_table( + safe, **options) + + +class VirtualModel(Model): + class Meta: + arguments = None + extension_module = None + prefix_arguments = None + primary_key = False + schema_manager_class = VirtualTableSchemaManager + + @classmethod + def clean_options(cls, options): + return options + + +class BaseFTSModel(VirtualModel): + @classmethod + def clean_options(cls, options): + content = options.get('content') + prefix = options.get('prefix') + tokenize = options.get('tokenize') + + if isinstance(content, basestring) and content == '': + # Special-case content-less full-text search tables. + options['content'] = "''" + elif isinstance(content, Field): + # Special-case to ensure fields are fully-qualified. + options['content'] = Entity(content.model._meta.table_name, + content.column_name) + + if prefix: + if isinstance(prefix, (list, tuple)): + prefix = ','.join([str(i) for i in prefix]) + options['prefix'] = "'%s'" % prefix.strip("' ") + + if tokenize and cls._meta.extension_module.lower() == 'fts5': + # Tokenizers need to be in quoted string for FTS5, but not for FTS3 + # or FTS4. + options['tokenize'] = '"%s"' % tokenize + + return options + + +class FTSModel(BaseFTSModel): + """ + VirtualModel class for creating tables that use either the FTS3 or FTS4 + search extensions. Peewee automatically determines which version of the + FTS extension is supported and will use FTS4 if possible. + """ + # FTS3/4 uses "docid" in the same way a normal table uses "rowid". 
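A sketch of the SQLite JSONField/JSONPath helpers defined earlier in this file; it assumes a sqlite3 build with the JSON1 extension, and the KV model is hypothetical:

    from peewee import Model, SqliteDatabase, TextField
    from playhouse.sqlite_ext import JSONField

    db = SqliteDatabase(':memory:')

    class KV(Model):
        key = TextField()
        value = JSONField()

        class Meta:
            database = db

    db.create_tables([KV])
    KV.create(key='a', value={'items': [1, 2, 3], 'flag': True})

    # Path access compiles to json_extract(); set() compiles to json_set().
    query = KV.select(KV.value['items'][0].alias('first_item'))
    KV.update(value=KV.value['flag'].set(False)).execute()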
+ docid = DocIDField() + + class Meta: + extension_module = 'FTS%s' % FTS_VERSION + + @classmethod + def _fts_cmd(cls, cmd): + tbl = cls._meta.table_name + res = cls._meta.database.execute_sql( + "INSERT INTO %s(%s) VALUES('%s');" % (tbl, tbl, cmd)) + return res.fetchone() + + @classmethod + def optimize(cls): + return cls._fts_cmd('optimize') + + @classmethod + def rebuild(cls): + return cls._fts_cmd('rebuild') + + @classmethod + def integrity_check(cls): + return cls._fts_cmd('integrity-check') + + @classmethod + def merge(cls, blocks=200, segments=8): + return cls._fts_cmd('merge=%s,%s' % (blocks, segments)) + + @classmethod + def automerge(cls, state=True): + return cls._fts_cmd('automerge=%s' % (state and '1' or '0')) + + @classmethod + def match(cls, term): + """ + Generate a `MATCH` expression appropriate for searching this table. + """ + return match(cls._meta.entity, term) + + @classmethod + def rank(cls, *weights): + matchinfo = fn.matchinfo(cls._meta.entity, FTS3_MATCHINFO) + return fn.fts_rank(matchinfo, *weights) + + @classmethod + def bm25(cls, *weights): + match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) + return fn.fts_bm25(match_info, *weights) + + @classmethod + def bm25f(cls, *weights): + match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) + return fn.fts_bm25f(match_info, *weights) + + @classmethod + def lucene(cls, *weights): + match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) + return fn.fts_lucene(match_info, *weights) + + @classmethod + def _search(cls, term, weights, with_score, score_alias, score_fn, + explicit_ordering): + if not weights: + rank = score_fn() + elif isinstance(weights, dict): + weight_args = [] + for field in cls._meta.sorted_fields: + # Attempt to get the specified weight of the field by looking + # it up using it's field instance followed by name. 
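A sketch of an FTS3/FTS4 table built on the FTSModel class above, assuming the local sqlite3 library was compiled with FTS4 support; the DocumentIndex model is hypothetical:

    from playhouse.sqlite_ext import FTSModel, SearchField, SqliteExtDatabase

    db = SqliteExtDatabase(':memory:')   # registers fts_rank()/fts_bm25()

    class DocumentIndex(FTSModel):
        title = SearchField()
        content = SearchField()

        class Meta:
            database = db

    db.create_tables([DocumentIndex])
    DocumentIndex.create(title='hello', content='hello full-text world')

    for doc in DocumentIndex.search('hello', with_score=True):
        print(doc.title, doc.score)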
+ field_weight = weights.get(field, weights.get(field.name, 1.0)) + weight_args.append(field_weight) + rank = score_fn(*weight_args) + else: + rank = score_fn(*weights) + + selection = () + order_by = rank + if with_score: + selection = (cls, rank.alias(score_alias)) + if with_score and not explicit_ordering: + order_by = SQL(score_alias) + + return (cls + .select(*selection) + .where(cls.match(term)) + .order_by(order_by)) + + @classmethod + def search(cls, term, weights=None, with_score=False, score_alias='score', + explicit_ordering=False): + """Full-text search using selected `term`.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.rank, + explicit_ordering) + + @classmethod + def search_bm25(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search for selected `term` using BM25 algorithm.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.bm25, + explicit_ordering) + + @classmethod + def search_bm25f(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search for selected `term` using BM25 algorithm.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.bm25f, + explicit_ordering) + + @classmethod + def search_lucene(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search for selected `term` using BM25 algorithm.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.lucene, + explicit_ordering) + + +_alphabet = 'abcdefghijklmnopqrstuvwxyz' +_alphanum = (set('\t ,"(){}*:_+0123456789') | + set(_alphabet) | + set(_alphabet.upper()) | + set((chr(26),))) +_invalid_ascii = set(chr(p) for p in range(128) if chr(p) not in _alphanum) +_quote_re = re.compile('(?:[^\s"]|"(?:\\.|[^"])*")+') + + +class FTS5Model(BaseFTSModel): + """ + Requires SQLite >= 3.9.0. + + Table options: + + content: table name of external content, or empty string for "contentless" + content_rowid: column name of external content primary key + prefix: integer(s). Ex: '2' or '2 3 4' + tokenize: porter, unicode61, ascii. Ex: 'porter unicode61' + + The unicode tokenizer supports the following parameters: + + * remove_diacritics (1 or 0, default is 1) + * tokenchars (string of characters, e.g. '-_' + * separators (string of characters) + + Parameters are passed as alternating parameter name and value, so: + + {'tokenize': "unicode61 remove_diacritics 0 tokenchars '-_'"} + + Content-less tables: + + If you don't need the full-text content in it's original form, you can + specify a content-less table. Searches and auxiliary functions will work + as usual, but the only values returned when SELECT-ing can be rowid. Also + content-less tables do not support UPDATE or DELETE. + + External content tables: + + You can set up triggers to sync these, e.g. + + -- Create a table. And an external content fts5 table to index it. + CREATE TABLE tbl(a INTEGER PRIMARY KEY, b); + CREATE VIRTUAL TABLE ft USING fts5(b, content='tbl', content_rowid='a'); + + -- Triggers to keep the FTS index up to date. 
+ CREATE TRIGGER tbl_ai AFTER INSERT ON tbl BEGIN + INSERT INTO ft(rowid, b) VALUES (new.a, new.b); + END; + CREATE TRIGGER tbl_ad AFTER DELETE ON tbl BEGIN + INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b); + END; + CREATE TRIGGER tbl_au AFTER UPDATE ON tbl BEGIN + INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b); + INSERT INTO ft(rowid, b) VALUES (new.a, new.b); + END; + + Built-in auxiliary functions: + + * bm25(tbl[, weight_0, ... weight_n]) + * highlight(tbl, col_idx, prefix, suffix) + * snippet(tbl, col_idx, prefix, suffix, ?, max_tokens) + """ + # FTS5 does not support declared primary keys, but we can use the + # implicit rowid. + rowid = RowIDField() + + class Meta: + extension_module = 'fts5' + + _error_messages = { + 'field_type': ('Besides the implicit `rowid` column, all columns must ' + 'be instances of SearchField'), + 'index': 'Secondary indexes are not supported for FTS5 models', + 'pk': 'FTS5 models must use the default `rowid` primary key', + } + + @classmethod + def validate_model(cls): + # Perform FTS5-specific validation and options post-processing. + if cls._meta.primary_key.name != 'rowid': + raise ImproperlyConfigured(cls._error_messages['pk']) + for field in cls._meta.fields.values(): + if not isinstance(field, (SearchField, RowIDField)): + raise ImproperlyConfigured(cls._error_messages['field_type']) + if cls._meta.indexes: + raise ImproperlyConfigured(cls._error_messages['index']) + + @classmethod + def fts5_installed(cls): + if sqlite3.sqlite_version_info[:3] < FTS5_MIN_SQLITE_VERSION: + return False + + # Test in-memory DB to determine if the FTS5 extension is installed. + tmp_db = sqlite3.connect(':memory:') + try: + tmp_db.execute('CREATE VIRTUAL TABLE fts5test USING fts5 (data);') + except: + try: + tmp_db.enable_load_extension(True) + tmp_db.load_extension('fts5') + except: + return False + else: + cls._meta.database.load_extension('fts5') + finally: + tmp_db.close() + + return True + + @staticmethod + def validate_query(query): + """ + Simple helper function to indicate whether a search query is a + valid FTS5 query. Note: this simply looks at the characters being + used, and is not guaranteed to catch all problematic queries. + """ + tokens = _quote_re.findall(query) + for token in tokens: + if token.startswith('"') and token.endswith('"'): + continue + if set(token) & _invalid_ascii: + return False + return True + + @staticmethod + def clean_query(query, replace=chr(26)): + """ + Clean a query of invalid tokens. + """ + accum = [] + any_invalid = False + tokens = _quote_re.findall(query) + for token in tokens: + if token.startswith('"') and token.endswith('"'): + accum.append(token) + continue + token_set = set(token) + invalid_for_token = token_set & _invalid_ascii + if invalid_for_token: + any_invalid = True + for c in invalid_for_token: + token = token.replace(c, replace) + accum.append(token) + + if any_invalid: + return ' '.join(accum) + return query + + @classmethod + def match(cls, term): + """ + Generate a `MATCH` expression appropriate for searching this table. 
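A sketch of an FTS5 table using the model class documented above; FTS5 availability is checked first, and the NoteIndex model is hypothetical:

    from playhouse.sqlite_ext import FTS5Model, SearchField, SqliteExtDatabase

    db = SqliteExtDatabase(':memory:')

    class NoteIndex(FTS5Model):
        title = SearchField()
        body = SearchField()

        class Meta:
            database = db

    if NoteIndex.fts5_installed():
        db.create_tables([NoteIndex])
        NoteIndex.create(title='peewee', body='tiny orm with sqlite extensions')

        # bm25()-ranked MATCH query; "score" is the rank alias.
        for note in NoteIndex.search_bm25('sqlite', with_score=True):
            print(note.title, note.score)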
+ """ + return match(cls._meta.entity, term) + + @classmethod + def rank(cls, *args): + return cls.bm25(*args) if args else SQL('rank') + + @classmethod + def bm25(cls, *weights): + return fn.bm25(cls._meta.entity, *weights) + + @classmethod + def search(cls, term, weights=None, with_score=False, score_alias='score', + explicit_ordering=False): + """Full-text search using selected `term`.""" + return cls.search_bm25( + FTS5Model.clean_query(term), + weights, + with_score, + score_alias, + explicit_ordering) + + @classmethod + def search_bm25(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search using selected `term`.""" + if not weights: + rank = SQL('rank') + elif isinstance(weights, dict): + weight_args = [] + for field in cls._meta.sorted_fields: + if isinstance(field, SearchField) and not field.unindexed: + weight_args.append( + weights.get(field, weights.get(field.name, 1.0))) + rank = fn.bm25(cls._meta.entity, *weight_args) + else: + rank = fn.bm25(cls._meta.entity, *weights) + + selection = () + order_by = rank + if with_score: + selection = (cls, rank.alias(score_alias)) + if with_score and not explicit_ordering: + order_by = SQL(score_alias) + + return (cls + .select(*selection) + .where(cls.match(FTS5Model.clean_query(term))) + .order_by(order_by)) + + @classmethod + def _fts_cmd_sql(cls, cmd, **extra_params): + tbl = cls._meta.entity + columns = [tbl] + values = [cmd] + for key, value in extra_params.items(): + columns.append(Entity(key)) + values.append(value) + + return NodeList(( + SQL('INSERT INTO'), + cls._meta.entity, + EnclosedNodeList(columns), + SQL('VALUES'), + EnclosedNodeList(values))) + + @classmethod + def _fts_cmd(cls, cmd, **extra_params): + query = cls._fts_cmd_sql(cmd, **extra_params) + return cls._meta.database.execute(query) + + @classmethod + def automerge(cls, level): + if not (0 <= level <= 16): + raise ValueError('level must be between 0 and 16') + return cls._fts_cmd('automerge', rank=level) + + @classmethod + def merge(cls, npages): + return cls._fts_cmd('merge', rank=npages) + + @classmethod + def set_pgsz(cls, pgsz): + return cls._fts_cmd('pgsz', rank=pgsz) + + @classmethod + def set_rank(cls, rank_expression): + return cls._fts_cmd('rank', rank=rank_expression) + + @classmethod + def delete_all(cls): + return cls._fts_cmd('delete-all') + + @classmethod + def VocabModel(cls, table_type='row', table=None): + if table_type not in ('row', 'col', 'instance'): + raise ValueError('table_type must be either "row", "col" or ' + '"instance".') + + attr = '_vocab_model_%s' % table_type + + if not hasattr(cls, attr): + class Meta: + database = cls._meta.database + table_name = table or cls._meta.table_name + '_v' + extension_module = fn.fts5vocab( + cls._meta.entity, + SQL(table_type)) + + attrs = { + 'term': VirtualField(TextField), + 'doc': IntegerField(), + 'cnt': IntegerField(), + 'rowid': RowIDField(), + 'Meta': Meta, + } + if table_type == 'col': + attrs['col'] = VirtualField(TextField) + elif table_type == 'instance': + attrs['offset'] = VirtualField(IntegerField) + + class_name = '%sVocab' % cls.__name__ + setattr(cls, attr, type(class_name, (VirtualModel,), attrs)) + + return getattr(cls, attr) + + +def ClosureTable(model_class, foreign_key=None, referencing_class=None, + referencing_key=None): + """Model factory for the transitive closure extension.""" + if referencing_class is None: + referencing_class = model_class + + if foreign_key is None: + for field_obj in model_class._meta.refs: + if 
field_obj.rel_model is model_class: + foreign_key = field_obj + break + else: + raise ValueError('Unable to find self-referential foreign key.') + + source_key = model_class._meta.primary_key + if referencing_key is None: + referencing_key = source_key + + class BaseClosureTable(VirtualModel): + depth = VirtualField(IntegerField) + id = VirtualField(IntegerField) + idcolumn = VirtualField(TextField) + parentcolumn = VirtualField(TextField) + root = VirtualField(IntegerField) + tablename = VirtualField(TextField) + + class Meta: + extension_module = 'transitive_closure' + + @classmethod + def descendants(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(source_key == cls.id)) + .where(cls.root == node) + .objects()) + if depth is not None: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def ancestors(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(source_key == cls.root)) + .where(cls.id == node) + .objects()) + if depth: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def siblings(cls, node, include_node=False): + if referencing_class is model_class: + # self-join + fk_value = node.__data__.get(foreign_key.name) + query = model_class.select().where(foreign_key == fk_value) + else: + # siblings as given in reference_class + siblings = (referencing_class + .select(referencing_key) + .join(cls, on=(foreign_key == cls.root)) + .where((cls.id == node) & (cls.depth == 1))) + + # the according models + query = (model_class + .select() + .where(source_key << siblings) + .objects()) + + if not include_node: + query = query.where(source_key != node) + + return query + + class Meta: + database = referencing_class._meta.database + options = { + 'tablename': referencing_class._meta.table_name, + 'idcolumn': referencing_key.column_name, + 'parentcolumn': foreign_key.column_name} + primary_key = False + + name = '%sClosure' % model_class.__name__ + return type(name, (BaseClosureTable,), {'Meta': Meta}) + + +class LSMTable(VirtualModel): + class Meta: + extension_module = 'lsm1' + filename = None + + @classmethod + def clean_options(cls, options): + filename = cls._meta.filename + if not filename: + raise ValueError('LSM1 extension requires that you specify a ' + 'filename for the LSM database.') + else: + if len(filename) >= 2 and filename[0] != '"': + filename = '"%s"' % filename + if not cls._meta.primary_key: + raise ValueError('LSM1 models must specify a primary-key field.') + + key = cls._meta.primary_key + if isinstance(key, AutoField): + raise ValueError('LSM1 models must explicitly declare a primary ' + 'key field.') + if not isinstance(key, (TextField, BlobField, IntegerField)): + raise ValueError('LSM1 key must be a TextField, BlobField, or ' + 'IntegerField.') + key._hidden = True + if isinstance(key, IntegerField): + data_type = 'UINT' + elif isinstance(key, BlobField): + data_type = 'BLOB' + else: + data_type = 'TEXT' + cls._meta.prefix_arguments = [filename, '"%s"' % key.name, data_type] + + # Does the key map to a scalar value, or a tuple of values? 
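A sketch of the ClosureTable() factory above, which wires a self-referential model to the transitive_closure loadable extension; the Category model and the extension path are hypothetical:

    from peewee import ForeignKeyField, Model, SqliteDatabase, TextField
    from playhouse.sqlite_ext import ClosureTable

    db = SqliteDatabase(':memory:')

    class Category(Model):
        name = TextField()
        parent = ForeignKeyField('self', backref='children', null=True)

        class Meta:
            database = db

    CategoryClosure = ClosureTable(Category)

    # With the closure extension available:
    # db.load_extension('closure')                  # hypothetical extension path
    # db.create_tables([Category, CategoryClosure])
    # CategoryClosure.descendants(node), CategoryClosure.ancestors(node), ...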
+ if len(cls._meta.sorted_fields) == 2: + cls._meta._value_field = cls._meta.sorted_fields[1] + else: + cls._meta._value_field = None + + return options + + @classmethod + def load_extension(cls, path='lsm.so'): + cls._meta.database.load_extension(path) + + @staticmethod + def slice_to_expr(key, idx): + if idx.start is not None and idx.stop is not None: + return key.between(idx.start, idx.stop) + elif idx.start is not None: + return key >= idx.start + elif idx.stop is not None: + return key <= idx.stop + + @staticmethod + def _apply_lookup_to_query(query, key, lookup): + if isinstance(lookup, slice): + expr = LSMTable.slice_to_expr(key, lookup) + if expr is not None: + query = query.where(expr) + return query, False + elif isinstance(lookup, Expression): + return query.where(lookup), False + else: + return query.where(key == lookup), True + + @classmethod + def get_by_id(cls, pk): + query, is_single = cls._apply_lookup_to_query( + cls.select().namedtuples(), + cls._meta.primary_key, + pk) + + if is_single: + try: + row = query.get() + except cls.DoesNotExist: + raise KeyError(pk) + return row[1] if cls._meta._value_field is not None else row + else: + return query + + @classmethod + def set_by_id(cls, key, value): + if cls._meta._value_field is not None: + data = {cls._meta._value_field: value} + elif isinstance(value, tuple): + data = {} + for field, fval in zip(cls._meta.sorted_fields[1:], value): + data[field] = fval + elif isinstance(value, dict): + data = value + elif isinstance(value, cls): + data = value.__dict__ + data[cls._meta.primary_key] = key + cls.replace(data).execute() + + @classmethod + def delete_by_id(cls, pk): + query, is_single = cls._apply_lookup_to_query( + cls.delete(), + cls._meta.primary_key, + pk) + return query.execute() + + +OP.MATCH = 'MATCH' + +def _sqlite_regexp(regex, value): + return re.search(regex, value) is not None + + +class SqliteExtDatabase(SqliteDatabase): + def __init__(self, database, c_extensions=None, rank_functions=True, + hash_functions=False, regexp_function=False, + bloomfilter=False, json_contains=False, *args, **kwargs): + super(SqliteExtDatabase, self).__init__(database, *args, **kwargs) + self._row_factory = None + + if c_extensions and not CYTHON_SQLITE_EXTENSIONS: + raise ImproperlyConfigured('SqliteExtDatabase initialized with ' + 'C extensions, but shared library was ' + 'not found!') + prefer_c = CYTHON_SQLITE_EXTENSIONS and (c_extensions is not False) + if rank_functions: + if prefer_c: + register_rank_functions(self) + else: + self.register_function(bm25, 'fts_bm25') + self.register_function(rank, 'fts_rank') + self.register_function(bm25, 'fts_bm25f') # Fall back to bm25. 
+ self.register_function(bm25, 'fts_lucene') + if hash_functions: + if not prefer_c: + raise ValueError('C extension required to register hash ' + 'functions.') + register_hash_functions(self) + if regexp_function: + self.register_function(_sqlite_regexp, 'regexp', 2) + if bloomfilter: + if not prefer_c: + raise ValueError('C extension required to use bloomfilter.') + register_bloomfilter(self) + if json_contains: + self.register_function(_json_contains, 'json_contains') + + self._c_extensions = prefer_c + + def _add_conn_hooks(self, conn): + super(SqliteExtDatabase, self)._add_conn_hooks(conn) + if self._row_factory: + conn.row_factory = self._row_factory + + def row_factory(self, fn): + self._row_factory = fn + + +if CYTHON_SQLITE_EXTENSIONS: + SQLITE_STATUS_MEMORY_USED = 0 + SQLITE_STATUS_PAGECACHE_USED = 1 + SQLITE_STATUS_PAGECACHE_OVERFLOW = 2 + SQLITE_STATUS_SCRATCH_USED = 3 + SQLITE_STATUS_SCRATCH_OVERFLOW = 4 + SQLITE_STATUS_MALLOC_SIZE = 5 + SQLITE_STATUS_PARSER_STACK = 6 + SQLITE_STATUS_PAGECACHE_SIZE = 7 + SQLITE_STATUS_SCRATCH_SIZE = 8 + SQLITE_STATUS_MALLOC_COUNT = 9 + SQLITE_DBSTATUS_LOOKASIDE_USED = 0 + SQLITE_DBSTATUS_CACHE_USED = 1 + SQLITE_DBSTATUS_SCHEMA_USED = 2 + SQLITE_DBSTATUS_STMT_USED = 3 + SQLITE_DBSTATUS_LOOKASIDE_HIT = 4 + SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5 + SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6 + SQLITE_DBSTATUS_CACHE_HIT = 7 + SQLITE_DBSTATUS_CACHE_MISS = 8 + SQLITE_DBSTATUS_CACHE_WRITE = 9 + SQLITE_DBSTATUS_DEFERRED_FKS = 10 + #SQLITE_DBSTATUS_CACHE_USED_SHARED = 11 + + def __status__(flag, return_highwater=False): + """ + Expose a sqlite3_status() call for a particular flag as a property of + the Database object. + """ + def getter(self): + result = sqlite_get_status(flag) + return result[1] if return_highwater else result + return property(getter) + + def __dbstatus__(flag, return_highwater=False, return_current=False): + """ + Expose a sqlite3_dbstatus() call for a particular flag as a property of + the Database instance. Unlike sqlite3_status(), the dbstatus properties + pertain to the current connection. 
+ """ + def getter(self): + if self._state.conn is None: + raise ImproperlyConfigured('database connection not opened.') + result = sqlite_get_db_status(self._state.conn, flag) + if return_current: + return result[0] + return result[1] if return_highwater else result + return property(getter) + + class CSqliteExtDatabase(SqliteExtDatabase): + def __init__(self, *args, **kwargs): + self._conn_helper = None + self._commit_hook = self._rollback_hook = self._update_hook = None + self._replace_busy_handler = False + super(CSqliteExtDatabase, self).__init__(*args, **kwargs) + + def init(self, database, replace_busy_handler=False, **kwargs): + super(CSqliteExtDatabase, self).init(database, **kwargs) + self._replace_busy_handler = replace_busy_handler + + def _close(self, conn): + if self._commit_hook: + self._conn_helper.set_commit_hook(None) + if self._rollback_hook: + self._conn_helper.set_rollback_hook(None) + if self._update_hook: + self._conn_helper.set_update_hook(None) + return super(CSqliteExtDatabase, self)._close(conn) + + def _add_conn_hooks(self, conn): + super(CSqliteExtDatabase, self)._add_conn_hooks(conn) + self._conn_helper = ConnectionHelper(conn) + if self._commit_hook is not None: + self._conn_helper.set_commit_hook(self._commit_hook) + if self._rollback_hook is not None: + self._conn_helper.set_rollback_hook(self._rollback_hook) + if self._update_hook is not None: + self._conn_helper.set_update_hook(self._update_hook) + if self._replace_busy_handler: + timeout = self._timeout or 5 + self._conn_helper.set_busy_handler(timeout * 1000) + + def on_commit(self, fn): + self._commit_hook = fn + if not self.is_closed(): + self._conn_helper.set_commit_hook(fn) + return fn + + def on_rollback(self, fn): + self._rollback_hook = fn + if not self.is_closed(): + self._conn_helper.set_rollback_hook(fn) + return fn + + def on_update(self, fn): + self._update_hook = fn + if not self.is_closed(): + self._conn_helper.set_update_hook(fn) + return fn + + def changes(self): + return self._conn_helper.changes() + + @property + def last_insert_rowid(self): + return self._conn_helper.last_insert_rowid() + + @property + def autocommit(self): + return self._conn_helper.autocommit() + + def backup(self, destination, pages=None, name=None, progress=None): + return backup(self.connection(), destination.connection(), + pages=pages, name=name, progress=progress) + + def backup_to_file(self, filename, pages=None, name=None, + progress=None): + return backup_to_file(self.connection(), filename, pages=pages, + name=name, progress=progress) + + def blob_open(self, table, column, rowid, read_only=False): + return Blob(self, table, column, rowid, read_only) + + # Status properties. + memory_used = __status__(SQLITE_STATUS_MEMORY_USED) + malloc_size = __status__(SQLITE_STATUS_MALLOC_SIZE, True) + malloc_count = __status__(SQLITE_STATUS_MALLOC_COUNT) + pagecache_used = __status__(SQLITE_STATUS_PAGECACHE_USED) + pagecache_overflow = __status__(SQLITE_STATUS_PAGECACHE_OVERFLOW) + pagecache_size = __status__(SQLITE_STATUS_PAGECACHE_SIZE, True) + scratch_used = __status__(SQLITE_STATUS_SCRATCH_USED) + scratch_overflow = __status__(SQLITE_STATUS_SCRATCH_OVERFLOW) + scratch_size = __status__(SQLITE_STATUS_SCRATCH_SIZE, True) + + # Connection status properties. 
+ lookaside_used = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_USED) + lookaside_hit = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_HIT, True) + lookaside_miss = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE, + True) + lookaside_miss_full = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL, + True) + cache_used = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED, False, True) + #cache_used_shared = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED_SHARED, + # False, True) + schema_used = __dbstatus__(SQLITE_DBSTATUS_SCHEMA_USED, False, True) + statement_used = __dbstatus__(SQLITE_DBSTATUS_STMT_USED, False, True) + cache_hit = __dbstatus__(SQLITE_DBSTATUS_CACHE_HIT, False, True) + cache_miss = __dbstatus__(SQLITE_DBSTATUS_CACHE_MISS, False, True) + cache_write = __dbstatus__(SQLITE_DBSTATUS_CACHE_WRITE, False, True) + + +def match(lhs, rhs): + return Expression(lhs, OP.MATCH, rhs) + +def _parse_match_info(buf): + # See http://sqlite.org/fts3.html#matchinfo + bufsize = len(buf) # Length in bytes. + return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)] + +def get_weights(ncol, raw_weights): + if not raw_weights: + return [1] * ncol + else: + weights = [0] * ncol + for i, weight in enumerate(raw_weights): + weights[i] = weight + return weights + +# Ranking implementation, which parse matchinfo. +def rank(raw_match_info, *raw_weights): + # Handle match_info called w/default args 'pcx' - based on the example rank + # function http://sqlite.org/fts3.html#appendix_a + match_info = _parse_match_info(raw_match_info) + score = 0.0 + + p, c = match_info[:2] + weights = get_weights(c, raw_weights) + + # matchinfo X value corresponds to, for each phrase in the search query, a + # list of 3 values for each column in the search table. + # So if we have a two-phrase search query and three columns of data, the + # following would be the layout: + # p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8] + # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17] + for phrase_num in range(p): + phrase_info_idx = 2 + (phrase_num * c * 3) + for col_num in range(c): + weight = weights[col_num] + if not weight: + continue + + col_idx = phrase_info_idx + (col_num * 3) + + # The idea is that we count the number of times the phrase appears + # in this column of the current row, compared to how many times it + # appears in this column across all rows. The ratio of these values + # provides a rough way to score based on "high value" terms. + row_hits = match_info[col_idx] + all_rows_hits = match_info[col_idx + 1] + if row_hits > 0: + score += weight * (float(row_hits) / all_rows_hits) + + return -score + +# Okapi BM25 ranking implementation (FTS4 only). +def bm25(raw_match_info, *args): + """ + Usage: + + # Format string *must* be pcnalx + # Second parameter to bm25 specifies the index of the column, on + # the table being queries. + bm25(matchinfo(document_tbl, 'pcnalx'), 1) AS rank + """ + match_info = _parse_match_info(raw_match_info) + K = 1.2 + B = 0.75 + score = 0.0 + + P_O, C_O, N_O, A_O = range(4) # Offsets into the matchinfo buffer. 
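+ # Illustrative note (not part of the upstream module): with
+ # matchinfo(tbl, 'pcnalx') the buffer holds, in order: p (phrase count),
+ # c (column count), n (total row count), c average column lengths ('a'),
+ # c per-column lengths for the current row ('l'), and 3 ints per
+ # phrase/column pair ('x'): hits in this row/column, hits in this column
+ # across all rows, and rows whose column matches the phrase. The L_O and
+ # X_O offsets computed below follow that layout.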
+ term_count = match_info[P_O] # n + col_count = match_info[C_O] + total_docs = match_info[N_O] # N + L_O = A_O + col_count + X_O = L_O + col_count + + weights = get_weights(col_count, args) + + for i in range(term_count): + for j in range(col_count): + weight = weights[j] + if weight == 0: + continue + + x = X_O + (3 * (j + i * col_count)) + term_frequency = float(match_info[x]) # f(qi, D) + docs_with_term = float(match_info[x + 2]) # n(qi) + + # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) + idf = math.log( + (total_docs - docs_with_term + 0.5) / + (docs_with_term + 0.5)) + if idf <= 0.0: + idf = 1e-6 + + doc_length = float(match_info[L_O + j]) # |D| + avg_length = float(match_info[A_O + j]) or 1. # avgdl + ratio = doc_length / avg_length + + num = term_frequency * (K + 1) + b_part = 1 - B + (B * ratio) + denom = term_frequency + (K * b_part) + + pc_score = idf * (num / denom) + score += (pc_score * weight) + + return -score + + +def _json_contains(src_json, obj_json): + stack = [] + try: + stack.append((json.loads(obj_json), json.loads(src_json))) + except: + # Invalid JSON! + return False + + while stack: + obj, src = stack.pop() + if isinstance(src, dict): + if isinstance(obj, dict): + for key in obj: + if key not in src: + return False + stack.append((obj[key], src[key])) + elif isinstance(obj, list): + for item in obj: + if item not in src: + return False + elif obj not in src: + return False + elif isinstance(src, list): + if isinstance(obj, dict): + return False + elif isinstance(obj, list): + try: + for i in range(len(obj)): + stack.append((obj[i], src[i])) + except IndexError: + return False + elif obj not in src: + return False + elif obj != src: + return False + return True diff --git a/libs/playhouse/sqlite_udf.py b/libs/playhouse/sqlite_udf.py new file mode 100644 index 000000000..28dbd8560 --- /dev/null +++ b/libs/playhouse/sqlite_udf.py @@ -0,0 +1,522 @@ +import datetime +import hashlib +import heapq +import math +import os +import random +import re +import sys +import threading +import zlib +try: + from collections import Counter +except ImportError: + Counter = None +try: + from urlparse import urlparse +except ImportError: + from urllib.parse import urlparse + +try: + from playhouse._sqlite_ext import TableFunction +except ImportError: + TableFunction = None + + +SQLITE_DATETIME_FORMATS = ( + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d', + '%H:%M:%S', + '%H:%M:%S.%f', + '%H:%M') + +from peewee import format_date_time + +def format_date_time_sqlite(date_value): + return format_date_time(date_value, SQLITE_DATETIME_FORMATS) + +try: + from playhouse import _sqlite_udf as cython_udf +except ImportError: + cython_udf = None + + +# Group udf by function. +CONTROL_FLOW = 'control_flow' +DATE = 'date' +FILE = 'file' +HELPER = 'helpers' +MATH = 'math' +STRING = 'string' + +AGGREGATE_COLLECTION = {} +TABLE_FUNCTION_COLLECTION = {} +UDF_COLLECTION = {} + + +class synchronized_dict(dict): + def __init__(self, *args, **kwargs): + super(synchronized_dict, self).__init__(*args, **kwargs) + self._lock = threading.Lock() + + def __getitem__(self, key): + with self._lock: + return super(synchronized_dict, self).__getitem__(key) + + def __setitem__(self, key, value): + with self._lock: + return super(synchronized_dict, self).__setitem__(key, value) + + def __delitem__(self, key): + with self._lock: + return super(synchronized_dict, self).__delitem__(key) + + +STATE = synchronized_dict() +SETTINGS = synchronized_dict() + +# Class and function decorators. 
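+# Illustrative usage (not part of the upstream module; `db` is assumed to be
+# a SqliteDatabase instance and `double` a made-up UDF): the decorators below
+# only file a function or class under one or more group names so it can be
+# registered in bulk later, e.g.:
+#
+#   @udf(MATH)
+#   def double(value):
+#       return value * 2 if value is not None else None
+#
+#   register_udf_groups(db, MATH)  # registers every UDF tagged with MATH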
+def aggregate(*groups): + def decorator(klass): + for group in groups: + AGGREGATE_COLLECTION.setdefault(group, []) + AGGREGATE_COLLECTION[group].append(klass) + return klass + return decorator + +def table_function(*groups): + def decorator(klass): + for group in groups: + TABLE_FUNCTION_COLLECTION.setdefault(group, []) + TABLE_FUNCTION_COLLECTION[group].append(klass) + return klass + return decorator + +def udf(*groups): + def decorator(fn): + for group in groups: + UDF_COLLECTION.setdefault(group, []) + UDF_COLLECTION[group].append(fn) + return fn + return decorator + +# Register aggregates / functions with connection. +def register_aggregate_groups(db, *groups): + seen = set() + for group in groups: + klasses = AGGREGATE_COLLECTION.get(group, ()) + for klass in klasses: + name = getattr(klass, 'name', klass.__name__) + if name not in seen: + seen.add(name) + db.register_aggregate(klass, name) + +def register_table_function_groups(db, *groups): + seen = set() + for group in groups: + klasses = TABLE_FUNCTION_COLLECTION.get(group, ()) + for klass in klasses: + if klass.name not in seen: + seen.add(klass.name) + db.register_table_function(klass) + +def register_udf_groups(db, *groups): + seen = set() + for group in groups: + functions = UDF_COLLECTION.get(group, ()) + for function in functions: + name = function.__name__ + if name not in seen: + seen.add(name) + db.register_function(function, name) + +def register_groups(db, *groups): + register_aggregate_groups(db, *groups) + register_table_function_groups(db, *groups) + register_udf_groups(db, *groups) + +def register_all(db): + register_aggregate_groups(db, *AGGREGATE_COLLECTION) + register_table_function_groups(db, *TABLE_FUNCTION_COLLECTION) + register_udf_groups(db, *UDF_COLLECTION) + + +# Begin actual user-defined functions and aggregates. + +# Scalar functions. 
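+# Illustrative usage (not part of the upstream module; assumes an open
+# SqliteExtDatabase `db`): once a group is registered, its scalar functions
+# are callable directly from SQL, e.g.:
+#
+#   register_udf_groups(db, STRING, DATE)
+#   db.execute_sql("SELECT substr_count('banana', 'an')").fetchone()  # (2,)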
+@udf(CONTROL_FLOW) +def if_then_else(cond, truthy, falsey=None): + if cond: + return truthy + return falsey + +@udf(DATE) +def strip_tz(date_str): + date_str = date_str.replace('T', ' ') + tz_idx1 = date_str.find('+') + if tz_idx1 != -1: + return date_str[:tz_idx1] + tz_idx2 = date_str.find('-') + if tz_idx2 > 13: + return date_str[:tz_idx2] + return date_str + +@udf(DATE) +def human_delta(nseconds, glue=', '): + parts = ( + (86400 * 365, 'year'), + (86400 * 30, 'month'), + (86400 * 7, 'week'), + (86400, 'day'), + (3600, 'hour'), + (60, 'minute'), + (1, 'second'), + ) + accum = [] + for offset, name in parts: + val, nseconds = divmod(nseconds, offset) + if val: + suffix = val != 1 and 's' or '' + accum.append('%s %s%s' % (val, name, suffix)) + if not accum: + return '0 seconds' + return glue.join(accum) + +@udf(FILE) +def file_ext(filename): + try: + res = os.path.splitext(filename) + except ValueError: + return None + return res[1] + +@udf(FILE) +def file_read(filename): + try: + with open(filename) as fh: + return fh.read() + except: + pass + +if sys.version_info[0] == 2: + @udf(HELPER) + def gzip(data, compression=9): + return buffer(zlib.compress(data, compression)) + + @udf(HELPER) + def gunzip(data): + return zlib.decompress(data) +else: + @udf(HELPER) + def gzip(data, compression=9): + if isinstance(data, str): + data = bytes(data.encode('raw_unicode_escape')) + return zlib.compress(data, compression) + + @udf(HELPER) + def gunzip(data): + return zlib.decompress(data) + +@udf(HELPER) +def hostname(url): + parse_result = urlparse(url) + if parse_result: + return parse_result.netloc + +@udf(HELPER) +def toggle(key): + key = key.lower() + STATE[key] = ret = not STATE.get(key) + return ret + +@udf(HELPER) +def setting(key, value=None): + if value is None: + return SETTINGS.get(key) + else: + SETTINGS[key] = value + return value + +@udf(HELPER) +def clear_settings(): + SETTINGS.clear() + +@udf(HELPER) +def clear_toggles(): + STATE.clear() + +@udf(MATH) +def randomrange(start, end=None, step=None): + if end is None: + start, end = 0, start + elif step is None: + step = 1 + return random.randrange(start, end, step) + +@udf(MATH) +def gauss_distribution(mean, sigma): + try: + return random.gauss(mean, sigma) + except ValueError: + return None + +@udf(MATH) +def sqrt(n): + try: + return math.sqrt(n) + except ValueError: + return None + +@udf(MATH) +def tonumber(s): + try: + return int(s) + except ValueError: + try: + return float(s) + except: + return None + +@udf(STRING) +def substr_count(haystack, needle): + if not haystack or not needle: + return 0 + return haystack.count(needle) + +@udf(STRING) +def strip_chars(haystack, chars): + return haystack.strip(chars) + +def _hash(constructor, *args): + hash_obj = constructor() + for arg in args: + hash_obj.update(arg) + return hash_obj.hexdigest() + +# Aggregates. 
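+# Illustrative note (not part of the upstream module): each aggregate below
+# follows the sqlite3 aggregate protocol -- step() is called once per row and
+# finalize() returns the result -- and is registered per group, e.g.:
+#
+#   register_aggregate_groups(db, MATH, DATE)
+#   db.execute_sql('SELECT mode(score) FROM results')  # hypothetical table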
+class _heap_agg(object): + def __init__(self): + self.heap = [] + self.ct = 0 + + def process(self, value): + return value + + def step(self, value): + self.ct += 1 + heapq.heappush(self.heap, self.process(value)) + +class _datetime_heap_agg(_heap_agg): + def process(self, value): + return format_date_time_sqlite(value) + +if sys.version_info[:2] == (2, 6): + def total_seconds(td): + return (td.seconds + + (td.days * 86400) + + (td.microseconds / (10.**6))) +else: + total_seconds = lambda td: td.total_seconds() + +@aggregate(DATE) +class mintdiff(_datetime_heap_agg): + def finalize(self): + dtp = min_diff = None + while self.heap: + if min_diff is None: + if dtp is None: + dtp = heapq.heappop(self.heap) + continue + dt = heapq.heappop(self.heap) + diff = dt - dtp + if min_diff is None or min_diff > diff: + min_diff = diff + dtp = dt + if min_diff is not None: + return total_seconds(min_diff) + +@aggregate(DATE) +class avgtdiff(_datetime_heap_agg): + def finalize(self): + if self.ct < 1: + return + elif self.ct == 1: + return 0 + + total = ct = 0 + dtp = None + while self.heap: + if total == 0: + if dtp is None: + dtp = heapq.heappop(self.heap) + continue + + dt = heapq.heappop(self.heap) + diff = dt - dtp + ct += 1 + total += total_seconds(diff) + dtp = dt + + return float(total) / ct + +@aggregate(DATE) +class duration(object): + def __init__(self): + self._min = self._max = None + + def step(self, value): + dt = format_date_time_sqlite(value) + if self._min is None or dt < self._min: + self._min = dt + if self._max is None or dt > self._max: + self._max = dt + + def finalize(self): + if self._min and self._max: + td = (self._max - self._min) + return total_seconds(td) + return None + +@aggregate(MATH) +class mode(object): + if Counter: + def __init__(self): + self.items = Counter() + + def step(self, *args): + self.items.update(args) + + def finalize(self): + if self.items: + return self.items.most_common(1)[0][0] + else: + def __init__(self): + self.items = [] + + def step(self, item): + self.items.append(item) + + def finalize(self): + if self.items: + return max(set(self.items), key=self.items.count) + +@aggregate(MATH) +class minrange(_heap_agg): + def finalize(self): + if self.ct == 0: + return + elif self.ct == 1: + return 0 + + prev = min_diff = None + + while self.heap: + if min_diff is None: + if prev is None: + prev = heapq.heappop(self.heap) + continue + curr = heapq.heappop(self.heap) + diff = curr - prev + if min_diff is None or min_diff > diff: + min_diff = diff + prev = curr + return min_diff + +@aggregate(MATH) +class avgrange(_heap_agg): + def finalize(self): + if self.ct == 0: + return + elif self.ct == 1: + return 0 + + total = ct = 0 + prev = None + while self.heap: + if total == 0: + if prev is None: + prev = heapq.heappop(self.heap) + continue + + curr = heapq.heappop(self.heap) + diff = curr - prev + ct += 1 + total += diff + prev = curr + + return float(total) / ct + +@aggregate(MATH) +class _range(object): + name = 'range' + + def __init__(self): + self._min = self._max = None + + def step(self, value): + if self._min is None or value < self._min: + self._min = value + if self._max is None or value > self._max: + self._max = value + + def finalize(self): + if self._min is not None and self._max is not None: + return self._max - self._min + return None + + +if cython_udf is not None: + damerau_levenshtein_dist = udf(STRING)(cython_udf.damerau_levenshtein_dist) + levenshtein_dist = udf(STRING)(cython_udf.levenshtein_dist) + str_dist = 
udf(STRING)(cython_udf.str_dist) + median = aggregate(MATH)(cython_udf.median) + + +if TableFunction is not None: + @table_function(STRING) + class RegexSearch(TableFunction): + params = ['regex', 'search_string'] + columns = ['match'] + name = 'regex_search' + + def initialize(self, regex=None, search_string=None): + self._iter = re.finditer(regex, search_string) + + def iterate(self, idx): + return (next(self._iter).group(0),) + + @table_function(DATE) + class DateSeries(TableFunction): + params = ['start', 'stop', 'step_seconds'] + columns = ['date'] + name = 'date_series' + + def initialize(self, start, stop, step_seconds=86400): + self.start = format_date_time_sqlite(start) + self.stop = format_date_time_sqlite(stop) + step_seconds = int(step_seconds) + self.step_seconds = datetime.timedelta(seconds=step_seconds) + + if (self.start.hour == 0 and + self.start.minute == 0 and + self.start.second == 0 and + step_seconds >= 86400): + self.format = '%Y-%m-%d' + elif (self.start.year == 1900 and + self.start.month == 1 and + self.start.day == 1 and + self.stop.year == 1900 and + self.stop.month == 1 and + self.stop.day == 1 and + step_seconds < 86400): + self.format = '%H:%M:%S' + else: + self.format = '%Y-%m-%d %H:%M:%S' + + def iterate(self, idx): + if self.start > self.stop: + raise StopIteration + current = self.start + self.start += self.step_seconds + return (current.strftime(self.format),) diff --git a/libs/playhouse/sqliteq.py b/libs/playhouse/sqliteq.py new file mode 100644 index 000000000..bd213549d --- /dev/null +++ b/libs/playhouse/sqliteq.py @@ -0,0 +1,330 @@ +import logging +import weakref +from threading import local as thread_local +from threading import Event +from threading import Thread +try: + from Queue import Queue +except ImportError: + from queue import Queue + +try: + import gevent + from gevent import Greenlet as GThread + from gevent.event import Event as GEvent + from gevent.local import local as greenlet_local + from gevent.queue import Queue as GQueue +except ImportError: + GThread = GQueue = GEvent = None + +from peewee import SENTINEL +from playhouse.sqlite_ext import SqliteExtDatabase + + +logger = logging.getLogger('peewee.sqliteq') + + +class ResultTimeout(Exception): + pass + +class WriterPaused(Exception): + pass + +class ShutdownException(Exception): + pass + + +class AsyncCursor(object): + __slots__ = ('sql', 'params', 'commit', 'timeout', + '_event', '_cursor', '_exc', '_idx', '_rows', '_ready') + + def __init__(self, event, sql, params, commit, timeout): + self._event = event + self.sql = sql + self.params = params + self.commit = commit + self.timeout = timeout + self._cursor = self._exc = self._idx = self._rows = None + self._ready = False + + def set_result(self, cursor, exc=None): + self._cursor = cursor + self._exc = exc + self._idx = 0 + self._rows = cursor.fetchall() if exc is None else [] + self._event.set() + return self + + def _wait(self, timeout=None): + timeout = timeout if timeout is not None else self.timeout + if not self._event.wait(timeout=timeout) and timeout: + raise ResultTimeout('results not ready, timed out.') + if self._exc is not None: + raise self._exc + self._ready = True + + def __iter__(self): + if not self._ready: + self._wait() + if self._exc is not None: + raise self._exec + return self + + def next(self): + if not self._ready: + self._wait() + try: + obj = self._rows[self._idx] + except IndexError: + raise StopIteration + else: + self._idx += 1 + return obj + __next__ = next + + @property + def lastrowid(self): + if 
not self._ready: + self._wait() + return self._cursor.lastrowid + + @property + def rowcount(self): + if not self._ready: + self._wait() + return self._cursor.rowcount + + @property + def description(self): + return self._cursor.description + + def close(self): + self._cursor.close() + + def fetchall(self): + return list(self) # Iterating implies waiting until populated. + + def fetchone(self): + if not self._ready: + self._wait() + try: + return next(self) + except StopIteration: + return None + +SHUTDOWN = StopIteration +PAUSE = object() +UNPAUSE = object() + + +class Writer(object): + __slots__ = ('database', 'queue') + + def __init__(self, database, queue): + self.database = database + self.queue = queue + + def run(self): + conn = self.database.connection() + try: + while True: + try: + if conn is None: # Paused. + if self.wait_unpause(): + conn = self.database.connection() + else: + conn = self.loop(conn) + except ShutdownException: + logger.info('writer received shutdown request, exiting.') + return + finally: + if conn is not None: + self.database._close(conn) + self.database._state.reset() + + def wait_unpause(self): + obj = self.queue.get() + if obj is UNPAUSE: + logger.info('writer unpaused - reconnecting to database.') + return True + elif obj is SHUTDOWN: + raise ShutdownException() + elif obj is PAUSE: + logger.error('writer received pause, but is already paused.') + else: + obj.set_result(None, WriterPaused()) + logger.warning('writer paused, not handling %s', obj) + + def loop(self, conn): + obj = self.queue.get() + if isinstance(obj, AsyncCursor): + self.execute(obj) + elif obj is PAUSE: + logger.info('writer paused - closing database connection.') + self.database._close(conn) + self.database._state.reset() + return + elif obj is UNPAUSE: + logger.error('writer received unpause, but is already running.') + elif obj is SHUTDOWN: + raise ShutdownException() + else: + logger.error('writer received unsupported object: %s', obj) + return conn + + def execute(self, obj): + logger.debug('received query %s', obj.sql) + try: + cursor = self.database._execute(obj.sql, obj.params, obj.commit) + except Exception as execute_err: + cursor = None + exc = execute_err # python3 is so fucking lame. + else: + exc = None + return obj.set_result(cursor, exc) + + +class SqliteQueueDatabase(SqliteExtDatabase): + WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL ' + 'journal mode when using this feature. WAL mode ' + 'allows one or more readers to continue reading ' + 'while another connection writes to the ' + 'database.') + + def __init__(self, database, use_gevent=False, autostart=True, + queue_max_size=None, results_timeout=None, *args, **kwargs): + kwargs['check_same_thread'] = False + + # Ensure that journal_mode is WAL. This value is passed to the parent + # class constructor below. + pragmas = self._validate_journal_mode(kwargs.pop('pragmas', None)) + + # Reference to execute_sql on the parent class. Since we've overridden + # execute_sql(), this is just a handy way to reference the real + # implementation. + Parent = super(SqliteQueueDatabase, self) + self._execute = Parent.execute_sql + + # Call the parent class constructor with our modified pragmas. + Parent.__init__(database, pragmas=pragmas, *args, **kwargs) + + self._autostart = autostart + self._results_timeout = results_timeout + self._is_stopped = True + + # Get different objects depending on the threading implementation. 
+ self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size) + + # Create the writer thread, optionally starting it. + self._create_write_queue() + if self._autostart: + self.start() + + def get_thread_impl(self, use_gevent): + return GreenletHelper if use_gevent else ThreadHelper + + def _validate_journal_mode(self, pragmas=None): + if pragmas: + pdict = dict((k.lower(), v) for (k, v) in pragmas) + if pdict.get('journal_mode', 'wal').lower() != 'wal': + raise ValueError(self.WAL_MODE_ERROR_MESSAGE) + + return [(k, v) for (k, v) in pragmas + if k != 'journal_mode'] + [('journal_mode', 'wal')] + else: + return [('journal_mode', 'wal')] + + def _create_write_queue(self): + self._write_queue = self._thread_helper.queue() + + def queue_size(self): + return self._write_queue.qsize() + + def execute_sql(self, sql, params=None, commit=SENTINEL, timeout=None): + if commit is SENTINEL: + commit = not sql.lower().startswith('select') + + if not commit: + return self._execute(sql, params, commit=commit) + + cursor = AsyncCursor( + event=self._thread_helper.event(), + sql=sql, + params=params, + commit=commit, + timeout=self._results_timeout if timeout is None else timeout) + self._write_queue.put(cursor) + return cursor + + def start(self): + with self._lock: + if not self._is_stopped: + return False + def run(): + writer = Writer(self, self._write_queue) + writer.run() + + self._writer = self._thread_helper.thread(run) + self._writer.start() + self._is_stopped = False + return True + + def stop(self): + logger.debug('environment stop requested.') + with self._lock: + if self._is_stopped: + return False + self._write_queue.put(SHUTDOWN) + self._writer.join() + self._is_stopped = True + return True + + def is_stopped(self): + with self._lock: + return self._is_stopped + + def pause(self): + with self._lock: + self._write_queue.put(PAUSE) + + def unpause(self): + with self._lock: + self._write_queue.put(UNPAUSE) + + def __unsupported__(self, *args, **kwargs): + raise ValueError('This method is not supported by %r.' 
% type(self)) + atomic = transaction = savepoint = __unsupported__ + + +class ThreadHelper(object): + __slots__ = ('queue_max_size',) + + def __init__(self, queue_max_size=None): + self.queue_max_size = queue_max_size + + def event(self): return Event() + + def queue(self, max_size=None): + max_size = max_size if max_size is not None else self.queue_max_size + return Queue(maxsize=max_size or 0) + + def thread(self, fn, *args, **kwargs): + thread = Thread(target=fn, args=args, kwargs=kwargs) + thread.daemon = True + return thread + + +class GreenletHelper(ThreadHelper): + __slots__ = () + + def event(self): return GEvent() + + def queue(self, max_size=None): + max_size = max_size if max_size is not None else self.queue_max_size + return GQueue(maxsize=max_size or 0) + + def thread(self, fn, *args, **kwargs): + def wrap(*a, **k): + gevent.sleep() + return fn(*a, **k) + return GThread(wrap, *args, **kwargs) diff --git a/libs/playhouse/test_utils.py b/libs/playhouse/test_utils.py new file mode 100644 index 000000000..333dc078b --- /dev/null +++ b/libs/playhouse/test_utils.py @@ -0,0 +1,62 @@ +from functools import wraps +import logging + + +logger = logging.getLogger('peewee') + + +class _QueryLogHandler(logging.Handler): + def __init__(self, *args, **kwargs): + self.queries = [] + logging.Handler.__init__(self, *args, **kwargs) + + def emit(self, record): + self.queries.append(record) + + +class count_queries(object): + def __init__(self, only_select=False): + self.only_select = only_select + self.count = 0 + + def get_queries(self): + return self._handler.queries + + def __enter__(self): + self._handler = _QueryLogHandler() + logger.setLevel(logging.DEBUG) + logger.addHandler(self._handler) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + logger.removeHandler(self._handler) + if self.only_select: + self.count = len([q for q in self._handler.queries + if q.msg[0].startswith('SELECT ')]) + else: + self.count = len(self._handler.queries) + + +class assert_query_count(count_queries): + def __init__(self, expected, only_select=False): + super(assert_query_count, self).__init__(only_select=only_select) + self.expected = expected + + def __call__(self, f): + @wraps(f) + def decorated(*args, **kwds): + with self: + ret = f(*args, **kwds) + + self._assert_count() + return ret + + return decorated + + def _assert_count(self): + error_msg = '%s != %s' % (self.count, self.expected) + assert self.count == self.expected, error_msg + + def __exit__(self, exc_type, exc_val, exc_tb): + super(assert_query_count, self).__exit__(exc_type, exc_val, exc_tb) + self._assert_count() diff --git a/libs/pwiz.py b/libs/pwiz.py new file mode 100644 index 000000000..dd50279fc --- /dev/null +++ b/libs/pwiz.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python + +import datetime +import sys +from getpass import getpass +from optparse import OptionParser + +from peewee import * +from peewee import print_ +from peewee import __version__ as peewee_version +from playhouse.reflection import * + + +HEADER = """from peewee import *%s + +database = %s('%s'%s) +""" + +BASE_MODEL = """\ +class BaseModel(Model): + class Meta: + database = database +""" + +UNKNOWN_FIELD = """\ +class UnknownField(object): + def __init__(self, *_, **__): pass +""" + +DATABASE_ALIASES = { + MySQLDatabase: ['mysql', 'mysqldb'], + PostgresqlDatabase: ['postgres', 'postgresql'], + SqliteDatabase: ['sqlite', 'sqlite3'], +} + +DATABASE_MAP = dict((value, key) + for key in DATABASE_ALIASES + for value in DATABASE_ALIASES[key]) + +def 
make_introspector(database_type, database_name, **kwargs): + if database_type not in DATABASE_MAP: + err('Unrecognized database, must be one of: %s' % + ', '.join(DATABASE_MAP.keys())) + sys.exit(1) + + schema = kwargs.pop('schema', None) + DatabaseClass = DATABASE_MAP[database_type] + db = DatabaseClass(database_name, **kwargs) + return Introspector.from_database(db, schema=schema) + +def print_models(introspector, tables=None, preserve_order=False, + include_views=False, ignore_unknown=False, snake_case=True): + database = introspector.introspect(table_names=tables, + include_views=include_views, + snake_case=snake_case) + + db_kwargs = introspector.get_database_kwargs() + header = HEADER % ( + introspector.get_additional_imports(), + introspector.get_database_class().__name__, + introspector.get_database_name(), + ', **%s' % repr(db_kwargs) if db_kwargs else '') + print_(header) + + if not ignore_unknown: + print_(UNKNOWN_FIELD) + + print_(BASE_MODEL) + + def _print_table(table, seen, accum=None): + accum = accum or [] + foreign_keys = database.foreign_keys[table] + for foreign_key in foreign_keys: + dest = foreign_key.dest_table + + # In the event the destination table has already been pushed + # for printing, then we have a reference cycle. + if dest in accum and table not in accum: + print_('# Possible reference cycle: %s' % dest) + + # If this is not a self-referential foreign key, and we have + # not already processed the destination table, do so now. + if dest not in seen and dest not in accum: + seen.add(dest) + if dest != table: + _print_table(dest, seen, accum + [table]) + + print_('class %s(BaseModel):' % database.model_names[table]) + columns = database.columns[table].items() + if not preserve_order: + columns = sorted(columns) + primary_keys = database.primary_keys[table] + for name, column in columns: + skip = all([ + name in primary_keys, + name == 'id', + len(primary_keys) == 1, + column.field_class in introspector.pk_classes]) + if skip: + continue + if column.primary_key and len(primary_keys) > 1: + # If we have a CompositeKey, then we do not want to explicitly + # mark the columns as being primary keys. 
+ column.primary_key = False + + is_unknown = column.field_class is UnknownField + if is_unknown and ignore_unknown: + disp = '%s - %s' % (column.name, column.raw_column_type or '?') + print_(' # %s' % disp) + else: + print_(' %s' % column.get_field()) + + print_('') + print_(' class Meta:') + print_(' table_name = \'%s\'' % table) + multi_column_indexes = database.multi_column_indexes(table) + if multi_column_indexes: + print_(' indexes = (') + for fields, unique in sorted(multi_column_indexes): + print_(' ((%s), %s),' % ( + ', '.join("'%s'" % field for field in fields), + unique, + )) + print_(' )') + + if introspector.schema: + print_(' schema = \'%s\'' % introspector.schema) + if len(primary_keys) > 1: + pk_field_names = sorted([ + field.name for col, field in columns + if col in primary_keys]) + pk_list = ', '.join("'%s'" % pk for pk in pk_field_names) + print_(' primary_key = CompositeKey(%s)' % pk_list) + elif not primary_keys: + print_(' primary_key = False') + print_('') + + seen.add(table) + + seen = set() + for table in sorted(database.model_names.keys()): + if table not in seen: + if not tables or table in tables: + _print_table(table, seen) + +def print_header(cmd_line, introspector): + timestamp = datetime.datetime.now() + print_('# Code generated by:') + print_('# python -m pwiz %s' % cmd_line) + print_('# Date: %s' % timestamp.strftime('%B %d, %Y %I:%M%p')) + print_('# Database: %s' % introspector.get_database_name()) + print_('# Peewee version: %s' % peewee_version) + print_('') + + +def err(msg): + sys.stderr.write('\033[91m%s\033[0m\n' % msg) + sys.stderr.flush() + +def get_option_parser(): + parser = OptionParser(usage='usage: %prog [options] database_name') + ao = parser.add_option + ao('-H', '--host', dest='host') + ao('-p', '--port', dest='port', type='int') + ao('-u', '--user', dest='user') + ao('-P', '--password', dest='password', action='store_true') + engines = sorted(DATABASE_MAP) + ao('-e', '--engine', dest='engine', default='postgresql', choices=engines, + help=('Database type, e.g. sqlite, mysql or postgresql. Default ' + 'is "postgresql".')) + ao('-s', '--schema', dest='schema') + ao('-t', '--tables', dest='tables', + help=('Only generate the specified tables. 
Multiple table names should ' + 'be separated by commas.')) + ao('-v', '--views', dest='views', action='store_true', + help='Generate model classes for VIEWs in addition to tables.') + ao('-i', '--info', dest='info', action='store_true', + help=('Add database information and other metadata to top of the ' + 'generated file.')) + ao('-o', '--preserve-order', action='store_true', dest='preserve_order', + help='Model definition column ordering matches source table.') + ao('-I', '--ignore-unknown', action='store_true', dest='ignore_unknown', + help='Ignore fields whose type cannot be determined.') + ao('-L', '--legacy-naming', action='store_true', dest='legacy_naming', + help='Use legacy table- and column-name generation.') + return parser + +def get_connect_kwargs(options): + ops = ('host', 'port', 'user', 'schema') + kwargs = dict((o, getattr(options, o)) for o in ops if getattr(options, o)) + if options.password: + kwargs['password'] = getpass() + return kwargs + + +if __name__ == '__main__': + raw_argv = sys.argv + + parser = get_option_parser() + options, args = parser.parse_args() + + if len(args) < 1: + err('Missing required parameter "database"') + parser.print_help() + sys.exit(1) + + connect = get_connect_kwargs(options) + database = args[-1] + + tables = None + if options.tables: + tables = [table.strip() for table in options.tables.split(',') + if table.strip()] + + introspector = make_introspector(options.engine, database, **connect) + if options.info: + cmd_line = ' '.join(raw_argv[1:]) + print_header(cmd_line, introspector) + + print_models(introspector, tables, options.preserve_order, options.views, + options.ignore_unknown, not options.legacy_naming) diff --git a/libs/version.txt b/libs/version.txt index 47a8584ac..2ec8fefe1 100644 --- a/libs/version.txt +++ b/libs/version.txt @@ -13,6 +13,7 @@ gevent-websocker=0.10.1 gitpython=2.1.9 guessit=2.1.4 langdetect=1.0.7 +peewee=3.9.6 py-pretty=1 pycountry=18.2.23 pysrt=1.1.1