parent
2f3cc8279f
commit
fc757a7e54
@ -0,0 +1,172 @@
|
|||||||
|
import os
|
||||||
|
import atexit
|
||||||
|
|
||||||
|
from get_args import args
|
||||||
|
from peewee import *
|
||||||
|
from playhouse.sqliteq import SqliteQueueDatabase
|
||||||
|
|
||||||
|
from helper import path_replace, path_replace_movie, path_replace_reverse, path_replace_reverse_movie
|
||||||
|
|
||||||
|
# Shared Bazarr database handle: a SQLite database whose writes are all
# serialized through a single background worker thread (SqliteQueueDatabase).
database = SqliteQueueDatabase(
    os.path.join(args.config_dir, 'db', 'bazarr.db'),
    use_gevent=False,  # Use the standard library "threading" module.
    autostart=True,  # Start the background worker thread automatically.
    queue_max_size=256,  # Max. # of pending writes that can accumulate.
    results_timeout=30.0)  # Max. time to wait for query to be executed.
|
||||||
|
|
||||||
|
|
||||||
|
@database.func('path_substitution')
def path_substitution(path):
    """SQL scalar function ``path_substitution(path)``.

    Applies Bazarr's configured series path mapping (helper.path_replace)
    so queries can translate Sonarr paths into local paths.
    """
    mapped = path_replace(path)
    return mapped
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownField(object):
    """Placeholder for columns whose type could not be introspected.

    Accepts and discards any constructor arguments (pwiz convention).
    """

    def __init__(self, *_, **__):
        pass
|
||||||
|
|
||||||
|
class BaseModel(Model):
    # Common base binding every model below to the shared
    # SqliteQueueDatabase instance defined above.
    class Meta:
        database = database
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteSequence(BaseModel):
    # Mapping of SQLite's internal AUTOINCREMENT bookkeeping table.
    name = BareField(null=True)  # table name
    seq = BareField(null=True)  # last AUTOINCREMENT value handed out

    class Meta:
        table_name = 'sqlite_sequence'
        primary_key = False  # system table; no declared primary key
|
||||||
|
|
||||||
|
|
||||||
|
class System(BaseModel):
    # App state flags (stored as text; presumably '0'/'1' -- TODO confirm
    # against the writers of this table).
    configured = TextField(null=True)
    updated = TextField(null=True)

    class Meta:
        table_name = 'system'
        primary_key = False
|
||||||
|
|
||||||
|
|
||||||
|
class TableShows(BaseModel):
    """One row per Sonarr series tracked by Bazarr.

    camelCase ``column_name`` values mirror the legacy DB schema; the
    attribute names are their snake_case equivalents.
    """
    alternate_titles = TextField(column_name='alternateTitles', null=True)
    audio_language = TextField(null=True)
    fanart = TextField(null=True)
    forced = TextField(null=True)
    hearing_impaired = TextField(null=True)
    languages = TextField(null=True)
    overview = TextField(null=True)
    path = TextField(unique=True)
    poster = TextField(null=True)
    sonarr_series_id = IntegerField(column_name='sonarrSeriesId', unique=True)
    sort_title = TextField(column_name='sortTitle', null=True)
    title = TextField()
    tvdb_id = AutoField(column_name='tvdbId')  # primary key

    year = TextField(null=True)

    class Meta:
        table_name = 'table_shows'
|
||||||
|
|
||||||
|
|
||||||
|
class TableEpisodes(BaseModel):
    """One row per Sonarr episode, linked to its series via sonarrSeriesId."""
    audio_codec = TextField(null=True)
    episode = IntegerField()
    failed_attempts = TextField(column_name='failedAttempts', null=True)
    format = TextField(null=True)
    missing_subtitles = TextField(null=True)
    monitored = TextField(null=True)
    path = TextField()
    resolution = TextField(null=True)
    scene_name = TextField(null=True)
    season = IntegerField()
    sonarr_episode_id = IntegerField(column_name='sonarrEpisodeId', unique=True)
    # FK targets TableShows.sonarr_series_id (not its tvdb_id primary key).
    sonarr_series_id = ForeignKeyField(TableShows, field='sonarr_series_id', column_name='sonarrSeriesId')
    subtitles = TextField(null=True)
    title = TextField()
    video_codec = TextField(null=True)

    class Meta:
        table_name = 'table_episodes'
        primary_key = False
|
||||||
|
|
||||||
|
|
||||||
|
class TableMovies(BaseModel):
    """One row per Radarr movie tracked by Bazarr; keyed by TMDB id."""
    alternative_titles = TextField(column_name='alternativeTitles', null=True)
    audio_codec = TextField(null=True)
    audio_language = TextField(null=True)
    failed_attempts = TextField(column_name='failedAttempts', null=True)
    fanart = TextField(null=True)
    forced = TextField(null=True)
    format = TextField(null=True)
    hearing_impaired = TextField(null=True)
    imdb_id = TextField(column_name='imdbId', null=True)
    languages = TextField(null=True)
    missing_subtitles = TextField(null=True)
    monitored = TextField(null=True)
    overview = TextField(null=True)
    path = TextField(unique=True)
    poster = TextField(null=True)
    radarr_id = IntegerField(column_name='radarrId', unique=True)
    resolution = TextField(null=True)
    scene_name = TextField(column_name='sceneName', null=True)
    sort_title = TextField(column_name='sortTitle', null=True)
    subtitles = TextField(null=True)
    title = TextField()
    tmdb_id = TextField(column_name='tmdbId', primary_key=True)
    video_codec = TextField(null=True)
    year = TextField(null=True)

    class Meta:
        table_name = 'table_movies'
|
||||||
|
|
||||||
|
|
||||||
|
class TableHistory(BaseModel):
    """Subtitle download/upgrade history for series episodes.

    ``action`` is a numeric event code and ``timestamp`` an integer epoch;
    both foreign keys resolve through the unique Sonarr ids rather than the
    tables' primary keys.
    """
    action = IntegerField()
    description = TextField()
    language = TextField(null=True)
    provider = TextField(null=True)
    score = TextField(null=True)
    # BUG FIX: the original was missing the comma between the ``field`` and
    # ``column_name`` keyword arguments in both calls below (SyntaxError).
    sonarr_episode_id = ForeignKeyField(TableEpisodes, field='sonarr_episode_id', column_name='sonarrEpisodeId')
    sonarr_series_id = ForeignKeyField(TableShows, field='sonarr_series_id', column_name='sonarrSeriesId')
    timestamp = IntegerField()
    video_path = TextField(null=True)

    class Meta:
        table_name = 'table_history'
|
||||||
|
|
||||||
|
|
||||||
|
class TableHistoryMovie(BaseModel):
    """Subtitle download/upgrade history for movies (mirrors TableHistory)."""
    action = IntegerField()
    description = TextField()
    language = TextField(null=True)
    provider = TextField(null=True)
    # BUG FIX: the original was missing the comma between the ``field`` and
    # ``column_name`` keyword arguments (SyntaxError).
    radarr_id = ForeignKeyField(TableMovies, field='radarr_id', column_name='radarrId')
    score = TextField(null=True)
    timestamp = IntegerField()
    video_path = TextField(null=True)

    class Meta:
        table_name = 'table_history_movie'
|
||||||
|
|
||||||
|
|
||||||
|
class TableSettingsLanguages(BaseModel):
    # Subtitle language catalogue; keyed by the ISO 639-2 three-letter code.
    code2 = TextField(null=True)  # ISO 639-1 two-letter code
    code3 = TextField(primary_key=True)
    code3b = TextField(null=True)  # bibliographic variant of code3
    enabled = IntegerField(null=True)  # 0/1 flag -- TODO confirm
    name = TextField()

    class Meta:
        table_name = 'table_settings_languages'
|
||||||
|
|
||||||
|
|
||||||
|
class TableSettingsNotifier(BaseModel):
    # Per-provider notification settings, keyed by notifier name.
    enabled = IntegerField(null=True)
    name = TextField(null=True, primary_key=True)
    url = TextField(null=True)

    class Meta:
        table_name = 'table_settings_notifier'
|
||||||
|
|
||||||
|
|
||||||
|
@atexit.register
def _stop_worker_threads():
    # Flush pending writes and stop SqliteQueueDatabase's worker thread at
    # interpreter shutdown so queued queries are not silently dropped.
    database.stop()
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,136 @@
|
|||||||
|
"""
|
||||||
|
Peewee integration with APSW, "another python sqlite wrapper".
|
||||||
|
|
||||||
|
Project page: https://rogerbinns.github.io/apsw/
|
||||||
|
|
||||||
|
APSW is a really neat library that provides a thin wrapper on top of SQLite's
|
||||||
|
C interface.
|
||||||
|
|
||||||
|
Here are just a few reasons to use APSW, taken from the documentation:
|
||||||
|
|
||||||
|
* APSW gives all functionality of SQLite, including virtual tables, virtual
|
||||||
|
file system, blob i/o, backups and file control.
|
||||||
|
* Connections can be shared across threads without any additional locking.
|
||||||
|
* Transactions are managed explicitly by your code.
|
||||||
|
* APSW can handle nested transactions.
|
||||||
|
* Unicode is handled correctly.
|
||||||
|
* APSW is faster.
|
||||||
|
"""
|
||||||
|
import apsw
|
||||||
|
from peewee import *
|
||||||
|
from peewee import __exception_wrapper__
|
||||||
|
from peewee import BooleanField as _BooleanField
|
||||||
|
from peewee import DateField as _DateField
|
||||||
|
from peewee import DateTimeField as _DateTimeField
|
||||||
|
from peewee import DecimalField as _DecimalField
|
||||||
|
from peewee import TimeField as _TimeField
|
||||||
|
from peewee import logger
|
||||||
|
|
||||||
|
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||||
|
|
||||||
|
|
||||||
|
class APSWDatabase(SqliteExtDatabase):
    """Peewee database backed by APSW connections.

    Re-implements peewee's connection plumbing on top of
    ``apsw.Connection`` and translates the registered hooks (scalar
    functions, aggregates, collations, loadable extensions and
    virtual-table modules) into the APSW API.
    """

    # SQLite library version as an int tuple, e.g. (3, 35, 5).
    server_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.'))

    def __init__(self, database, **kwargs):
        # Virtual-table modules registered via register_module():
        # name -> module instance. Re-applied to every new connection.
        self._modules = {}
        super(APSWDatabase, self).__init__(database, **kwargs)

    def register_module(self, mod_name, mod_inst):
        """Register a virtual-table module, applying it immediately when a
        connection is already open."""
        self._modules[mod_name] = mod_inst
        if not self.is_closed():
            self.connection().createmodule(mod_name, mod_inst)

    def unregister_module(self, mod_name):
        """Forget a registered module (open connections are not touched)."""
        del self._modules[mod_name]

    def _connect(self):
        conn = apsw.Connection(self.database, **self.connect_params)
        if self._timeout is not None:
            # APSW takes the busy timeout in milliseconds.
            conn.setbusytimeout(self._timeout * 1000)
        try:
            self._add_conn_hooks(conn)
        except:
            # Do not leak the raw connection if hook installation fails.
            conn.close()
            raise
        return conn

    def _add_conn_hooks(self, conn):
        super(APSWDatabase, self)._add_conn_hooks(conn)
        self._load_modules(conn)  # APSW-only.

    def _load_modules(self, conn):
        for mod_name, mod_inst in self._modules.items():
            conn.createmodule(mod_name, mod_inst)
        return conn

    def _load_aggregates(self, conn):
        for name, (klass, num_params) in self._aggregates.items():
            # BUG FIX: bind ``klass`` as a default argument. The original
            # closed over the loop variable, so when SQLite later invoked
            # the factory, every aggregate resolved to the *last* class
            # registered (classic late-binding closure bug).
            def make_aggregate(klass=klass):
                return (klass(), klass.step, klass.finalize)
            conn.createaggregatefunction(name, make_aggregate)

    def _load_collations(self, conn):
        for name, fn in self._collations.items():
            conn.createcollation(name, fn)

    def _load_functions(self, conn):
        for name, (fn, num_params) in self._functions.items():
            conn.createscalarfunction(name, fn, num_params)

    def _load_extensions(self, conn):
        conn.enableloadextension(True)
        for extension in self._extensions:
            conn.loadextension(extension)

    def load_extension(self, extension):
        """Load a SQLite extension now (if connected) and on reconnect."""
        self._extensions.add(extension)
        if not self.is_closed():
            conn = self.connection()
            conn.enableloadextension(True)
            conn.loadextension(extension)

    def last_insert_id(self, cursor, query_type=None):
        return cursor.getconnection().last_insert_rowid()

    def rows_affected(self, cursor):
        return cursor.getconnection().changes()

    def begin(self, lock_type='deferred'):
        # APSW leaves transaction management to the caller: issue explicit
        # BEGIN/COMMIT/ROLLBACK statements.
        self.cursor().execute('begin %s;' % lock_type)

    def commit(self):
        self.cursor().execute('commit;')

    def rollback(self):
        self.cursor().execute('rollback;')

    def execute_sql(self, sql, params=None, commit=True):
        # ``commit`` is accepted for interface compatibility but unused
        # (see begin/commit/rollback above).
        logger.debug((sql, params))
        with __exception_wrapper__:
            cursor = self.cursor()
            cursor.execute(sql, params or ())
        return cursor
|
||||||
|
|
||||||
|
|
||||||
|
def nh(s, v):
    """'Null handler' used as a ``db_value`` method: return ``str(v)``,
    letting None (SQL NULL) pass through. *s* is the field instance and is
    unused."""
    if v is None:
        return None
    return str(v)
|
||||||
|
|
||||||
|
class BooleanField(_BooleanField):
    """BooleanField storing 1/0 integers; None passes through as NULL."""

    def db_value(self, v):
        converted = super(BooleanField, self).db_value(v)
        if converted is None:
            return None
        return 1 if converted else 0
|
||||||
|
|
||||||
|
class DateField(_DateField):
    # Store dates as str(value); None passes through (see nh above).
    db_value = nh
|
||||||
|
|
||||||
|
class TimeField(_TimeField):
    # Store times as str(value); None passes through (see nh above).
    db_value = nh
|
||||||
|
|
||||||
|
class DateTimeField(_DateTimeField):
    # Store datetimes as str(value); None passes through (see nh above).
    db_value = nh
|
||||||
|
|
||||||
|
class DecimalField(_DecimalField):
    # Store decimals as str(value); None passes through (see nh above).
    db_value = nh
|
@ -0,0 +1,415 @@
|
|||||||
|
import csv
|
||||||
|
import datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
import json
|
||||||
|
import operator
|
||||||
|
try:
|
||||||
|
from urlparse import urlparse
|
||||||
|
except ImportError:
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from peewee import *
|
||||||
|
from playhouse.db_url import connect
|
||||||
|
from playhouse.migrate import migrate
|
||||||
|
from playhouse.migrate import SchemaMigrator
|
||||||
|
from playhouse.reflection import Introspector
|
||||||
|
|
||||||
|
# Python 2/3 compatibility shims: restore the ``basestring`` name on Py3,
# import ``reduce`` (moved into functools), and make open_file() use an
# explicit UTF-8 encoding where the text mode supports it.
if sys.version_info[0] == 3:
    basestring = str
    from functools import reduce
    def open_file(f, mode):
        return open(f, mode, encoding='utf8')
else:
    open_file = open
|
||||||
|
|
||||||
|
|
||||||
|
class DataSet(object):
    """High-level, schema-flexible wrapper over a peewee database.

    Accepts either a db URL (e.g. ``sqlite:///data.db``) or an existing
    peewee ``Database``. Tables are introspected lazily and exposed as
    ``Table`` wrappers via ``dataset[table_name]``.
    """

    def __init__(self, url, bare_fields=False):
        if isinstance(url, Database):
            self._url = None
            self._database = url
            self._database_path = self._database.database
        else:
            self._url = url
            parse_result = urlparse(url)
            self._database_path = parse_result.path[1:]

            # Connect to the database.
            self._database = connect(url)

        self._database.connect()

        # Introspect the database and generate models.
        self._introspector = Introspector.from_database(self._database)
        self._models = self._introspector.generate_models(
            skip_invalid=True,
            literal_column_names=True,
            bare_fields=bare_fields)
        self._migrator = SchemaMigrator.from_database(self._database)

        class BaseModel(Model):
            class Meta:
                database = self._database
        self._base_model = BaseModel
        self._export_formats = self.get_export_formats()
        self._import_formats = self.get_import_formats()

    def __repr__(self):
        return '<DataSet: %s>' % self._database_path

    def get_export_formats(self):
        """Mapping of format name -> Exporter class used by freeze()."""
        return {
            'csv': CSVExporter,
            'json': JSONExporter}

    def get_import_formats(self):
        """Mapping of format name -> Importer class used by thaw()."""
        return {
            'csv': CSVImporter,
            'json': JSONImporter}

    def __getitem__(self, table):
        # Refresh the model cache for tables that exist but have not been
        # introspected yet.
        if table not in self._models and table in self.tables:
            self.update_cache(table)
        return Table(self, table, self._models.get(table))

    @property
    def tables(self):
        return self._database.get_tables()

    def __contains__(self, table):
        return table in self.tables

    def connect(self):
        self._database.connect()

    def close(self):
        self._database.close()

    def update_cache(self, table=None):
        """Re-introspect *table* and its FK dependencies, or every table
        when *table* is None, refreshing the cached model classes."""
        if table:
            dependencies = [table]
            if table in self._models:
                model_class = self._models[table]
                dependencies.extend([
                    related._meta.table_name for _, related, _ in
                    model_class._meta.model_graph()])
            else:
                dependencies.extend(self.get_table_dependencies(table))
        else:
            dependencies = None  # Update all tables.
            self._models = {}
        updated = self._introspector.generate_models(
            skip_invalid=True,
            table_names=dependencies,
            literal_column_names=True)
        self._models.update(updated)

    def get_table_dependencies(self, table):
        # Depth-first walk of outgoing foreign keys starting at *table*.
        stack = [table]
        accum = []
        seen = set()
        while stack:
            table = stack.pop()
            for fk_meta in self._database.get_foreign_keys(table):
                dest = fk_meta.dest_table
                if dest not in seen:
                    # BUG FIX: the original never added to ``seen``, so a
                    # circular FK graph looped forever.
                    seen.add(dest)
                    stack.append(dest)
                    accum.append(dest)
        return accum

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self._database.is_closed():
            self.close()

    def query(self, sql, params=None, commit=True):
        return self._database.execute_sql(sql, params, commit)

    def transaction(self):
        # Nested calls become savepoints so they can commit/roll back
        # independently of the outer transaction.
        if self._database.transaction_depth() == 0:
            return self._database.transaction()
        else:
            return self._database.savepoint()

    def _check_arguments(self, filename, file_obj, format, format_dict):
        if filename and file_obj:
            raise ValueError('file is over-specified. Please use either '
                             'filename or file_obj, but not both.')
        if not filename and not file_obj:
            raise ValueError('A filename or file-like object must be '
                             'specified.')
        if format not in format_dict:
            valid_formats = ', '.join(sorted(format_dict.keys()))
            raise ValueError('Unsupported format "%s". Use one of %s.' % (
                format, valid_formats))

    def freeze(self, query, format='csv', filename=None, file_obj=None,
               **kwargs):
        """Export the rows of *query* to a file in the given format."""
        self._check_arguments(filename, file_obj, format, self._export_formats)
        if filename:
            file_obj = open_file(filename, 'w')

        exporter = self._export_formats[format](query)
        exporter.export(file_obj, **kwargs)

        if filename:
            file_obj.close()

    def thaw(self, table, format='csv', filename=None, file_obj=None,
             strict=False, **kwargs):
        """Import rows into *table* from a file; return the imported count.

        BUG FIX: validate *format* against the import formats -- the
        original checked ``self._export_formats``, so a format that only
        existed for import would be rejected and vice versa.
        """
        self._check_arguments(filename, file_obj, format, self._import_formats)
        if filename:
            file_obj = open_file(filename, 'r')

        importer = self._import_formats[format](self[table], strict)
        count = importer.load(file_obj, **kwargs)

        if filename:
            file_obj.close()

        return count
|
||||||
|
|
||||||
|
|
||||||
|
class Table(object):
    """Wrapper pairing a DataSet with one table's generated model class.

    Columns are migrated on the fly: inserting/updating a dict with unknown
    keys adds nullable columns whose types are guessed from the values.
    """

    def __init__(self, dataset, name, model_class):
        self.dataset = dataset
        self.name = name
        if model_class is None:
            # Unknown table: synthesize a model, create the table, cache it.
            model_class = self._create_model()
            model_class.create_table()
            self.dataset._models[name] = model_class

    @property
    def model_class(self):
        # Always read through the dataset cache, which update_cache()
        # may replace after a migration.
        return self.dataset._models[self.name]

    def __repr__(self):
        return '<Table: %s>' % self.name

    def __len__(self):
        return self.find().count()

    def __iter__(self):
        return iter(self.find().iterator())

    def _create_model(self):
        # Build a bare model subclassing the dataset's BaseModel.
        class Meta:
            table_name = self.name
        return type(
            str(self.name),
            (self.dataset._base_model,),
            {'Meta': Meta})

    def create_index(self, columns, unique=False):
        self.dataset._database.create_index(
            self.model_class,
            columns,
            unique=unique)

    def _guess_field_type(self, value):
        # Map a Python value to a peewee field class; TextField is the
        # fallback for anything unrecognized.
        if isinstance(value, basestring):
            return TextField
        if isinstance(value, (datetime.date, datetime.datetime)):
            return DateTimeField
        elif value is True or value is False:
            return BooleanField
        elif isinstance(value, int):
            return IntegerField
        elif isinstance(value, float):
            return FloatField
        elif isinstance(value, Decimal):
            return DecimalField
        return TextField

    @property
    def columns(self):
        return [f.name for f in self.model_class._meta.sorted_fields]

    def _migrate_new_columns(self, data):
        # Add a nullable column for every key in *data* the model lacks,
        # then re-introspect so the cached model picks them up.
        new_keys = set(data) - set(self.model_class._meta.fields)
        if new_keys:
            operations = []
            for key in new_keys:
                field_class = self._guess_field_type(data[key])
                field = field_class(null=True)
                operations.append(
                    self.dataset._migrator.add_column(self.name, key, field))
                field.bind(self.model_class, key)

            migrate(*operations)

            self.dataset.update_cache(self.name)

    def insert(self, **data):
        self._migrate_new_columns(data)
        return self.model_class.insert(**data).execute()

    def _apply_where(self, query, filters, conjunction=None):
        # AND the column==value filters together by default.
        conjunction = conjunction or operator.and_
        if filters:
            expressions = [
                (self.model_class._meta.fields[column] == value)
                for column, value in filters.items()]
            query = query.where(reduce(conjunction, expressions))
        return query

    def update(self, columns=None, conjunction=None, **data):
        """Update rows; *columns* names the keys in **data** used as the
        filter, the rest become the new values."""
        self._migrate_new_columns(data)
        filters = {}
        if columns:
            for column in columns:
                filters[column] = data.pop(column)

        return self._apply_where(
            self.model_class.update(**data),
            filters,
            conjunction).execute()

    def _query(self, **query):
        return self._apply_where(self.model_class.select(), query)

    def find(self, **query):
        # Rows come back as plain dicts, not model instances.
        return self._query(**query).dicts()

    def find_one(self, **query):
        try:
            return self.find(**query).get()
        except self.model_class.DoesNotExist:
            return None

    def all(self):
        return self.find()

    def delete(self, **query):
        return self._apply_where(self.model_class.delete(), query).execute()

    def freeze(self, *args, **kwargs):
        """Export all rows of this table (see DataSet.freeze)."""
        return self.dataset.freeze(self.all(), *args, **kwargs)

    def thaw(self, *args, **kwargs):
        """Import rows into this table (see DataSet.thaw)."""
        return self.dataset.thaw(self.name, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Exporter(object):
    """Base class for query exporters; subclasses implement export()."""

    def __init__(self, query):
        self.query = query

    def export(self, file_obj):
        """Write the query's rows to *file_obj*."""
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class JSONExporter(Exporter):
    """Serialize query rows to JSON, optionally with ISO-8601 datetimes."""

    def __init__(self, query, iso8601_datetimes=False):
        super(JSONExporter, self).__init__(query)
        self.iso8601_datetimes = iso8601_datetimes

    def _make_default(self):
        """Build the ``default`` hook json.dump uses for non-native types:
        datetimes/dates/times and Decimals are stringified (datetimes via
        isoformat() when iso8601_datetimes is set)."""
        datetime_types = (datetime.datetime, datetime.date, datetime.time)
        use_iso = self.iso8601_datetimes

        def default(o):
            if use_iso and isinstance(o, datetime_types):
                return o.isoformat()
            if isinstance(o, datetime_types + (Decimal,)):
                return str(o)
            raise TypeError('Unable to serialize %r as JSON' % o)

        return default

    def export(self, file_obj, **kwargs):
        rows = list(self.query)
        json.dump(rows, file_obj, default=self._make_default(), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class CSVExporter(Exporter):
    """Write query rows to *file_obj* as CSV, optionally with a header."""

    def export(self, file_obj, header=True, **kwargs):
        writer = csv.writer(file_obj, **kwargs)
        tuples = self.query.tuples().execute()
        # Run the cursor so its column metadata becomes available.
        tuples.initialize()
        if header and getattr(tuples, 'columns', None):
            writer.writerow([column for column in tuples.columns])
        for row in tuples:
            writer.writerow(row)
|
||||||
|
|
||||||
|
|
||||||
|
class Importer(object):
    """Base class for table importers.

    ``self.columns`` maps both db column names and field attribute names
    to their Field instances, so either spelling resolves in strict mode.
    """

    def __init__(self, table, strict=False):
        self.table = table
        self.strict = strict

        model = self.table.model_class
        # NOTE(review): this aliases model._meta.columns and then mutates it
        # in place via update() -- presumably harmless, but verify it does
        # not corrupt the model's metadata for other users of the model.
        self.columns = model._meta.columns
        self.columns.update(model._meta.fields)

    def load(self, file_obj):
        """Read rows from *file_obj* into the table; return the count."""
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class JSONImporter(Importer):
    """Load rows from a JSON file (a list of objects) into the table.

    In strict mode, keys that do not match a known column are dropped and
    values are coerced through the field's ``python_value``; empty rows
    are skipped and not counted.
    """

    def load(self, file_obj, **kwargs):
        rows = json.load(file_obj, **kwargs)
        inserted = 0

        for row in rows:
            if self.strict:
                matched = [(key, self.columns.get(key)) for key in row]
                record = dict(
                    (field.name, field.python_value(row[key]))
                    for key, field in matched if field is not None)
            else:
                record = row

            if record:
                self.table.insert(**record)
                inserted += 1

        return inserted
|
||||||
|
|
||||||
|
|
||||||
|
class CSVImporter(Importer):
    """Load rows from a CSV file into the table.

    With ``header=True`` the first row names the columns; otherwise rows
    are assumed to follow the model's declared field order. In strict mode
    only keys matching known columns are kept and values are coerced via
    ``python_value``.
    """

    def load(self, file_obj, header=True, **kwargs):
        count = 0
        reader = csv.reader(file_obj, **kwargs)
        if header:
            try:
                header_keys = next(reader)
            except StopIteration:
                return count  # Empty file.

            if self.strict:
                header_fields = []
                for idx, key in enumerate(header_keys):
                    if key in self.columns:
                        header_fields.append((idx, self.columns[key]))
            else:
                header_fields = list(enumerate(header_keys))
        else:
            # BUG FIX: the original referenced ``self.model``, which is
            # never defined (AttributeError); the model class lives on the
            # wrapped table.
            header_fields = list(enumerate(
                self.table.model_class._meta.sorted_fields))

        if not header_fields:
            return count

        for row in reader:
            obj = {}
            for idx, field in header_fields:
                if self.strict:
                    obj[field.name] = field.python_value(row[idx])
                else:
                    obj[field] = row[idx]

            self.table.insert(**obj)
            count += 1

        return count
|
@ -0,0 +1,124 @@
|
|||||||
|
try:
|
||||||
|
from urlparse import parse_qsl, unquote, urlparse
|
||||||
|
except ImportError:
|
||||||
|
from urllib.parse import parse_qsl, unquote, urlparse
|
||||||
|
|
||||||
|
from peewee import *
|
||||||
|
from playhouse.pool import PooledMySQLDatabase
|
||||||
|
from playhouse.pool import PooledPostgresqlDatabase
|
||||||
|
from playhouse.pool import PooledSqliteDatabase
|
||||||
|
from playhouse.pool import PooledSqliteExtDatabase
|
||||||
|
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||||
|
|
||||||
|
|
||||||
|
# URL scheme -> peewee database class. Extended at import time below for
# optionally-installed backends, and at runtime via register_database().
schemes = {
    'mysql': MySQLDatabase,
    'mysql+pool': PooledMySQLDatabase,
    'postgres': PostgresqlDatabase,
    'postgresql': PostgresqlDatabase,
    'postgres+pool': PooledPostgresqlDatabase,
    'postgresql+pool': PooledPostgresqlDatabase,
    'sqlite': SqliteDatabase,
    'sqliteext': SqliteExtDatabase,
    'sqlite+pool': PooledSqliteDatabase,
    'sqliteext+pool': PooledSqliteExtDatabase,
}
|
||||||
|
|
||||||
|
def register_database(db_class, *names):
    """Register *db_class* so connect()/parse() resolve each URL scheme in
    *names* to it."""
    global schemes
    schemes.update((scheme, db_class) for scheme in names)
|
||||||
|
|
||||||
|
def parseresult_to_dict(parsed, unquote_password=False):
    """Translate a urlparse() result into peewee connection kwargs.

    Handles the python 2.6 quirk where the query string stays attached to
    ``path``, coerces query-string values to bool/int/float/None where
    unambiguous, and applies backend fixups (MySQL ``passwd`` key, sqlite
    in-memory default for an empty database name).
    """
    # urlparse in python 2.6 is broken so query will be empty and instead
    # appended to path complete with '?'.
    path_parts = parsed.path[1:].split('?')
    try:
        query = path_parts[1]
    except IndexError:
        query = parsed.query

    connect_kwargs = {'database': path_parts[0]}
    if parsed.username:
        connect_kwargs['user'] = parsed.username
    if parsed.password:
        connect_kwargs['password'] = parsed.password
        if unquote_password:
            connect_kwargs['password'] = unquote(connect_kwargs['password'])
    if parsed.hostname:
        connect_kwargs['host'] = parsed.hostname
    if parsed.port:
        connect_kwargs['port'] = parsed.port

    # Adjust parameters for MySQL.
    if parsed.scheme == 'mysql' and 'password' in connect_kwargs:
        connect_kwargs['passwd'] = connect_kwargs.pop('password')
    elif 'sqlite' in parsed.scheme and not connect_kwargs['database']:
        connect_kwargs['database'] = ':memory:'

    def coerce(raw):
        # Interpret a query-string value: booleans, ints, dotted floats and
        # null/none sentinels; anything else stays a string.
        lowered = raw.lower()
        if lowered == 'false':
            return False
        if lowered == 'true':
            return True
        if raw.isdigit():
            return int(raw)
        if '.' in raw and all(p.isdigit() for p in raw.split('.', 1)):
            try:
                return float(raw)
            except ValueError:
                pass
        elif lowered in ('null', 'none'):
            return None
        return raw

    # Get additional connection args from the query string.
    for key, raw in parse_qsl(query, keep_blank_values=True):
        connect_kwargs[key] = coerce(raw)

    return connect_kwargs
|
||||||
|
|
||||||
|
def parse(url, unquote_password=False):
    """Parse a database URL into a dict of connection kwargs."""
    return parseresult_to_dict(urlparse(url), unquote_password)
|
||||||
|
|
||||||
|
def connect(url, unquote_password=False, **connect_params):
    """Create a database instance from *url*.

    Extra keyword arguments override values parsed from the URL. Raises
    RuntimeError for unrecognized schemes, or for schemes that are known
    but whose backing library failed to import.
    """
    parsed = urlparse(url)
    connect_kwargs = parseresult_to_dict(parsed, unquote_password)
    connect_kwargs.update(connect_params)
    database_class = schemes.get(parsed.scheme)

    if database_class is None:
        # BUG FIX: the original tested ``database_class in schemes``, which
        # is always False here (database_class is None), making the
        # "required library could not be imported" diagnostic unreachable.
        # Test the scheme itself instead.
        if parsed.scheme in schemes:
            raise RuntimeError('Attempted to use "%s" but a required library '
                               'could not be imported.' % parsed.scheme)
        else:
            raise RuntimeError('Unrecognized or unsupported scheme: "%s".' %
                               parsed.scheme)

    return database_class(**connect_kwargs)
|
||||||
|
|
||||||
|
# Conditionally register additional databases: each backend is optional, so
# an ImportError simply leaves its scheme(s) unregistered.
try:
    from playhouse.pool import PooledPostgresqlExtDatabase
except ImportError:
    pass
else:
    register_database(
        PooledPostgresqlExtDatabase,
        'postgresext+pool',
        'postgresqlext+pool')

try:
    from playhouse.apsw_ext import APSWDatabase
except ImportError:
    pass
else:
    register_database(APSWDatabase, 'apsw')

try:
    from playhouse.postgres_ext import PostgresqlExtDatabase
except ImportError:
    pass
else:
    register_database(PostgresqlExtDatabase, 'postgresext', 'postgresqlext')
|
@ -0,0 +1,64 @@
|
|||||||
|
try:
|
||||||
|
import bz2
|
||||||
|
except ImportError:
|
||||||
|
bz2 = None
|
||||||
|
try:
|
||||||
|
import zlib
|
||||||
|
except ImportError:
|
||||||
|
zlib = None
|
||||||
|
try:
|
||||||
|
import cPickle as pickle
|
||||||
|
except ImportError:
|
||||||
|
import pickle
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from peewee import BlobField
|
||||||
|
from peewee import buffer_type
|
||||||
|
|
||||||
|
|
||||||
|
# True when running under Python 2 (bytes/str handling differs below).
PY2 = sys.version_info[0] == 2
|
||||||
|
|
||||||
|
|
||||||
|
class CompressedField(BlobField):
    """BlobField that transparently (de)compresses values with zlib or bz2.

    ``compression_level`` follows the underlying module's scale (zlib/bz2
    use 1-9). Raises ValueError for unknown algorithms or when the chosen
    library is unavailable in this interpreter.
    """
    ZLIB = 'zlib'
    BZ2 = 'bz2'
    # A value of None means the library failed to import (see top of file).
    algorithm_to_import = {
        ZLIB: zlib,
        BZ2: bz2,
    }

    def __init__(self, compression_level=6, algorithm=ZLIB, *args,
                 **kwargs):
        self.compression_level = compression_level
        if algorithm not in self.algorithm_to_import:
            raise ValueError('Unrecognized algorithm %s' % algorithm)
        compress_module = self.algorithm_to_import[algorithm]
        if compress_module is None:
            raise ValueError('Missing library required for %s.' % algorithm)

        self.algorithm = algorithm
        self.compress = compress_module.compress
        self.decompress = compress_module.decompress
        super(CompressedField, self).__init__(*args, **kwargs)

    def python_value(self, value):
        # DB -> Python: decompress; None (SQL NULL) passes through.
        if value is not None:
            return self.decompress(value)

    def db_value(self, value):
        # Python -> DB: compress and wrap in the driver's blob constructor.
        if value is not None:
            return self._constructor(
                self.compress(value, self.compression_level))
|
||||||
|
|
||||||
|
|
||||||
|
class PickleField(BlobField):
    """Blob field that transparently pickles/unpickles Python objects."""

    def python_value(self, value):
        # Deserialize the stored blob back into a Python object.
        if value is None:
            return None
        raw = bytes(value) if isinstance(value, buffer_type) else value
        return pickle.loads(raw)

    def db_value(self, value):
        # Serialize with the highest protocol and wrap for the backend.
        if value is None:
            return None
        return self._constructor(pickle.dumps(value, pickle.HIGHEST_PROTOCOL))
|
@ -0,0 +1,185 @@
|
|||||||
|
import math
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from flask import abort
|
||||||
|
from flask import render_template
|
||||||
|
from flask import request
|
||||||
|
from peewee import Database
|
||||||
|
from peewee import DoesNotExist
|
||||||
|
from peewee import Model
|
||||||
|
from peewee import Proxy
|
||||||
|
from peewee import SelectQuery
|
||||||
|
from playhouse.db_url import connect as db_url_connect
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedQuery(object):
    """Paginate a peewee query using a Flask request argument.

    :param query_or_model: a SelectQuery, or a Model whose full table is
        selected.
    :param int paginate_by: number of rows per page.
    :param str page_var: query-string argument holding the page number.
    :param int page: explicit page number; overrides the request argument.
    :param bool check_bounds: when True, abort(404) for out-of-range pages.
    """
    def __init__(self, query_or_model, paginate_by, page_var='page', page=None,
                 check_bounds=False):
        self.paginate_by = paginate_by
        self.page_var = page_var
        self.page = page or None
        self.check_bounds = check_bounds

        if isinstance(query_or_model, SelectQuery):
            self.query = query_or_model
            self.model = self.query.model
        else:
            self.model = query_or_model
            self.query = self.model.select()

    def get_page(self):
        """Return the current 1-based page number."""
        if self.page is not None:
            return self.page

        # Fall back to the request's query-string argument; non-numeric or
        # missing values default to page 1.
        curr_page = request.args.get(self.page_var)
        if curr_page and curr_page.isdigit():
            return max(1, int(curr_page))
        return 1

    def get_page_count(self):
        """Return (and cache) the total number of pages."""
        if not hasattr(self, '_page_count'):
            self._page_count = int(math.ceil(
                float(self.query.count()) / self.paginate_by))
        return self._page_count

    def get_object_list(self):
        """Return the rows for the current page (404 when out of bounds)."""
        if self.check_bounds and self.get_page() > self.get_page_count():
            abort(404)
        return self.query.paginate(self.get_page(), self.paginate_by)
|
||||||
|
|
||||||
|
|
||||||
|
def get_object_or_404(query_or_model, *query):
    """Return the single row matching *query*, or abort with a 404."""
    select = query_or_model
    if not isinstance(select, SelectQuery):
        select = select.select()
    try:
        return select.where(*query).get()
    except DoesNotExist:
        abort(404)
|
||||||
|
|
||||||
|
def object_list(template_name, query, context_variable='object_list',
                paginate_by=20, page_var='page', page=None, check_bounds=True,
                **kwargs):
    """Render a template with one page of a query's results.

    The page of rows is exposed to the template as ``context_variable``;
    the PaginatedQuery is available as ``pagination`` and the current page
    number as ``page``.  Extra keyword arguments are forwarded verbatim.
    """
    paginated_query = PaginatedQuery(
        query,
        paginate_by=paginate_by,
        page_var=page_var,
        page=page,
        check_bounds=check_bounds)
    kwargs[context_variable] = paginated_query.get_object_list()
    return render_template(
        template_name,
        pagination=paginated_query,
        page=paginated_query.get_page(),
        **kwargs)
|
||||||
|
|
||||||
|
def get_current_url():
    """Return the current request path, including any query string.

    NOTE(review): under Python 3, ``request.query_string`` is bytes, so the
    '%s' interpolation would embed a ``b'...'`` repr in the URL — confirm
    this helper is only relied upon under Python 2 semantics.
    """
    if not request.query_string:
        return request.path
    return '%s?%s' % (request.path, request.query_string)
|
||||||
|
|
||||||
|
def get_next_url(default='/'):
    """Return the 'next' redirect target from the query string or, failing
    that, from the POSTed form data; otherwise return *default*."""
    for source in (request.args, request.form):
        candidate = source.get('next')
        if candidate:
            return candidate
    return default
|
||||||
|
|
||||||
|
class FlaskDB(object):
    """Flask integration helper for peewee.

    Resolves the database from the app config (``DATABASE`` dict/instance or
    ``DATABASE_URL``), manages a per-request connection, and exposes a
    pre-configured base ``Model`` class.
    """
    def __init__(self, app=None, database=None, model_class=Model):
        self.database = None  # Reference to actual Peewee database instance.
        self.base_model_class = model_class
        self._app = app
        self._db = database  # dict, url, Database, or None (default).
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Bind to a Flask app and resolve the database configuration."""
        self._app = app

        if self._db is None:
            # No database given at construction time: read the app config.
            if 'DATABASE' in app.config:
                initial_db = app.config['DATABASE']
            elif 'DATABASE_URL' in app.config:
                initial_db = app.config['DATABASE_URL']
            else:
                raise ValueError('Missing required configuration data for '
                                 'database: DATABASE or DATABASE_URL.')
        else:
            initial_db = self._db

        self._load_database(app, initial_db)
        self._register_handlers(app)

    def _load_database(self, app, config_value):
        # Accept a Database instance, a config dict, or a connection URL.
        if isinstance(config_value, Database):
            database = config_value
        elif isinstance(config_value, dict):
            database = self._load_from_config_dict(dict(config_value))
        else:
            # Assume a database connection URL.
            database = db_url_connect(config_value)

        # If .Model was accessed before init_app, self.database is a Proxy
        # that deferred models are already bound to — initialize it in place.
        if isinstance(self.database, Proxy):
            self.database.initialize(database)
        else:
            self.database = database

    def _load_from_config_dict(self, config_dict):
        # Build a Database from {'name': ..., 'engine': ..., **connect_args}.
        try:
            name = config_dict.pop('name')
            engine = config_dict.pop('engine')
        except KeyError:
            raise RuntimeError('DATABASE configuration must specify a '
                               '`name` and `engine`.')

        # 'engine' may be a dotted path; bare names resolve against peewee.
        if '.' in engine:
            path, class_name = engine.rsplit('.', 1)
        else:
            path, class_name = 'peewee', engine

        try:
            __import__(path)
            module = sys.modules[path]
            database_class = getattr(module, class_name)
            assert issubclass(database_class, Database)
        except ImportError:
            raise RuntimeError('Unable to import %s' % engine)
        except AttributeError:
            raise RuntimeError('Database engine not found %s' % engine)
        except AssertionError:
            raise RuntimeError('Database engine not a subclass of '
                               'peewee.Database: %s' % engine)

        return database_class(name, **config_dict)

    def _register_handlers(self, app):
        # Open a connection before each request; close it afterwards.
        app.before_request(self.connect_db)
        app.teardown_request(self.close_db)

    def get_model_class(self):
        """Return a Model subclass bound to the configured database."""
        if self.database is None:
            raise RuntimeError('Database must be initialized.')

        class BaseModel(self.base_model_class):
            class Meta:
                database = self.database

        return BaseModel

    @property
    def Model(self):
        # Lazily build (and cache) the base model class.  When no app is
        # bound yet, fall back to a Proxy that _load_database initializes
        # later, so models can be declared before init_app runs.
        if self._app is None:
            database = getattr(self, 'database', None)
            if database is None:
                self.database = Proxy()

        if not hasattr(self, '_model_class'):
            self._model_class = self.get_model_class()
        return self._model_class

    def connect_db(self):
        self.database.connect()

    def close_db(self, exc):
        # Teardown handler: close the connection if one is still open.
        if not self.database.is_closed():
            self.database.close()
|
@ -0,0 +1,50 @@
|
|||||||
|
# Hybrid methods/attributes, based on similar functionality in SQLAlchemy:
|
||||||
|
# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html
|
||||||
|
class hybrid_method(object):
    """Descriptor for a method that behaves differently at class vs.
    instance level: instances get the Python implementation, the class
    gets the (SQL) expression, mirroring SQLAlchemy's hybrid methods."""

    def __init__(self, func, expr=None):
        self.func = func
        # Until an explicit expression is registered, reuse the function.
        self.expr = expr or func

    def __get__(self, instance, instance_type):
        if instance is not None:
            # Instance access: plain bound method.
            return self.func.__get__(instance, instance_type)
        # Class access: bind the expression to the class itself.
        return self.expr.__get__(instance_type, instance_type.__class__)

    def expression(self, expr):
        """Decorator registering the class-level expression."""
        self.expr = expr
        return self
|
||||||
|
|
||||||
|
|
||||||
|
class hybrid_property(object):
    """Property descriptor with distinct instance-level (Python) and
    class-level (expression) behavior, after SQLAlchemy's hybrids."""

    def __init__(self, fget, fset=None, fdel=None, expr=None):
        self.fget = fget
        self.fset = fset
        self.fdel = fdel
        # The getter doubles as the expression unless one is registered.
        self.expr = expr or fget

    def __get__(self, instance, instance_type):
        if instance is not None:
            return self.fget(instance)
        # Class access: evaluate the expression against the class.
        return self.expr(instance_type)

    def __set__(self, instance, value):
        if self.fset is None:
            raise AttributeError('Cannot set attribute.')
        self.fset(instance, value)

    def __delete__(self, instance):
        if self.fdel is None:
            raise AttributeError('Cannot delete attribute.')
        self.fdel(instance)

    def setter(self, fset):
        """Decorator registering the setter; returns the descriptor."""
        self.fset = fset
        return self

    def deleter(self, fdel):
        """Decorator registering the deleter; returns the descriptor."""
        self.fdel = fdel
        return self

    def expression(self, expr):
        """Decorator registering the class-level expression."""
        self.expr = expr
        return self
|
@ -0,0 +1,172 @@
|
|||||||
|
import operator
|
||||||
|
|
||||||
|
from peewee import *
|
||||||
|
from peewee import Expression
|
||||||
|
from playhouse.fields import PickleField
|
||||||
|
try:
|
||||||
|
from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase
|
||||||
|
except ImportError:
|
||||||
|
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||||
|
|
||||||
|
|
||||||
|
# Unique marker used to distinguish "no default supplied" from None.
Sentinel = type('Sentinel', (object,), {})
|
||||||
|
|
||||||
|
|
||||||
|
class KeyValue(object):
    """
    Persistent dictionary.

    :param Field key_field: field to use for key. Defaults to CharField.
    :param Field value_field: field to use for value. Defaults to PickleField.
    :param bool ordered: data should be returned in key-sorted order.
    :param Database database: database where key/value data is stored.
    :param str table_name: table name for data.
    """
    def __init__(self, key_field=None, value_field=None, ordered=False,
                 database=None, table_name='keyvalue'):
        if key_field is None:
            key_field = CharField(max_length=255, primary_key=True)
        if not key_field.primary_key:
            raise ValueError('key_field must have primary_key=True.')

        if value_field is None:
            value_field = PickleField()

        self._key_field = key_field
        self._value_field = value_field
        self._ordered = ordered
        self._database = database or SqliteExtDatabase(':memory:')
        self._table_name = table_name
        # Postgres spells upsert differently (ON CONFLICT ... DO UPDATE),
        # so dispatch to backend-specific implementations up front.
        if isinstance(self._database, PostgresqlDatabase):
            self.upsert = self._postgres_upsert
            self.update = self._postgres_update
        else:
            self.upsert = self._upsert
            self.update = self._update

        self.model = self.create_model()
        self.key = self.model.key
        self.value = self.model.value

        # Ensure table exists.
        self.model.create_table()

    def create_model(self):
        # Build a Model class bound to our fields, database and table name.
        class KeyValue(Model):
            key = self._key_field
            value = self._value_field
            class Meta:
                database = self._database
                table_name = self._table_name
        return KeyValue

    def query(self, *select):
        """Return a tuples() query over the table, ordered if requested."""
        query = self.model.select(*select).tuples()
        if self._ordered:
            query = query.order_by(self.key)
        return query

    def convert_expression(self, expr):
        # Accept either a peewee Expression or a bare key value; the second
        # element reports whether the argument addressed a single key.
        if not isinstance(expr, Expression):
            return (self.key == expr), True
        return expr, False

    def __contains__(self, key):
        expr, _ = self.convert_expression(key)
        return self.model.select().where(expr).exists()

    def __len__(self):
        return len(self.model)

    def __getitem__(self, expr):
        # Single-key lookups return the bare value (or raise KeyError);
        # expression lookups return a list of matching values.
        converted, is_single = self.convert_expression(expr)
        query = self.query(self.value).where(converted)
        item_getter = operator.itemgetter(0)
        result = [item_getter(row) for row in query]
        if len(result) == 0 and is_single:
            raise KeyError(expr)
        elif is_single:
            return result[0]
        return result

    def _upsert(self, key, value):
        # SQLite/MySQL upsert via ON CONFLICT REPLACE.
        (self.model
         .insert(key=key, value=value)
         .on_conflict('replace')
         .execute())

    def _postgres_upsert(self, key, value):
        (self.model
         .insert(key=key, value=value)
         .on_conflict(conflict_target=[self.key],
                      preserve=[self.value])
         .execute())

    def __setitem__(self, expr, value):
        # Expression keys update matching rows; scalar keys upsert.
        if isinstance(expr, Expression):
            self.model.update(value=value).where(expr).execute()
        else:
            self.upsert(expr, value)

    def __delitem__(self, expr):
        converted, _ = self.convert_expression(expr)
        self.model.delete().where(converted).execute()

    def __iter__(self):
        return iter(self.query().execute())

    def keys(self):
        return map(operator.itemgetter(0), self.query(self.key))

    def values(self):
        return map(operator.itemgetter(0), self.query(self.value))

    def items(self):
        return iter(self.query().execute())

    def _update(self, __data=None, **mapping):
        # Bulk upsert (SQLite/MySQL flavor); mirrors dict.update().
        if __data is not None:
            mapping.update(__data)
        return (self.model
                .insert_many(list(mapping.items()),
                             fields=[self.key, self.value])
                .on_conflict('replace')
                .execute())

    def _postgres_update(self, __data=None, **mapping):
        # Bulk upsert (Postgres flavor); mirrors dict.update().
        if __data is not None:
            mapping.update(__data)
        return (self.model
                .insert_many(list(mapping.items()),
                             fields=[self.key, self.value])
                .on_conflict(conflict_target=[self.key],
                             preserve=[self.value])
                .execute())

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def setdefault(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            self[key] = default
            return default

    def pop(self, key, default=Sentinel):
        # Read-and-delete inside a transaction so concurrent writers
        # cannot interleave between the read and the delete.
        with self._database.atomic():
            try:
                result = self[key]
            except KeyError:
                if default is Sentinel:
                    raise
                return default
            del self[key]

        return result

    def clear(self):
        self.model.delete().execute()
|
@ -0,0 +1,823 @@
|
|||||||
|
"""
|
||||||
|
Lightweight schema migrations.
|
||||||
|
|
||||||
|
NOTE: Currently tested with SQLite and Postgresql. MySQL may be missing some
|
||||||
|
features.
|
||||||
|
|
||||||
|
Example Usage
|
||||||
|
-------------
|
||||||
|
|
||||||
|
Instantiate a migrator:
|
||||||
|
|
||||||
|
# Postgres example:
|
||||||
|
my_db = PostgresqlDatabase(...)
|
||||||
|
migrator = PostgresqlMigrator(my_db)
|
||||||
|
|
||||||
|
# SQLite example:
|
||||||
|
my_db = SqliteDatabase('my_database.db')
|
||||||
|
migrator = SqliteMigrator(my_db)
|
||||||
|
|
||||||
|
Then you will use the `migrate` function to run various `Operation`s which
|
||||||
|
are generated by the migrator:
|
||||||
|
|
||||||
|
migrate(
|
||||||
|
migrator.add_column('some_table', 'column_name', CharField(default=''))
|
||||||
|
)
|
||||||
|
|
||||||
|
Migrations are not run inside a transaction, so if you wish the migration to
|
||||||
|
run in a transaction you will need to wrap the call to `migrate` in a
|
||||||
|
transaction block, e.g.:
|
||||||
|
|
||||||
|
with my_db.transaction():
|
||||||
|
migrate(...)
|
||||||
|
|
||||||
|
Supported Operations
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
Add new field(s) to an existing model:
|
||||||
|
|
||||||
|
# Create your field instances. For non-null fields you must specify a
|
||||||
|
# default value.
|
||||||
|
pubdate_field = DateTimeField(null=True)
|
||||||
|
comment_field = TextField(default='')
|
||||||
|
|
||||||
|
# Run the migration, specifying the database table, field name and field.
|
||||||
|
migrate(
|
||||||
|
migrator.add_column('comment_tbl', 'pub_date', pubdate_field),
|
||||||
|
migrator.add_column('comment_tbl', 'comment', comment_field),
|
||||||
|
)
|
||||||
|
|
||||||
|
Renaming a field:
|
||||||
|
|
||||||
|
# Specify the table, original name of the column, and its new name.
|
||||||
|
migrate(
|
||||||
|
migrator.rename_column('story', 'pub_date', 'publish_date'),
|
||||||
|
migrator.rename_column('story', 'mod_date', 'modified_date'),
|
||||||
|
)
|
||||||
|
|
||||||
|
Dropping a field:
|
||||||
|
|
||||||
|
migrate(
|
||||||
|
migrator.drop_column('story', 'some_old_field'),
|
||||||
|
)
|
||||||
|
|
||||||
|
Making a field nullable or not nullable:
|
||||||
|
|
||||||
|
# Note that when making a field not null that field must not have any
|
||||||
|
# NULL values present.
|
||||||
|
migrate(
|
||||||
|
# Make `pub_date` allow NULL values.
|
||||||
|
migrator.drop_not_null('story', 'pub_date'),
|
||||||
|
|
||||||
|
# Prevent `modified_date` from containing NULL values.
|
||||||
|
migrator.add_not_null('story', 'modified_date'),
|
||||||
|
)
|
||||||
|
|
||||||
|
Renaming a table:
|
||||||
|
|
||||||
|
migrate(
|
||||||
|
migrator.rename_table('story', 'stories_tbl'),
|
||||||
|
)
|
||||||
|
|
||||||
|
Adding an index:
|
||||||
|
|
||||||
|
# Specify the table, column names, and whether the index should be
|
||||||
|
# UNIQUE or not.
|
||||||
|
migrate(
|
||||||
|
# Create an index on the `pub_date` column.
|
||||||
|
migrator.add_index('story', ('pub_date',), False),
|
||||||
|
|
||||||
|
# Create a multi-column index on the `pub_date` and `status` fields.
|
||||||
|
migrator.add_index('story', ('pub_date', 'status'), False),
|
||||||
|
|
||||||
|
# Create a unique index on the category and title fields.
|
||||||
|
migrator.add_index('story', ('category_id', 'title'), True),
|
||||||
|
)
|
||||||
|
|
||||||
|
Dropping an index:
|
||||||
|
|
||||||
|
# Specify the index name.
|
||||||
|
migrate(migrator.drop_index('story', 'story_pub_date_status'))
|
||||||
|
|
||||||
|
Adding or dropping table constraints:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# Add a CHECK() constraint to enforce the price cannot be negative.
|
||||||
|
migrate(migrator.add_constraint(
|
||||||
|
'products',
|
||||||
|
'price_check',
|
||||||
|
Check('price >= 0')))
|
||||||
|
|
||||||
|
# Remove the price check constraint.
|
||||||
|
migrate(migrator.drop_constraint('products', 'price_check'))
|
||||||
|
|
||||||
|
# Add a UNIQUE constraint on the first and last names.
|
||||||
|
migrate(migrator.add_unique('person', 'first_name', 'last_name'))
|
||||||
|
"""
|
||||||
|
from collections import namedtuple
|
||||||
|
import functools
|
||||||
|
import hashlib
|
||||||
|
import re
|
||||||
|
|
||||||
|
from peewee import *
|
||||||
|
from peewee import CommaNodeList
|
||||||
|
from peewee import EnclosedNodeList
|
||||||
|
from peewee import Entity
|
||||||
|
from peewee import Expression
|
||||||
|
from peewee import Node
|
||||||
|
from peewee import NodeList
|
||||||
|
from peewee import OP
|
||||||
|
from peewee import callable_
|
||||||
|
from peewee import sort_models
|
||||||
|
from peewee import _truncate_constraint_name
|
||||||
|
|
||||||
|
|
||||||
|
class Operation(object):
    """Encapsulate a single schema altering operation."""
    def __init__(self, migrator, method, *args, **kwargs):
        self.migrator = migrator
        self.method = method  # Name of the migrator method to invoke.
        self.args = args
        self.kwargs = kwargs

    def execute(self, node):
        # Hand a SQL node/context to the underlying database for execution.
        self.migrator.database.execute(node)

    def _handle_result(self, result):
        # A migrator method may return SQL (Node/Context), another
        # Operation, or a list/tuple of either; recurse until everything
        # has been executed.
        if isinstance(result, (Node, Context)):
            self.execute(result)
        elif isinstance(result, Operation):
            result.run()
        elif isinstance(result, (list, tuple)):
            for item in result:
                self._handle_result(item)

    def run(self):
        """Invoke the wrapped migrator method and execute its result(s)."""
        kwargs = self.kwargs.copy()
        # Ask the @operation-decorated method for real SQL rather than
        # another deferred Operation.
        kwargs['with_context'] = True
        method = getattr(self.migrator, self.method)
        self._handle_result(method(*self.args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
def operation(fn):
    """Decorator for migrator methods: by default the call is deferred
    into an Operation object (run later by migrate()); passing
    ``with_context=True`` executes the wrapped method immediately."""
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        if kwargs.pop('with_context', False):
            return fn(self, *args, **kwargs)
        return Operation(self, fn.__name__, *args, **kwargs)
    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def make_index_name(table_name, columns):
    """Build a deterministic index name from the table and column names,
    truncating and hash-suffixing names longer than 64 characters."""
    name = '_'.join([table_name] + list(columns))
    if len(name) <= 64:
        return name
    digest = hashlib.md5(name.encode('utf-8')).hexdigest()
    return '%s_%s' % (name[:56], digest[:7])
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaMigrator(object):
    """Generates schema-altering SQL for a particular database backend."""
    # Subclasses set these when foreign-key constraints must be managed
    # with separate ALTER statements rather than inline column DDL.
    explicit_create_foreign_key = False
    explicit_delete_foreign_key = False

    def __init__(self, database):
        self.database = database

    def make_context(self):
        # Fresh SQL-generation context bound to the target database.
        return self.database.get_sql_context()

    @classmethod
    def from_database(cls, database):
        """Return the migrator subclass matching `database`'s backend."""
        if isinstance(database, PostgresqlDatabase):
            return PostgresqlMigrator(database)
        elif isinstance(database, MySQLDatabase):
            return MySQLMigrator(database)
        elif isinstance(database, SqliteDatabase):
            return SqliteMigrator(database)
        raise ValueError('Unsupported database: %s' % database)

    @operation
    def apply_default(self, table, column_name, field):
        # UPDATE <table> SET <column> = <field default> for all rows;
        # used when adding a NOT NULL column to a populated table.
        default = field.default
        if callable_(default):
            default = default()

        return (self.make_context()
                .literal('UPDATE ')
                .sql(Entity(table))
                .literal(' SET ')
                .sql(Expression(
                    Entity(column_name),
                    OP.EQ,
                    field.db_value(default),
                    flat=True)))

    def _alter_table(self, ctx, table):
        return ctx.literal('ALTER TABLE ').sql(Entity(table))

    def _alter_column(self, ctx, table, column):
        return (self
                ._alter_table(ctx, table)
                .literal(' ALTER COLUMN ')
                .sql(Entity(column)))

    @operation
    def alter_add_column(self, table, column_name, field):
        # Make field null at first.
        ctx = self.make_context()
        field_null, field.null = field.null, True
        field.name = field.column_name = column_name
        (self
         ._alter_table(ctx, table)
         .literal(' ADD COLUMN ')
         .sql(field.ddl(ctx)))

        # Restore the field's original nullability after generating DDL.
        field.null = field_null
        if isinstance(field, ForeignKeyField):
            self.add_inline_fk_sql(ctx, field)
        return ctx

    @operation
    def add_constraint(self, table, name, constraint):
        return (self
                ._alter_table(self.make_context(), table)
                .literal(' ADD CONSTRAINT ')
                .sql(Entity(name))
                .literal(' ')
                .sql(constraint))

    @operation
    def add_unique(self, table, *column_names):
        # Convenience wrapper: UNIQUE constraint over the given columns.
        constraint_name = 'uniq_%s' % '_'.join(column_names)
        constraint = NodeList((
            SQL('UNIQUE'),
            EnclosedNodeList([Entity(column) for column in column_names])))
        return self.add_constraint(table, constraint_name, constraint)

    @operation
    def drop_constraint(self, table, name):
        return (self
                ._alter_table(self.make_context(), table)
                .literal(' DROP CONSTRAINT ')
                .sql(Entity(name)))

    def add_inline_fk_sql(self, ctx, field):
        # Append inline "REFERENCES ..." clause for a foreign-key column.
        ctx = (ctx
               .literal(' REFERENCES ')
               .sql(Entity(field.rel_model._meta.table_name))
               .literal(' ')
               .sql(EnclosedNodeList((Entity(field.rel_field.column_name),))))
        if field.on_delete is not None:
            ctx = ctx.literal(' ON DELETE %s' % field.on_delete)
        if field.on_update is not None:
            ctx = ctx.literal(' ON UPDATE %s' % field.on_update)
        return ctx

    @operation
    def add_foreign_key_constraint(self, table, column_name, rel, rel_column,
                                   on_delete=None, on_update=None):
        # Standalone ALTER TABLE ... ADD CONSTRAINT ... FOREIGN KEY.
        constraint = 'fk_%s_%s_refs_%s' % (table, column_name, rel)
        ctx = (self
               .make_context()
               .literal('ALTER TABLE ')
               .sql(Entity(table))
               .literal(' ADD CONSTRAINT ')
               .sql(Entity(_truncate_constraint_name(constraint)))
               .literal(' FOREIGN KEY ')
               .sql(EnclosedNodeList((Entity(column_name),)))
               .literal(' REFERENCES ')
               .sql(Entity(rel))
               .literal(' (')
               .sql(Entity(rel_column))
               .literal(')'))
        if on_delete is not None:
            ctx = ctx.literal(' ON DELETE %s' % on_delete)
        if on_update is not None:
            ctx = ctx.literal(' ON UPDATE %s' % on_update)
        return ctx

    @operation
    def add_column(self, table, column_name, field):
        # Adding a column is complicated by the fact that if there are rows
        # present and the field is non-null, then we need to first add the
        # column as a nullable field, then set the value, then add a not null
        # constraint.
        if not field.null and field.default is None:
            raise ValueError('%s is not null but has no default' % column_name)

        is_foreign_key = isinstance(field, ForeignKeyField)
        if is_foreign_key and not field.rel_field:
            raise ValueError('Foreign keys must specify a `field`.')

        operations = [self.alter_add_column(table, column_name, field)]

        # In the event the field is *not* nullable, update with the default
        # value and set not null.
        if not field.null:
            operations.extend([
                self.apply_default(table, column_name, field),
                self.add_not_null(table, column_name)])

        if is_foreign_key and self.explicit_create_foreign_key:
            operations.append(
                self.add_foreign_key_constraint(
                    table,
                    column_name,
                    field.rel_model._meta.table_name,
                    field.rel_field.column_name,
                    field.on_delete,
                    field.on_update))

        if field.index or field.unique:
            using = getattr(field, 'index_type', None)
            operations.append(self.add_index(table, (column_name,),
                                             field.unique, using))

        return operations

    @operation
    def drop_foreign_key_constraint(self, table, column_name):
        # Backend-specific; implemented by subclasses that need it.
        raise NotImplementedError

    @operation
    def drop_column(self, table, column_name, cascade=True):
        ctx = self.make_context()
        (self._alter_table(ctx, table)
         .literal(' DROP COLUMN ')
         .sql(Entity(column_name)))

        if cascade:
            ctx.literal(' CASCADE')

        # If the column participates in a foreign key and the backend needs
        # an explicit constraint drop, do that before dropping the column.
        fk_columns = [
            foreign_key.column
            for foreign_key in self.database.get_foreign_keys(table)]
        if column_name in fk_columns and self.explicit_delete_foreign_key:
            return [self.drop_foreign_key_constraint(table, column_name), ctx]

        return ctx

    @operation
    def rename_column(self, table, old_name, new_name):
        return (self
                ._alter_table(self.make_context(), table)
                .literal(' RENAME COLUMN ')
                .sql(Entity(old_name))
                .literal(' TO ')
                .sql(Entity(new_name)))

    @operation
    def add_not_null(self, table, column):
        return (self
                ._alter_column(self.make_context(), table, column)
                .literal(' SET NOT NULL'))

    @operation
    def drop_not_null(self, table, column):
        return (self
                ._alter_column(self.make_context(), table, column)
                .literal(' DROP NOT NULL'))

    @operation
    def rename_table(self, old_name, new_name):
        return (self
                ._alter_table(self.make_context(), old_name)
                .literal(' RENAME TO ')
                .sql(Entity(new_name)))

    @operation
    def add_index(self, table, columns, unique=False, using=None):
        # Build a CREATE INDEX over a lightweight Table object so the
        # column references render correctly for the backend.
        ctx = self.make_context()
        index_name = make_index_name(table, columns)
        table_obj = Table(table)
        cols = [getattr(table_obj.c, column) for column in columns]
        index = Index(index_name, table_obj, cols, unique=unique, using=using)
        return ctx.sql(index)

    @operation
    def drop_index(self, table, index_name):
        return (self
                .make_context()
                .literal('DROP INDEX ')
                .sql(Entity(index_name)))
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresqlMigrator(SchemaMigrator):
    """Postgres-specific migration SQL."""
    def _primary_key_columns(self, tbl):
        # Look up the table's primary-key column name(s) via the catalogs.
        query = """
            SELECT pg_attribute.attname
            FROM pg_index, pg_class, pg_attribute
            WHERE
                pg_class.oid = '%s'::regclass AND
                indrelid = pg_class.oid AND
                pg_attribute.attrelid = pg_class.oid AND
                pg_attribute.attnum = any(pg_index.indkey) AND
                indisprimary;
        """
        cursor = self.database.execute_sql(query % tbl)
        return [row[0] for row in cursor.fetchall()]

    @operation
    def set_search_path(self, schema_name):
        return (self
                .make_context()
                .literal('SET search_path TO %s' % schema_name))

    @operation
    def rename_table(self, old_name, new_name):
        """Rename the table and, when present, its serial PK sequence so
        auto-increment keeps working under the new name."""
        pk_names = self._primary_key_columns(old_name)
        ParentClass = super(PostgresqlMigrator, self)

        operations = [
            ParentClass.rename_table(old_name, new_name, with_context=True)]

        if len(pk_names) == 1:
            # Check for existence of primary key sequence.
            seq_name = '%s_%s_seq' % (old_name, pk_names[0])
            query = """
                SELECT 1
                FROM information_schema.sequences
                WHERE LOWER(sequence_name) = LOWER(%s)
            """
            cursor = self.database.execute_sql(query, (seq_name,))
            if bool(cursor.fetchone()):
                new_seq_name = '%s_%s_seq' % (new_name, pk_names[0])
                operations.append(ParentClass.rename_table(
                    seq_name, new_seq_name))

        return operations
|
||||||
|
|
||||||
|
|
||||||
|
class MySQLColumn(namedtuple('_Column', ('name', 'definition', 'null', 'pk',
|
||||||
|
'default', 'extra'))):
|
||||||
|
@property
|
||||||
|
def is_pk(self):
|
||||||
|
return self.pk == 'PRI'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_unique(self):
|
||||||
|
return self.pk == 'UNI'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_null(self):
|
||||||
|
return self.null == 'YES'
|
||||||
|
|
||||||
|
def sql(self, column_name=None, is_null=None):
|
||||||
|
if is_null is None:
|
||||||
|
is_null = self.is_null
|
||||||
|
if column_name is None:
|
||||||
|
column_name = self.name
|
||||||
|
parts = [
|
||||||
|
Entity(column_name),
|
||||||
|
SQL(self.definition)]
|
||||||
|
if self.is_unique:
|
||||||
|
parts.append(SQL('UNIQUE'))
|
||||||
|
if is_null:
|
||||||
|
parts.append(SQL('NULL'))
|
||||||
|
else:
|
||||||
|
parts.append(SQL('NOT NULL'))
|
||||||
|
if self.is_pk:
|
||||||
|
parts.append(SQL('PRIMARY KEY'))
|
||||||
|
if self.extra:
|
||||||
|
parts.append(SQL(self.extra))
|
||||||
|
return NodeList(parts)
|
||||||
|
|
||||||
|
|
||||||
|
class MySQLMigrator(SchemaMigrator):
|
||||||
|
explicit_create_foreign_key = True
|
||||||
|
explicit_delete_foreign_key = True
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def rename_table(self, old_name, new_name):
|
||||||
|
return (self
|
||||||
|
.make_context()
|
||||||
|
.literal('RENAME TABLE ')
|
||||||
|
.sql(Entity(old_name))
|
||||||
|
.literal(' TO ')
|
||||||
|
.sql(Entity(new_name)))
|
||||||
|
|
||||||
|
def _get_column_definition(self, table, column_name):
|
||||||
|
cursor = self.database.execute_sql('DESCRIBE `%s`;' % table)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
for row in rows:
|
||||||
|
column = MySQLColumn(*row)
|
||||||
|
if column.name == column_name:
|
||||||
|
return column
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_foreign_key_constraint(self, table, column_name):
|
||||||
|
cursor = self.database.execute_sql(
|
||||||
|
('SELECT constraint_name '
|
||||||
|
'FROM information_schema.key_column_usage WHERE '
|
||||||
|
'table_schema = DATABASE() AND '
|
||||||
|
'table_name = %s AND '
|
||||||
|
'column_name = %s AND '
|
||||||
|
'referenced_table_name IS NOT NULL AND '
|
||||||
|
'referenced_column_name IS NOT NULL;'),
|
||||||
|
(table, column_name))
|
||||||
|
result = cursor.fetchone()
|
||||||
|
if not result:
|
||||||
|
raise AttributeError(
|
||||||
|
'Unable to find foreign key constraint for '
|
||||||
|
'"%s" on table "%s".' % (table, column_name))
|
||||||
|
return result[0]
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def drop_foreign_key_constraint(self, table, column_name):
|
||||||
|
fk_constraint = self.get_foreign_key_constraint(table, column_name)
|
||||||
|
return (self
|
||||||
|
.make_context()
|
||||||
|
.literal('ALTER TABLE ')
|
||||||
|
.sql(Entity(table))
|
||||||
|
.literal(' DROP FOREIGN KEY ')
|
||||||
|
.sql(Entity(fk_constraint)))
|
||||||
|
|
||||||
|
def add_inline_fk_sql(self, ctx, field):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def add_not_null(self, table, column):
|
||||||
|
column_def = self._get_column_definition(table, column)
|
||||||
|
add_not_null = (self
|
||||||
|
.make_context()
|
||||||
|
.literal('ALTER TABLE ')
|
||||||
|
.sql(Entity(table))
|
||||||
|
.literal(' MODIFY ')
|
||||||
|
.sql(column_def.sql(is_null=False)))
|
||||||
|
|
||||||
|
fk_objects = dict(
|
||||||
|
(fk.column, fk)
|
||||||
|
for fk in self.database.get_foreign_keys(table))
|
||||||
|
if column not in fk_objects:
|
||||||
|
return add_not_null
|
||||||
|
|
||||||
|
fk_metadata = fk_objects[column]
|
||||||
|
return (self.drop_foreign_key_constraint(table, column),
|
||||||
|
add_not_null,
|
||||||
|
self.add_foreign_key_constraint(
|
||||||
|
table,
|
||||||
|
column,
|
||||||
|
fk_metadata.dest_table,
|
||||||
|
fk_metadata.dest_column))
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def drop_not_null(self, table, column):
|
||||||
|
column = self._get_column_definition(table, column)
|
||||||
|
if column.is_pk:
|
||||||
|
raise ValueError('Primary keys can not be null')
|
||||||
|
return (self
|
||||||
|
.make_context()
|
||||||
|
.literal('ALTER TABLE ')
|
||||||
|
.sql(Entity(table))
|
||||||
|
.literal(' MODIFY ')
|
||||||
|
.sql(column.sql(is_null=True)))
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def rename_column(self, table, old_name, new_name):
|
||||||
|
fk_objects = dict(
|
||||||
|
(fk.column, fk)
|
||||||
|
for fk in self.database.get_foreign_keys(table))
|
||||||
|
is_foreign_key = old_name in fk_objects
|
||||||
|
|
||||||
|
column = self._get_column_definition(table, old_name)
|
||||||
|
rename_ctx = (self
|
||||||
|
.make_context()
|
||||||
|
.literal('ALTER TABLE ')
|
||||||
|
.sql(Entity(table))
|
||||||
|
.literal(' CHANGE ')
|
||||||
|
.sql(Entity(old_name))
|
||||||
|
.literal(' ')
|
||||||
|
.sql(column.sql(column_name=new_name)))
|
||||||
|
if is_foreign_key:
|
||||||
|
fk_metadata = fk_objects[old_name]
|
||||||
|
return [
|
||||||
|
self.drop_foreign_key_constraint(table, old_name),
|
||||||
|
rename_ctx,
|
||||||
|
self.add_foreign_key_constraint(
|
||||||
|
table,
|
||||||
|
new_name,
|
||||||
|
fk_metadata.dest_table,
|
||||||
|
fk_metadata.dest_column),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return rename_ctx
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def drop_index(self, table, index_name):
|
||||||
|
return (self
|
||||||
|
.make_context()
|
||||||
|
.literal('DROP INDEX ')
|
||||||
|
.sql(Entity(index_name))
|
||||||
|
.literal(' ON ')
|
||||||
|
.sql(Entity(table)))
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteMigrator(SchemaMigrator):
|
||||||
|
"""
|
||||||
|
SQLite supports a subset of ALTER TABLE queries, view the docs for the
|
||||||
|
full details http://sqlite.org/lang_altertable.html
|
||||||
|
"""
|
||||||
|
column_re = re.compile('(.+?)\((.+)\)')
|
||||||
|
column_split_re = re.compile(r'(?:[^,(]|\([^)]*\))+')
|
||||||
|
column_name_re = re.compile('["`\']?([\w]+)')
|
||||||
|
fk_re = re.compile('FOREIGN KEY\s+\("?([\w]+)"?\)\s+', re.I)
|
||||||
|
|
||||||
|
def _get_column_names(self, table):
|
||||||
|
res = self.database.execute_sql('select * from "%s" limit 1' % table)
|
||||||
|
return [item[0] for item in res.description]
|
||||||
|
|
||||||
|
def _get_create_table(self, table):
|
||||||
|
res = self.database.execute_sql(
|
||||||
|
('select name, sql from sqlite_master '
|
||||||
|
'where type=? and LOWER(name)=?'),
|
||||||
|
['table', table.lower()])
|
||||||
|
return res.fetchone()
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def _update_column(self, table, column_to_update, fn):
|
||||||
|
columns = set(column.name.lower()
|
||||||
|
for column in self.database.get_columns(table))
|
||||||
|
if column_to_update.lower() not in columns:
|
||||||
|
raise ValueError('Column "%s" does not exist on "%s"' %
|
||||||
|
(column_to_update, table))
|
||||||
|
|
||||||
|
# Get the SQL used to create the given table.
|
||||||
|
table, create_table = self._get_create_table(table)
|
||||||
|
|
||||||
|
# Get the indexes and SQL to re-create indexes.
|
||||||
|
indexes = self.database.get_indexes(table)
|
||||||
|
|
||||||
|
# Find any foreign keys we may need to remove.
|
||||||
|
self.database.get_foreign_keys(table)
|
||||||
|
|
||||||
|
# Make sure the create_table does not contain any newlines or tabs,
|
||||||
|
# allowing the regex to work correctly.
|
||||||
|
create_table = re.sub(r'\s+', ' ', create_table)
|
||||||
|
|
||||||
|
# Parse out the `CREATE TABLE` and column list portions of the query.
|
||||||
|
raw_create, raw_columns = self.column_re.search(create_table).groups()
|
||||||
|
|
||||||
|
# Clean up the individual column definitions.
|
||||||
|
split_columns = self.column_split_re.findall(raw_columns)
|
||||||
|
column_defs = [col.strip() for col in split_columns]
|
||||||
|
|
||||||
|
new_column_defs = []
|
||||||
|
new_column_names = []
|
||||||
|
original_column_names = []
|
||||||
|
constraint_terms = ('foreign ', 'primary ', 'constraint ')
|
||||||
|
|
||||||
|
for column_def in column_defs:
|
||||||
|
column_name, = self.column_name_re.match(column_def).groups()
|
||||||
|
|
||||||
|
if column_name == column_to_update:
|
||||||
|
new_column_def = fn(column_name, column_def)
|
||||||
|
if new_column_def:
|
||||||
|
new_column_defs.append(new_column_def)
|
||||||
|
original_column_names.append(column_name)
|
||||||
|
column_name, = self.column_name_re.match(
|
||||||
|
new_column_def).groups()
|
||||||
|
new_column_names.append(column_name)
|
||||||
|
else:
|
||||||
|
new_column_defs.append(column_def)
|
||||||
|
|
||||||
|
# Avoid treating constraints as columns.
|
||||||
|
if not column_def.lower().startswith(constraint_terms):
|
||||||
|
new_column_names.append(column_name)
|
||||||
|
original_column_names.append(column_name)
|
||||||
|
|
||||||
|
# Create a mapping of original columns to new columns.
|
||||||
|
original_to_new = dict(zip(original_column_names, new_column_names))
|
||||||
|
new_column = original_to_new.get(column_to_update)
|
||||||
|
|
||||||
|
fk_filter_fn = lambda column_def: column_def
|
||||||
|
if not new_column:
|
||||||
|
# Remove any foreign keys associated with this column.
|
||||||
|
fk_filter_fn = lambda column_def: None
|
||||||
|
elif new_column != column_to_update:
|
||||||
|
# Update any foreign keys for this column.
|
||||||
|
fk_filter_fn = lambda column_def: self.fk_re.sub(
|
||||||
|
'FOREIGN KEY ("%s") ' % new_column,
|
||||||
|
column_def)
|
||||||
|
|
||||||
|
cleaned_columns = []
|
||||||
|
for column_def in new_column_defs:
|
||||||
|
match = self.fk_re.match(column_def)
|
||||||
|
if match is not None and match.groups()[0] == column_to_update:
|
||||||
|
column_def = fk_filter_fn(column_def)
|
||||||
|
if column_def:
|
||||||
|
cleaned_columns.append(column_def)
|
||||||
|
|
||||||
|
# Update the name of the new CREATE TABLE query.
|
||||||
|
temp_table = table + '__tmp__'
|
||||||
|
rgx = re.compile('("?)%s("?)' % table, re.I)
|
||||||
|
create = rgx.sub(
|
||||||
|
'\\1%s\\2' % temp_table,
|
||||||
|
raw_create)
|
||||||
|
|
||||||
|
# Create the new table.
|
||||||
|
columns = ', '.join(cleaned_columns)
|
||||||
|
queries = [
|
||||||
|
NodeList([SQL('DROP TABLE IF EXISTS'), Entity(temp_table)]),
|
||||||
|
SQL('%s (%s)' % (create.strip(), columns))]
|
||||||
|
|
||||||
|
# Populate new table.
|
||||||
|
populate_table = NodeList((
|
||||||
|
SQL('INSERT INTO'),
|
||||||
|
Entity(temp_table),
|
||||||
|
EnclosedNodeList([Entity(col) for col in new_column_names]),
|
||||||
|
SQL('SELECT'),
|
||||||
|
CommaNodeList([Entity(col) for col in original_column_names]),
|
||||||
|
SQL('FROM'),
|
||||||
|
Entity(table)))
|
||||||
|
drop_original = NodeList([SQL('DROP TABLE'), Entity(table)])
|
||||||
|
|
||||||
|
# Drop existing table and rename temp table.
|
||||||
|
queries += [
|
||||||
|
populate_table,
|
||||||
|
drop_original,
|
||||||
|
self.rename_table(temp_table, table)]
|
||||||
|
|
||||||
|
# Re-create user-defined indexes. User-defined indexes will have a
|
||||||
|
# non-empty SQL attribute.
|
||||||
|
for index in filter(lambda idx: idx.sql, indexes):
|
||||||
|
if column_to_update not in index.columns:
|
||||||
|
queries.append(SQL(index.sql))
|
||||||
|
elif new_column:
|
||||||
|
sql = self._fix_index(index.sql, column_to_update, new_column)
|
||||||
|
if sql is not None:
|
||||||
|
queries.append(SQL(sql))
|
||||||
|
|
||||||
|
return queries
|
||||||
|
|
||||||
|
def _fix_index(self, sql, column_to_update, new_column):
|
||||||
|
# Split on the name of the column to update. If it splits into two
|
||||||
|
# pieces, then there's no ambiguity and we can simply replace the
|
||||||
|
# old with the new.
|
||||||
|
parts = sql.split(column_to_update)
|
||||||
|
if len(parts) == 2:
|
||||||
|
return sql.replace(column_to_update, new_column)
|
||||||
|
|
||||||
|
# Find the list of columns in the index expression.
|
||||||
|
lhs, rhs = sql.rsplit('(', 1)
|
||||||
|
|
||||||
|
# Apply the same "split in two" logic to the column list portion of
|
||||||
|
# the query.
|
||||||
|
if len(rhs.split(column_to_update)) == 2:
|
||||||
|
return '%s(%s' % (lhs, rhs.replace(column_to_update, new_column))
|
||||||
|
|
||||||
|
# Strip off the trailing parentheses and go through each column.
|
||||||
|
parts = rhs.rsplit(')', 1)[0].split(',')
|
||||||
|
columns = [part.strip('"`[]\' ') for part in parts]
|
||||||
|
|
||||||
|
# `columns` looks something like: ['status', 'timestamp" DESC']
|
||||||
|
# https://www.sqlite.org/lang_keywords.html
|
||||||
|
# Strip out any junk after the column name.
|
||||||
|
clean = []
|
||||||
|
for column in columns:
|
||||||
|
if re.match('%s(?:[\'"`\]]?\s|$)' % column_to_update, column):
|
||||||
|
column = new_columne + column[len(column_to_update):]
|
||||||
|
clean.append(column)
|
||||||
|
|
||||||
|
return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean))
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def drop_column(self, table, column_name, cascade=True):
|
||||||
|
return self._update_column(table, column_name, lambda a, b: None)
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def rename_column(self, table, old_name, new_name):
|
||||||
|
def _rename(column_name, column_def):
|
||||||
|
return column_def.replace(column_name, new_name)
|
||||||
|
return self._update_column(table, old_name, _rename)
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def add_not_null(self, table, column):
|
||||||
|
def _add_not_null(column_name, column_def):
|
||||||
|
return column_def + ' NOT NULL'
|
||||||
|
return self._update_column(table, column, _add_not_null)
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def drop_not_null(self, table, column):
|
||||||
|
def _drop_not_null(column_name, column_def):
|
||||||
|
return column_def.replace('NOT NULL', '')
|
||||||
|
return self._update_column(table, column, _drop_not_null)
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def add_constraint(self, table, name, constraint):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def drop_constraint(self, table, name):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@operation
|
||||||
|
def add_foreign_key_constraint(self, table, column_name, field,
|
||||||
|
on_delete=None, on_update=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def migrate(*operations, **kwargs):
|
||||||
|
for operation in operations:
|
||||||
|
operation.run()
|
@ -0,0 +1,34 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
import mysql.connector as mysql_connector
|
||||||
|
except ImportError:
|
||||||
|
mysql_connector = None
|
||||||
|
|
||||||
|
from peewee import ImproperlyConfigured
|
||||||
|
from peewee import MySQLDatabase
|
||||||
|
from peewee import TextField
|
||||||
|
|
||||||
|
|
||||||
|
class MySQLConnectorDatabase(MySQLDatabase):
|
||||||
|
def _connect(self):
|
||||||
|
if mysql_connector is None:
|
||||||
|
raise ImproperlyConfigured('MySQL connector not installed!')
|
||||||
|
return mysql_connector.connect(db=self.database, **self.connect_params)
|
||||||
|
|
||||||
|
def cursor(self, commit=None):
|
||||||
|
if self.is_closed():
|
||||||
|
self.connect()
|
||||||
|
return self._state.conn.cursor(buffered=True)
|
||||||
|
|
||||||
|
|
||||||
|
class JSONField(TextField):
|
||||||
|
field_type = 'JSON'
|
||||||
|
|
||||||
|
def db_value(self, value):
|
||||||
|
if value is not None:
|
||||||
|
return json.dumps(value)
|
||||||
|
|
||||||
|
def python_value(self, value):
|
||||||
|
if value is not None:
|
||||||
|
return json.loads(value)
|
@ -0,0 +1,317 @@
|
|||||||
|
"""
|
||||||
|
Lightweight connection pooling for peewee.
|
||||||
|
|
||||||
|
In a multi-threaded application, up to `max_connections` will be opened. Each
|
||||||
|
thread (or, if using gevent, greenlet) will have it's own connection.
|
||||||
|
|
||||||
|
In a single-threaded application, only one connection will be created. It will
|
||||||
|
be continually recycled until either it exceeds the stale timeout or is closed
|
||||||
|
explicitly (using `.manual_close()`).
|
||||||
|
|
||||||
|
By default, all your application needs to do is ensure that connections are
|
||||||
|
closed when you are finished with them, and they will be returned to the pool.
|
||||||
|
For web applications, this typically means that at the beginning of a request,
|
||||||
|
you will open a connection, and when you return a response, you will close the
|
||||||
|
connection.
|
||||||
|
|
||||||
|
Simple Postgres pool example code:
|
||||||
|
|
||||||
|
# Use the special postgresql extensions.
|
||||||
|
from playhouse.pool import PooledPostgresqlExtDatabase
|
||||||
|
|
||||||
|
db = PooledPostgresqlExtDatabase(
|
||||||
|
'my_app',
|
||||||
|
max_connections=32,
|
||||||
|
stale_timeout=300, # 5 minutes.
|
||||||
|
user='postgres')
|
||||||
|
|
||||||
|
class BaseModel(Model):
|
||||||
|
class Meta:
|
||||||
|
database = db
|
||||||
|
|
||||||
|
That's it!
|
||||||
|
"""
|
||||||
|
import heapq
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import namedtuple
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
try:
|
||||||
|
from psycopg2.extensions import TRANSACTION_STATUS_IDLE
|
||||||
|
from psycopg2.extensions import TRANSACTION_STATUS_INERROR
|
||||||
|
from psycopg2.extensions import TRANSACTION_STATUS_UNKNOWN
|
||||||
|
except ImportError:
|
||||||
|
TRANSACTION_STATUS_IDLE = \
|
||||||
|
TRANSACTION_STATUS_INERROR = \
|
||||||
|
TRANSACTION_STATUS_UNKNOWN = None
|
||||||
|
|
||||||
|
from peewee import MySQLDatabase
|
||||||
|
from peewee import PostgresqlDatabase
|
||||||
|
from peewee import SqliteDatabase
|
||||||
|
|
||||||
|
logger = logging.getLogger('peewee.pool')
|
||||||
|
|
||||||
|
|
||||||
|
def make_int(val):
|
||||||
|
if val is not None and not isinstance(val, (int, float)):
|
||||||
|
return int(val)
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
class MaxConnectionsExceeded(ValueError): pass
|
||||||
|
|
||||||
|
|
||||||
|
PoolConnection = namedtuple('PoolConnection', ('timestamp', 'connection',
|
||||||
|
'checked_out'))
|
||||||
|
|
||||||
|
|
||||||
|
class PooledDatabase(object):
|
||||||
|
def __init__(self, database, max_connections=20, stale_timeout=None,
|
||||||
|
timeout=None, **kwargs):
|
||||||
|
self._max_connections = make_int(max_connections)
|
||||||
|
self._stale_timeout = make_int(stale_timeout)
|
||||||
|
self._wait_timeout = make_int(timeout)
|
||||||
|
if self._wait_timeout == 0:
|
||||||
|
self._wait_timeout = float('inf')
|
||||||
|
|
||||||
|
# Available / idle connections stored in a heap, sorted oldest first.
|
||||||
|
self._connections = []
|
||||||
|
|
||||||
|
# Mapping of connection id to PoolConnection. Ordinarily we would want
|
||||||
|
# to use something like a WeakKeyDictionary, but Python typically won't
|
||||||
|
# allow us to create weak references to connection objects.
|
||||||
|
self._in_use = {}
|
||||||
|
|
||||||
|
# Use the memory address of the connection as the key in the event the
|
||||||
|
# connection object is not hashable. Connections will not get
|
||||||
|
# garbage-collected, however, because a reference to them will persist
|
||||||
|
# in "_in_use" as long as the conn has not been closed.
|
||||||
|
self.conn_key = id
|
||||||
|
|
||||||
|
super(PooledDatabase, self).__init__(database, **kwargs)
|
||||||
|
|
||||||
|
def init(self, database, max_connections=None, stale_timeout=None,
|
||||||
|
timeout=None, **connect_kwargs):
|
||||||
|
super(PooledDatabase, self).init(database, **connect_kwargs)
|
||||||
|
if max_connections is not None:
|
||||||
|
self._max_connections = make_int(max_connections)
|
||||||
|
if stale_timeout is not None:
|
||||||
|
self._stale_timeout = make_int(stale_timeout)
|
||||||
|
if timeout is not None:
|
||||||
|
self._wait_timeout = make_int(timeout)
|
||||||
|
if self._wait_timeout == 0:
|
||||||
|
self._wait_timeout = float('inf')
|
||||||
|
|
||||||
|
def connect(self, reuse_if_open=False):
|
||||||
|
if not self._wait_timeout:
|
||||||
|
return super(PooledDatabase, self).connect(reuse_if_open)
|
||||||
|
|
||||||
|
expires = time.time() + self._wait_timeout
|
||||||
|
while expires > time.time():
|
||||||
|
try:
|
||||||
|
ret = super(PooledDatabase, self).connect(reuse_if_open)
|
||||||
|
except MaxConnectionsExceeded:
|
||||||
|
time.sleep(0.1)
|
||||||
|
else:
|
||||||
|
return ret
|
||||||
|
raise MaxConnectionsExceeded('Max connections exceeded, timed out '
|
||||||
|
'attempting to connect.')
|
||||||
|
|
||||||
|
def _connect(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Remove the oldest connection from the heap.
|
||||||
|
ts, conn = heapq.heappop(self._connections)
|
||||||
|
key = self.conn_key(conn)
|
||||||
|
except IndexError:
|
||||||
|
ts = conn = None
|
||||||
|
logger.debug('No connection available in pool.')
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if self._is_closed(conn):
|
||||||
|
# This connecton was closed, but since it was not stale
|
||||||
|
# it got added back to the queue of available conns. We
|
||||||
|
# then closed it and marked it as explicitly closed, so
|
||||||
|
# it's safe to throw it away now.
|
||||||
|
# (Because Database.close() calls Database._close()).
|
||||||
|
logger.debug('Connection %s was closed.', key)
|
||||||
|
ts = conn = None
|
||||||
|
elif self._stale_timeout and self._is_stale(ts):
|
||||||
|
# If we are attempting to check out a stale connection,
|
||||||
|
# then close it. We don't need to mark it in the "closed"
|
||||||
|
# set, because it is not in the list of available conns
|
||||||
|
# anymore.
|
||||||
|
logger.debug('Connection %s was stale, closing.', key)
|
||||||
|
self._close(conn, True)
|
||||||
|
ts = conn = None
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if conn is None:
|
||||||
|
if self._max_connections and (
|
||||||
|
len(self._in_use) >= self._max_connections):
|
||||||
|
raise MaxConnectionsExceeded('Exceeded maximum connections.')
|
||||||
|
conn = super(PooledDatabase, self)._connect()
|
||||||
|
ts = time.time()
|
||||||
|
key = self.conn_key(conn)
|
||||||
|
logger.debug('Created new connection %s.', key)
|
||||||
|
|
||||||
|
self._in_use[key] = PoolConnection(ts, conn, time.time())
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _is_stale(self, timestamp):
|
||||||
|
# Called on check-out and check-in to ensure the connection has
|
||||||
|
# not outlived the stale timeout.
|
||||||
|
return (time.time() - timestamp) > self._stale_timeout
|
||||||
|
|
||||||
|
def _is_closed(self, conn):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _can_reuse(self, conn):
|
||||||
|
# Called on check-in to make sure the connection can be re-used.
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _close(self, conn, close_conn=False):
|
||||||
|
key = self.conn_key(conn)
|
||||||
|
if close_conn:
|
||||||
|
super(PooledDatabase, self)._close(conn)
|
||||||
|
elif key in self._in_use:
|
||||||
|
pool_conn = self._in_use.pop(key)
|
||||||
|
if self._stale_timeout and self._is_stale(pool_conn.timestamp):
|
||||||
|
logger.debug('Closing stale connection %s.', key)
|
||||||
|
super(PooledDatabase, self)._close(conn)
|
||||||
|
elif self._can_reuse(conn):
|
||||||
|
logger.debug('Returning %s to pool.', key)
|
||||||
|
heapq.heappush(self._connections, (pool_conn.timestamp, conn))
|
||||||
|
else:
|
||||||
|
logger.debug('Closed %s.', key)
|
||||||
|
|
||||||
|
def manual_close(self):
|
||||||
|
"""
|
||||||
|
Close the underlying connection without returning it to the pool.
|
||||||
|
"""
|
||||||
|
if self.is_closed():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Obtain reference to the connection in-use by the calling thread.
|
||||||
|
conn = self.connection()
|
||||||
|
|
||||||
|
# A connection will only be re-added to the available list if it is
|
||||||
|
# marked as "in use" at the time it is closed. We will explicitly
|
||||||
|
# remove it from the "in use" list, call "close()" for the
|
||||||
|
# side-effects, and then explicitly close the connection.
|
||||||
|
self._in_use.pop(self.conn_key(conn), None)
|
||||||
|
self.close()
|
||||||
|
self._close(conn, close_conn=True)
|
||||||
|
|
||||||
|
def close_idle(self):
|
||||||
|
# Close any open connections that are not currently in-use.
|
||||||
|
with self._lock:
|
||||||
|
for _, conn in self._connections:
|
||||||
|
self._close(conn, close_conn=True)
|
||||||
|
self._connections = []
|
||||||
|
|
||||||
|
def close_stale(self, age=600):
|
||||||
|
# Close any connections that are in-use but were checked out quite some
|
||||||
|
# time ago and can be considered stale.
|
||||||
|
with self._lock:
|
||||||
|
in_use = {}
|
||||||
|
cutoff = time.time() - age
|
||||||
|
n = 0
|
||||||
|
for key, pool_conn in self._in_use.items():
|
||||||
|
if pool_conn.checked_out < cutoff:
|
||||||
|
self._close(pool_conn.connection, close_conn=True)
|
||||||
|
n += 1
|
||||||
|
else:
|
||||||
|
in_use[key] = pool_conn
|
||||||
|
self._in_use = in_use
|
||||||
|
return n
|
||||||
|
|
||||||
|
def close_all(self):
|
||||||
|
# Close all connections -- available and in-use. Warning: may break any
|
||||||
|
# active connections used by other threads.
|
||||||
|
self.close()
|
||||||
|
with self._lock:
|
||||||
|
for _, conn in self._connections:
|
||||||
|
self._close(conn, close_conn=True)
|
||||||
|
for pool_conn in self._in_use.values():
|
||||||
|
self._close(pool_conn.connection, close_conn=True)
|
||||||
|
self._connections = []
|
||||||
|
self._in_use = {}
|
||||||
|
|
||||||
|
|
||||||
|
class PooledMySQLDatabase(PooledDatabase, MySQLDatabase):
|
||||||
|
def _is_closed(self, conn):
|
||||||
|
try:
|
||||||
|
conn.ping(False)
|
||||||
|
except:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class _PooledPostgresqlDatabase(PooledDatabase):
|
||||||
|
def _is_closed(self, conn):
|
||||||
|
if conn.closed:
|
||||||
|
return True
|
||||||
|
|
||||||
|
txn_status = conn.get_transaction_status()
|
||||||
|
if txn_status == TRANSACTION_STATUS_UNKNOWN:
|
||||||
|
return True
|
||||||
|
elif txn_status != TRANSACTION_STATUS_IDLE:
|
||||||
|
conn.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _can_reuse(self, conn):
|
||||||
|
txn_status = conn.get_transaction_status()
|
||||||
|
# Do not return connection in an error state, as subsequent queries
|
||||||
|
# will all fail. If the status is unknown then we lost the connection
|
||||||
|
# to the server and the connection should not be re-used.
|
||||||
|
if txn_status == TRANSACTION_STATUS_UNKNOWN:
|
||||||
|
return False
|
||||||
|
elif txn_status == TRANSACTION_STATUS_INERROR:
|
||||||
|
conn.reset()
|
||||||
|
elif txn_status != TRANSACTION_STATUS_IDLE:
|
||||||
|
conn.rollback()
|
||||||
|
return True
|
||||||
|
|
||||||
|
class PooledPostgresqlDatabase(_PooledPostgresqlDatabase, PostgresqlDatabase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from playhouse.postgres_ext import PostgresqlExtDatabase
|
||||||
|
|
||||||
|
class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase, PostgresqlExtDatabase):
|
||||||
|
pass
|
||||||
|
except ImportError:
|
||||||
|
PooledPostgresqlExtDatabase = None
|
||||||
|
|
||||||
|
|
||||||
|
class _PooledSqliteDatabase(PooledDatabase):
|
||||||
|
def _is_closed(self, conn):
|
||||||
|
try:
|
||||||
|
conn.total_changes
|
||||||
|
except:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||||
|
|
||||||
|
class PooledSqliteExtDatabase(_PooledSqliteDatabase, SqliteExtDatabase):
|
||||||
|
pass
|
||||||
|
except ImportError:
|
||||||
|
PooledSqliteExtDatabase = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from playhouse.sqlite_ext import CSqliteExtDatabase
|
||||||
|
|
||||||
|
class PooledCSqliteExtDatabase(_PooledSqliteDatabase, CSqliteExtDatabase):
|
||||||
|
pass
|
||||||
|
except ImportError:
|
||||||
|
PooledCSqliteExtDatabase = None
|
@ -0,0 +1,468 @@
|
|||||||
|
"""
|
||||||
|
Collection of postgres-specific extensions, currently including:
|
||||||
|
|
||||||
|
* Support for hstore, a key/value type storage
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from peewee import *
|
||||||
|
from peewee import ColumnBase
|
||||||
|
from peewee import Expression
|
||||||
|
from peewee import Node
|
||||||
|
from peewee import NodeList
|
||||||
|
from peewee import SENTINEL
|
||||||
|
from peewee import __exception_wrapper__
|
||||||
|
|
||||||
|
try:
|
||||||
|
from psycopg2cffi import compat
|
||||||
|
compat.register()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from psycopg2.extras import register_hstore
|
||||||
|
try:
|
||||||
|
from psycopg2.extras import Json
|
||||||
|
except:
|
||||||
|
Json = None
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger('peewee')
|
||||||
|
|
||||||
|
|
||||||
|
HCONTAINS_DICT = '@>'
|
||||||
|
HCONTAINS_KEYS = '?&'
|
||||||
|
HCONTAINS_KEY = '?'
|
||||||
|
HCONTAINS_ANY_KEY = '?|'
|
||||||
|
HKEY = '->'
|
||||||
|
HUPDATE = '||'
|
||||||
|
ACONTAINS = '@>'
|
||||||
|
ACONTAINS_ANY = '&&'
|
||||||
|
TS_MATCH = '@@'
|
||||||
|
JSONB_CONTAINS = '@>'
|
||||||
|
JSONB_CONTAINED_BY = '<@'
|
||||||
|
JSONB_CONTAINS_KEY = '?'
|
||||||
|
JSONB_CONTAINS_ANY_KEY = '?|'
|
||||||
|
JSONB_CONTAINS_ALL_KEYS = '?&'
|
||||||
|
JSONB_EXISTS = '?'
|
||||||
|
JSONB_REMOVE = '-'
|
||||||
|
|
||||||
|
|
||||||
|
class _LookupNode(ColumnBase):
|
||||||
|
def __init__(self, node, parts):
|
||||||
|
self.node = node
|
||||||
|
self.parts = parts
|
||||||
|
super(_LookupNode, self).__init__()
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
return type(self)(self.node, list(self.parts))
|
||||||
|
|
||||||
|
|
||||||
|
class _JsonLookupBase(_LookupNode):
|
||||||
|
def __init__(self, node, parts, as_json=False):
|
||||||
|
super(_JsonLookupBase, self).__init__(node, parts)
|
||||||
|
self._as_json = as_json
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
return type(self)(self.node, list(self.parts), self._as_json)
|
||||||
|
|
||||||
|
@Node.copy
|
||||||
|
def as_json(self, as_json=True):
|
||||||
|
self._as_json = as_json
|
||||||
|
|
||||||
|
def concat(self, rhs):
|
||||||
|
return Expression(self.as_json(True), OP.CONCAT, Json(rhs))
|
||||||
|
|
||||||
|
def contains(self, other):
|
||||||
|
clone = self.as_json(True)
|
||||||
|
if isinstance(other, (list, dict)):
|
||||||
|
return Expression(clone, JSONB_CONTAINS, Json(other))
|
||||||
|
return Expression(clone, JSONB_EXISTS, other)
|
||||||
|
|
||||||
|
def contains_any(self, *keys):
|
||||||
|
return Expression(
|
||||||
|
self.as_json(True),
|
||||||
|
JSONB_CONTAINS_ANY_KEY,
|
||||||
|
Value(list(keys), unpack=False))
|
||||||
|
|
||||||
|
def contains_all(self, *keys):
|
||||||
|
return Expression(
|
||||||
|
self.as_json(True),
|
||||||
|
JSONB_CONTAINS_ALL_KEYS,
|
||||||
|
Value(list(keys), unpack=False))
|
||||||
|
|
||||||
|
def has_key(self, key):
|
||||||
|
return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key)
|
||||||
|
|
||||||
|
|
||||||
|
class JsonLookup(_JsonLookupBase):
|
||||||
|
def __getitem__(self, value):
|
||||||
|
return JsonLookup(self.node, self.parts + [value], self._as_json)
|
||||||
|
|
||||||
|
def __sql__(self, ctx):
|
||||||
|
ctx.sql(self.node)
|
||||||
|
for part in self.parts[:-1]:
|
||||||
|
ctx.literal('->').sql(part)
|
||||||
|
if self.parts:
|
||||||
|
(ctx
|
||||||
|
.literal('->' if self._as_json else '->>')
|
||||||
|
.sql(self.parts[-1]))
|
||||||
|
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
|
class JsonPath(_JsonLookupBase):
|
||||||
|
def __sql__(self, ctx):
|
||||||
|
return (ctx
|
||||||
|
.sql(self.node)
|
||||||
|
.literal('#>' if self._as_json else '#>>')
|
||||||
|
.sql(Value('{%s}' % ','.join(map(str, self.parts)))))
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectSlice(_LookupNode):
|
||||||
|
@classmethod
|
||||||
|
def create(cls, node, value):
|
||||||
|
if isinstance(value, slice):
|
||||||
|
parts = [value.start or 0, value.stop or 0]
|
||||||
|
elif isinstance(value, int):
|
||||||
|
parts = [value]
|
||||||
|
else:
|
||||||
|
parts = map(int, value.split(':'))
|
||||||
|
return cls(node, parts)
|
||||||
|
|
||||||
|
def __sql__(self, ctx):
|
||||||
|
return (ctx
|
||||||
|
.sql(self.node)
|
||||||
|
.literal('[%s]' % ':'.join(str(p + 1) for p in self.parts)))
|
||||||
|
|
||||||
|
def __getitem__(self, value):
|
||||||
|
return ObjectSlice.create(self, value)
|
||||||
|
|
||||||
|
|
||||||
|
class IndexedFieldMixin(object):
|
||||||
|
default_index_type = 'GIN'
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
kwargs.setdefault('index', True) # By default, use an index.
|
||||||
|
super(IndexedFieldMixin, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayField(IndexedFieldMixin, Field):
|
||||||
|
passthrough = True
|
||||||
|
|
||||||
|
def __init__(self, field_class=IntegerField, field_kwargs=None,
|
||||||
|
dimensions=1, convert_values=False, *args, **kwargs):
|
||||||
|
self.__field = field_class(**(field_kwargs or {}))
|
||||||
|
self.dimensions = dimensions
|
||||||
|
self.convert_values = convert_values
|
||||||
|
self.field_type = self.__field.field_type
|
||||||
|
super(ArrayField, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def bind(self, model, name, set_attribute=True):
|
||||||
|
ret = super(ArrayField, self).bind(model, name, set_attribute)
|
||||||
|
self.__field.bind(model, '__array_%s' % name, False)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def ddl_datatype(self, ctx):
|
||||||
|
data_type = self.__field.ddl_datatype(ctx)
|
||||||
|
return NodeList((data_type, SQL('[]' * self.dimensions)), glue='')
|
||||||
|
|
||||||
|
def db_value(self, value):
|
||||||
|
if value is None or isinstance(value, Node):
|
||||||
|
return value
|
||||||
|
elif self.convert_values:
|
||||||
|
return self._process(self.__field.db_value, value, self.dimensions)
|
||||||
|
else:
|
||||||
|
return value if isinstance(value, list) else list(value)
|
||||||
|
|
||||||
|
def python_value(self, value):
|
||||||
|
if self.convert_values and value is not None:
|
||||||
|
conv = self.__field.python_value
|
||||||
|
if isinstance(value, list):
|
||||||
|
return self._process(conv, value, self.dimensions)
|
||||||
|
else:
|
||||||
|
return conv(value)
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _process(self, conv, value, dimensions):
|
||||||
|
dimensions -= 1
|
||||||
|
if dimensions == 0:
|
||||||
|
return [conv(v) for v in value]
|
||||||
|
else:
|
||||||
|
return [self._process(conv, v, dimensions) for v in value]
|
||||||
|
|
||||||
|
def __getitem__(self, value):
|
||||||
|
return ObjectSlice.create(self, value)
|
||||||
|
|
||||||
|
def _e(op):
|
||||||
|
def inner(self, rhs):
|
||||||
|
return Expression(self, op, ArrayValue(self, rhs))
|
||||||
|
return inner
|
||||||
|
__eq__ = _e(OP.EQ)
|
||||||
|
__ne__ = _e(OP.NE)
|
||||||
|
__gt__ = _e(OP.GT)
|
||||||
|
__ge__ = _e(OP.GTE)
|
||||||
|
__lt__ = _e(OP.LT)
|
||||||
|
__le__ = _e(OP.LTE)
|
||||||
|
__hash__ = Field.__hash__
|
||||||
|
|
||||||
|
def contains(self, *items):
|
||||||
|
return Expression(self, ACONTAINS, ArrayValue(self, items))
|
||||||
|
|
||||||
|
def contains_any(self, *items):
|
||||||
|
return Expression(self, ACONTAINS_ANY, ArrayValue(self, items))
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayValue(Node):
|
||||||
|
def __init__(self, field, value):
|
||||||
|
self.field = field
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __sql__(self, ctx):
|
||||||
|
return (ctx
|
||||||
|
.sql(Value(self.value, unpack=False))
|
||||||
|
.literal('::')
|
||||||
|
.sql(self.field.ddl_datatype(ctx)))
|
||||||
|
|
||||||
|
|
||||||
|
class DateTimeTZField(DateTimeField):
|
||||||
|
field_type = 'TIMESTAMPTZ'
|
||||||
|
|
||||||
|
|
||||||
|
class HStoreField(IndexedFieldMixin, Field):
|
||||||
|
field_type = 'HSTORE'
|
||||||
|
__hash__ = Field.__hash__
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return Expression(self, HKEY, Value(key))
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return fn.akeys(self)
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
return fn.avals(self)
|
||||||
|
|
||||||
|
def items(self):
|
||||||
|
return fn.hstore_to_matrix(self)
|
||||||
|
|
||||||
|
def slice(self, *args):
|
||||||
|
return fn.slice(self, Value(list(args), unpack=False))
|
||||||
|
|
||||||
|
def exists(self, key):
|
||||||
|
return fn.exist(self, key)
|
||||||
|
|
||||||
|
def defined(self, key):
|
||||||
|
return fn.defined(self, key)
|
||||||
|
|
||||||
|
def update(self, **data):
|
||||||
|
return Expression(self, HUPDATE, data)
|
||||||
|
|
||||||
|
def delete(self, *keys):
|
||||||
|
return fn.delete(self, Value(list(keys), unpack=False))
|
||||||
|
|
||||||
|
def contains(self, value):
|
||||||
|
if isinstance(value, dict):
|
||||||
|
rhs = Value(value, unpack=False)
|
||||||
|
return Expression(self, HCONTAINS_DICT, rhs)
|
||||||
|
elif isinstance(value, (list, tuple)):
|
||||||
|
rhs = Value(value, unpack=False)
|
||||||
|
return Expression(self, HCONTAINS_KEYS, rhs)
|
||||||
|
return Expression(self, HCONTAINS_KEY, value)
|
||||||
|
|
||||||
|
def contains_any(self, *keys):
|
||||||
|
return Expression(self, HCONTAINS_ANY_KEY, Value(list(keys),
|
||||||
|
unpack=False))
|
||||||
|
|
||||||
|
|
||||||
|
class JSONField(Field):
|
||||||
|
field_type = 'JSON'
|
||||||
|
|
||||||
|
def __init__(self, dumps=None, *args, **kwargs):
|
||||||
|
if Json is None:
|
||||||
|
raise Exception('Your version of psycopg2 does not support JSON.')
|
||||||
|
self.dumps = dumps
|
||||||
|
super(JSONField, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def db_value(self, value):
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
if not isinstance(value, Json):
|
||||||
|
return Json(value, dumps=self.dumps)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def __getitem__(self, value):
|
||||||
|
return JsonLookup(self, [value])
|
||||||
|
|
||||||
|
def path(self, *keys):
|
||||||
|
return JsonPath(self, keys)
|
||||||
|
|
||||||
|
def concat(self, value):
|
||||||
|
return super(JSONField, self).concat(Json(value))
|
||||||
|
|
||||||
|
|
||||||
|
def cast_jsonb(node):
|
||||||
|
return NodeList((node, SQL('::jsonb')), glue='')
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryJSONField(IndexedFieldMixin, JSONField):
|
||||||
|
field_type = 'JSONB'
|
||||||
|
__hash__ = Field.__hash__
|
||||||
|
|
||||||
|
def contains(self, other):
|
||||||
|
if isinstance(other, (list, dict)):
|
||||||
|
return Expression(self, JSONB_CONTAINS, Json(other))
|
||||||
|
return Expression(cast_jsonb(self), JSONB_EXISTS, other)
|
||||||
|
|
||||||
|
def contained_by(self, other):
|
||||||
|
return Expression(cast_jsonb(self), JSONB_CONTAINED_BY, Json(other))
|
||||||
|
|
||||||
|
def contains_any(self, *items):
|
||||||
|
return Expression(
|
||||||
|
cast_jsonb(self),
|
||||||
|
JSONB_CONTAINS_ANY_KEY,
|
||||||
|
Value(list(items), unpack=False))
|
||||||
|
|
||||||
|
def contains_all(self, *items):
|
||||||
|
return Expression(
|
||||||
|
cast_jsonb(self),
|
||||||
|
JSONB_CONTAINS_ALL_KEYS,
|
||||||
|
Value(list(items), unpack=False))
|
||||||
|
|
||||||
|
def has_key(self, key):
|
||||||
|
return Expression(cast_jsonb(self), JSONB_CONTAINS_KEY, key)
|
||||||
|
|
||||||
|
def remove(self, *items):
|
||||||
|
return Expression(
|
||||||
|
cast_jsonb(self),
|
||||||
|
JSONB_REMOVE,
|
||||||
|
Value(list(items), unpack=False))
|
||||||
|
|
||||||
|
|
||||||
|
class TSVectorField(IndexedFieldMixin, TextField):
|
||||||
|
field_type = 'TSVECTOR'
|
||||||
|
__hash__ = Field.__hash__
|
||||||
|
|
||||||
|
def match(self, query, language=None, plain=False):
|
||||||
|
params = (language, query) if language is not None else (query,)
|
||||||
|
func = fn.plainto_tsquery if plain else fn.to_tsquery
|
||||||
|
return Expression(self, TS_MATCH, func(*params))
|
||||||
|
|
||||||
|
|
||||||
|
def Match(field, query, language=None):
|
||||||
|
params = (language, query) if language is not None else (query,)
|
||||||
|
field_params = (language, field) if language is not None else (field,)
|
||||||
|
return Expression(
|
||||||
|
fn.to_tsvector(*field_params),
|
||||||
|
TS_MATCH,
|
||||||
|
fn.to_tsquery(*params))
|
||||||
|
|
||||||
|
|
||||||
|
class IntervalField(Field):
|
||||||
|
field_type = 'INTERVAL'
|
||||||
|
|
||||||
|
|
||||||
|
class FetchManyCursor(object):
|
||||||
|
__slots__ = ('cursor', 'array_size', 'exhausted', 'iterable')
|
||||||
|
|
||||||
|
def __init__(self, cursor, array_size=None):
|
||||||
|
self.cursor = cursor
|
||||||
|
self.array_size = array_size or cursor.itersize
|
||||||
|
self.exhausted = False
|
||||||
|
self.iterable = self.row_gen()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return self.cursor.description
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.cursor.close()
|
||||||
|
|
||||||
|
def row_gen(self):
|
||||||
|
while True:
|
||||||
|
rows = self.cursor.fetchmany(self.array_size)
|
||||||
|
if not rows:
|
||||||
|
return
|
||||||
|
for row in rows:
|
||||||
|
yield row
|
||||||
|
|
||||||
|
def fetchone(self):
|
||||||
|
if self.exhausted:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
return next(self.iterable)
|
||||||
|
except StopIteration:
|
||||||
|
self.exhausted = True
|
||||||
|
|
||||||
|
|
||||||
|
class ServerSideQuery(Node):
|
||||||
|
def __init__(self, query, array_size=None):
|
||||||
|
self.query = query
|
||||||
|
self.array_size = array_size
|
||||||
|
self._cursor_wrapper = None
|
||||||
|
|
||||||
|
def __sql__(self, ctx):
|
||||||
|
return self.query.__sql__(ctx)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if self._cursor_wrapper is None:
|
||||||
|
self._execute(self.query._database)
|
||||||
|
return iter(self._cursor_wrapper.iterator())
|
||||||
|
|
||||||
|
def _execute(self, database):
|
||||||
|
if self._cursor_wrapper is None:
|
||||||
|
cursor = database.execute(self.query, named_cursor=True,
|
||||||
|
array_size=self.array_size)
|
||||||
|
self._cursor_wrapper = self.query._get_cursor_wrapper(cursor)
|
||||||
|
return self._cursor_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def ServerSide(query, database=None, array_size=None):
|
||||||
|
if database is None:
|
||||||
|
database = query._database
|
||||||
|
with database.transaction():
|
||||||
|
server_side_query = ServerSideQuery(query, array_size=array_size)
|
||||||
|
for row in server_side_query:
|
||||||
|
yield row
|
||||||
|
|
||||||
|
|
||||||
|
class _empty_object(object):
|
||||||
|
__slots__ = ()
|
||||||
|
def __nonzero__(self):
|
||||||
|
return False
|
||||||
|
__bool__ = __nonzero__
|
||||||
|
|
||||||
|
__named_cursor__ = _empty_object()
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresqlExtDatabase(PostgresqlDatabase):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._register_hstore = kwargs.pop('register_hstore', False)
|
||||||
|
self._server_side_cursors = kwargs.pop('server_side_cursors', False)
|
||||||
|
super(PostgresqlExtDatabase, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def _connect(self):
|
||||||
|
conn = super(PostgresqlExtDatabase, self)._connect()
|
||||||
|
if self._register_hstore:
|
||||||
|
register_hstore(conn, globally=True)
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def cursor(self, commit=None):
|
||||||
|
if self.is_closed():
|
||||||
|
self.connect()
|
||||||
|
if commit is __named_cursor__:
|
||||||
|
return self._state.conn.cursor(name=str(uuid.uuid1()))
|
||||||
|
return self._state.conn.cursor()
|
||||||
|
|
||||||
|
def execute(self, query, commit=SENTINEL, named_cursor=False,
|
||||||
|
array_size=None, **context_options):
|
||||||
|
ctx = self.get_sql_context(**context_options)
|
||||||
|
sql, params = ctx.sql(query).query()
|
||||||
|
named_cursor = named_cursor or (self._server_side_cursors and
|
||||||
|
sql[:6].lower() == 'select')
|
||||||
|
if named_cursor:
|
||||||
|
commit = __named_cursor__
|
||||||
|
cursor = self.execute_sql(sql, params, commit=commit)
|
||||||
|
if named_cursor:
|
||||||
|
cursor = FetchManyCursor(cursor, array_size)
|
||||||
|
return cursor
|
@ -0,0 +1,799 @@
|
|||||||
|
try:
|
||||||
|
from collections import OrderedDict
|
||||||
|
except ImportError:
|
||||||
|
OrderedDict = dict
|
||||||
|
from collections import namedtuple
|
||||||
|
from inspect import isclass
|
||||||
|
import re
|
||||||
|
|
||||||
|
from peewee import *
|
||||||
|
from peewee import _StringField
|
||||||
|
from peewee import _query_val_transform
|
||||||
|
from peewee import CommaNodeList
|
||||||
|
from peewee import SCOPE_VALUES
|
||||||
|
from peewee import make_snake_case
|
||||||
|
from peewee import text_type
|
||||||
|
try:
|
||||||
|
from pymysql.constants import FIELD_TYPE
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
from MySQLdb.constants import FIELD_TYPE
|
||||||
|
except ImportError:
|
||||||
|
FIELD_TYPE = None
|
||||||
|
try:
|
||||||
|
from playhouse import postgres_ext
|
||||||
|
except ImportError:
|
||||||
|
postgres_ext = None
|
||||||
|
|
||||||
|
RESERVED_WORDS = set([
|
||||||
|
'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif',
|
||||||
|
'else', 'except', 'exec', 'finally', 'for', 'from', 'global', 'if',
|
||||||
|
'import', 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', 'raise',
|
||||||
|
'return', 'try', 'while', 'with', 'yield',
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownField(object):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Column(object):
|
||||||
|
"""
|
||||||
|
Store metadata about a database column.
|
||||||
|
"""
|
||||||
|
primary_key_types = (IntegerField, AutoField)
|
||||||
|
|
||||||
|
def __init__(self, name, field_class, raw_column_type, nullable,
|
||||||
|
primary_key=False, column_name=None, index=False,
|
||||||
|
unique=False, default=None, extra_parameters=None):
|
||||||
|
self.name = name
|
||||||
|
self.field_class = field_class
|
||||||
|
self.raw_column_type = raw_column_type
|
||||||
|
self.nullable = nullable
|
||||||
|
self.primary_key = primary_key
|
||||||
|
self.column_name = column_name
|
||||||
|
self.index = index
|
||||||
|
self.unique = unique
|
||||||
|
self.default = default
|
||||||
|
self.extra_parameters = extra_parameters
|
||||||
|
|
||||||
|
# Foreign key metadata.
|
||||||
|
self.rel_model = None
|
||||||
|
self.related_name = None
|
||||||
|
self.to_field = None
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
attrs = [
|
||||||
|
'field_class',
|
||||||
|
'raw_column_type',
|
||||||
|
'nullable',
|
||||||
|
'primary_key',
|
||||||
|
'column_name']
|
||||||
|
keyword_args = ', '.join(
|
||||||
|
'%s=%s' % (attr, getattr(self, attr))
|
||||||
|
for attr in attrs)
|
||||||
|
return 'Column(%s, %s)' % (self.name, keyword_args)
|
||||||
|
|
||||||
|
def get_field_parameters(self):
|
||||||
|
params = {}
|
||||||
|
if self.extra_parameters is not None:
|
||||||
|
params.update(self.extra_parameters)
|
||||||
|
|
||||||
|
# Set up default attributes.
|
||||||
|
if self.nullable:
|
||||||
|
params['null'] = True
|
||||||
|
if self.field_class is ForeignKeyField or self.name != self.column_name:
|
||||||
|
params['column_name'] = "'%s'" % self.column_name
|
||||||
|
if self.primary_key and not issubclass(self.field_class, AutoField):
|
||||||
|
params['primary_key'] = True
|
||||||
|
if self.default is not None:
|
||||||
|
params['constraints'] = '[SQL("DEFAULT %s")]' % self.default
|
||||||
|
|
||||||
|
# Handle ForeignKeyField-specific attributes.
|
||||||
|
if self.is_foreign_key():
|
||||||
|
params['model'] = self.rel_model
|
||||||
|
if self.to_field:
|
||||||
|
params['field'] = "'%s'" % self.to_field
|
||||||
|
if self.related_name:
|
||||||
|
params['backref'] = "'%s'" % self.related_name
|
||||||
|
|
||||||
|
# Handle indexes on column.
|
||||||
|
if not self.is_primary_key():
|
||||||
|
if self.unique:
|
||||||
|
params['unique'] = 'True'
|
||||||
|
elif self.index and not self.is_foreign_key():
|
||||||
|
params['index'] = 'True'
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
def is_primary_key(self):
|
||||||
|
return self.field_class is AutoField or self.primary_key
|
||||||
|
|
||||||
|
def is_foreign_key(self):
|
||||||
|
return self.field_class is ForeignKeyField
|
||||||
|
|
||||||
|
def is_self_referential_fk(self):
|
||||||
|
return (self.field_class is ForeignKeyField and
|
||||||
|
self.rel_model == "'self'")
|
||||||
|
|
||||||
|
def set_foreign_key(self, foreign_key, model_names, dest=None,
|
||||||
|
related_name=None):
|
||||||
|
self.foreign_key = foreign_key
|
||||||
|
self.field_class = ForeignKeyField
|
||||||
|
if foreign_key.dest_table == foreign_key.table:
|
||||||
|
self.rel_model = "'self'"
|
||||||
|
else:
|
||||||
|
self.rel_model = model_names[foreign_key.dest_table]
|
||||||
|
self.to_field = dest and dest.name or None
|
||||||
|
self.related_name = related_name or None
|
||||||
|
|
||||||
|
def get_field(self):
|
||||||
|
# Generate the field definition for this column.
|
||||||
|
field_params = {}
|
||||||
|
for key, value in self.get_field_parameters().items():
|
||||||
|
if isclass(value) and issubclass(value, Field):
|
||||||
|
value = value.__name__
|
||||||
|
field_params[key] = value
|
||||||
|
|
||||||
|
param_str = ', '.join('%s=%s' % (k, v)
|
||||||
|
for k, v in sorted(field_params.items()))
|
||||||
|
field = '%s = %s(%s)' % (
|
||||||
|
self.name,
|
||||||
|
self.field_class.__name__,
|
||||||
|
param_str)
|
||||||
|
|
||||||
|
if self.field_class is UnknownField:
|
||||||
|
field = '%s # %s' % (field, self.raw_column_type)
|
||||||
|
|
||||||
|
return field
|
||||||
|
|
||||||
|
|
||||||
|
class Metadata(object):
|
||||||
|
column_map = {}
|
||||||
|
extension_import = ''
|
||||||
|
|
||||||
|
def __init__(self, database):
|
||||||
|
self.database = database
|
||||||
|
self.requires_extension = False
|
||||||
|
|
||||||
|
def execute(self, sql, *params):
|
||||||
|
return self.database.execute_sql(sql, params)
|
||||||
|
|
||||||
|
def get_columns(self, table, schema=None):
|
||||||
|
metadata = OrderedDict(
|
||||||
|
(metadata.name, metadata)
|
||||||
|
for metadata in self.database.get_columns(table, schema))
|
||||||
|
|
||||||
|
# Look up the actual column type for each column.
|
||||||
|
column_types, extra_params = self.get_column_types(table, schema)
|
||||||
|
|
||||||
|
# Look up the primary keys.
|
||||||
|
pk_names = self.get_primary_keys(table, schema)
|
||||||
|
if len(pk_names) == 1:
|
||||||
|
pk = pk_names[0]
|
||||||
|
if column_types[pk] is IntegerField:
|
||||||
|
column_types[pk] = AutoField
|
||||||
|
elif column_types[pk] is BigIntegerField:
|
||||||
|
column_types[pk] = BigAutoField
|
||||||
|
|
||||||
|
columns = OrderedDict()
|
||||||
|
for name, column_data in metadata.items():
|
||||||
|
field_class = column_types[name]
|
||||||
|
default = self._clean_default(field_class, column_data.default)
|
||||||
|
|
||||||
|
columns[name] = Column(
|
||||||
|
name,
|
||||||
|
field_class=field_class,
|
||||||
|
raw_column_type=column_data.data_type,
|
||||||
|
nullable=column_data.null,
|
||||||
|
primary_key=column_data.primary_key,
|
||||||
|
column_name=name,
|
||||||
|
default=default,
|
||||||
|
extra_parameters=extra_params.get(name))
|
||||||
|
|
||||||
|
return columns
|
||||||
|
|
||||||
|
def get_column_types(self, table, schema=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _clean_default(self, field_class, default):
|
||||||
|
if default is None or field_class in (AutoField, BigAutoField) or \
|
||||||
|
default.lower() == 'null':
|
||||||
|
return
|
||||||
|
if issubclass(field_class, _StringField) and \
|
||||||
|
isinstance(default, text_type) and not default.startswith("'"):
|
||||||
|
default = "'%s'" % default
|
||||||
|
return default or "''"
|
||||||
|
|
||||||
|
def get_foreign_keys(self, table, schema=None):
|
||||||
|
return self.database.get_foreign_keys(table, schema)
|
||||||
|
|
||||||
|
def get_primary_keys(self, table, schema=None):
|
||||||
|
return self.database.get_primary_keys(table, schema)
|
||||||
|
|
||||||
|
def get_indexes(self, table, schema=None):
|
||||||
|
return self.database.get_indexes(table, schema)
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresqlMetadata(Metadata):
|
||||||
|
column_map = {
|
||||||
|
16: BooleanField,
|
||||||
|
17: BlobField,
|
||||||
|
20: BigIntegerField,
|
||||||
|
21: IntegerField,
|
||||||
|
23: IntegerField,
|
||||||
|
25: TextField,
|
||||||
|
700: FloatField,
|
||||||
|
701: FloatField,
|
||||||
|
1042: CharField, # blank-padded CHAR
|
||||||
|
1043: CharField,
|
||||||
|
1082: DateField,
|
||||||
|
1114: DateTimeField,
|
||||||
|
1184: DateTimeField,
|
||||||
|
1083: TimeField,
|
||||||
|
1266: TimeField,
|
||||||
|
1700: DecimalField,
|
||||||
|
2950: TextField, # UUID
|
||||||
|
}
|
||||||
|
array_types = {
|
||||||
|
1000: BooleanField,
|
||||||
|
1001: BlobField,
|
||||||
|
1005: SmallIntegerField,
|
||||||
|
1007: IntegerField,
|
||||||
|
1009: TextField,
|
||||||
|
1014: CharField,
|
||||||
|
1015: CharField,
|
||||||
|
1016: BigIntegerField,
|
||||||
|
1115: DateTimeField,
|
||||||
|
1182: DateField,
|
||||||
|
1183: TimeField,
|
||||||
|
}
|
||||||
|
extension_import = 'from playhouse.postgres_ext import *'
|
||||||
|
|
||||||
|
def __init__(self, database):
|
||||||
|
super(PostgresqlMetadata, self).__init__(database)
|
||||||
|
|
||||||
|
if postgres_ext is not None:
|
||||||
|
# Attempt to add types like HStore and JSON.
|
||||||
|
cursor = self.execute('select oid, typname, format_type(oid, NULL)'
|
||||||
|
' from pg_type;')
|
||||||
|
results = cursor.fetchall()
|
||||||
|
|
||||||
|
for oid, typname, formatted_type in results:
|
||||||
|
if typname == 'json':
|
||||||
|
self.column_map[oid] = postgres_ext.JSONField
|
||||||
|
elif typname == 'jsonb':
|
||||||
|
self.column_map[oid] = postgres_ext.BinaryJSONField
|
||||||
|
elif typname == 'hstore':
|
||||||
|
self.column_map[oid] = postgres_ext.HStoreField
|
||||||
|
elif typname == 'tsvector':
|
||||||
|
self.column_map[oid] = postgres_ext.TSVectorField
|
||||||
|
|
||||||
|
for oid in self.array_types:
|
||||||
|
self.column_map[oid] = postgres_ext.ArrayField
|
||||||
|
|
||||||
|
def get_column_types(self, table, schema):
|
||||||
|
column_types = {}
|
||||||
|
extra_params = {}
|
||||||
|
extension_types = set((
|
||||||
|
postgres_ext.ArrayField,
|
||||||
|
postgres_ext.BinaryJSONField,
|
||||||
|
postgres_ext.JSONField,
|
||||||
|
postgres_ext.TSVectorField,
|
||||||
|
postgres_ext.HStoreField)) if postgres_ext is not None else set()
|
||||||
|
|
||||||
|
# Look up the actual column type for each column.
|
||||||
|
identifier = '"%s"."%s"' % (schema, table)
|
||||||
|
cursor = self.execute('SELECT * FROM %s LIMIT 1' % identifier)
|
||||||
|
|
||||||
|
# Store column metadata in dictionary keyed by column name.
|
||||||
|
for column_description in cursor.description:
|
||||||
|
name = column_description.name
|
||||||
|
oid = column_description.type_code
|
||||||
|
column_types[name] = self.column_map.get(oid, UnknownField)
|
||||||
|
if column_types[name] in extension_types:
|
||||||
|
self.requires_extension = True
|
||||||
|
if oid in self.array_types:
|
||||||
|
extra_params[name] = {'field_class': self.array_types[oid]}
|
||||||
|
|
||||||
|
return column_types, extra_params
|
||||||
|
|
||||||
|
def get_columns(self, table, schema=None):
|
||||||
|
schema = schema or 'public'
|
||||||
|
return super(PostgresqlMetadata, self).get_columns(table, schema)
|
||||||
|
|
||||||
|
def get_foreign_keys(self, table, schema=None):
|
||||||
|
schema = schema or 'public'
|
||||||
|
return super(PostgresqlMetadata, self).get_foreign_keys(table, schema)
|
||||||
|
|
||||||
|
def get_primary_keys(self, table, schema=None):
|
||||||
|
schema = schema or 'public'
|
||||||
|
return super(PostgresqlMetadata, self).get_primary_keys(table, schema)
|
||||||
|
|
||||||
|
def get_indexes(self, table, schema=None):
|
||||||
|
schema = schema or 'public'
|
||||||
|
return super(PostgresqlMetadata, self).get_indexes(table, schema)
|
||||||
|
|
||||||
|
|
||||||
|
class MySQLMetadata(Metadata):
|
||||||
|
if FIELD_TYPE is None:
|
||||||
|
column_map = {}
|
||||||
|
else:
|
||||||
|
column_map = {
|
||||||
|
FIELD_TYPE.BLOB: TextField,
|
||||||
|
FIELD_TYPE.CHAR: CharField,
|
||||||
|
FIELD_TYPE.DATE: DateField,
|
||||||
|
FIELD_TYPE.DATETIME: DateTimeField,
|
||||||
|
FIELD_TYPE.DECIMAL: DecimalField,
|
||||||
|
FIELD_TYPE.DOUBLE: FloatField,
|
||||||
|
FIELD_TYPE.FLOAT: FloatField,
|
||||||
|
FIELD_TYPE.INT24: IntegerField,
|
||||||
|
FIELD_TYPE.LONG_BLOB: TextField,
|
||||||
|
FIELD_TYPE.LONG: IntegerField,
|
||||||
|
FIELD_TYPE.LONGLONG: BigIntegerField,
|
||||||
|
FIELD_TYPE.MEDIUM_BLOB: TextField,
|
||||||
|
FIELD_TYPE.NEWDECIMAL: DecimalField,
|
||||||
|
FIELD_TYPE.SHORT: IntegerField,
|
||||||
|
FIELD_TYPE.STRING: CharField,
|
||||||
|
FIELD_TYPE.TIMESTAMP: DateTimeField,
|
||||||
|
FIELD_TYPE.TIME: TimeField,
|
||||||
|
FIELD_TYPE.TINY_BLOB: TextField,
|
||||||
|
FIELD_TYPE.TINY: IntegerField,
|
||||||
|
FIELD_TYPE.VAR_STRING: CharField,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, database, **kwargs):
|
||||||
|
if 'password' in kwargs:
|
||||||
|
kwargs['passwd'] = kwargs.pop('password')
|
||||||
|
super(MySQLMetadata, self).__init__(database, **kwargs)
|
||||||
|
|
||||||
|
def get_column_types(self, table, schema=None):
|
||||||
|
column_types = {}
|
||||||
|
|
||||||
|
# Look up the actual column type for each column.
|
||||||
|
cursor = self.execute('SELECT * FROM `%s` LIMIT 1' % table)
|
||||||
|
|
||||||
|
# Store column metadata in dictionary keyed by column name.
|
||||||
|
for column_description in cursor.description:
|
||||||
|
name, type_code = column_description[:2]
|
||||||
|
column_types[name] = self.column_map.get(type_code, UnknownField)
|
||||||
|
|
||||||
|
return column_types, {}
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteMetadata(Metadata):
|
||||||
|
column_map = {
|
||||||
|
'bigint': BigIntegerField,
|
||||||
|
'blob': BlobField,
|
||||||
|
'bool': BooleanField,
|
||||||
|
'boolean': BooleanField,
|
||||||
|
'char': CharField,
|
||||||
|
'date': DateField,
|
||||||
|
'datetime': DateTimeField,
|
||||||
|
'decimal': DecimalField,
|
||||||
|
'float': FloatField,
|
||||||
|
'integer': IntegerField,
|
||||||
|
'integer unsigned': IntegerField,
|
||||||
|
'int': IntegerField,
|
||||||
|
'long': BigIntegerField,
|
||||||
|
'numeric': DecimalField,
|
||||||
|
'real': FloatField,
|
||||||
|
'smallinteger': IntegerField,
|
||||||
|
'smallint': IntegerField,
|
||||||
|
'smallint unsigned': IntegerField,
|
||||||
|
'text': TextField,
|
||||||
|
'time': TimeField,
|
||||||
|
'varchar': CharField,
|
||||||
|
}
|
||||||
|
|
||||||
|
begin = '(?:["\[\(]+)?'
|
||||||
|
end = '(?:["\]\)]+)?'
|
||||||
|
re_foreign_key = (
|
||||||
|
'(?:FOREIGN KEY\s*)?'
|
||||||
|
'{begin}(.+?){end}\s+(?:.+\s+)?'
|
||||||
|
'references\s+{begin}(.+?){end}'
|
||||||
|
'\s*\(["|\[]?(.+?)["|\]]?\)').format(begin=begin, end=end)
|
||||||
|
re_varchar = r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$'
|
||||||
|
|
||||||
|
def _map_col(self, column_type):
|
||||||
|
raw_column_type = column_type.lower()
|
||||||
|
if raw_column_type in self.column_map:
|
||||||
|
field_class = self.column_map[raw_column_type]
|
||||||
|
elif re.search(self.re_varchar, raw_column_type):
|
||||||
|
field_class = CharField
|
||||||
|
else:
|
||||||
|
column_type = re.sub('\(.+\)', '', raw_column_type)
|
||||||
|
if column_type == '':
|
||||||
|
field_class = BareField
|
||||||
|
else:
|
||||||
|
field_class = self.column_map.get(column_type, UnknownField)
|
||||||
|
return field_class
|
||||||
|
|
||||||
|
def get_column_types(self, table, schema=None):
|
||||||
|
column_types = {}
|
||||||
|
columns = self.database.get_columns(table)
|
||||||
|
|
||||||
|
for column in columns:
|
||||||
|
column_types[column.name] = self._map_col(column.data_type)
|
||||||
|
|
||||||
|
return column_types, {}
|
||||||
|
|
||||||
|
|
||||||
|
_DatabaseMetadata = namedtuple('_DatabaseMetadata', (
|
||||||
|
'columns',
|
||||||
|
'primary_keys',
|
||||||
|
'foreign_keys',
|
||||||
|
'model_names',
|
||||||
|
'indexes'))
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseMetadata(_DatabaseMetadata):
|
||||||
|
def multi_column_indexes(self, table):
|
||||||
|
accum = []
|
||||||
|
for index in self.indexes[table]:
|
||||||
|
if len(index.columns) > 1:
|
||||||
|
field_names = [self.columns[table][column].name
|
||||||
|
for column in index.columns
|
||||||
|
if column in self.columns[table]]
|
||||||
|
accum.append((field_names, index.unique))
|
||||||
|
return accum
|
||||||
|
|
||||||
|
def column_indexes(self, table):
|
||||||
|
accum = {}
|
||||||
|
for index in self.indexes[table]:
|
||||||
|
if len(index.columns) == 1:
|
||||||
|
accum[index.columns[0]] = index.unique
|
||||||
|
return accum
|
||||||
|
|
||||||
|
|
||||||
|
class Introspector(object):
    """Reverse-engineer peewee model definitions from a live database.

    Wraps a database-specific metadata reader and exposes helpers to turn
    table/column names into Python identifiers, plus :meth:`introspect`
    (raw metadata) and :meth:`generate_models` (actual Model classes).
    """
    pk_classes = [AutoField, IntegerField]

    def __init__(self, metadata, schema=None):
        self.metadata = metadata
        self.schema = schema

    def __repr__(self):
        return '<Introspector: %s>' % self.metadata.database

    @classmethod
    def from_database(cls, database, schema=None):
        """Build an Introspector with the metadata reader matching *database*."""
        if isinstance(database, PostgresqlDatabase):
            metadata = PostgresqlMetadata(database)
        elif isinstance(database, MySQLDatabase):
            metadata = MySQLMetadata(database)
        elif isinstance(database, SqliteDatabase):
            metadata = SqliteMetadata(database)
        else:
            raise ValueError('Introspection not supported for %r' % database)
        return cls(metadata, schema=schema)

    def get_database_class(self):
        return type(self.metadata.database)

    def get_database_name(self):
        return self.metadata.database.database

    def get_database_kwargs(self):
        return self.metadata.database.connect_params

    def get_additional_imports(self):
        """Extra import lines required by the backend (e.g. extensions)."""
        if self.metadata.requires_extension:
            return '\n' + self.metadata.extension_import
        return ''

    def make_model_name(self, table, snake_case=True):
        """Derive a CamelCase class name from *table*.

        Non-identifier characters are stripped; a leading non-alpha
        character is prefixed with 'T' to keep the name valid.
        """
        if snake_case:
            table = make_snake_case(table)
        # Raw string: '\w' is an invalid escape in a plain literal
        # (DeprecationWarning since Python 3.6); value is unchanged.
        model = re.sub(r'[^\w]+', '', table)
        model_name = ''.join(sub.title() for sub in model.split('_'))
        if not model_name[0].isalpha():
            model_name = 'T' + model_name
        return model_name

    def make_column_name(self, column, is_foreign_key=False, snake_case=True):
        """Derive a valid, lowercase Python attribute name from *column*."""
        column = column.strip()
        if snake_case:
            column = make_snake_case(column)
        column = column.lower()
        if is_foreign_key:
            # Strip "_id" from foreign keys, unless the foreign-key happens to
            # be named "_id", in which case the name is retained.
            column = re.sub('_id$', '', column) or column

        # Remove characters that are invalid for Python identifiers.
        column = re.sub(r'[^\w]+', '_', column)
        if column in RESERVED_WORDS:
            column += '_'
        if len(column) and column[0].isdigit():
            column = '_' + column
        return column

    def introspect(self, table_names=None, literal_column_names=False,
                   include_views=False, snake_case=True):
        """Read the schema and return a DatabaseMetadata description.

        :param table_names: optional whitelist of tables to introspect
            (FK-referenced tables are pulled in even if not listed).
        :param literal_column_names: keep column names as-is (sanitized)
            instead of applying :meth:`make_column_name`.
        :param include_views: also introspect views.
        :param snake_case: snake_case generated names.
        """
        # Retrieve all the tables in the database.
        tables = self.metadata.database.get_tables(schema=self.schema)
        if include_views:
            views = self.metadata.database.get_views(schema=self.schema)
            tables.extend([view.name for view in views])

        if table_names is not None:
            tables = [table for table in tables if table in table_names]
        table_set = set(tables)

        columns = {}        # table -> {column name -> column}
        primary_keys = {}   # table -> list of PK column names
        foreign_keys = {}   # table -> list of FK descriptors
        model_names = {}    # table -> generated class name
        indexes = {}        # table -> list of index descriptors

        # Gather the columns for each table.
        for table in tables:
            table_indexes = self.metadata.get_indexes(table, self.schema)
            table_columns = self.metadata.get_columns(table, self.schema)
            try:
                foreign_keys[table] = self.metadata.get_foreign_keys(
                    table, self.schema)
            except ValueError as exc:
                err(*exc.args)
                foreign_keys[table] = []
            else:
                # If there is a possibility we could exclude a dependent table,
                # ensure that we introspect it so FKs will work.
                if table_names is not None:
                    for foreign_key in foreign_keys[table]:
                        if foreign_key.dest_table not in table_set:
                            tables.append(foreign_key.dest_table)
                            table_set.add(foreign_key.dest_table)

            model_names[table] = self.make_model_name(table, snake_case)

            # Collect sets of all the column names as well as all the
            # foreign-key column names.
            lower_col_names = set(column_name.lower()
                                  for column_name in table_columns)
            fks = set(fk_col.column for fk_col in foreign_keys[table])

            for col_name, column in table_columns.items():
                if literal_column_names:
                    new_name = re.sub(r'[^\w]+', '_', col_name)
                else:
                    new_name = self.make_column_name(col_name, col_name in fks,
                                                     snake_case)

                # If we have two columns, "parent" and "parent_id", ensure
                # that we don't introduce naming conflicts.
                lower_name = col_name.lower()
                if lower_name.endswith('_id') and new_name in lower_col_names:
                    new_name = col_name.lower()

                column.name = new_name

            # Propagate single-column index info onto the columns themselves.
            for index in table_indexes:
                if len(index.columns) == 1:
                    column = index.columns[0]
                    if column in table_columns:
                        table_columns[column].unique = index.unique
                        table_columns[column].index = True

            primary_keys[table] = self.metadata.get_primary_keys(
                table, self.schema)
            columns[table] = table_columns
            indexes[table] = table_indexes

        # Gather all instances where we might have a `related_name` conflict,
        # either due to multiple FKs on a table pointing to the same table,
        # or a related_name that would conflict with an existing field.
        related_names = {}
        for table in tables:
            models_referenced = set()
            for foreign_key in sorted(foreign_keys[table],
                                      key=lambda fk: fk.column):
                try:
                    column = columns[table][foreign_key.column]
                except KeyError:
                    continue

                dest_table = foreign_key.dest_table
                if dest_table in models_referenced:
                    related_names[column] = '%s_%s_set' % (
                        dest_table,
                        column.name)
                else:
                    models_referenced.add(dest_table)

        # On the second pass convert all foreign keys.
        for table in tables:
            for foreign_key in foreign_keys[table]:
                src = columns[foreign_key.table][foreign_key.column]
                try:
                    dest = columns[foreign_key.dest_table][
                        foreign_key.dest_column]
                except KeyError:
                    dest = None

                src.set_foreign_key(
                    foreign_key=foreign_key,
                    model_names=model_names,
                    dest=dest,
                    related_name=related_names.get(src))

        return DatabaseMetadata(
            columns,
            primary_keys,
            foreign_keys,
            model_names,
            indexes)

    def generate_models(self, skip_invalid=False, table_names=None,
                        literal_column_names=False, bare_fields=False,
                        include_views=False):
        """Introspect the database and build actual Model classes.

        :param skip_invalid: silently skip tables whose model class cannot
            be created (otherwise the ValueError propagates).
        :param bare_fields: use BareField for all non-FK columns (SQLite).
        :returns: dict mapping table name -> generated Model class.
        """
        database = self.introspect(table_names, literal_column_names,
                                   include_views)
        models = {}

        class BaseModel(Model):
            class Meta:
                database = self.metadata.database
                schema = self.schema

        def _create_model(table, models):
            # Recursively ensure FK target models exist before this one.
            for foreign_key in database.foreign_keys[table]:
                dest = foreign_key.dest_table

                if dest not in models and dest != table:
                    _create_model(dest, models)

            primary_keys = []
            columns = database.columns[table]
            for column_name, column in columns.items():
                if column.primary_key:
                    primary_keys.append(column.name)

            multi_column_indexes = database.multi_column_indexes(table)
            column_indexes = database.column_indexes(table)

            class Meta:
                indexes = multi_column_indexes
                table_name = table

            # Fix models with multi-column primary keys.
            composite_key = False
            if len(primary_keys) == 0:
                primary_keys = columns.keys()
            if len(primary_keys) > 1:
                Meta.primary_key = CompositeKey(*[
                    field.name for col, field in columns.items()
                    if col in primary_keys])
                composite_key = True

            attrs = {'Meta': Meta}
            for column_name, column in columns.items():
                FieldClass = column.field_class
                if FieldClass is not ForeignKeyField and bare_fields:
                    FieldClass = BareField
                elif FieldClass is UnknownField:
                    FieldClass = BareField

                params = {
                    'column_name': column_name,
                    'null': column.nullable}
                if column.primary_key and composite_key:
                    # Composite PKs are expressed via Meta.primary_key, so
                    # demote per-column AutoFields to plain IntegerFields.
                    if FieldClass is AutoField:
                        FieldClass = IntegerField
                    params['primary_key'] = False
                elif column.primary_key and FieldClass is not AutoField:
                    params['primary_key'] = True
                if column.is_foreign_key():
                    if column.is_self_referential_fk():
                        params['model'] = 'self'
                    else:
                        dest_table = column.foreign_key.dest_table
                        params['model'] = models[dest_table]
                    if column.to_field:
                        params['field'] = column.to_field

                    # Generate a unique related name.
                    params['backref'] = '%s_%s_rel' % (table, column_name)

                if column.default is not None:
                    constraint = SQL('DEFAULT %s' % column.default)
                    params['constraints'] = [constraint]

                if column_name in column_indexes and not \
                   column.is_primary_key():
                    if column_indexes[column_name]:
                        params['unique'] = True
                    elif not column.is_foreign_key():
                        params['index'] = True

                attrs[column.name] = FieldClass(**params)

            try:
                models[table] = type(str(table), (BaseModel,), attrs)
            except ValueError:
                if not skip_invalid:
                    raise

        # Actually generate Model classes (sorted for deterministic order;
        # only the table name is needed here).
        for table in sorted(database.model_names):
            if table not in models:
                _create_model(table, models)

        return models
|
||||||
|
|
||||||
|
|
||||||
|
def introspect(database, schema=None):
    """Shortcut: introspect *database* and return its DatabaseMetadata."""
    return Introspector.from_database(database, schema=schema).introspect()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_models(database, schema=None, **options):
    """Shortcut: build Model classes for the tables in *database*."""
    return Introspector.from_database(
        database, schema=schema).generate_models(**options)
|
||||||
|
|
||||||
|
|
||||||
|
def print_model(model, indexes=True, inline_indexes=False):
    """Print a human-readable summary of *model*: its fields, PK/unique
    markers, FK targets and (optionally) its index definitions.

    :param indexes: print the model's index list after the fields.
    :param inline_indexes: annotate each field with UNIQUE/INDEX markers.
    """
    print(model._meta.name)
    for field in model._meta.sorted_fields:
        parts = [' %s %s' % (field.name, field.field_type)]
        if field.primary_key:
            parts.append(' PK')
        elif inline_indexes:
            if field.unique:
                parts.append(' UNIQUE')
            elif field.index:
                parts.append(' INDEX')
        if isinstance(field, ForeignKeyField):
            parts.append(' FK: %s.%s' % (field.rel_model.__name__,
                                         field.rel_field.name))
        print(''.join(parts))

    if indexes:
        index_list = model._meta.fields_to_index()
        if not index_list:
            return

        print('\nindex(es)')
        for index in index_list:
            parts = [' ']
            # Render the index expression with generic '%s' placeholders and
            # double-quote identifier quoting, then strip the quotes below.
            ctx = model._meta.database.get_sql_context()
            with ctx.scope_values(param='%s', quote='""'):
                ctx.sql(CommaNodeList(index._expressions))
                if index._where:
                    ctx.literal(' WHERE ')
                    ctx.sql(index._where)
                sql, params = ctx.query()

            clean = sql % tuple(map(_query_val_transform, params))
            parts.append(clean.replace('"', ''))

            if index._unique:
                parts.append(' UNIQUE')
            print(''.join(parts))
|
||||||
|
|
||||||
|
|
||||||
|
def get_table_sql(model):
    """Return the CREATE TABLE statement for *model*, pretty-printed.

    Placeholders are normalized to '%s' and the column list is split onto
    indented lines for readability.
    """
    sql, params = model._schema._create_table().query()
    if model._meta.database.param != '%s':
        sql = sql.replace(model._meta.database.param, '%s')

    # Format and indent the table declaration, simplest possible approach.
    # Raw string: '\(' is an invalid escape sequence in a plain literal
    # (DeprecationWarning since Python 3.6); pattern value is unchanged.
    match_obj = re.match(r'^(.+?\()(.+)(\).*)', sql)
    create, columns, extra = match_obj.groups()
    indented = ',\n'.join(' %s' % column for column in columns.split(', '))

    clean = '\n'.join((create, indented, extra)).strip()
    return clean % tuple(map(_query_val_transform, params))
|
||||||
|
|
||||||
|
def print_table_sql(model):
    """Print the formatted CREATE TABLE statement for *model*."""
    sql = get_table_sql(model)
    print(sql)
|
@ -0,0 +1,224 @@
|
|||||||
|
from peewee import *
|
||||||
|
from peewee import Alias
|
||||||
|
from peewee import SENTINEL
|
||||||
|
from peewee import callable_
|
||||||
|
|
||||||
|
|
||||||
|
_clone_set = lambda s: set(s) if s else set()
|
||||||
|
|
||||||
|
|
||||||
|
def model_to_dict(model, recurse=True, backrefs=False, only=None,
                  exclude=None, seen=None, extra_attrs=None,
                  fields_from_query=None, max_depth=None, manytomany=False):
    """
    Convert a model instance (and any related objects) to a dictionary.

    :param bool recurse: Whether foreign-keys should be recursed.
    :param bool backrefs: Whether lists of related objects should be recursed.
    :param only: A list (or set) of field instances indicating which fields
        should be included.
    :param exclude: A list (or set) of field instances that should be
        excluded from the dictionary.
    :param list extra_attrs: Names of model instance attributes or methods
        that should be included.
    :param SelectQuery fields_from_query: Query that was source of model. Take
        fields explicitly selected by the query and serialize them.
    :param int max_depth: Maximum depth to recurse, value <= 0 means no max.
    :param bool manytomany: Process many-to-many fields.
    """
    max_depth = -1 if max_depth is None else max_depth
    if max_depth == 0:
        recurse = False

    only = _clone_set(only)
    extra_attrs = _clone_set(extra_attrs)
    # NOTE: closes over `exclude`/`only`, which are rebound below before the
    # lambda is first invoked.
    should_skip = lambda n: (n in exclude) or (only and (n not in only))

    if fields_from_query is not None:
        # Serialize exactly what the query selected: real fields go into
        # `only`, aliased expressions are treated as extra attributes.
        for item in fields_from_query._returning:
            if isinstance(item, Field):
                only.add(item)
            elif isinstance(item, Alias):
                extra_attrs.add(item._alias)

    data = {}
    exclude = _clone_set(exclude)
    seen = _clone_set(seen)
    exclude |= seen  # never revisit objects already serialized upstream
    model_class = type(model)

    if manytomany:
        for name, m2m in model._meta.manytomany.items():
            if should_skip(name):
                continue

            # Prevent the reverse M2M accessor and the through-model FKs
            # from being serialized again on recursion.
            exclude.update((m2m, m2m.rel_model._meta.manytomany[m2m.backref]))
            for fkf in m2m.through_model._meta.refs:
                exclude.add(fkf)

            accum = []
            for rel_obj in getattr(model, name):
                accum.append(model_to_dict(
                    rel_obj,
                    recurse=recurse,
                    backrefs=backrefs,
                    only=only,
                    exclude=exclude,
                    max_depth=max_depth - 1))
            data[name] = accum

    for field in model._meta.sorted_fields:
        if should_skip(field):
            continue

        field_data = model.__data__.get(field.name)
        if isinstance(field, ForeignKeyField) and recurse:
            if field_data:
                seen.add(field)
                rel_obj = getattr(model, field.name)
                field_data = model_to_dict(
                    rel_obj,
                    recurse=recurse,
                    backrefs=backrefs,
                    only=only,
                    exclude=exclude,
                    seen=seen,
                    max_depth=max_depth - 1)
            else:
                field_data = None

        data[field.name] = field_data

    if extra_attrs:
        for attr_name in extra_attrs:
            attr = getattr(model, attr_name)
            if callable_(attr):
                data[attr_name] = attr()
            else:
                data[attr_name] = attr

    if backrefs and recurse:
        for foreign_key, rel_model in model._meta.backrefs.items():
            if foreign_key.backref == '+': continue
            descriptor = getattr(model_class, foreign_key.backref)
            if descriptor in exclude or foreign_key in exclude:
                continue
            if only and (descriptor not in only) and (foreign_key not in only):
                continue

            accum = []
            exclude.add(foreign_key)
            related_query = getattr(model, foreign_key.backref)

            for rel_obj in related_query:
                accum.append(model_to_dict(
                    rel_obj,
                    recurse=recurse,
                    backrefs=backrefs,
                    only=only,
                    exclude=exclude,
                    max_depth=max_depth - 1))

            data[foreign_key.backref] = accum

    return data
|
||||||
|
|
||||||
|
|
||||||
|
def update_model_from_dict(instance, data, ignore_unknown=False):
    """Apply the key/value pairs in *data* to *instance* and return it.

    Keys matching fields are assigned directly; a dict value on a
    foreign-key field recursively updates (or creates) the related
    instance; a list/tuple value on a backref builds related instances.
    Unknown keys raise AttributeError unless *ignore_unknown* is set,
    in which case they are set as plain attributes.
    """
    meta = instance._meta
    backrefs = dict([(fk.backref, fk) for fk in meta.backrefs])

    for key, value in data.items():
        if key in meta.combined:
            field = meta.combined[key]
            is_backref = False
        elif key in backrefs:
            field = backrefs[key]
            is_backref = True
        elif ignore_unknown:
            setattr(instance, key, value)
            continue
        else:
            raise AttributeError('Unrecognized attribute "%s" for model '
                                 'class %s.' % (key, type(instance)))

        is_foreign_key = isinstance(field, ForeignKeyField)

        if not is_backref and is_foreign_key and isinstance(value, dict):
            # Reuse an already-loaded related instance when available so we
            # update it in place rather than replacing it.
            try:
                rel_instance = instance.__rel__[field.name]
            except KeyError:
                rel_instance = field.rel_model()
            setattr(
                instance,
                field.name,
                update_model_from_dict(rel_instance, value, ignore_unknown))
        elif is_backref and isinstance(value, (list, tuple)):
            instances = [
                dict_to_model(field.model, row_data, ignore_unknown)
                for row_data in value]
            # Point each related instance back at the parent.
            for rel_instance in instances:
                setattr(rel_instance, field.name, instance)
            setattr(instance, field.backref, instances)
        else:
            setattr(instance, field.name, value)

    return instance
|
||||||
|
|
||||||
|
|
||||||
|
def dict_to_model(model_class, data, ignore_unknown=False):
    """Instantiate *model_class* and populate it from the dict *data*."""
    instance = model_class()
    return update_model_from_dict(instance, data, ignore_unknown)
|
||||||
|
|
||||||
|
|
||||||
|
class ReconnectMixin(object):
    """
    Mixin class that attempts to automatically reconnect to the database under
    certain error conditions.

    For example, MySQL servers will typically close connections that are idle
    for 28800 seconds ("wait_timeout" setting). If your application makes use
    of long-lived connections, you may find your connections are closed after
    a period of no activity. This mixin will attempt to reconnect automatically
    when these errors occur.

    This mixin class probably should not be used with Postgres (unless you
    REALLY know what you are doing) and definitely has no business being used
    with Sqlite. If you wish to use with Postgres, you will need to adapt the
    `reconnect_errors` attribute to something appropriate for Postgres.
    """
    reconnect_errors = (
        # Error class, error message fragment (or empty string for all).
        (OperationalError, '2006'),  # MySQL server has gone away.
        (OperationalError, '2013'),  # Lost connection to MySQL server.
        (OperationalError, '2014'),  # Commands out of sync.
    )

    def __init__(self, *args, **kwargs):
        super(ReconnectMixin, self).__init__(*args, **kwargs)

        # Normalize the reconnect errors to a more efficient data-structure:
        # {exception class: [lowercased message fragments]}.
        self._reconnect_errors = {}
        for exc_class, err_fragment in self.reconnect_errors:
            self._reconnect_errors.setdefault(exc_class, [])
            self._reconnect_errors[exc_class].append(err_fragment.lower())

    def execute_sql(self, sql, params=None, commit=SENTINEL):
        """Execute *sql*, reconnecting once and retrying if the failure
        matches one of the configured `reconnect_errors`."""
        try:
            return super(ReconnectMixin, self).execute_sql(sql, params, commit)
        except Exception as exc:
            exc_class = type(exc)
            if exc_class not in self._reconnect_errors:
                raise exc

            # Only reconnect when the error message contains one of the
            # registered fragments for this exception class.
            exc_repr = str(exc).lower()
            for err_fragment in self._reconnect_errors[exc_class]:
                if err_fragment in exc_repr:
                    break
            else:
                raise exc

            if not self.is_closed():
                self.close()
            self.connect()

            # Retry exactly once; a second failure propagates normally.
            return super(ReconnectMixin, self).execute_sql(sql, params, commit)
|
@ -0,0 +1,79 @@
|
|||||||
|
"""
|
||||||
|
Provide django-style hooks for model events.
|
||||||
|
"""
|
||||||
|
from peewee import Model as _Model
|
||||||
|
|
||||||
|
|
||||||
|
class Signal(object):
    """A named event to which receiver callables may be connected.

    Receivers are keyed by ``(name, sender)``; dispatch preserves
    connection order. A receiver connected with ``sender=None`` fires for
    every instance; otherwise it fires only for instances of *sender*.
    """
    def __init__(self):
        self._flush()

    def _flush(self):
        # Set of (name, sender) keys for duplicate detection, plus an
        # ordered list of (name, receiver, sender) used for dispatch.
        self._receivers = set()
        self._receiver_list = []

    def connect(self, receiver, name=None, sender=None):
        """Register *receiver* under *name* (default: its __name__)."""
        name = name or receiver.__name__
        key = (name, sender)
        if key not in self._receivers:
            self._receivers.add(key)
            self._receiver_list.append((name, receiver, sender))
        else:
            raise ValueError('receiver named %s (for sender=%s) already '
                             'connected' % (name, sender or 'any'))

    def disconnect(self, receiver=None, name=None, sender=None):
        """Remove the receiver registered under (name, sender)."""
        if receiver:
            name = name or receiver.__name__
        if not name:
            raise ValueError('a receiver or a name must be provided')

        key = (name, sender)
        if key not in self._receivers:
            raise ValueError('receiver named %s for sender=%s not found.' %
                             (name, sender or 'any'))

        self._receivers.remove(key)
        # BUG FIX: the previous filter `n != name and s != sender` removed
        # every receiver sharing EITHER the name OR the sender (e.g.
        # disconnecting one sender=None receiver dropped them all). Remove
        # only the exact (name, sender) entry being disconnected.
        self._receiver_list = [(n, r, s) for n, r, s in self._receiver_list
                               if (n, s) != key]

    def __call__(self, name=None, sender=None):
        """Decorator form of :meth:`connect`."""
        def decorator(fn):
            self.connect(fn, name, sender)
            return fn
        return decorator

    def send(self, instance, *args, **kwargs):
        """Invoke matching receivers with (sender, instance, ...).

        :returns: list of (receiver, return value) pairs.
        """
        sender = type(instance)
        responses = []
        for n, r, s in self._receiver_list:
            if s is None or isinstance(instance, s):
                responses.append((r, r(sender, instance, *args, **kwargs)))
        return responses
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level signal instances fired by the signal-aware Model subclass
# defined below.
pre_save = Signal()
post_save = Signal()
pre_delete = Signal()
post_delete = Signal()
pre_init = Signal()
|
||||||
|
|
||||||
|
|
||||||
|
class Model(_Model):
    """Model subclass that fires the module-level signals around
    initialization, save and delete."""
    def __init__(self, *args, **kwargs):
        super(Model, self).__init__(*args, **kwargs)
        pre_init.send(self)

    def save(self, *args, **kwargs):
        # `created` is True for forced inserts or when no PK is set yet
        # (i.e. the row does not exist before this save).
        pk_value = self._pk
        created = kwargs.get('force_insert', False) or not bool(pk_value)
        pre_save.send(self, created=created)
        ret = super(Model, self).save(*args, **kwargs)
        post_save.send(self, created=created)
        return ret

    def delete_instance(self, *args, **kwargs):
        pre_delete.send(self)
        ret = super(Model, self).delete_instance(*args, **kwargs)
        post_delete.send(self)
        return ret
|
@ -0,0 +1,103 @@
|
|||||||
|
"""
|
||||||
|
Peewee integration with pysqlcipher.
|
||||||
|
|
||||||
|
Project page: https://github.com/leapcode/pysqlcipher/
|
||||||
|
|
||||||
|
**WARNING!!! EXPERIMENTAL!!!**
|
||||||
|
|
||||||
|
* Although this extension's code is short, it has not been properly
|
||||||
|
peer-reviewed yet and may have introduced vulnerabilities.
|
||||||
|
|
||||||
|
Also note that this code relies on pysqlcipher and sqlcipher, and
|
||||||
|
the code there might have vulnerabilities as well, but since these
|
||||||
|
are widely used crypto modules, we can expect "short zero days" there.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
from peewee.playground.ciphersql_ext import SqlCipherDatabase
|
||||||
|
db = SqlCipherDatabase('/path/to/my.db', passphrase="don'tuseme4real")
|
||||||
|
|
||||||
|
* `passphrase`: should be "long enough".
|
||||||
|
Note that *length beats vocabulary* (much exponential), and even
|
||||||
|
a lowercase-only passphrase like easytorememberyethardforotherstoguess
|
||||||
|
packs more noise than 8 random printable characters and *can* be memorized.
|
||||||
|
|
||||||
|
When opening an existing database, passphrase should be the one used when the
|
||||||
|
database was created. If the passphrase is incorrect, an exception will only be
|
||||||
|
raised **when you access the database**.
|
||||||
|
|
||||||
|
If you need to ask for an interactive passphrase, here's example code you can
|
||||||
|
put after the `db = ...` line:
|
||||||
|
|
||||||
|
try: # Just access the database so that it checks the encryption.
|
||||||
|
db.get_tables()
|
||||||
|
# We're looking for a DatabaseError with a specific error message.
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
# Check whether the message *means* "passphrase is wrong"
|
||||||
|
if e.args[0] == 'file is encrypted or is not a database':
|
||||||
|
raise Exception('Developer should Prompt user for passphrase '
|
||||||
|
'again.')
|
||||||
|
else:
|
||||||
|
# A different DatabaseError. Raise it.
|
||||||
|
raise e
|
||||||
|
|
||||||
|
See a more elaborate example with this code at
|
||||||
|
https://gist.github.com/thedod/11048875
|
||||||
|
"""
|
||||||
|
import datetime
|
||||||
|
import decimal
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from peewee import *
|
||||||
|
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||||
|
if sys.version_info[0] != 3:
|
||||||
|
from pysqlcipher import dbapi2 as sqlcipher
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from sqlcipher3 import dbapi2 as sqlcipher
|
||||||
|
except ImportError:
|
||||||
|
from pysqlcipher3 import dbapi2 as sqlcipher
|
||||||
|
|
||||||
|
sqlcipher.register_adapter(decimal.Decimal, str)
|
||||||
|
sqlcipher.register_adapter(datetime.date, str)
|
||||||
|
sqlcipher.register_adapter(datetime.time, str)
|
||||||
|
|
||||||
|
|
||||||
|
class _SqlCipherDatabase(object):
    """Mixin providing SQLCipher key handling on top of a SQLite database.

    The passphrase is taken from `connect_params` and applied with
    ``PRAGMA key`` immediately after connecting.
    """
    def _connect(self):
        params = dict(self.connect_params)
        # Single quotes are doubled for SQL-literal escaping; PRAGMA does
        # not support bound parameters, hence the string formatting below.
        passphrase = params.pop('passphrase', '').replace("'", "''")

        conn = sqlcipher.connect(self.database, isolation_level=None, **params)
        try:
            if passphrase:
                conn.execute("PRAGMA key='%s'" % passphrase)
            self._add_conn_hooks(conn)
        except:
            # Don't leak the connection if keying or hook setup fails.
            conn.close()
            raise
        return conn

    def set_passphrase(self, passphrase):
        """Set the passphrase used for future connections (db must be closed)."""
        if not self.is_closed():
            raise ImproperlyConfigured('Cannot set passphrase when database '
                                       'is open. To change passphrase of an '
                                       'open database use the rekey() method.')

        self.connect_params['passphrase'] = passphrase

    def rekey(self, passphrase):
        """Re-encrypt the open database under *passphrase* and remember it."""
        if self.is_closed():
            self.connect()

        self.execute_sql("PRAGMA rekey='%s'" % passphrase.replace("'", "''"))
        self.connect_params['passphrase'] = passphrase
        return True
|
||||||
|
|
||||||
|
|
||||||
|
class SqlCipherDatabase(_SqlCipherDatabase, SqliteDatabase):
    """SQLCipher-encrypted variant of the plain SqliteDatabase."""
    pass
|
||||||
|
|
||||||
|
|
||||||
|
class SqlCipherExtDatabase(_SqlCipherDatabase, SqliteExtDatabase):
    """SQLCipher-encrypted variant of SqliteExtDatabase."""
    pass
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,522 @@
|
|||||||
|
import datetime
|
||||||
|
import hashlib
|
||||||
|
import heapq
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import zlib
|
||||||
|
try:
|
||||||
|
from collections import Counter
|
||||||
|
except ImportError:
|
||||||
|
Counter = None
|
||||||
|
try:
|
||||||
|
from urlparse import urlparse
|
||||||
|
except ImportError:
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
try:
|
||||||
|
from playhouse._sqlite_ext import TableFunction
|
||||||
|
except ImportError:
|
||||||
|
TableFunction = None
|
||||||
|
|
||||||
|
|
||||||
|
# Datetime/time formats tried, in order, when parsing values coming out of
# SQLite in the date-manipulation UDFs below.
SQLITE_DATETIME_FORMATS = (
    '%Y-%m-%d %H:%M:%S',
    '%Y-%m-%d %H:%M:%S.%f',
    '%Y-%m-%d',
    '%H:%M:%S',
    '%H:%M:%S.%f',
    '%H:%M')
|
||||||
|
|
||||||
|
from peewee import format_date_time
|
||||||
|
|
||||||
|
def format_date_time_sqlite(date_value):
    """Parse *date_value* using the SQLite-compatible format list above."""
    return format_date_time(date_value, SQLITE_DATETIME_FORMATS)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from playhouse import _sqlite_udf as cython_udf
|
||||||
|
except ImportError:
|
||||||
|
cython_udf = None
|
||||||
|
|
||||||
|
|
||||||
|
# Group udf by function.
|
||||||
|
# Group names used to organize the user-defined functions below.
CONTROL_FLOW = 'control_flow'
DATE = 'date'
FILE = 'file'
HELPER = 'helpers'
MATH = 'math'
STRING = 'string'

# Registries mapping group name -> list of aggregates / table functions /
# scalar UDFs; populated by the decorators below and consumed by the
# register_* helpers.
AGGREGATE_COLLECTION = {}
TABLE_FUNCTION_COLLECTION = {}
UDF_COLLECTION = {}
|
||||||
|
|
||||||
|
|
||||||
|
class synchronized_dict(dict):
    """Dict whose single-item get/set/delete are serialized by a lock.

    NOTE(review): only __getitem__/__setitem__/__delitem__ take the lock;
    iteration and bulk methods (update, items, ...) are NOT synchronized.
    """
    def __init__(self, *args, **kwargs):
        super(synchronized_dict, self).__init__(*args, **kwargs)
        self._lock = threading.Lock()

    def __getitem__(self, key):
        with self._lock:
            return super(synchronized_dict, self).__getitem__(key)

    def __setitem__(self, key, value):
        with self._lock:
            return super(synchronized_dict, self).__setitem__(key, value)

    def __delitem__(self, key):
        with self._lock:
            return super(synchronized_dict, self).__delitem__(key)
|
||||||
|
|
||||||
|
|
||||||
|
# Shared, lock-protected key/value stores available to the stateful UDFs.
STATE = synchronized_dict()
SETTINGS = synchronized_dict()
|
||||||
|
|
||||||
|
# Class and function decorators.
|
||||||
|
def aggregate(*groups):
    """Class decorator: register an aggregate class under each group."""
    def decorator(klass):
        for group in groups:
            AGGREGATE_COLLECTION.setdefault(group, []).append(klass)
        return klass
    return decorator
|
||||||
|
|
||||||
|
def table_function(*groups):
    """Class decorator: register a table-function class under each group."""
    def decorator(klass):
        for group in groups:
            TABLE_FUNCTION_COLLECTION.setdefault(group, []).append(klass)
        return klass
    return decorator
|
||||||
|
|
||||||
|
def udf(*groups):
    """Function decorator: register a scalar UDF under each group."""
    def decorator(fn):
        for group in groups:
            UDF_COLLECTION.setdefault(group, []).append(fn)
        return fn
    return decorator
|
||||||
|
|
||||||
|
# Register aggregates / functions with connection.
|
||||||
|
def register_aggregate_groups(db, *groups):
    """Register on *db* every aggregate collected under *groups* (dedup by name)."""
    registered = set()
    for group in groups:
        for klass in AGGREGATE_COLLECTION.get(group, ()):
            name = getattr(klass, 'name', klass.__name__)
            if name not in registered:
                registered.add(name)
                db.register_aggregate(klass, name)
|
||||||
|
|
||||||
|
def register_table_function_groups(db, *groups):
    """Register on *db* every table function collected under *groups* (dedup by name)."""
    registered = set()
    for group in groups:
        for klass in TABLE_FUNCTION_COLLECTION.get(group, ()):
            if klass.name not in registered:
                registered.add(klass.name)
                db.register_table_function(klass)
|
||||||
|
|
||||||
|
def register_udf_groups(db, *groups):
    """Register on *db* every scalar UDF collected under *groups* (dedup by name)."""
    registered = set()
    for group in groups:
        for function in UDF_COLLECTION.get(group, ()):
            name = function.__name__
            if name not in registered:
                registered.add(name)
                db.register_function(function, name)
|
||||||
|
|
||||||
|
def register_groups(db, *groups):
    """Register the aggregates, table functions and UDFs for *groups* on *db*."""
    for registrar in (register_aggregate_groups,
                      register_table_function_groups,
                      register_udf_groups):
        registrar(db, *groups)
|
||||||
|
|
||||||
|
def register_all(db):
    """Register every collected aggregate, table function and UDF on *db*."""
    for registrar, collection in (
            (register_aggregate_groups, AGGREGATE_COLLECTION),
            (register_table_function_groups, TABLE_FUNCTION_COLLECTION),
            (register_udf_groups, UDF_COLLECTION)):
        # Iterating a dict yields its keys, i.e. every group name.
        registrar(db, *collection)
|
||||||
|
|
||||||
|
|
||||||
|
# Begin actual user-defined functions and aggregates.
|
||||||
|
|
||||||
|
# Scalar functions.
|
||||||
|
@udf(CONTROL_FLOW)
def if_then_else(cond, truthy, falsey=None):
    """SQL IF(): return *truthy* when *cond* is truthy, otherwise *falsey*."""
    return truthy if cond else falsey
|
||||||
|
|
||||||
|
@udf(DATE)
def strip_tz(date_str):
    """Strip a trailing UTC offset ("+HH:MM" / "-HH:MM") from an ISO-ish
    datetime string, normalizing the 'T' separator to a space.

    Returns the string unchanged when no offset is present.
    """
    date_str = date_str.replace('T', ' ')
    tz_idx1 = date_str.find('+')
    if tz_idx1 != -1:
        return date_str[:tz_idx1]
    # BUGFIX: was find('-'), which locates the FIRST hyphen -- the date
    # separator in "YYYY-MM-DD..." (index 4) -- so the `> 13` guard never
    # fired and a negative offset was never stripped. Search from the right:
    # any '-' past index 13 is beyond "YYYY-MM-DD HH" and must be an offset.
    tz_idx2 = date_str.rfind('-')
    if tz_idx2 > 13:
        return date_str[:tz_idx2]
    return date_str
|
||||||
|
|
||||||
|
@udf(DATE)
def human_delta(nseconds, glue=', '):
    """Render a duration in seconds, e.g. "1 hour, 5 minutes, 3 seconds"."""
    units = ((86400 * 365, 'year'),
             (86400 * 30, 'month'),
             (86400 * 7, 'week'),
             (86400, 'day'),
             (3600, 'hour'),
             (60, 'minute'),
             (1, 'second'))
    pieces = []
    for size, label in units:
        count, nseconds = divmod(nseconds, size)
        if count:
            plural = 's' if count != 1 else ''
            pieces.append('%s %s%s' % (count, label, plural))
    return glue.join(pieces) if pieces else '0 seconds'
|
||||||
|
|
||||||
|
@udf(FILE)
def file_ext(filename):
    """Return the extension of *filename* (including the dot), or None when
    os.path.splitext() cannot process the value."""
    try:
        _, extension = os.path.splitext(filename)
    except ValueError:
        return None
    return extension
|
||||||
|
|
||||||
|
@udf(FILE)
def file_read(filename):
    """Return the text contents of *filename*, or None (implicitly) when the
    file cannot be opened, read, or decoded."""
    try:
        with open(filename) as fh:
            return fh.read()
    except (IOError, OSError, UnicodeDecodeError):
        # BUGFIX: was a bare `except:`, which also swallowed exceptions such
        # as KeyboardInterrupt/SystemExit. Best-effort reads still yield NULL.
        pass
|
||||||
|
|
||||||
|
# zlib helpers. NOTE: despite the names, these use zlib.compress/decompress,
# which emit zlib-format data rather than gzip framing.
if sys.version_info[0] == 2:
    @udf(HELPER)
    def gzip(data, compression=9):
        # Python 2: wrap in buffer() so the value is stored as a SQLite BLOB.
        return buffer(zlib.compress(data, compression))

    @udf(HELPER)
    def gunzip(data):
        return zlib.decompress(data)
else:
    @udf(HELPER)
    def gzip(data, compression=9):
        # SQLite may hand text values to Python as str; zlib needs bytes.
        # raw_unicode_escape round-trips arbitrary code-points losslessly.
        if isinstance(data, str):
            data = bytes(data.encode('raw_unicode_escape'))
        return zlib.compress(data, compression)

    @udf(HELPER)
    def gunzip(data):
        return zlib.decompress(data)
|
||||||
|
|
||||||
|
@udf(HELPER)
def hostname(url):
    """Return the network-location (host[:port]) component of *url*."""
    parsed = urlparse(url)
    return parsed.netloc if parsed else None
|
||||||
|
|
||||||
|
@udf(HELPER)
def toggle(key):
    """Flip and return the boolean stored under *key* (case-insensitive) in
    the module-level STATE dict; unseen keys start out falsey."""
    key = key.lower()
    new_value = not STATE.get(key)
    STATE[key] = new_value
    return new_value
|
||||||
|
|
||||||
|
@udf(HELPER)
def setting(key, value=None):
    """Read (one-argument form) or write (two-argument form) the module-level
    SETTINGS store. Writes return the value that was stored."""
    if value is None:
        return SETTINGS.get(key)
    SETTINGS[key] = value
    return value
|
||||||
|
|
||||||
|
@udf(HELPER)
def clear_settings():
    # Empty the module-level SETTINGS store; returns None (SQL NULL).
    SETTINGS.clear()
|
||||||
|
|
||||||
|
@udf(HELPER)
def clear_toggles():
    # Empty the module-level STATE store used by toggle(); returns SQL NULL.
    STATE.clear()
|
||||||
|
|
||||||
|
@udf(MATH)
def randomrange(start, end=None, step=None):
    """random.randrange() exposed to SQL: with one argument, return an int in
    [0, start); otherwise an int in [start, end) advancing by *step*."""
    if end is None:
        start, end = 0, start
    # BUGFIX: this was an `elif`, so the one-argument form left step=None and
    # random.randrange(start, end, None) raised a TypeError.
    if step is None:
        step = 1
    return random.randrange(start, end, step)
|
||||||
|
|
||||||
|
@udf(MATH)
def gauss_distribution(mean, sigma):
    # Sample the normal distribution N(mean, sigma), returning SQL NULL
    # rather than raising on a ValueError.
    # NOTE(review): non-numeric input raises TypeError, which is NOT caught
    # here -- confirm that is intended.
    try:
        return random.gauss(mean, sigma)
    except ValueError:
        return None
|
||||||
|
|
||||||
|
@udf(MATH)
def sqrt(n):
    # math.sqrt exposed to SQL; negative input yields NULL instead of an error.
    try:
        return math.sqrt(n)
    except ValueError:
        return None
|
||||||
|
|
||||||
|
@udf(MATH)
def tonumber(s):
    """Coerce *s* to int if possible, else to float, else return None."""
    try:
        return int(s)
    except ValueError:
        try:
            return float(s)
        except ValueError:
            # BUGFIX: was a bare `except:`; float() on a string only raises
            # ValueError here, so catch exactly that and yield SQL NULL.
            return None
|
||||||
|
|
||||||
|
@udf(STRING)
def substr_count(haystack, needle):
    """Number of non-overlapping occurrences of *needle* in *haystack*;
    0 when either argument is empty or NULL."""
    return haystack.count(needle) if haystack and needle else 0
|
||||||
|
|
||||||
|
@udf(STRING)
def strip_chars(haystack, chars):
    # str.strip() exposed to SQL: remove any characters in *chars* from both
    # ends of *haystack*.
    return haystack.strip(chars)
|
||||||
|
|
||||||
|
def _hash(constructor, *args):
|
||||||
|
hash_obj = constructor()
|
||||||
|
for arg in args:
|
||||||
|
hash_obj.update(arg)
|
||||||
|
return hash_obj.hexdigest()
|
||||||
|
|
||||||
|
# Aggregates.
|
||||||
|
class _heap_agg(object):
    """Base class for aggregates that consume their inputs in sorted order:
    values are pushed onto a min-heap and counted as they arrive."""

    def __init__(self):
        self.heap = []  # min-heap of processed values.
        self.ct = 0     # number of values stepped.

    def process(self, value):
        # Hook for subclasses to transform a value before it is heaped.
        return value

    def step(self, value):
        self.ct += 1
        heapq.heappush(self.heap, self.process(value))
|
||||||
|
|
||||||
|
class _datetime_heap_agg(_heap_agg):
    # Heap aggregate whose incoming values are parsed into datetimes first.
    def process(self, value):
        return format_date_time_sqlite(value)
|
||||||
|
|
||||||
|
# timedelta.total_seconds() was added in Python 2.7; emulate it on 2.6.
if sys.version_info[:2] == (2, 6):
    def total_seconds(td):
        return (td.seconds +
                (td.days * 86400) +
                (td.microseconds / (10.**6)))
else:
    total_seconds = lambda td: td.total_seconds()
|
||||||
|
|
||||||
|
@aggregate(DATE)
class mintdiff(_datetime_heap_agg):
    """Aggregate: smallest gap, in seconds, between consecutive datetimes in
    chronological order; NULL for fewer than two rows."""
    def finalize(self):
        dtp = min_diff = None
        while self.heap:
            if min_diff is None:
                if dtp is None:
                    # Seed the "previous" value with the earliest datetime.
                    dtp = heapq.heappop(self.heap)
                    continue
            dt = heapq.heappop(self.heap)
            diff = dt - dtp
            if min_diff is None or min_diff > diff:
                min_diff = diff
            dtp = dt
        if min_diff is not None:
            return total_seconds(min_diff)
|
||||||
|
|
||||||
|
@aggregate(DATE)
class avgtdiff(_datetime_heap_agg):
    """Aggregate: average gap, in seconds, between consecutive datetimes in
    chronological order; 0 for a single row, NULL for no rows."""
    def finalize(self):
        if self.ct < 1:
            return
        elif self.ct == 1:
            return 0

        total = ct = 0
        dtp = None
        while self.heap:
            if total == 0:
                if dtp is None:
                    # Seed the "previous" value with the earliest datetime.
                    dtp = heapq.heappop(self.heap)
                    continue

            dt = heapq.heappop(self.heap)
            diff = dt - dtp
            ct += 1
            total += total_seconds(diff)
            dtp = dt

        return float(total) / ct
|
||||||
|
|
||||||
|
@aggregate(DATE)
class duration(object):
    """Aggregate: elapsed seconds between the earliest and latest datetime."""
    def __init__(self):
        self._min = self._max = None

    def step(self, value):
        dt = format_date_time_sqlite(value)
        if self._min is None or dt < self._min:
            self._min = dt
        if self._max is None or dt > self._max:
            self._max = dt

    def finalize(self):
        # NOTE(review): truthiness check rather than an explicit None test --
        # presumably equivalent here since datetime objects are truthy; verify
        # what format_date_time_sqlite can return.
        if self._min and self._max:
            td = (self._max - self._min)
            return total_seconds(td)
        return None
|
||||||
|
|
||||||
|
@aggregate(MATH)
class mode(object):
    """Aggregate: the most frequently occurring value.

    One of two implementations is selected at class-definition time depending
    on whether collections.Counter imported successfully (Counter is None
    otherwise, per the module's import guard).
    """
    if Counter:
        def __init__(self):
            self.items = Counter()

        def step(self, *args):
            self.items.update(args)

        def finalize(self):
            if self.items:
                return self.items.most_common(1)[0][0]
    else:
        def __init__(self):
            self.items = []

        def step(self, item):
            self.items.append(item)

        def finalize(self):
            if self.items:
                # O(n^2) fallback: count each distinct item via list.count.
                return max(set(self.items), key=self.items.count)
|
||||||
|
|
||||||
|
@aggregate(MATH)
class minrange(_heap_agg):
    """Aggregate: smallest difference between consecutive values in sorted
    order; 0 for a single row, NULL for no rows."""
    def finalize(self):
        if self.ct == 0:
            return
        elif self.ct == 1:
            return 0

        prev = min_diff = None

        while self.heap:
            if min_diff is None:
                if prev is None:
                    # Seed the "previous" value with the smallest element.
                    prev = heapq.heappop(self.heap)
                    continue
            curr = heapq.heappop(self.heap)
            diff = curr - prev
            if min_diff is None or min_diff > diff:
                min_diff = diff
            prev = curr
        return min_diff
|
||||||
|
|
||||||
|
@aggregate(MATH)
class avgrange(_heap_agg):
    """Aggregate: average difference between consecutive values in sorted
    order; 0 for a single row, NULL for no rows."""
    def finalize(self):
        if self.ct == 0:
            return
        elif self.ct == 1:
            return 0

        total = ct = 0
        prev = None
        while self.heap:
            if total == 0:
                if prev is None:
                    # Seed the "previous" value with the smallest element.
                    prev = heapq.heappop(self.heap)
                    continue

            curr = heapq.heappop(self.heap)
            diff = curr - prev
            ct += 1
            total += diff
            prev = curr

        return float(total) / ct
|
||||||
|
|
||||||
|
@aggregate(MATH)
class _range(object):
    """Aggregate registered as "range": max(value) - min(value)."""
    name = 'range'

    def __init__(self):
        self._min = self._max = None

    def step(self, value):
        if self._min is None or value < self._min:
            self._min = value
        if self._max is None or value > self._max:
            self._max = value

    def finalize(self):
        # No rows stepped -> SQL NULL.
        if self._min is None or self._max is None:
            return None
        return self._max - self._min
|
||||||
|
|
||||||
|
|
||||||
|
# Optional compiled string-distance / median implementations; registered only
# when the cython_udf extension module imported successfully (it is None
# otherwise, per the module's import guard).
if cython_udf is not None:
    damerau_levenshtein_dist = udf(STRING)(cython_udf.damerau_levenshtein_dist)
    levenshtein_dist = udf(STRING)(cython_udf.levenshtein_dist)
    str_dist = udf(STRING)(cython_udf.str_dist)
    median = aggregate(MATH)(cython_udf.median)
|
||||||
|
|
||||||
|
|
||||||
|
# Table-valued functions; only defined when TableFunction imported
# successfully (it is None otherwise, per the module's import guard).
if TableFunction is not None:
    @table_function(STRING)
    class RegexSearch(TableFunction):
        # regex_search(regex, search_string): emits one row per regex match.
        params = ['regex', 'search_string']
        columns = ['match']
        name = 'regex_search'

        def initialize(self, regex=None, search_string=None):
            self._iter = re.finditer(regex, search_string)

        def iterate(self, idx):
            # next() raising StopIteration signals end-of-rows to the caller.
            return (next(self._iter).group(0),)

    @table_function(DATE)
    class DateSeries(TableFunction):
        # date_series(start, stop[, step_seconds]): emits one row per step
        # from start to stop inclusive.
        params = ['start', 'stop', 'step_seconds']
        columns = ['date']
        name = 'date_series'

        def initialize(self, start, stop, step_seconds=86400):
            self.start = format_date_time_sqlite(start)
            self.stop = format_date_time_sqlite(stop)
            step_seconds = int(step_seconds)
            self.step_seconds = datetime.timedelta(seconds=step_seconds)

            # Pick an output format matching the inputs: date-only when
            # stepping by whole days from midnight; time-only when both
            # endpoints have the year-1900 date (presumably what bare times
            # parse to -- verify against format_date_time_sqlite); otherwise
            # a full datetime.
            if (self.start.hour == 0 and
                    self.start.minute == 0 and
                    self.start.second == 0 and
                    step_seconds >= 86400):
                self.format = '%Y-%m-%d'
            elif (self.start.year == 1900 and
                    self.start.month == 1 and
                    self.start.day == 1 and
                    self.stop.year == 1900 and
                    self.stop.month == 1 and
                    self.stop.day == 1 and
                    step_seconds < 86400):
                self.format = '%H:%M:%S'
            else:
                self.format = '%Y-%m-%d %H:%M:%S'

        def iterate(self, idx):
            if self.start > self.stop:
                raise StopIteration
            current = self.start
            self.start += self.step_seconds
            return (current.strftime(self.format),)
|
@ -0,0 +1,330 @@
|
|||||||
|
import logging
|
||||||
|
import weakref
|
||||||
|
from threading import local as thread_local
|
||||||
|
from threading import Event
|
||||||
|
from threading import Thread
|
||||||
|
try:
|
||||||
|
from Queue import Queue
|
||||||
|
except ImportError:
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
|
try:
|
||||||
|
import gevent
|
||||||
|
from gevent import Greenlet as GThread
|
||||||
|
from gevent.event import Event as GEvent
|
||||||
|
from gevent.local import local as greenlet_local
|
||||||
|
from gevent.queue import Queue as GQueue
|
||||||
|
except ImportError:
|
||||||
|
GThread = GQueue = GEvent = None
|
||||||
|
|
||||||
|
from peewee import SENTINEL
|
||||||
|
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||||
|
|
||||||
|
|
||||||
|
# Module logger; the writer thread reports lifecycle events and errors here.
logger = logging.getLogger('peewee.sqliteq')
|
||||||
|
|
||||||
|
|
||||||
|
class ResultTimeout(Exception):
    """Raised when an AsyncCursor's results are not ready within its timeout."""
    pass
|
||||||
|
|
||||||
|
class WriterPaused(Exception):
    """Set as the result of queries submitted while the writer is paused."""
    pass
|
||||||
|
|
||||||
|
class ShutdownException(Exception):
    """Internal control-flow signal telling the writer thread to exit."""
    pass
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCursor(object):
    """Cursor-like handle for a query executed on the writer thread.

    The writer fulfills the request via set_result(); consumers block on the
    event (up to `timeout` seconds) the first time results are needed, then
    iterate the pre-fetched rows.
    """
    __slots__ = ('sql', 'params', 'commit', 'timeout',
                 '_event', '_cursor', '_exc', '_idx', '_rows', '_ready')

    def __init__(self, event, sql, params, commit, timeout):
        self._event = event
        self.sql = sql
        self.params = params
        self.commit = commit
        self.timeout = timeout
        self._cursor = self._exc = self._idx = self._rows = None
        self._ready = False

    def set_result(self, cursor, exc=None):
        """Called by the writer thread to publish the result (or an error)."""
        self._cursor = cursor
        self._exc = exc
        self._idx = 0
        # Rows are materialized eagerly so the writer thread can move on; on
        # error there is nothing to fetch.
        self._rows = cursor.fetchall() if exc is None else []
        self._event.set()
        return self

    def _wait(self, timeout=None):
        # A falsy timeout (None/0) means wait indefinitely without raising.
        timeout = timeout if timeout is not None else self.timeout
        if not self._event.wait(timeout=timeout) and timeout:
            raise ResultTimeout('results not ready, timed out.')
        if self._exc is not None:
            raise self._exc
        self._ready = True

    def __iter__(self):
        if not self._ready:
            self._wait()
        if self._exc is not None:
            # BUGFIX: was `raise self._exec` -- a nonexistent attribute (the
            # slot is `_exc`), which would mask the stored exception with an
            # AttributeError if this path were ever taken.
            raise self._exc
        return self

    def next(self):
        if not self._ready:
            self._wait()
        try:
            obj = self._rows[self._idx]
        except IndexError:
            raise StopIteration
        else:
            self._idx += 1
            return obj
    __next__ = next

    @property
    def lastrowid(self):
        if not self._ready:
            self._wait()
        return self._cursor.lastrowid

    @property
    def rowcount(self):
        if not self._ready:
            self._wait()
        return self._cursor.rowcount

    @property
    def description(self):
        return self._cursor.description

    def close(self):
        self._cursor.close()

    def fetchall(self):
        return list(self)  # Iterating implies waiting until populated.

    def fetchone(self):
        if not self._ready:
            self._wait()
        try:
            return next(self)
        except StopIteration:
            return None
|
||||||
|
|
||||||
|
# Control messages placed on the write queue for the writer thread. SHUTDOWN
# reuses the StopIteration class as a unique marker; PAUSE and UNPAUSE are
# fresh sentinel objects compared by identity.
SHUTDOWN = StopIteration
PAUSE = object()
UNPAUSE = object()
|
||||||
|
|
||||||
|
|
||||||
|
class Writer(object):
    """Worker that drains the write queue, executing queries serially on a
    single connection and honoring PAUSE/UNPAUSE/SHUTDOWN control messages."""
    __slots__ = ('database', 'queue')

    def __init__(self, database, queue):
        self.database = database
        self.queue = queue

    def run(self):
        """Main loop: execute queued queries until SHUTDOWN is received."""
        conn = self.database.connection()
        try:
            while True:
                try:
                    if conn is None:  # Paused.
                        if self.wait_unpause():
                            conn = self.database.connection()
                    else:
                        conn = self.loop(conn)
                except ShutdownException:
                    logger.info('writer received shutdown request, exiting.')
                    return
        finally:
            if conn is not None:
                self.database._close(conn)
                self.database._state.reset()

    def wait_unpause(self):
        """Block while paused; True when an UNPAUSE message arrives.

        Queries submitted while paused are failed with WriterPaused rather
        than executed.
        """
        obj = self.queue.get()
        if obj is UNPAUSE:
            logger.info('writer unpaused - reconnecting to database.')
            return True
        elif obj is SHUTDOWN:
            raise ShutdownException()
        elif obj is PAUSE:
            logger.error('writer received pause, but is already paused.')
        else:
            obj.set_result(None, WriterPaused())
            logger.warning('writer paused, not handling %s', obj)

    def loop(self, conn):
        """Handle one queue item; returns the connection, or None when the
        writer transitions to the paused state."""
        obj = self.queue.get()
        if isinstance(obj, AsyncCursor):
            self.execute(obj)
        elif obj is PAUSE:
            logger.info('writer paused - closing database connection.')
            self.database._close(conn)
            self.database._state.reset()
            return
        elif obj is UNPAUSE:
            logger.error('writer received unpause, but is already running.')
        elif obj is SHUTDOWN:
            raise ShutdownException()
        else:
            logger.error('writer received unsupported object: %s', obj)
        return conn

    def execute(self, obj):
        """Run one AsyncCursor's SQL and publish its result (or exception)."""
        logger.debug('received query %s', obj.sql)
        try:
            cursor = self.database._execute(obj.sql, obj.params, obj.commit)
        except Exception as execute_err:
            cursor = None
            exc = execute_err  # Bind outside the except block's scope.
        else:
            exc = None
        return obj.set_result(cursor, exc)
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteQueueDatabase(SqliteExtDatabase):
    """SQLite database that serializes writes through a single worker thread.

    SELECT statements execute directly on the calling thread; everything else
    is wrapped in an AsyncCursor and enqueued for the Writer. WAL journal
    mode is enforced so readers are not blocked by the writer.
    """
    WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL '
                              'journal mode when using this feature. WAL mode '
                              'allows one or more readers to continue reading '
                              'while another connection writes to the '
                              'database.')

    def __init__(self, database, use_gevent=False, autostart=True,
                 queue_max_size=None, results_timeout=None, *args, **kwargs):
        # The single writer thread will issue all writes, so connections may
        # safely cross threads.
        kwargs['check_same_thread'] = False

        # Ensure that journal_mode is WAL. This value is passed to the parent
        # class constructor below.
        pragmas = self._validate_journal_mode(kwargs.pop('pragmas', None))

        # Reference to execute_sql on the parent class. Since we've overridden
        # execute_sql(), this is just a handy way to reference the real
        # implementation.
        Parent = super(SqliteQueueDatabase, self)
        self._execute = Parent.execute_sql

        # Call the parent class constructor with our modified pragmas.
        Parent.__init__(database, pragmas=pragmas, *args, **kwargs)

        self._autostart = autostart
        self._results_timeout = results_timeout
        self._is_stopped = True

        # Get different objects depending on the threading implementation.
        self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size)

        # Create the writer thread, optionally starting it.
        self._create_write_queue()
        if self._autostart:
            self.start()

    def get_thread_impl(self, use_gevent):
        # NOTE(review): GreenletHelper requires gevent to have imported
        # successfully (its factories are None otherwise).
        return GreenletHelper if use_gevent else ThreadHelper

    def _validate_journal_mode(self, pragmas=None):
        # Reject an explicit non-WAL journal_mode; otherwise force WAL.
        if pragmas:
            pdict = dict((k.lower(), v) for (k, v) in pragmas)
            if pdict.get('journal_mode', 'wal').lower() != 'wal':
                raise ValueError(self.WAL_MODE_ERROR_MESSAGE)

            return [(k, v) for (k, v) in pragmas
                    if k != 'journal_mode'] + [('journal_mode', 'wal')]
        else:
            return [('journal_mode', 'wal')]

    def _create_write_queue(self):
        self._write_queue = self._thread_helper.queue()

    def queue_size(self):
        # Number of pending writes awaiting the writer thread.
        return self._write_queue.qsize()

    def execute_sql(self, sql, params=None, commit=SENTINEL, timeout=None):
        # Heuristic: only statements starting with SELECT bypass the queue.
        if commit is SENTINEL:
            commit = not sql.lower().startswith('select')

        if not commit:
            return self._execute(sql, params, commit=commit)

        cursor = AsyncCursor(
            event=self._thread_helper.event(),
            sql=sql,
            params=params,
            commit=commit,
            timeout=self._results_timeout if timeout is None else timeout)
        self._write_queue.put(cursor)
        return cursor

    def start(self):
        """Start the writer thread; False if it was already running."""
        # NOTE(review): self._lock is not assigned anywhere in this class --
        # presumably it is provided by the peewee Database base class; verify.
        with self._lock:
            if not self._is_stopped:
                return False
            def run():
                writer = Writer(self, self._write_queue)
                writer.run()

            self._writer = self._thread_helper.thread(run)
            self._writer.start()
            self._is_stopped = False
            return True

    def stop(self):
        """Ask the writer to finish queued work and exit; blocks via join()."""
        logger.debug('environment stop requested.')
        with self._lock:
            if self._is_stopped:
                return False
            self._write_queue.put(SHUTDOWN)
            self._writer.join()
            self._is_stopped = True
            return True

    def is_stopped(self):
        with self._lock:
            return self._is_stopped

    def pause(self):
        # Ask the writer to close its connection until unpause() is called.
        with self._lock:
            self._write_queue.put(PAUSE)

    def unpause(self):
        with self._lock:
            self._write_queue.put(UNPAUSE)

    def __unsupported__(self, *args, **kwargs):
        # Transactions cannot span the write queue, so these APIs are blocked.
        raise ValueError('This method is not supported by %r.' % type(self))
    atomic = transaction = savepoint = __unsupported__
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadHelper(object):
    """Factory for threading-module primitives: events, queues, and daemon
    worker threads. Subclassed to swap in gevent equivalents."""
    __slots__ = ('queue_max_size',)

    def __init__(self, queue_max_size=None):
        self.queue_max_size = queue_max_size

    def event(self):
        return Event()

    def queue(self, max_size=None):
        # Fall back to the configured default; None/0 means unbounded.
        if max_size is None:
            max_size = self.queue_max_size
        return Queue(maxsize=max_size or 0)

    def thread(self, fn, *args, **kwargs):
        worker = Thread(target=fn, args=args, kwargs=kwargs)
        worker.daemon = True
        return worker
|
||||||
|
|
||||||
|
|
||||||
|
class GreenletHelper(ThreadHelper):
    """ThreadHelper variant backed by gevent (GEvent/GQueue/GThread are None
    when gevent failed to import, per the module's import guard)."""
    __slots__ = ()

    def event(self): return GEvent()

    def queue(self, max_size=None):
        max_size = max_size if max_size is not None else self.queue_max_size
        return GQueue(maxsize=max_size or 0)

    def thread(self, fn, *args, **kwargs):
        def wrap(*a, **k):
            # Yield control once before running fn -- presumably so the
            # spawning greenlet can proceed first; verify against gevent docs.
            gevent.sleep()
            return fn(*a, **k)
        return GThread(wrap, *args, **kwargs)
|
@ -0,0 +1,62 @@
|
|||||||
|
from functools import wraps
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
# The 'peewee' logger receives executed queries; the helpers below attach a
# capturing handler to it to count them.
logger = logging.getLogger('peewee')
|
||||||
|
|
||||||
|
|
||||||
|
class _QueryLogHandler(logging.Handler):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.queries = []
|
||||||
|
logging.Handler.__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
self.queries.append(record)
|
||||||
|
|
||||||
|
|
||||||
|
class count_queries(object):
    """Context manager that counts queries logged by the `peewee` logger
    while the block runs; optionally restricted to SELECT statements."""
    def __init__(self, only_select=False):
        self.only_select = only_select
        self.count = 0

    def get_queries(self):
        return self._handler.queries

    def __enter__(self):
        self._handler = _QueryLogHandler()
        logger.setLevel(logging.DEBUG)
        logger.addHandler(self._handler)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        logger.removeHandler(self._handler)
        queries = self._handler.queries
        if self.only_select:
            # Each record's msg is (sql, params); filter on the SQL text.
            queries = [q for q in queries if q.msg[0].startswith('SELECT ')]
        self.count = len(queries)
|
||||||
|
|
||||||
|
|
||||||
|
class assert_query_count(count_queries):
    """count_queries that additionally asserts the observed query count
    equals `expected`; usable as a context manager or function decorator."""
    def __init__(self, expected, only_select=False):
        super(assert_query_count, self).__init__(only_select=only_select)
        self.expected = expected

    def __call__(self, f):
        @wraps(f)
        def decorated(*args, **kwds):
            with self:
                result = f(*args, **kwds)
            self._assert_count()
            return result
        return decorated

    def _assert_count(self):
        assert self.count == self.expected, '%s != %s' % (self.count,
                                                          self.expected)

    def __exit__(self, exc_type, exc_val, exc_tb):
        super(assert_query_count, self).__exit__(exc_type, exc_val, exc_tb)
        self._assert_count()
|
@ -0,0 +1,221 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import sys
|
||||||
|
from getpass import getpass
|
||||||
|
from optparse import OptionParser
|
||||||
|
|
||||||
|
from peewee import *
|
||||||
|
from peewee import print_
|
||||||
|
from peewee import __version__ as peewee_version
|
||||||
|
from playhouse.reflection import *
|
||||||
|
|
||||||
|
|
||||||
|
# Template for the generated module preamble: extra imports, database class
# name, database name, and optional connection kwargs.
HEADER = """from peewee import *%s

database = %s('%s'%s)
"""

# Emitted base class binding all generated models to `database`.
BASE_MODEL = """\
class BaseModel(Model):
    class Meta:
        database = database
"""

# Placeholder emitted for columns whose field type could not be introspected.
UNKNOWN_FIELD = """\
class UnknownField(object):
    def __init__(self, *_, **__): pass
"""

# Accepted command-line engine aliases for each supported database class.
DATABASE_ALIASES = {
    MySQLDatabase: ['mysql', 'mysqldb'],
    PostgresqlDatabase: ['postgres', 'postgresql'],
    SqliteDatabase: ['sqlite', 'sqlite3'],
}

# Reverse mapping: engine alias -> database class.
DATABASE_MAP = dict((value, key)
                    for key in DATABASE_ALIASES
                    for value in DATABASE_ALIASES[key])
|
||||||
|
|
||||||
|
def make_introspector(database_type, database_name, **kwargs):
    """Build an Introspector for the given engine alias and database name.

    Exits the process (status 1) when the alias is not in DATABASE_MAP. The
    remaining kwargs (host, port, user, password, ...) are passed to the
    database class; `schema` is routed to the introspector instead.
    """
    if database_type not in DATABASE_MAP:
        err('Unrecognized database, must be one of: %s' %
            ', '.join(DATABASE_MAP.keys()))
        sys.exit(1)

    schema = kwargs.pop('schema', None)
    DatabaseClass = DATABASE_MAP[database_type]
    db = DatabaseClass(database_name, **kwargs)
    return Introspector.from_database(db, schema=schema)
|
||||||
|
|
||||||
|
def print_models(introspector, tables=None, preserve_order=False,
                 include_views=False, ignore_unknown=False, snake_case=True):
    """Introspect the database and print generated peewee model definitions.

    Models are emitted depth-first along foreign-key dependencies so that a
    referenced model is always defined before the models that point at it.
    """
    database = introspector.introspect(table_names=tables,
                                       include_views=include_views,
                                       snake_case=snake_case)

    db_kwargs = introspector.get_database_kwargs()
    header = HEADER % (
        introspector.get_additional_imports(),
        introspector.get_database_class().__name__,
        introspector.get_database_name(),
        ', **%s' % repr(db_kwargs) if db_kwargs else '')
    print_(header)

    if not ignore_unknown:
        print_(UNKNOWN_FIELD)

    print_(BASE_MODEL)

    def _print_table(table, seen, accum=None):
        # `accum` holds the chain of tables currently being expanded, which
        # lets us detect foreign-key reference cycles.
        accum = accum or []
        foreign_keys = database.foreign_keys[table]
        for foreign_key in foreign_keys:
            dest = foreign_key.dest_table

            # In the event the destination table has already been pushed
            # for printing, then we have a reference cycle.
            if dest in accum and table not in accum:
                print_('# Possible reference cycle: %s' % dest)

            # If this is not a self-referential foreign key, and we have
            # not already processed the destination table, do so now.
            if dest not in seen and dest not in accum:
                seen.add(dest)
                if dest != table:
                    _print_table(dest, seen, accum + [table])

        print_('class %s(BaseModel):' % database.model_names[table])
        columns = database.columns[table].items()
        if not preserve_order:
            columns = sorted(columns)
        primary_keys = database.primary_keys[table]
        for name, column in columns:
            # Skip a lone auto-increment 'id' primary key; peewee implies it.
            skip = all([
                name in primary_keys,
                name == 'id',
                len(primary_keys) == 1,
                column.field_class in introspector.pk_classes])
            if skip:
                continue
            if column.primary_key and len(primary_keys) > 1:
                # If we have a CompositeKey, then we do not want to explicitly
                # mark the columns as being primary keys.
                column.primary_key = False

            is_unknown = column.field_class is UnknownField
            if is_unknown and ignore_unknown:
                disp = '%s - %s' % (column.name, column.raw_column_type or '?')
                print_('    # %s' % disp)
            else:
                print_('    %s' % column.get_field())

        print_('')
        print_('    class Meta:')
        print_('        table_name = \'%s\'' % table)
        multi_column_indexes = database.multi_column_indexes(table)
        if multi_column_indexes:
            print_('        indexes = (')
            for fields, unique in sorted(multi_column_indexes):
                print_('            ((%s), %s),' % (
                    ', '.join("'%s'" % field for field in fields),
                    unique,
                ))
            print_('        )')

        if introspector.schema:
            print_('        schema = \'%s\'' % introspector.schema)
        if len(primary_keys) > 1:
            # Multiple primary-key columns: emit an explicit CompositeKey.
            pk_field_names = sorted([
                field.name for col, field in columns
                if col in primary_keys])
            pk_list = ', '.join("'%s'" % pk for pk in pk_field_names)
            print_('        primary_key = CompositeKey(%s)' % pk_list)
        elif not primary_keys:
            print_('        primary_key = False')
        print_('')

        seen.add(table)

    seen = set()
    for table in sorted(database.model_names.keys()):
        if table not in seen:
            if not tables or table in tables:
                _print_table(table, seen)
|
||||||
|
|
||||||
|
def print_header(cmd_line, introspector):
    """Print a comment banner recording how and when the module was generated."""
    timestamp = datetime.datetime.now()
    print_('# Code generated by:')
    print_('# python -m pwiz %s' % cmd_line)
    print_('# Date: %s' % timestamp.strftime('%B %d, %Y %I:%M%p'))
    print_('# Database: %s' % introspector.get_database_name())
    print_('# Peewee version: %s' % peewee_version)
    print_('')
|
||||||
|
|
||||||
|
|
||||||
|
def err(msg):
    """Write *msg* to stderr wrapped in red ANSI escapes, flushing at once."""
    colored = '\033[91m%s\033[0m\n' % msg
    sys.stderr.write(colored)
    sys.stderr.flush()
|
||||||
|
|
||||||
|
def get_option_parser():
    """Build the OptionParser describing pwiz's command-line interface."""
    parser = OptionParser(usage='usage: %prog [options] database_name')
    add = parser.add_option

    # Connection parameters.
    add('-H', '--host', dest='host')
    add('-p', '--port', dest='port', type='int')
    add('-u', '--user', dest='user')
    # -P only flags that a password is needed; the value is prompted for later.
    add('-P', '--password', dest='password', action='store_true')

    # Engine selection is restricted to the supported backends.
    add('-e', '--engine', dest='engine', default='postgresql',
        choices=sorted(DATABASE_MAP),
        help=('Database type, e.g. sqlite, mysql or postgresql. Default '
              'is "postgresql".'))
    add('-s', '--schema', dest='schema')
    add('-t', '--tables', dest='tables',
        help=('Only generate the specified tables. Multiple table names should '
              'be separated by commas.'))
    add('-v', '--views', dest='views', action='store_true',
        help='Generate model classes for VIEWs in addition to tables.')
    add('-i', '--info', dest='info', action='store_true',
        help=('Add database information and other metadata to top of the '
              'generated file.'))
    add('-o', '--preserve-order', action='store_true', dest='preserve_order',
        help='Model definition column ordering matches source table.')
    add('-I', '--ignore-unknown', action='store_true', dest='ignore_unknown',
        help='Ignore fields whose type cannot be determined.')
    add('-L', '--legacy-naming', action='store_true', dest='legacy_naming',
        help='Use legacy table- and column-name generation.')
    return parser
|
||||||
|
|
||||||
|
def get_connect_kwargs(options):
    """Collect non-empty connection settings from parsed *options*.

    Returns a dict of host/port/user/schema values (falsy values are
    omitted), prompting interactively for a password when -P was given.
    """
    kwargs = {}
    for name in ('host', 'port', 'user', 'schema'):
        value = getattr(options, name)
        if value:
            kwargs[name] = value
    if options.password:
        # Never accept the password on the command line; prompt for it.
        kwargs['password'] = getpass()
    return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Keep the raw argv so the generated header can reproduce the command.
    raw_argv = sys.argv

    parser = get_option_parser()
    options, positional = parser.parse_args()

    if not positional:
        err('Missing required parameter "database"')
        parser.print_help()
        sys.exit(1)

    connect = get_connect_kwargs(options)
    database = positional[-1]

    # -t accepts a comma-separated list; blank entries are discarded.
    tables = None
    if options.tables:
        tables = [name.strip() for name in options.tables.split(',')
                  if name.strip()]

    introspector = make_introspector(options.engine, database, **connect)
    if options.info:
        print_header(' '.join(raw_argv[1:]), introspector)

    print_models(introspector, tables, options.preserve_order, options.views,
                 options.ignore_unknown, not options.legacy_naming)
|
Loading…
Reference in new issue