diff --git a/bazarr/check_update.py b/bazarr/check_update.py index 4e800edba..f3ef75a7f 100644 --- a/bazarr/check_update.py +++ b/bazarr/check_update.py @@ -8,7 +8,7 @@ import tarfile from get_args import args from config import settings, bazarr_url from queueconfig import notifications -from database import System +from database import database if not args.no_update and not args.release_update: import git @@ -300,4 +300,4 @@ def updated(restart=True): logging.info('BAZARR Restart failed, please restart Bazarr manualy') updated(restart=False) else: - System.update({System.updated: 1}).execute() + database.execute("UPDATE system SET updated='1'") diff --git a/bazarr/database.py b/bazarr/database.py index 46ab691a8..6afa8a4b8 100644 --- a/bazarr/database.py +++ b/bazarr/database.py @@ -1,224 +1,7 @@ import os -import atexit +from sqlite3worker import Sqlite3Worker from get_args import args -from peewee import * -from playhouse.sqliteq import SqliteQueueDatabase -from playhouse.migrate import * - from helper import path_replace, path_replace_movie, path_replace_reverse, path_replace_reverse_movie -database = SqliteQueueDatabase( - None, - use_gevent=False, - autostart=False, - queue_max_size=256, # Max. # of pending writes that can accumulate. - results_timeout=30.0 # Max. time to wait for query to be executed. -) - - -@database.func('path_substitution') -def path_substitution(path): - return path_replace(path) - - -@database.func('path_substitution_movie') -def path_substitution_movie(path): - return path_replace_movie(path) - - -class UnknownField(object): - def __init__(self, *_, **__): pass - -class BaseModel(Model): - class Meta: - database = database - - -class System(BaseModel): - configured = TextField(null=True) - updated = TextField(null=True) - - class Meta: - table_name = 'system' - primary_key = False - - -class TableShows(BaseModel): - alternate_titles = TextField(column_name='alternateTitles', null=True) - audio_language = TextField(null=True) - fanart = TextField(null=True) - forced = TextField(null=True, constraints=[SQL('DEFAULT "False"')]) - hearing_impaired = TextField(null=True) - languages = TextField(null=True) - overview = TextField(null=True) - path = TextField(null=False, unique=True) - poster = TextField(null=True) - sonarr_series_id = IntegerField(column_name='sonarrSeriesId', null=True, unique=True) - sort_title = TextField(column_name='sortTitle', null=True) - title = TextField(null=True) - tvdb_id = IntegerField(column_name='tvdbId', null=True, unique=True, primary_key=True) - year = TextField(null=True) - - class Meta: - table_name = 'table_shows' - - -class TableEpisodes(BaseModel): - rowid = IntegerField() - audio_codec = TextField(null=True) - episode = IntegerField(null=False) - failed_attempts = TextField(column_name='failedAttempts', null=True) - format = TextField(null=True) - missing_subtitles = TextField(null=True) - monitored = TextField(null=True) - path = TextField(null=False) - resolution = TextField(null=True) - scene_name = TextField(null=True) - season = IntegerField(null=False) - sonarr_episode_id = IntegerField(column_name='sonarrEpisodeId', unique=True, null=False) - sonarr_series_id = ForeignKeyField(TableShows, field='sonarr_series_id', column_name='sonarrSeriesId', null=False) - subtitles = TextField(null=True) - title = TextField(null=True) - video_codec = TextField(null=True) - episode_file_id = IntegerField(null=True) - - class Meta: - table_name = 'table_episodes' - primary_key = False - - -class TableMovies(BaseModel): - rowid = 
IntegerField() - alternative_titles = TextField(column_name='alternativeTitles', null=True) - audio_codec = TextField(null=True) - audio_language = TextField(null=True) - failed_attempts = TextField(column_name='failedAttempts', null=True) - fanart = TextField(null=True) - forced = TextField(null=True, constraints=[SQL('DEFAULT "False"')]) - format = TextField(null=True) - hearing_impaired = TextField(null=True) - imdb_id = TextField(column_name='imdbId', null=True) - languages = TextField(null=True) - missing_subtitles = TextField(null=True) - monitored = TextField(null=True) - overview = TextField(null=True) - path = TextField(unique=True) - poster = TextField(null=True) - radarr_id = IntegerField(column_name='radarrId', null=False, unique=True) - resolution = TextField(null=True) - scene_name = TextField(column_name='sceneName', null=True) - sort_title = TextField(column_name='sortTitle', null=True) - subtitles = TextField(null=True) - title = TextField(null=False) - tmdb_id = TextField(column_name='tmdbId', primary_key=True, null=False) - video_codec = TextField(null=True) - year = TextField(null=True) - movie_file_id = IntegerField(null=True) - - class Meta: - table_name = 'table_movies' - - -class TableHistory(BaseModel): - id = PrimaryKeyField(null=False) - action = IntegerField(null=False) - description = TextField(null=False) - language = TextField(null=True) - provider = TextField(null=True) - score = TextField(null=True) - sonarr_episode_id = ForeignKeyField(TableEpisodes, field='sonarr_episode_id', column_name='sonarrEpisodeId', null=False) - sonarr_series_id = ForeignKeyField(TableShows, field='sonarr_series_id', column_name='sonarrSeriesId', null=False) - timestamp = IntegerField(null=False) - video_path = TextField(null=True) - - class Meta: - table_name = 'table_history' - - -class TableHistoryMovie(BaseModel): - id = PrimaryKeyField(null=False) - action = IntegerField(null=False) - description = TextField(null=False) - language = TextField(null=True) - provider = TextField(null=True) - radarr_id = ForeignKeyField(TableMovies, field='radarr_id', column_name='radarrId', null=False) - score = TextField(null=True) - timestamp = IntegerField(null=False) - video_path = TextField(null=True) - - class Meta: - table_name = 'table_history_movie' - - -class TableSettingsLanguages(BaseModel): - code2 = TextField(null=False) - code3 = TextField(null=False, unique=True, primary_key=True) - code3b = TextField(null=True) - enabled = IntegerField(null=True) - name = TextField(null=False) - - class Meta: - table_name = 'table_settings_languages' - - -class TableSettingsNotifier(BaseModel): - enabled = IntegerField(null=False) - name = TextField(null=False, primary_key=True) - url = TextField(null=True) - - class Meta: - table_name = 'table_settings_notifier' - - -def database_init(): - database.init(os.path.join(args.config_dir, 'db', 'bazarr.db')) - database.start() - database.connect() - - database.pragma('wal_checkpoint', 'TRUNCATE') # Run a checkpoint and merge remaining wal-journal. - database.cache_size = -1024 # Number of KB of cache for wal-journal. - # Must be negative because positive means number of pages. - database.wal_autocheckpoint = 50 # Run an automatic checkpoint every 50 write transactions. 
- - # Database tables creation if they don't exists - models_list = [TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie, TableSettingsLanguages, - TableSettingsNotifier, System] - - database.create_tables(models_list, safe=True) - - # Upgrade DB schema - database_upgrade() - - -def database_upgrade(): - # Database migration - migrator = SqliteMigrator(database) - - # TableShows migration - table_shows_columns = [] - for column in database.get_columns('table_shows'): - table_shows_columns.append(column.name) - if 'forced' not in table_shows_columns: - migrate(migrator.add_column('table_shows', 'forced', TableShows.forced)) - - # TableEpisodes migration - table_episodes_columns = [] - for column in database.get_columns('table_episodes'): - table_episodes_columns.append(column.name) - if 'episode_file_id' not in table_episodes_columns: - migrate(migrator.add_column('table_episodes', 'episode_file_id', TableEpisodes.episode_file_id)) - - # TableMovies migration - table_movies_columns = [] - for column in database.get_columns('table_movies'): - table_movies_columns.append(column.name) - if 'forced' not in table_movies_columns: - migrate(migrator.add_column('table_movies', 'forced', TableMovies.forced)) - if 'movie_file_id' not in table_movies_columns: - migrate(migrator.add_column('table_movies', 'movie_file_id', TableMovies.movie_file_id)) - - -def wal_cleaning(): - database.pragma('wal_checkpoint', 'TRUNCATE') # Run a checkpoint and merge remaining wal-journal. - database.wal_autocheckpoint = 50 # Run an automatic checkpoint every 50 write transactions. +database = Sqlite3Worker(os.path.join(args.config_dir, 'db', 'bazarr.db'), max_queue_size=256) diff --git a/bazarr/database_peewee.py b/bazarr/database_peewee.py new file mode 100644 index 000000000..46ab691a8 --- /dev/null +++ b/bazarr/database_peewee.py @@ -0,0 +1,224 @@ +import os +import atexit + +from get_args import args +from peewee import * +from playhouse.sqliteq import SqliteQueueDatabase +from playhouse.migrate import * + +from helper import path_replace, path_replace_movie, path_replace_reverse, path_replace_reverse_movie + +database = SqliteQueueDatabase( + None, + use_gevent=False, + autostart=False, + queue_max_size=256, # Max. # of pending writes that can accumulate. + results_timeout=30.0 # Max. time to wait for query to be executed. 
+) + + +@database.func('path_substitution') +def path_substitution(path): + return path_replace(path) + + +@database.func('path_substitution_movie') +def path_substitution_movie(path): + return path_replace_movie(path) + + +class UnknownField(object): + def __init__(self, *_, **__): pass + +class BaseModel(Model): + class Meta: + database = database + + +class System(BaseModel): + configured = TextField(null=True) + updated = TextField(null=True) + + class Meta: + table_name = 'system' + primary_key = False + + +class TableShows(BaseModel): + alternate_titles = TextField(column_name='alternateTitles', null=True) + audio_language = TextField(null=True) + fanart = TextField(null=True) + forced = TextField(null=True, constraints=[SQL('DEFAULT "False"')]) + hearing_impaired = TextField(null=True) + languages = TextField(null=True) + overview = TextField(null=True) + path = TextField(null=False, unique=True) + poster = TextField(null=True) + sonarr_series_id = IntegerField(column_name='sonarrSeriesId', null=True, unique=True) + sort_title = TextField(column_name='sortTitle', null=True) + title = TextField(null=True) + tvdb_id = IntegerField(column_name='tvdbId', null=True, unique=True, primary_key=True) + year = TextField(null=True) + + class Meta: + table_name = 'table_shows' + + +class TableEpisodes(BaseModel): + rowid = IntegerField() + audio_codec = TextField(null=True) + episode = IntegerField(null=False) + failed_attempts = TextField(column_name='failedAttempts', null=True) + format = TextField(null=True) + missing_subtitles = TextField(null=True) + monitored = TextField(null=True) + path = TextField(null=False) + resolution = TextField(null=True) + scene_name = TextField(null=True) + season = IntegerField(null=False) + sonarr_episode_id = IntegerField(column_name='sonarrEpisodeId', unique=True, null=False) + sonarr_series_id = ForeignKeyField(TableShows, field='sonarr_series_id', column_name='sonarrSeriesId', null=False) + subtitles = TextField(null=True) + title = TextField(null=True) + video_codec = TextField(null=True) + episode_file_id = IntegerField(null=True) + + class Meta: + table_name = 'table_episodes' + primary_key = False + + +class TableMovies(BaseModel): + rowid = IntegerField() + alternative_titles = TextField(column_name='alternativeTitles', null=True) + audio_codec = TextField(null=True) + audio_language = TextField(null=True) + failed_attempts = TextField(column_name='failedAttempts', null=True) + fanart = TextField(null=True) + forced = TextField(null=True, constraints=[SQL('DEFAULT "False"')]) + format = TextField(null=True) + hearing_impaired = TextField(null=True) + imdb_id = TextField(column_name='imdbId', null=True) + languages = TextField(null=True) + missing_subtitles = TextField(null=True) + monitored = TextField(null=True) + overview = TextField(null=True) + path = TextField(unique=True) + poster = TextField(null=True) + radarr_id = IntegerField(column_name='radarrId', null=False, unique=True) + resolution = TextField(null=True) + scene_name = TextField(column_name='sceneName', null=True) + sort_title = TextField(column_name='sortTitle', null=True) + subtitles = TextField(null=True) + title = TextField(null=False) + tmdb_id = TextField(column_name='tmdbId', primary_key=True, null=False) + video_codec = TextField(null=True) + year = TextField(null=True) + movie_file_id = IntegerField(null=True) + + class Meta: + table_name = 'table_movies' + + +class TableHistory(BaseModel): + id = PrimaryKeyField(null=False) + action = IntegerField(null=False) + description = 
TextField(null=False) + language = TextField(null=True) + provider = TextField(null=True) + score = TextField(null=True) + sonarr_episode_id = ForeignKeyField(TableEpisodes, field='sonarr_episode_id', column_name='sonarrEpisodeId', null=False) + sonarr_series_id = ForeignKeyField(TableShows, field='sonarr_series_id', column_name='sonarrSeriesId', null=False) + timestamp = IntegerField(null=False) + video_path = TextField(null=True) + + class Meta: + table_name = 'table_history' + + +class TableHistoryMovie(BaseModel): + id = PrimaryKeyField(null=False) + action = IntegerField(null=False) + description = TextField(null=False) + language = TextField(null=True) + provider = TextField(null=True) + radarr_id = ForeignKeyField(TableMovies, field='radarr_id', column_name='radarrId', null=False) + score = TextField(null=True) + timestamp = IntegerField(null=False) + video_path = TextField(null=True) + + class Meta: + table_name = 'table_history_movie' + + +class TableSettingsLanguages(BaseModel): + code2 = TextField(null=False) + code3 = TextField(null=False, unique=True, primary_key=True) + code3b = TextField(null=True) + enabled = IntegerField(null=True) + name = TextField(null=False) + + class Meta: + table_name = 'table_settings_languages' + + +class TableSettingsNotifier(BaseModel): + enabled = IntegerField(null=False) + name = TextField(null=False, primary_key=True) + url = TextField(null=True) + + class Meta: + table_name = 'table_settings_notifier' + + +def database_init(): + database.init(os.path.join(args.config_dir, 'db', 'bazarr.db')) + database.start() + database.connect() + + database.pragma('wal_checkpoint', 'TRUNCATE') # Run a checkpoint and merge remaining wal-journal. + database.cache_size = -1024 # Number of KB of cache for wal-journal. + # Must be negative because positive means number of pages. + database.wal_autocheckpoint = 50 # Run an automatic checkpoint every 50 write transactions. + + # Database tables creation if they don't exists + models_list = [TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie, TableSettingsLanguages, + TableSettingsNotifier, System] + + database.create_tables(models_list, safe=True) + + # Upgrade DB schema + database_upgrade() + + +def database_upgrade(): + # Database migration + migrator = SqliteMigrator(database) + + # TableShows migration + table_shows_columns = [] + for column in database.get_columns('table_shows'): + table_shows_columns.append(column.name) + if 'forced' not in table_shows_columns: + migrate(migrator.add_column('table_shows', 'forced', TableShows.forced)) + + # TableEpisodes migration + table_episodes_columns = [] + for column in database.get_columns('table_episodes'): + table_episodes_columns.append(column.name) + if 'episode_file_id' not in table_episodes_columns: + migrate(migrator.add_column('table_episodes', 'episode_file_id', TableEpisodes.episode_file_id)) + + # TableMovies migration + table_movies_columns = [] + for column in database.get_columns('table_movies'): + table_movies_columns.append(column.name) + if 'forced' not in table_movies_columns: + migrate(migrator.add_column('table_movies', 'forced', TableMovies.forced)) + if 'movie_file_id' not in table_movies_columns: + migrate(migrator.add_column('table_movies', 'movie_file_id', TableMovies.movie_file_id)) + + +def wal_cleaning(): + database.pragma('wal_checkpoint', 'TRUNCATE') # Run a checkpoint and merge remaining wal-journal. + database.wal_autocheckpoint = 50 # Run an automatic checkpoint every 50 write transactions. 
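A note on the new database layer: peewee's SqliteQueueDatabase is replaced by sqlite3worker's Sqlite3Worker, which funnels every statement through a single writer thread so concurrent callers do not trip over SQLite locking. Below is a minimal sketch of the access pattern the rest of this changeset relies on; the file path is an example only, and dict-style row access (row['column'], used throughout this diff) is an assumption about the bundled sqlite3worker copy, since the stock library returns plain tuples.

    from sqlite3worker import Sqlite3Worker

    # Example path; Bazarr itself points this at <config_dir>/db/bazarr.db.
    database = Sqlite3Worker('/tmp/example.db', max_queue_size=256)

    # Writes are parameterized and queued for the single worker thread.
    database.execute("CREATE TABLE IF NOT EXISTS system (configured TEXT, updated TEXT)")
    database.execute("INSERT INTO system (configured, updated) VALUES (?, ?)", ('0', '0'))
    database.execute("UPDATE system SET updated=?", ('1',))

    # SELECTs block until the worker replies and return a list of rows.
    # Stock sqlite3worker yields tuples (hence row[0] here); row['updated']
    # would require a row_factory tweak in the bundled copy.
    for row in database.execute("SELECT updated FROM system"):
        print(row[0])

    database.close()  # drain pending statements and close the connection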
diff --git a/bazarr/get_languages.py b/bazarr/get_languages.py index ab71deb3f..2e9c7edc6 100644 --- a/bazarr/get_languages.py +++ b/bazarr/get_languages.py @@ -5,7 +5,7 @@ import pycountry from get_args import args from subzero.language import Language -from database import TableSettingsLanguages +from database import database def load_language_in_db(): @@ -16,17 +16,10 @@ def load_language_in_db(): # Insert languages in database table for lang in langs: - TableSettingsLanguages.insert( - lang - ).on_conflict_ignore().execute() - - TableSettingsLanguages.insert( - { - TableSettingsLanguages.code3: 'pob', - TableSettingsLanguages.code2: 'pb', - TableSettingsLanguages.name: 'Brazilian Portuguese' - } - ).on_conflict_ignore().execute() + database.execute("INSERT OR IGNORE INTO table_settings_languages (code3, code2, name) VALUES (?, ?, ?)", (lang['code3'], lang['code2'], lang['name'])) + + database.execute("INSERT OR IGNORE INTO table_settings_languages (code3, code2, name) " + "VALUES ('pob', 'pb', 'Brazilian Portuguese')") langs = [{'code3b': lang.bibliographic, 'code3': lang.alpha_3} for lang in pycountry.languages @@ -34,85 +27,49 @@ # Update languages in database table for lang in langs: - TableSettingsLanguages.update( - { - TableSettingsLanguages.code3b: lang['code3b'] - } - ).where( - TableSettingsLanguages.code3 == lang['code3'] - ).execute() + database.execute("UPDATE table_settings_languages SET code3b=? WHERE code3=?", (lang['code3b'], lang['code3'])) def language_from_alpha2(lang): - result = TableSettingsLanguages.select( - TableSettingsLanguages.name - ).where( - TableSettingsLanguages.code2 == lang - ).first() - return result.name + result = database.execute("SELECT name FROM table_settings_languages WHERE code2=?", (lang,)) + return result[0]['name'] or None def language_from_alpha3(lang): - result = TableSettingsLanguages.select( - TableSettingsLanguages.name - ).where( - (TableSettingsLanguages.code3 == lang) | - (TableSettingsLanguages.code3b == lang) - ).first() - return result.name + result = database.execute("SELECT name FROM table_settings_languages WHERE code3=? or code3b=?", (lang, lang)) + return result[0]['name'] or None def alpha2_from_alpha3(lang): - result = TableSettingsLanguages.select( - TableSettingsLanguages.code2 - ).where( - (TableSettingsLanguages.code3 == lang) | - (TableSettingsLanguages.code3b == lang) - ).first() - return result.code2 + result = database.execute("SELECT code2 FROM table_settings_languages WHERE code3=? 
or code3b=?", (lang, lang)) + return result[0]['code2'] or None def alpha2_from_language(lang): - result = TableSettingsLanguages.select( - TableSettingsLanguages.code2 - ).where( - TableSettingsLanguages.name == lang - ).first() - return result.code2 + result = database.execute("SELECT code2 FROM table_settings_languages WHERE name=?", (lang,)) + return result[0]['code2'] or None def alpha3_from_alpha2(lang): - result = TableSettingsLanguages.select( - TableSettingsLanguages.code3 - ).where( - TableSettingsLanguages.code2 == lang - ).first() - return result.code3 + result = database.execute("SELECT code3 FROM table_settings_languages WHERE code2=?", (lang,)) + return result[0]['code3'] or None def alpha3_from_language(lang): - result = TableSettingsLanguages.select( - TableSettingsLanguages.code3 - ).where( - TableSettingsLanguages.name == lang - ).first() - return result.code3 + result = database.execute("SELECT code3 FROM table_settings_languages WHERE name=?", (lang,)) + return result[0]['code3'] or None def get_language_set(): - languages = TableSettingsLanguages.select( - TableSettingsLanguages.code3 - ).where( - TableSettingsLanguages.enabled == 1 - ) + languages = database.execute("SELECT code3 FROM table_settings_languages WHERE enabled=1") language_set = set() for lang in languages: - if lang.code3 == 'pob': + if lang['code3'] == 'pob': language_set.add(Language('por', 'BR')) else: - language_set.add(Language(lang.code3)) + language_set.add(Language(lang['code3'])) return language_set diff --git a/bazarr/init.py b/bazarr/init.py index e76b0c80f..552848bb7 100644 --- a/bazarr/init.py +++ b/bazarr/init.py @@ -11,7 +11,6 @@ from config import settings from check_update import check_releases from get_args import args from utils import get_binary -from database import database_init from dogpile.cache.region import register_backend as register_cache_backend import subliminal @@ -55,9 +54,6 @@ if not os.path.exists(os.path.join(args.config_dir, 'cache')): os.mkdir(os.path.join(args.config_dir, 'cache')) logging.debug("BAZARR Created cache folder") -# Initiate database -database_init() - # Configure dogpile file caching for Subliminal request register_cache_backend("subzero.cache.file", "subzero.cache_backends.file", "SZFileBackend") subliminal.region.configure('subzero.cache.file', expiration_time=datetime.timedelta(days=30), diff --git a/bazarr/main.py b/bazarr/main.py index 1ebe11bd6..5f7511fbd 100644 --- a/bazarr/main.py +++ b/bazarr/main.py @@ -1,6 +1,6 @@ # coding=utf-8 -bazarr_version = '0.8.2.5' +bazarr_version = '0.8.3' import gc import sys @@ -17,14 +17,12 @@ import warnings import queueconfig import platform import apprise -from peewee import fn, JOIN import operator from calendar import day_name from get_args import args from init import * -from database import database, TableEpisodes, TableShows, TableMovies, TableHistory, TableHistoryMovie, \ - TableSettingsLanguages, TableSettingsNotifier, System +from database import database from notifier import update_notifier from logger import configure_logging, empty_log @@ -85,16 +83,7 @@ else: bottle.ERROR_PAGE_TEMPLATE = bottle.ERROR_PAGE_TEMPLATE.replace('if DEBUG and', 'if') # Reset restart required warning on start -if System.select().count(): - System.update({ - System.configured: 0, - System.updated: 0 - }).execute() -else: - System.insert({ - System.configured: 0, - System.updated: 0 - }).execute() +database.execute("UPDATE system SET configured='0', updated='0'") # Load languages in database load_language_in_db() @@ -193,7 
+182,6 @@ def shutdown(): else: server.stop() database.close() - database.stop() stop_file.write('') stop_file.close() sys.exit(0) @@ -210,7 +198,6 @@ def restart(): logging.info('Bazarr is being restarted...') server.stop() database.close() - database.stop() restart_file.write('') restart_file.close() sys.exit(0) @@ -222,7 +209,7 @@ def wizard(): authorize() # Get languages list - settings_languages = TableSettingsLanguages.select().order_by(TableSettingsLanguages.name) + settings_languages = database.execute("SELECT * FROM table_settings_languages ORDER BY name") # Get providers list settings_providers = sorted(provider_manager.names()) @@ -394,14 +381,10 @@ def save_wizard(): settings_subliminal_languages = request.forms.getall('settings_subliminal_languages') # Disable all languages in DB - TableSettingsLanguages.update({TableSettingsLanguages.enabled: 0}) + database.execute("UPDATE table_settings_languages SET enabled=0") for item in settings_subliminal_languages: # Enable each desired language in DB - TableSettingsLanguages.update( - {TableSettingsLanguages.enabled: 1} - ).where( - TableSettingsLanguages.code2 == item - ).execute() + database.execute("UPDATE table_settings_languages SET enabled=1 WHERE code2=?", (item,)) settings_serie_default_enabled = request.forms.get('settings_serie_default_enabled') if settings_serie_default_enabled is None: @@ -534,7 +517,7 @@ def redirect_root(): def series(): authorize() - missing_count = TableShows.select().count() + missing_count = database.execute("SELECT COUNT(*) as count FROM table_shows")[0]['count'] page = request.GET.page if page == "": page = "1" @@ -543,79 +526,46 @@ def series(): max_page = int(math.ceil(missing_count / (page_size + 0.0))) # Get list of series - data = TableShows.select( - TableShows.tvdb_id, - TableShows.title, - fn.path_substitution(TableShows.path).alias('path'), - TableShows.languages, - TableShows.hearing_impaired, - TableShows.sonarr_series_id, - TableShows.poster, - TableShows.audio_language, - TableShows.forced - ).order_by( - TableShows.sort_title.asc() - ).paginate( - int(page), - page_size - ) + # path_replace + data = database.execute("SELECT tvdbId, title, path, languages, hearing_impaired, sonarrSeriesId, poster, " + "audio_language, forced FROM table_shows ORDER BY sortTitle ASC LIMIT ? 
OFFSET ?", (page_size, (int(page) - 1) * page_size)) # Get languages list - languages = TableSettingsLanguages.select( - TableSettingsLanguages.code2, - TableSettingsLanguages.name - ).where( - TableSettingsLanguages.enabled == 1 - ) + languages = database.execute("SELECT code2, name FROM table_settings_languages WHERE enabled=1") # Build missing subtitles clause depending on only_monitored - missing_subtitles_clause = [ - (TableShows.languages != 'None'), - (TableEpisodes.missing_subtitles != '[]') - ] if settings.sonarr.getboolean('only_monitored'): - missing_subtitles_clause.append( - (TableEpisodes.monitored == 'True') - ) + missing_subtitles_clause = " AND table_episodes.monitored='True'" + else: + missing_subtitles_clause = '' # Get missing subtitles count by series - missing_subtitles_list = TableShows.select( - TableShows.sonarr_series_id, - fn.COUNT(TableEpisodes.missing_subtitles).alias('missing_subtitles') - ).join_from( - TableShows, TableEpisodes, JOIN.LEFT_OUTER - ).where( - reduce(operator.and_, missing_subtitles_clause) - ).group_by( - TableShows.sonarr_series_id - ) + missing_subtitles_list = database.execute("SELECT table_shows.sonarrSeriesId, " + "COUNT(table_episodes.missing_subtitles) as missing_subtitles FROM table_shows LEFT JOIN " + "table_episodes ON table_shows.sonarrSeriesId=" + "table_episodes.sonarrSeriesId WHERE table_shows.languages IS NOT 'None' " + "AND table_episodes.missing_subtitles IS NOT '[]'" + + missing_subtitles_clause + " GROUP BY table_shows.sonarrSeriesId") # Build total subtitles clause depending on only_monitored - total_subtitles_clause = [ - (TableShows.languages != 'None') - ] if settings.sonarr.getboolean('only_monitored'): - total_subtitles_clause.append( - (TableEpisodes.monitored == 'True') - ) + total_subtitles_clause = " AND table_episodes.monitored='True'" + else: + total_subtitles_clause = '' # Get total subtitles count by series - total_subtitles_list = TableShows.select( - TableShows.sonarr_series_id, - fn.COUNT(TableEpisodes.missing_subtitles).alias('missing_subtitles') - ).join_from( - TableShows, TableEpisodes, JOIN.LEFT_OUTER - ).where( - reduce(operator.and_, total_subtitles_clause) - ).group_by( - TableShows.sonarr_series_id - ) - - return template('series', bazarr_version=bazarr_version, rows=data, - missing_subtitles_list=missing_subtitles_list, total_subtitles_list=total_subtitles_list, - languages=languages, missing_count=missing_count, page=page, max_page=max_page, base_url=base_url, - single_language=settings.general.getboolean('single_language'), page_size=page_size, - current_port=settings.general.port) + total_subtitles_list = database.execute("SELECT table_shows.sonarrSeriesId, " + "COUNT(table_episodes.missing_subtitles) as missing_subtitles FROM table_shows LEFT JOIN " + "table_episodes ON table_shows.sonarrSeriesId=" + "table_episodes.sonarrSeriesId WHERE table_shows.languages IS NOT 'None'" + + total_subtitles_clause + " GROUP BY table_shows.sonarrSeriesId") + + return template('series', bazarr_version=bazarr_version, rows=data, missing_subtitles_list=missing_subtitles_list, + total_subtitles_list=total_subtitles_list, languages=languages, missing_count=missing_count, + page=page, max_page=max_page, base_url=base_url, + single_language=settings.general.getboolean('single_language'), page_size=page_size, + current_port=settings.general.port) @@ -624,30 +574,15 @@ def serieseditor(): authorize() # Get missing count - missing_count = TableShows().select().count() - - # Get movies list - data = TableShows.select( - TableShows.tvdb_id, - 
TableShows.title, - fn.path_substitution(TableShows.path).alias('path'), - TableShows.languages, - TableShows.hearing_impaired, - TableShows.sonarr_series_id, - TableShows.poster, - TableShows.audio_language, - TableShows.forced - ).order_by( - TableShows.sort_title.asc() - ) + missing_count = database.execute("SELECT COUNT(*) as count FROM table_shows")[0]['count'] + + # Get series list + # path_replace + data = database.execute("SELECT tvdbId, title, path, languages, hearing_impaired, sonarrSeriesId, poster, " + "audio_language, forced FROM table_shows ORDER BY sortTitle ASC") # Get languages list - languages = TableSettingsLanguages.select( - TableSettingsLanguages.code2, - TableSettingsLanguages.name - ).where( - TableSettingsLanguages.enabled == 1 - ) + languages = database.execute("SELECT code2, name FROM table_settings_languages WHERE enabled=1") return template('serieseditor', bazarr_version=bazarr_version, rows=data, languages=languages, missing_count=missing_count, base_url=base_url, @@ -664,33 +599,19 @@ def search_json(query): if settings.general.getboolean('use_sonarr'): # Get matching series - series = TableShows.select( - TableShows.title, - TableShows.sonarr_series_id, - TableShows.year - ).where( - TableShows.title ** query - ).order_by( - TableShows.title.asc() - ) + series = database.execute("SELECT title, sonarrSeriesId, year FROM table_shows WHERE title LIKE ? ORDER BY " + "title ASC", (query,)) for serie in series: - search_list.append(dict([('name', re.sub(r'\ \(\d{4}\)', '', serie.title) + ' (' + serie.year + ')'), - ('url', base_url + 'episodes/' + str(serie.sonarr_series_id))])) + search_list.append(dict([('name', re.sub(r'\ \(\d{4}\)', '', serie['title']) + ' (' + serie['year'] + ')'), + ('url', base_url + 'episodes/' + str(serie['sonarrSeriesId']))])) if settings.general.getboolean('use_radarr'): # Get matching movies - movies = TableMovies.select( - TableMovies.title, - TableMovies.radarr_id, - TableMovies.year - ).where( - TableMovies.title ** query - ).order_by( - TableMovies.title.asc() - ) + movies = database.execute("SELECT title, radarrId, year FROM table_movies WHERE title LIKE ? ORDER BY " + "title ASC", (query,)) for movie in movies: - search_list.append(dict([('name', re.sub(r'\ \(\d{4}\)', '', movie.title) + ' (' + movie.year + ')'), - ('url', base_url + 'movie/' + str(movie.radarr_id))])) + search_list.append(dict([('name', re.sub(r'\ \(\d{4}\)', '', movie['title']) + ' (' + movie['year'] + ')'), + ('url', base_url + 'movie/' + str(movie['radarrId']))])) response.content_type = 'application/json' return dict(items=search_list) @@ -726,15 +647,8 @@ def edit_series(no): else: hi = "False" - result = TableShows.update( - { - TableShows.languages: lang, - TableShows.hearing_impaired: hi, - TableShows.forced: forced - } - ).where( - TableShows.sonarr_series_id == no - ).execute() + result = database.execute("UPDATE table_shows SET languages=?, hearing_impaired=?, forced=? WHERE " + "sonarrSeriesId=?", (lang, hi, forced, no)) list_missing_subtitles(no=no) @@ -759,29 +673,11 @@ def edit_serieseditor(): lang = 'None' else: lang = str(lang) - TableShows.update( - { - TableShows.languages: lang - } - ).where( - TableShows.sonarr_series_id % serie - ).execute() + database.execute("UPDATE table_shows SET languages=? WHERE sonarrSeriesId=?", (lang, serie)) if hi != '': - TableShows.update( - { - TableShows.hearing_impaired: hi - } - ).where( - TableShows.sonarr_series_id % serie - ).execute() + database.execute("UPDATE table_shows SET hearing_impaired=? 
WHERE sonarrSeriesId=?", (hi, serie)) if forced != '': - TableShows.update( - { - TableShows.forced: forced - } - ).where( - TableShows.sonarr_series_id % serie - ).execute() + database.execute("UPDATE table_shows SET forced=? WHERE sonarrSeriesId=?", (forced, serie)) for serie in series: list_missing_subtitles(no=serie) diff --git a/bazarr/notifier.py b/bazarr/notifier.py index 9a0eea98f..7e7c82233 100644 --- a/bazarr/notifier.py +++ b/bazarr/notifier.py @@ -5,7 +5,7 @@ import os import logging from get_args import args -from database import TableSettingsNotifier, TableShows, TableEpisodes, TableMovies +from database import database def update_notifier(): @@ -18,13 +18,11 @@ def update_notifier(): notifiers_new = [] notifiers_old = [] - notifiers_current_db = TableSettingsNotifier.select( - TableSettingsNotifier.name - ) + notifiers_current_db = database.execute("SELECT name FROM table_settings_notifier") notifiers_current = [] for notifier in notifiers_current_db: - notifiers_current.append(notifier.name) + notifiers_current.append(notifier['name']) for x in results['schemas']: if x['service_name'] not in notifiers_current: @@ -37,60 +35,32 @@ notifiers_to_delete = list(set(notifiers_current) - set(notifiers_old)) for notifier_new in notifiers_new: - TableSettingsNotifier.insert( - { - TableSettingsNotifier.name: notifier_new, - TableSettingsNotifier.enabled: 0 - } - ).execute() + database.execute("INSERT INTO table_settings_notifier (name, enabled) VALUES (?, ?)", (notifier_new, 0)) for notifier_to_delete in notifiers_to_delete: - TableSettingsNotifier.delete().where( - TableSettingsNotifier.name == notifier_to_delete - ).execute() + database.execute("DELETE FROM table_settings_notifier WHERE name=?", (notifier_to_delete,)) def get_notifier_providers(): - providers = TableSettingsNotifier.select( - TableSettingsNotifier.name, - TableSettingsNotifier.url - ).where( - TableSettingsNotifier.enabled == 1 - ) - + providers = database.execute("SELECT name, url FROM table_settings_notifier WHERE enabled=1") return providers def get_series_name(sonarrSeriesId): - data = TableShows.select( - TableShows.title - ).where( - TableShows.sonarr_series_id == sonarrSeriesId - ).first() + data = database.execute("SELECT title FROM table_shows WHERE sonarrSeriesId=?", (sonarrSeriesId,)) - return data.title + return data[0]['title'] or None def get_episode_name(sonarrEpisodeId): - data = TableEpisodes.select( - TableEpisodes.title, - TableEpisodes.season, - TableEpisodes.episode - ).where( - TableEpisodes.sonarr_episode_id == sonarrEpisodeId - ).first() + data = database.execute("SELECT title, season, episode FROM table_episodes WHERE sonarrEpisodeId=?", (sonarrEpisodeId,)) - return data.title, data.season, data.episode + return data[0]['title'], data[0]['season'], data[0]['episode'] def get_movies_name(radarrId): - data = TableMovies.select( - TableMovies.title - ).where( - TableMovies.radarr_id == radarrId - ).first() - - return data.title + data = database.execute("SELECT title FROM table_movies WHERE radarrId=?", (radarrId,)) + return data[0]['title'] def send_notifications(sonarrSeriesId, sonarrEpisodeId, message): diff --git a/bazarr/utils.py b/bazarr/utils.py index 1409f495f..ec8d23a3b 100644 --- a/bazarr/utils.py +++ b/bazarr/utils.py @@ -10,7 +10,7 @@ import requests from whichcraft import which from get_args import args from config import settings, url_sonarr, url_radarr -from database import TableHistory, TableHistoryMovie +from database import database from subliminal import 
region as subliminal_cache_region import datetime @@ -19,35 +19,18 @@ import glob def history_log(action, sonarrSeriesId, sonarrEpisodeId, description, video_path=None, language=None, provider=None, score=None, forced=False): - TableHistory.insert( - { - TableHistory.action: action, - TableHistory.sonarr_series_id: sonarrSeriesId, - TableHistory.sonarr_episode_id: sonarrEpisodeId, - TableHistory.timestamp: time.time(), - TableHistory.description: description, - TableHistory.video_path: video_path, - TableHistory.language: language, - TableHistory.provider: provider, - TableHistory.score: score - } - ).execute() + database.execute("INSERT INTO table_history (action, sonarrSeriesId, sonarrEpisodeId, timestamp, description," + "video_path, language, provider, score) VALUES (?,?,?,?,?,?,?,?,?)", (action, sonarrSeriesId, + sonarrEpisodeId, time.time(), + description, video_path, + language, provider, score)) def history_log_movie(action, radarrId, description, video_path=None, language=None, provider=None, score=None, forced=False): - TableHistoryMovie.insert( - { - TableHistoryMovie.action: action, - TableHistoryMovie.radarr_id: radarrId, - TableHistoryMovie.timestamp: time.time(), - TableHistoryMovie.description: description, - TableHistoryMovie.video_path: video_path, - TableHistoryMovie.language: language, - TableHistoryMovie.provider: provider, - TableHistoryMovie.score: score - } - ).execute() + database.execute("INSERT INTO table_history_movie (action, radarrId, timestamp, description, video_path, language, " + "provider, score) VALUES (?,?,?,?,?,?,?,?)", (action, radarrId, time.time(), description, + video_path, language, provider, score)) def get_binary(name): diff --git a/libs/peewee.py b/libs/peewee.py deleted file mode 100644 index c41dc7135..000000000 --- a/libs/peewee.py +++ /dev/null @@ -1,7508 +0,0 @@ -from bisect import bisect_left -from bisect import bisect_right -from contextlib import contextmanager -from copy import deepcopy -from functools import wraps -from inspect import isclass -import calendar -import collections -import datetime -import decimal -import hashlib -import itertools -import logging -import operator -import re -import socket -import struct -import sys -import threading -import time -import uuid -import warnings -try: - from collections.abc import Mapping -except ImportError: - from collections import Mapping - -try: - from pysqlite3 import dbapi2 as pysq3 -except ImportError: - try: - from pysqlite2 import dbapi2 as pysq3 - except ImportError: - pysq3 = None -try: - import sqlite3 -except ImportError: - sqlite3 = pysq3 -else: - if pysq3 and pysq3.sqlite_version_info >= sqlite3.sqlite_version_info: - sqlite3 = pysq3 -try: - from psycopg2cffi import compat - compat.register() -except ImportError: - pass -try: - import psycopg2 - from psycopg2 import extensions as pg_extensions - try: - from psycopg2 import errors as pg_errors - except ImportError: - pg_errors = None -except ImportError: - psycopg2 = pg_errors = None - -mysql_passwd = False -try: - import pymysql as mysql -except ImportError: - try: - import MySQLdb as mysql - mysql_passwd = True - except ImportError: - mysql = None - - -__version__ = '3.11.2' -__all__ = [ - 'AsIs', - 'AutoField', - 'BareField', - 'BigAutoField', - 'BigBitField', - 'BigIntegerField', - 'BinaryUUIDField', - 'BitField', - 'BlobField', - 'BooleanField', - 'Case', - 'Cast', - 'CharField', - 'Check', - 'chunked', - 'Column', - 'CompositeKey', - 'Context', - 'Database', - 'DatabaseError', - 'DatabaseProxy', - 'DataError', - 
'DateField', - 'DateTimeField', - 'DecimalField', - 'DeferredForeignKey', - 'DeferredThroughModel', - 'DJANGO_MAP', - 'DoesNotExist', - 'DoubleField', - 'DQ', - 'EXCLUDED', - 'Field', - 'FixedCharField', - 'FloatField', - 'fn', - 'ForeignKeyField', - 'IdentityField', - 'ImproperlyConfigured', - 'Index', - 'IntegerField', - 'IntegrityError', - 'InterfaceError', - 'InternalError', - 'IPField', - 'JOIN', - 'ManyToManyField', - 'Model', - 'ModelIndex', - 'MySQLDatabase', - 'NotSupportedError', - 'OP', - 'OperationalError', - 'PostgresqlDatabase', - 'PrimaryKeyField', # XXX: Deprecated, change to AutoField. - 'prefetch', - 'ProgrammingError', - 'Proxy', - 'QualifiedNames', - 'SchemaManager', - 'SmallIntegerField', - 'Select', - 'SQL', - 'SqliteDatabase', - 'Table', - 'TextField', - 'TimeField', - 'TimestampField', - 'Tuple', - 'UUIDField', - 'Value', - 'ValuesList', - 'Window', -] - -try: # Python 2.7+ - from logging import NullHandler -except ImportError: - class NullHandler(logging.Handler): - def emit(self, record): - pass - -logger = logging.getLogger('peewee') -logger.addHandler(NullHandler()) - - -if sys.version_info[0] == 2: - text_type = unicode - bytes_type = str - buffer_type = buffer - izip_longest = itertools.izip_longest - callable_ = callable - exec('def reraise(tp, value, tb=None): raise tp, value, tb') - def print_(s): - sys.stdout.write(s) - sys.stdout.write('\n') -else: - import builtins - try: - from collections.abc import Callable - except ImportError: - from collections import Callable - from functools import reduce - callable_ = lambda c: isinstance(c, Callable) - text_type = str - bytes_type = bytes - buffer_type = memoryview - basestring = str - long = int - print_ = getattr(builtins, 'print') - izip_longest = itertools.zip_longest - def reraise(tp, value, tb=None): - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value - - -if sqlite3: - sqlite3.register_adapter(decimal.Decimal, str) - sqlite3.register_adapter(datetime.date, str) - sqlite3.register_adapter(datetime.time, str) - __sqlite_version__ = sqlite3.sqlite_version_info -else: - __sqlite_version__ = (0, 0, 0) - - -__date_parts__ = set(('year', 'month', 'day', 'hour', 'minute', 'second')) - -# Sqlite does not support the `date_part` SQL function, so we will define an -# implementation in python. 
-__sqlite_datetime_formats__ = ( - '%Y-%m-%d %H:%M:%S', - '%Y-%m-%d %H:%M:%S.%f', - '%Y-%m-%d', - '%H:%M:%S', - '%H:%M:%S.%f', - '%H:%M') - -__sqlite_date_trunc__ = { - 'year': '%Y-01-01 00:00:00', - 'month': '%Y-%m-01 00:00:00', - 'day': '%Y-%m-%d 00:00:00', - 'hour': '%Y-%m-%d %H:00:00', - 'minute': '%Y-%m-%d %H:%M:00', - 'second': '%Y-%m-%d %H:%M:%S'} - -__mysql_date_trunc__ = __sqlite_date_trunc__.copy() -__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i:00' -__mysql_date_trunc__['second'] = '%Y-%m-%d %H:%i:%S' - -def _sqlite_date_part(lookup_type, datetime_string): - assert lookup_type in __date_parts__ - if not datetime_string: - return - dt = format_date_time(datetime_string, __sqlite_datetime_formats__) - return getattr(dt, lookup_type) - -def _sqlite_date_trunc(lookup_type, datetime_string): - assert lookup_type in __sqlite_date_trunc__ - if not datetime_string: - return - dt = format_date_time(datetime_string, __sqlite_datetime_formats__) - return dt.strftime(__sqlite_date_trunc__[lookup_type]) - - -def __deprecated__(s): - warnings.warn(s, DeprecationWarning) - - -class attrdict(dict): - def __getattr__(self, attr): - try: - return self[attr] - except KeyError: - raise AttributeError(attr) - def __setattr__(self, attr, value): self[attr] = value - def __iadd__(self, rhs): self.update(rhs); return self - def __add__(self, rhs): d = attrdict(self); d.update(rhs); return d - -SENTINEL = object() - -#: Operations for use in SQL expressions. -OP = attrdict( - AND='AND', - OR='OR', - ADD='+', - SUB='-', - MUL='*', - DIV='/', - BIN_AND='&', - BIN_OR='|', - XOR='#', - MOD='%', - EQ='=', - LT='<', - LTE='<=', - GT='>', - GTE='>=', - NE='!=', - IN='IN', - NOT_IN='NOT IN', - IS='IS', - IS_NOT='IS NOT', - LIKE='LIKE', - ILIKE='ILIKE', - BETWEEN='BETWEEN', - REGEXP='REGEXP', - IREGEXP='IREGEXP', - CONCAT='||', - BITWISE_NEGATION='~') - -# To support "django-style" double-underscore filters, create a mapping between -# operation name and operation code, e.g. "__eq" == OP.EQ. -DJANGO_MAP = attrdict({ - 'eq': operator.eq, - 'lt': operator.lt, - 'lte': operator.le, - 'gt': operator.gt, - 'gte': operator.ge, - 'ne': operator.ne, - 'in': operator.lshift, - 'is': lambda l, r: Expression(l, OP.IS, r), - 'like': lambda l, r: Expression(l, OP.LIKE, r), - 'ilike': lambda l, r: Expression(l, OP.ILIKE, r), - 'regexp': lambda l, r: Expression(l, OP.REGEXP, r), -}) - -#: Mapping of field type to the data-type supported by the database. Databases -#: may override or add to this list. -FIELD = attrdict( - AUTO='INTEGER', - BIGAUTO='BIGINT', - BIGINT='BIGINT', - BLOB='BLOB', - BOOL='SMALLINT', - CHAR='CHAR', - DATE='DATE', - DATETIME='DATETIME', - DECIMAL='DECIMAL', - DEFAULT='', - DOUBLE='REAL', - FLOAT='REAL', - INT='INTEGER', - SMALLINT='SMALLINT', - TEXT='TEXT', - TIME='TIME', - UUID='TEXT', - UUIDB='BLOB', - VARCHAR='VARCHAR') - -#: Join helpers (for convenience) -- all join types are supported, this object -#: is just to help avoid introducing errors by using strings everywhere. -JOIN = attrdict( - INNER='INNER', - LEFT_OUTER='LEFT OUTER', - RIGHT_OUTER='RIGHT OUTER', - FULL='FULL', - FULL_OUTER='FULL OUTER', - CROSS='CROSS', - NATURAL='NATURAL') - -# Row representations. -ROW = attrdict( - TUPLE=1, - DICT=2, - NAMED_TUPLE=3, - CONSTRUCTOR=4, - MODEL=5) - -SCOPE_NORMAL = 1 -SCOPE_SOURCE = 2 -SCOPE_VALUES = 4 -SCOPE_CTE = 8 -SCOPE_COLUMN = 16 - -# Rules for parentheses around subqueries in compound select. 
-CSQ_PARENTHESES_NEVER = 0 -CSQ_PARENTHESES_ALWAYS = 1 -CSQ_PARENTHESES_UNNESTED = 2 - -# Regular expressions used to convert class names to snake-case table names. -# First regex handles acronym followed by word or initial lower-word followed -# by a capitalized word. e.g. APIResponse -> API_Response / fooBar -> foo_Bar. -# Second regex handles the normal case of two title-cased words. -SNAKE_CASE_STEP1 = re.compile('(.)_*([A-Z][a-z]+)') -SNAKE_CASE_STEP2 = re.compile('([a-z0-9])_*([A-Z])') - -# Helper functions that are used in various parts of the codebase. -MODEL_BASE = '_metaclass_helper_' - -def with_metaclass(meta, base=object): - return meta(MODEL_BASE, (base,), {}) - -def merge_dict(source, overrides): - merged = source.copy() - if overrides: - merged.update(overrides) - return merged - -def quote(path, quote_chars): - if len(path) == 1: - return path[0].join(quote_chars) - return '.'.join([part.join(quote_chars) for part in path]) - -is_model = lambda o: isclass(o) and issubclass(o, Model) - -def ensure_tuple(value): - if value is not None: - return value if isinstance(value, (list, tuple)) else (value,) - -def ensure_entity(value): - if value is not None: - return value if isinstance(value, Node) else Entity(value) - -def make_snake_case(s): - first = SNAKE_CASE_STEP1.sub(r'\1_\2', s) - return SNAKE_CASE_STEP2.sub(r'\1_\2', first).lower() - -def chunked(it, n): - marker = object() - for group in (list(g) for g in izip_longest(*[iter(it)] * n, - fillvalue=marker)): - if group[-1] is marker: - del group[group.index(marker):] - yield group - - -class _callable_context_manager(object): - def __call__(self, fn): - @wraps(fn) - def inner(*args, **kwargs): - with self: - return fn(*args, **kwargs) - return inner - - -class Proxy(object): - """ - Create a proxy or placeholder for another object. - """ - __slots__ = ('obj', '_callbacks') - - def __init__(self): - self._callbacks = [] - self.initialize(None) - - def initialize(self, obj): - self.obj = obj - for callback in self._callbacks: - callback(obj) - - def attach_callback(self, callback): - self._callbacks.append(callback) - return callback - - def passthrough(method): - def inner(self, *args, **kwargs): - if self.obj is None: - raise AttributeError('Cannot use uninitialized Proxy.') - return getattr(self.obj, method)(*args, **kwargs) - return inner - - # Allow proxy to be used as a context-manager. - __enter__ = passthrough('__enter__') - __exit__ = passthrough('__exit__') - - def __getattr__(self, attr): - if self.obj is None: - raise AttributeError('Cannot use uninitialized Proxy.') - return getattr(self.obj, attr) - - def __setattr__(self, attr, value): - if attr not in self.__slots__: - raise AttributeError('Cannot set attribute on proxy.') - return super(Proxy, self).__setattr__(attr, value) - - -class DatabaseProxy(Proxy): - """ - Proxy implementation specifically for proxying `Database` objects. - """ - def connection_context(self): - return ConnectionContext(self) - def atomic(self): - return _atomic(self) - def manual_commit(self): - return _manual(self) - def transaction(self): - return _transaction(self) - def savepoint(self): - return _savepoint(self) - - -class ModelDescriptor(object): pass - - -# SQL Generation. - - -class AliasManager(object): - __slots__ = ('_counter', '_current_index', '_mapping') - - def __init__(self): - # A list of dictionaries containing mappings at various depths. 
- self._counter = 0 - self._current_index = 0 - self._mapping = [] - self.push() - - @property - def mapping(self): - return self._mapping[self._current_index - 1] - - def add(self, source): - if source not in self.mapping: - self._counter += 1 - self[source] = 't%d' % self._counter - return self.mapping[source] - - def get(self, source, any_depth=False): - if any_depth: - for idx in reversed(range(self._current_index)): - if source in self._mapping[idx]: - return self._mapping[idx][source] - return self.add(source) - - def __getitem__(self, source): - return self.get(source) - - def __setitem__(self, source, alias): - self.mapping[source] = alias - - def push(self): - self._current_index += 1 - if self._current_index > len(self._mapping): - self._mapping.append({}) - - def pop(self): - if self._current_index == 1: - raise ValueError('Cannot pop() from empty alias manager.') - self._current_index -= 1 - - -class State(collections.namedtuple('_State', ('scope', 'parentheses', - 'settings'))): - def __new__(cls, scope=SCOPE_NORMAL, parentheses=False, **kwargs): - return super(State, cls).__new__(cls, scope, parentheses, kwargs) - - def __call__(self, scope=None, parentheses=None, **kwargs): - # Scope and settings are "inherited" (parentheses is not, however). - scope = self.scope if scope is None else scope - - # Try to avoid unnecessary dict copying. - if kwargs and self.settings: - settings = self.settings.copy() # Copy original settings dict. - settings.update(kwargs) # Update copy with overrides. - elif kwargs: - settings = kwargs - else: - settings = self.settings - return State(scope, parentheses, **settings) - - def __getattr__(self, attr_name): - return self.settings.get(attr_name) - - -def __scope_context__(scope): - @contextmanager - def inner(self, **kwargs): - with self(scope=scope, **kwargs): - yield self - return inner - - -class Context(object): - __slots__ = ('stack', '_sql', '_values', 'alias_manager', 'state') - - def __init__(self, **settings): - self.stack = [] - self._sql = [] - self._values = [] - self.alias_manager = AliasManager() - self.state = State(**settings) - - def as_new(self): - return Context(**self.state.settings) - - def column_sort_key(self, item): - return item[0].get_sort_key(self) - - @property - def scope(self): - return self.state.scope - - @property - def parentheses(self): - return self.state.parentheses - - @property - def subquery(self): - return self.state.subquery - - def __call__(self, **overrides): - if overrides and overrides.get('scope') == self.scope: - del overrides['scope'] - - self.stack.append(self.state) - self.state = self.state(**overrides) - return self - - scope_normal = __scope_context__(SCOPE_NORMAL) - scope_source = __scope_context__(SCOPE_SOURCE) - scope_values = __scope_context__(SCOPE_VALUES) - scope_cte = __scope_context__(SCOPE_CTE) - scope_column = __scope_context__(SCOPE_COLUMN) - - def __enter__(self): - if self.parentheses: - self.literal('(') - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.parentheses: - self.literal(')') - self.state = self.stack.pop() - - @contextmanager - def push_alias(self): - self.alias_manager.push() - yield - self.alias_manager.pop() - - def sql(self, obj): - if isinstance(obj, (Node, Context)): - return obj.__sql__(self) - elif is_model(obj): - return obj._meta.table.__sql__(self) - else: - return self.sql(Value(obj)) - - def literal(self, keyword): - self._sql.append(keyword) - return self - - def value(self, value, converter=None, add_param=True): - if converter: - 
value = converter(value) - if isinstance(value, Node): - return self.sql(value) - elif converter is None and self.state.converter: - # Explicitly check for None so that "False" can be used to signify - # that no conversion should be applied. - value = self.state.converter(value) - - if isinstance(value, Node): - with self(converter=None): - return self.sql(value) - - self._values.append(value) - return self.literal(self.state.param or '?') if add_param else self - - def __sql__(self, ctx): - ctx._sql.extend(self._sql) - ctx._values.extend(self._values) - return ctx - - def parse(self, node): - return self.sql(node).query() - - def query(self): - return ''.join(self._sql), self._values - - -def query_to_string(query): - # NOTE: this function is not exported by default as it might be misused -- - # and this misuse could lead to sql injection vulnerabilities. This - # function is intended for debugging or logging purposes ONLY. - db = getattr(query, '_database', None) - if db is not None: - ctx = db.get_sql_context() - else: - ctx = Context() - - sql, params = ctx.sql(query).query() - if not params: - return sql - - param = ctx.state.param or '?' - if param == '?': - sql = sql.replace('?', '%s') - - return sql % tuple(map(_query_val_transform, params)) - -def _query_val_transform(v): - # Interpolate parameters. - if isinstance(v, (text_type, datetime.datetime, datetime.date, - datetime.time)): - v = "'%s'" % v - elif isinstance(v, bytes_type): - try: - v = v.decode('utf8') - except UnicodeDecodeError: - v = v.decode('raw_unicode_escape') - v = "'%s'" % v - elif isinstance(v, int): - v = '%s' % int(v) # Also handles booleans -> 1 or 0. - elif v is None: - v = 'NULL' - else: - v = str(v) - return v - - -# AST. - - -class Node(object): - _coerce = True - - def clone(self): - obj = self.__class__.__new__(self.__class__) - obj.__dict__ = self.__dict__.copy() - return obj - - def __sql__(self, ctx): - raise NotImplementedError - - @staticmethod - def copy(method): - def inner(self, *args, **kwargs): - clone = self.clone() - method(clone, *args, **kwargs) - return clone - return inner - - def coerce(self, _coerce=True): - if _coerce != self._coerce: - clone = self.clone() - clone._coerce = _coerce - return clone - return self - - def is_alias(self): - return False - - def unwrap(self): - return self - - -class ColumnFactory(object): - __slots__ = ('node',) - - def __init__(self, node): - self.node = node - - def __getattr__(self, attr): - return Column(self.node, attr) - - -class _DynamicColumn(object): - __slots__ = () - - def __get__(self, instance, instance_type=None): - if instance is not None: - return ColumnFactory(instance) # Implements __getattr__(). - return self - - -class _ExplicitColumn(object): - __slots__ = () - - def __get__(self, instance, instance_type=None): - if instance is not None: - raise AttributeError( - '%s specifies columns explicitly, and does not support ' - 'dynamic column lookups.' 
% instance) - return self - - -class Source(Node): - c = _DynamicColumn() - - def __init__(self, alias=None): - super(Source, self).__init__() - self._alias = alias - - @Node.copy - def alias(self, name): - self._alias = name - - def select(self, *columns): - if not columns: - columns = (SQL('*'),) - return Select((self,), columns) - - def join(self, dest, join_type='INNER', on=None): - return Join(self, dest, join_type, on) - - def left_outer_join(self, dest, on=None): - return Join(self, dest, JOIN.LEFT_OUTER, on) - - def cte(self, name, recursive=False, columns=None): - return CTE(name, self, recursive=recursive, columns=columns) - - def get_sort_key(self, ctx): - if self._alias: - return (self._alias,) - return (ctx.alias_manager[self],) - - def apply_alias(self, ctx): - # If we are defining the source, include the "AS alias" declaration. An - # alias is created for the source if one is not already defined. - if ctx.scope == SCOPE_SOURCE: - if self._alias: - ctx.alias_manager[self] = self._alias - ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) - return ctx - - def apply_column(self, ctx): - if self._alias: - ctx.alias_manager[self] = self._alias - return ctx.sql(Entity(ctx.alias_manager[self])) - - -class _HashableSource(object): - def __init__(self, *args, **kwargs): - super(_HashableSource, self).__init__(*args, **kwargs) - self._update_hash() - - @Node.copy - def alias(self, name): - self._alias = name - self._update_hash() - - def _update_hash(self): - self._hash = self._get_hash() - - def _get_hash(self): - return hash((self.__class__, self._path, self._alias)) - - def __hash__(self): - return self._hash - - def __eq__(self, other): - return self._hash == other._hash - - def __ne__(self, other): - return not (self == other) - - -def __bind_database__(meth): - @wraps(meth) - def inner(self, *args, **kwargs): - result = meth(self, *args, **kwargs) - if self._database: - return result.bind(self._database) - return result - return inner - - -def __join__(join_type='INNER', inverted=False): - def method(self, other): - if inverted: - self, other = other, self - return Join(self, other, join_type=join_type) - return method - - -class BaseTable(Source): - __and__ = __join__(JOIN.INNER) - __add__ = __join__(JOIN.LEFT_OUTER) - __sub__ = __join__(JOIN.RIGHT_OUTER) - __or__ = __join__(JOIN.FULL_OUTER) - __mul__ = __join__(JOIN.CROSS) - __rand__ = __join__(JOIN.INNER, inverted=True) - __radd__ = __join__(JOIN.LEFT_OUTER, inverted=True) - __rsub__ = __join__(JOIN.RIGHT_OUTER, inverted=True) - __ror__ = __join__(JOIN.FULL_OUTER, inverted=True) - __rmul__ = __join__(JOIN.CROSS, inverted=True) - - -class _BoundTableContext(_callable_context_manager): - def __init__(self, table, database): - self.table = table - self.database = database - - def __enter__(self): - self._orig_database = self.table._database - self.table.bind(self.database) - if self.table._model is not None: - self.table._model.bind(self.database) - return self.table - - def __exit__(self, exc_type, exc_val, exc_tb): - self.table.bind(self._orig_database) - if self.table._model is not None: - self.table._model.bind(self._orig_database) - - -class Table(_HashableSource, BaseTable): - def __init__(self, name, columns=None, primary_key=None, schema=None, - alias=None, _model=None, _database=None): - self.__name__ = name - self._columns = columns - self._primary_key = primary_key - self._schema = schema - self._path = (schema, name) if schema else (name,) - self._model = _model - self._database = _database - super(Table, 
self).__init__(alias=alias) - - # Allow tables to restrict what columns are available. - if columns is not None: - self.c = _ExplicitColumn() - for column in columns: - setattr(self, column, Column(self, column)) - - if primary_key: - col_src = self if self._columns else self.c - self.primary_key = getattr(col_src, primary_key) - else: - self.primary_key = None - - def clone(self): - # Ensure a deep copy of the column instances. - return Table( - self.__name__, - columns=self._columns, - primary_key=self._primary_key, - schema=self._schema, - alias=self._alias, - _model=self._model, - _database=self._database) - - def bind(self, database=None): - self._database = database - return self - - def bind_ctx(self, database=None): - return _BoundTableContext(self, database) - - def _get_hash(self): - return hash((self.__class__, self._path, self._alias, self._model)) - - @__bind_database__ - def select(self, *columns): - if not columns and self._columns: - columns = [Column(self, column) for column in self._columns] - return Select((self,), columns) - - @__bind_database__ - def insert(self, insert=None, columns=None, **kwargs): - if kwargs: - insert = {} if insert is None else insert - src = self if self._columns else self.c - for key, value in kwargs.items(): - insert[getattr(src, key)] = value - return Insert(self, insert=insert, columns=columns) - - @__bind_database__ - def replace(self, insert=None, columns=None, **kwargs): - return (self - .insert(insert=insert, columns=columns) - .on_conflict('REPLACE')) - - @__bind_database__ - def update(self, update=None, **kwargs): - if kwargs: - update = {} if update is None else update - for key, value in kwargs.items(): - src = self if self._columns else self.c - update[getattr(src, key)] = value - return Update(self, update=update) - - @__bind_database__ - def delete(self): - return Delete(self) - - def __sql__(self, ctx): - if ctx.scope == SCOPE_VALUES: - # Return the quoted table name. - return ctx.sql(Entity(*self._path)) - - if self._alias: - ctx.alias_manager[self] = self._alias - - if ctx.scope == SCOPE_SOURCE: - # Define the table and its alias. - return self.apply_alias(ctx.sql(Entity(*self._path))) - else: - # Refer to the table using the alias. 
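
For orientation, a minimal sketch of how the vendored Table helper above was typically driven (assuming peewee 3.x semantics; the "person" table and its columns are hypothetical):

    from peewee import Table

    person = Table('person', columns=('id', 'name'))
    # Explicit columns become attributes, so predicates are built with
    # ordinary Python operators (each comparison returns an Expression node).
    query = person.select(person.name).where(person.id == 3)
    sql, params = query.sql()  # unbound queries render through a default Context
    # sql    -> 'SELECT "t1"."name" FROM "person" AS "t1" WHERE ("t1"."id" = ?)'
    # params -> [3]
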
- return self.apply_column(ctx) - - -class Join(BaseTable): - def __init__(self, lhs, rhs, join_type=JOIN.INNER, on=None, alias=None): - super(Join, self).__init__(alias=alias) - self.lhs = lhs - self.rhs = rhs - self.join_type = join_type - self._on = on - - def on(self, predicate): - self._on = predicate - return self - - def __sql__(self, ctx): - (ctx - .sql(self.lhs) - .literal(' %s JOIN ' % self.join_type) - .sql(self.rhs)) - if self._on is not None: - ctx.literal(' ON ').sql(self._on) - return ctx - - -class ValuesList(_HashableSource, BaseTable): - def __init__(self, values, columns=None, alias=None): - self._values = values - self._columns = columns - super(ValuesList, self).__init__(alias=alias) - - def _get_hash(self): - return hash((self.__class__, id(self._values), self._alias)) - - @Node.copy - def columns(self, *names): - self._columns = names - - def __sql__(self, ctx): - if self._alias: - ctx.alias_manager[self] = self._alias - - if ctx.scope == SCOPE_SOURCE or ctx.scope == SCOPE_NORMAL: - with ctx(parentheses=not ctx.parentheses): - ctx = (ctx - .literal('VALUES ') - .sql(CommaNodeList([ - EnclosedNodeList(row) for row in self._values]))) - - if ctx.scope == SCOPE_SOURCE: - ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) - if self._columns: - entities = [Entity(c) for c in self._columns] - ctx.sql(EnclosedNodeList(entities)) - else: - ctx.sql(Entity(ctx.alias_manager[self])) - - return ctx - - -class CTE(_HashableSource, Source): - def __init__(self, name, query, recursive=False, columns=None): - self._alias = name - self._query = query - self._recursive = recursive - if columns is not None: - columns = [Entity(c) if isinstance(c, basestring) else c - for c in columns] - self._columns = columns - query._cte_list = () - super(CTE, self).__init__(alias=name) - - def select_from(self, *columns): - if not columns: - raise ValueError('select_from() must specify one or more columns ' - 'from the CTE to select.') - - query = (Select((self,), columns) - .with_cte(self) - .bind(self._query._database)) - try: - query = query.objects(self._query.model) - except AttributeError: - pass - return query - - def _get_hash(self): - return hash((self.__class__, self._alias, id(self._query))) - - def union_all(self, rhs): - clone = self._query.clone() - return CTE(self._alias, clone + rhs, self._recursive, self._columns) - __add__ = union_all - - def __sql__(self, ctx): - if ctx.scope != SCOPE_CTE: - return ctx.sql(Entity(self._alias)) - - with ctx.push_alias(): - ctx.alias_manager[self] = self._alias - ctx.sql(Entity(self._alias)) - - if self._columns: - ctx.literal(' ').sql(EnclosedNodeList(self._columns)) - ctx.literal(' AS ') - with ctx.scope_normal(parentheses=True): - ctx.sql(self._query) - return ctx - - -class ColumnBase(Node): - def alias(self, alias): - if alias: - return Alias(self, alias) - return self - - def unalias(self): - return self - - def cast(self, as_type): - return Cast(self, as_type) - - def asc(self, collation=None, nulls=None): - return Asc(self, collation=collation, nulls=nulls) - __pos__ = asc - - def desc(self, collation=None, nulls=None): - return Desc(self, collation=collation, nulls=nulls) - __neg__ = desc - - def __invert__(self): - return Negated(self) - - def _e(op, inv=False): - """ - Lightweight factory which returns a method that builds an Expression - consisting of the left-hand and right-hand operands, using `op`. 
- """ - def inner(self, rhs): - if inv: - return Expression(rhs, op, self) - return Expression(self, op, rhs) - return inner - __and__ = _e(OP.AND) - __or__ = _e(OP.OR) - - __add__ = _e(OP.ADD) - __sub__ = _e(OP.SUB) - __mul__ = _e(OP.MUL) - __div__ = __truediv__ = _e(OP.DIV) - __xor__ = _e(OP.XOR) - __radd__ = _e(OP.ADD, inv=True) - __rsub__ = _e(OP.SUB, inv=True) - __rmul__ = _e(OP.MUL, inv=True) - __rdiv__ = __rtruediv__ = _e(OP.DIV, inv=True) - __rand__ = _e(OP.AND, inv=True) - __ror__ = _e(OP.OR, inv=True) - __rxor__ = _e(OP.XOR, inv=True) - - def __eq__(self, rhs): - op = OP.IS if rhs is None else OP.EQ - return Expression(self, op, rhs) - def __ne__(self, rhs): - op = OP.IS_NOT if rhs is None else OP.NE - return Expression(self, op, rhs) - - __lt__ = _e(OP.LT) - __le__ = _e(OP.LTE) - __gt__ = _e(OP.GT) - __ge__ = _e(OP.GTE) - __lshift__ = _e(OP.IN) - __rshift__ = _e(OP.IS) - __mod__ = _e(OP.LIKE) - __pow__ = _e(OP.ILIKE) - - bin_and = _e(OP.BIN_AND) - bin_or = _e(OP.BIN_OR) - in_ = _e(OP.IN) - not_in = _e(OP.NOT_IN) - regexp = _e(OP.REGEXP) - - # Special expressions. - def is_null(self, is_null=True): - op = OP.IS if is_null else OP.IS_NOT - return Expression(self, op, None) - def contains(self, rhs): - if isinstance(rhs, Node): - rhs = Expression('%', OP.CONCAT, - Expression(rhs, OP.CONCAT, '%')) - else: - rhs = '%%%s%%' % rhs - return Expression(self, OP.ILIKE, rhs) - def startswith(self, rhs): - if isinstance(rhs, Node): - rhs = Expression(rhs, OP.CONCAT, '%') - else: - rhs = '%s%%' % rhs - return Expression(self, OP.ILIKE, rhs) - def endswith(self, rhs): - if isinstance(rhs, Node): - rhs = Expression('%', OP.CONCAT, rhs) - else: - rhs = '%%%s' % rhs - return Expression(self, OP.ILIKE, rhs) - def between(self, lo, hi): - return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi))) - def concat(self, rhs): - return StringExpression(self, OP.CONCAT, rhs) - def regexp(self, rhs): - return Expression(self, OP.REGEXP, rhs) - def iregexp(self, rhs): - return Expression(self, OP.IREGEXP, rhs) - def __getitem__(self, item): - if isinstance(item, slice): - if item.start is None or item.stop is None: - raise ValueError('BETWEEN range must have both a start- and ' - 'end-point.') - return self.between(item.start, item.stop) - return self == item - - def distinct(self): - return NodeList((SQL('DISTINCT'), self)) - - def collate(self, collation): - return NodeList((self, SQL('COLLATE %s' % collation))) - - def get_sort_key(self, ctx): - return () - - -class Column(ColumnBase): - def __init__(self, source, name): - self.source = source - self.name = name - - def get_sort_key(self, ctx): - if ctx.scope == SCOPE_VALUES: - return (self.name,) - else: - return self.source.get_sort_key(ctx) + (self.name,) - - def __hash__(self): - return hash((self.source, self.name)) - - def __sql__(self, ctx): - if ctx.scope == SCOPE_VALUES: - return ctx.sql(Entity(self.name)) - else: - with ctx.scope_column(): - return ctx.sql(self.source).literal('.').sql(Entity(self.name)) - - -class WrappedNode(ColumnBase): - def __init__(self, node): - self.node = node - self._coerce = getattr(node, '_coerce', True) - - def is_alias(self): - return self.node.is_alias() - - def unwrap(self): - return self.node.unwrap() - - -class EntityFactory(object): - __slots__ = ('node',) - def __init__(self, node): - self.node = node - def __getattr__(self, attr): - return Entity(self.node, attr) - - -class _DynamicEntity(object): - __slots__ = () - def __get__(self, instance, instance_type=None): - if instance is not None: - 
return EntityFactory(instance._alias) # Implements __getattr__(). - return self - - -class Alias(WrappedNode): - c = _DynamicEntity() - - def __init__(self, node, alias): - super(Alias, self).__init__(node) - self._alias = alias - - def __hash__(self): - return hash(self._alias) - - def alias(self, alias=None): - if alias is None: - return self.node - else: - return Alias(self.node, alias) - - def unalias(self): - return self.node - - def is_alias(self): - return True - - def __sql__(self, ctx): - if ctx.scope == SCOPE_SOURCE: - return (ctx - .sql(self.node) - .literal(' AS ') - .sql(Entity(self._alias))) - else: - return ctx.sql(Entity(self._alias)) - - -class Negated(WrappedNode): - def __invert__(self): - return self.node - - def __sql__(self, ctx): - return ctx.literal('NOT ').sql(self.node) - - -class BitwiseMixin(object): - def __and__(self, other): - return self.bin_and(other) - - def __or__(self, other): - return self.bin_or(other) - - def __sub__(self, other): - return self.bin_and(other.bin_negated()) - - def __invert__(self): - return BitwiseNegated(self) - - -class BitwiseNegated(BitwiseMixin, WrappedNode): - def __invert__(self): - return self.node - - def __sql__(self, ctx): - if ctx.state.operations: - op_sql = ctx.state.operations.get(self.op, self.op) - else: - op_sql = self.op - return ctx.literal(op_sql).sql(self.node) - - -class Value(ColumnBase): - _multi_types = (list, tuple, frozenset, set) - - def __init__(self, value, converter=None, unpack=True): - self.value = value - self.converter = converter - self.multi = isinstance(self.value, self._multi_types) and unpack - if self.multi: - self.values = [] - for item in self.value: - if isinstance(item, Node): - self.values.append(item) - else: - self.values.append(Value(item, self.converter)) - - def __sql__(self, ctx): - if self.multi: - # For multi-part values (e.g. lists of IDs). 
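
The Alias and Negated wrappers above are what .alias() and the ~ operator produce on any column-like node; a small sketch (hypothetical table and column names):

    from peewee import Table

    t = Table('events')
    named = t.c.created.alias('created_at')  # renders: "t1"."created" AS "created_at"
    negated = ~(t.c.archived == True)        # renders: NOT ("t1"."archived" = ?)
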
- return ctx.sql(EnclosedNodeList(self.values)) - - return ctx.value(self.value, self.converter) - - -def AsIs(value): - return Value(value, unpack=False) - - -class Cast(WrappedNode): - def __init__(self, node, cast): - super(Cast, self).__init__(node) - self._cast = cast - self._coerce = False - - def __sql__(self, ctx): - return (ctx - .literal('CAST(') - .sql(self.node) - .literal(' AS %s)' % self._cast)) - - -class Ordering(WrappedNode): - def __init__(self, node, direction, collation=None, nulls=None): - super(Ordering, self).__init__(node) - self.direction = direction - self.collation = collation - self.nulls = nulls - if nulls and nulls.lower() not in ('first', 'last'): - raise ValueError('Ordering nulls= parameter must be "first" or ' - '"last", got: %s' % nulls) - - def collate(self, collation=None): - return Ordering(self.node, self.direction, collation) - - def _null_ordering_case(self, nulls): - if nulls.lower() == 'last': - ifnull, notnull = 1, 0 - elif nulls.lower() == 'first': - ifnull, notnull = 0, 1 - else: - raise ValueError('unsupported value for nulls= ordering.') - return Case(None, ((self.node.is_null(), ifnull),), notnull) - - def __sql__(self, ctx): - if self.nulls and not ctx.state.nulls_ordering: - ctx.sql(self._null_ordering_case(self.nulls)).literal(', ') - - ctx.sql(self.node).literal(' %s' % self.direction) - if self.collation: - ctx.literal(' COLLATE %s' % self.collation) - if self.nulls and ctx.state.nulls_ordering: - ctx.literal(' NULLS %s' % self.nulls) - return ctx - - -def Asc(node, collation=None, nulls=None): - return Ordering(node, 'ASC', collation, nulls) - - -def Desc(node, collation=None, nulls=None): - return Ordering(node, 'DESC', collation, nulls) - - -class Expression(ColumnBase): - def __init__(self, lhs, op, rhs, flat=False): - self.lhs = lhs - self.op = op - self.rhs = rhs - self.flat = flat - - def __sql__(self, ctx): - overrides = {'parentheses': not self.flat, 'in_expr': True} - - # First attempt to unwrap the node on the left-hand-side, so that we - # can get at the underlying Field if one is present. - node = raw_node = self.lhs - if isinstance(raw_node, WrappedNode): - node = raw_node.unwrap() - - # Set up the appropriate converter if we have a field on the left side. - if isinstance(node, Field) and raw_node._coerce: - overrides['converter'] = node.db_value - else: - overrides['converter'] = None - - if ctx.state.operations: - op_sql = ctx.state.operations.get(self.op, self.op) - else: - op_sql = self.op - - with ctx(**overrides): - # Postgresql reports an error for IN/NOT IN (), so convert to - # the equivalent boolean expression. 
- op_in = self.op == OP.IN or self.op == OP.NOT_IN - if op_in and ctx.as_new().parse(self.rhs)[0] == '()': - return ctx.literal('0 = 1' if self.op == OP.IN else '1 = 1') - - return (ctx - .sql(self.lhs) - .literal(' %s ' % op_sql) - .sql(self.rhs)) - - -class StringExpression(Expression): - def __add__(self, rhs): - return self.concat(rhs) - def __radd__(self, lhs): - return StringExpression(lhs, OP.CONCAT, self) - - -class Entity(ColumnBase): - def __init__(self, *path): - self._path = [part.replace('"', '""') for part in path if part] - - def __getattr__(self, attr): - return Entity(*self._path + [attr]) - - def get_sort_key(self, ctx): - return tuple(self._path) - - def __hash__(self): - return hash((self.__class__.__name__, tuple(self._path))) - - def __sql__(self, ctx): - return ctx.literal(quote(self._path, ctx.state.quote or '""')) - - -class SQL(ColumnBase): - def __init__(self, sql, params=None): - self.sql = sql - self.params = params - - def __sql__(self, ctx): - ctx.literal(self.sql) - if self.params: - for param in self.params: - ctx.value(param, False, add_param=False) - return ctx - - -def Check(constraint): - return SQL('CHECK (%s)' % constraint) - - -class Function(ColumnBase): - def __init__(self, name, arguments, coerce=True, python_value=None): - self.name = name - self.arguments = arguments - self._filter = None - self._python_value = python_value - if name and name.lower() in ('sum', 'count', 'cast'): - self._coerce = False - else: - self._coerce = coerce - - def __getattr__(self, attr): - def decorator(*args, **kwargs): - return Function(attr, args, **kwargs) - return decorator - - @Node.copy - def filter(self, where=None): - self._filter = where - - @Node.copy - def python_value(self, func=None): - self._python_value = func - - def over(self, partition_by=None, order_by=None, start=None, end=None, - frame_type=None, window=None, exclude=None): - if isinstance(partition_by, Window) and window is None: - window = partition_by - - if window is not None: - node = WindowAlias(window) - else: - node = Window(partition_by=partition_by, order_by=order_by, - start=start, end=end, frame_type=frame_type, - exclude=exclude, _inline=True) - return NodeList((self, SQL('OVER'), node)) - - def __sql__(self, ctx): - ctx.literal(self.name) - if not len(self.arguments): - ctx.literal('()') - else: - with ctx(in_function=True, function_arg_count=len(self.arguments)): - ctx.sql(EnclosedNodeList([ - (argument if isinstance(argument, Node) - else Value(argument, False)) - for argument in self.arguments])) - - if self._filter: - ctx.literal(' FILTER (WHERE ').sql(self._filter).literal(')') - return ctx - - -fn = Function(None, None) - - -class Window(Node): - # Frame start/end and frame exclusion. - CURRENT_ROW = SQL('CURRENT ROW') - GROUP = SQL('GROUP') - TIES = SQL('TIES') - NO_OTHERS = SQL('NO OTHERS') - - # Frame types. 
- GROUPS = 'GROUPS' - RANGE = 'RANGE' - ROWS = 'ROWS' - - def __init__(self, partition_by=None, order_by=None, start=None, end=None, - frame_type=None, extends=None, exclude=None, alias=None, - _inline=False): - super(Window, self).__init__() - if start is not None and not isinstance(start, SQL): - start = SQL(start) - if end is not None and not isinstance(end, SQL): - end = SQL(end) - - self.partition_by = ensure_tuple(partition_by) - self.order_by = ensure_tuple(order_by) - self.start = start - self.end = end - if self.start is None and self.end is not None: - raise ValueError('Cannot specify WINDOW end without start.') - self._alias = alias or 'w' - self._inline = _inline - self.frame_type = frame_type - self._extends = extends - self._exclude = exclude - - def alias(self, alias=None): - self._alias = alias or 'w' - return self - - @Node.copy - def as_range(self): - self.frame_type = Window.RANGE - - @Node.copy - def as_rows(self): - self.frame_type = Window.ROWS - - @Node.copy - def as_groups(self): - self.frame_type = Window.GROUPS - - @Node.copy - def extends(self, window=None): - self._extends = window - - @Node.copy - def exclude(self, frame_exclusion=None): - if isinstance(frame_exclusion, basestring): - frame_exclusion = SQL(frame_exclusion) - self._exclude = frame_exclusion - - @staticmethod - def following(value=None): - if value is None: - return SQL('UNBOUNDED FOLLOWING') - return SQL('%d FOLLOWING' % value) - - @staticmethod - def preceding(value=None): - if value is None: - return SQL('UNBOUNDED PRECEDING') - return SQL('%d PRECEDING' % value) - - def __sql__(self, ctx): - if ctx.scope != SCOPE_SOURCE and not self._inline: - ctx.literal(self._alias) - ctx.literal(' AS ') - - with ctx(parentheses=True): - parts = [] - if self._extends is not None: - ext = self._extends - if isinstance(ext, Window): - ext = SQL(ext._alias) - elif isinstance(ext, basestring): - ext = SQL(ext) - parts.append(ext) - if self.partition_by: - parts.extend(( - SQL('PARTITION BY'), - CommaNodeList(self.partition_by))) - if self.order_by: - parts.extend(( - SQL('ORDER BY'), - CommaNodeList(self.order_by))) - if self.start is not None and self.end is not None: - frame = self.frame_type or 'ROWS' - parts.extend(( - SQL('%s BETWEEN' % frame), - self.start, - SQL('AND'), - self.end)) - elif self.start is not None: - parts.extend((SQL(self.frame_type or 'ROWS'), self.start)) - elif self.frame_type is not None: - parts.append(SQL('%s UNBOUNDED PRECEDING' % self.frame_type)) - if self._exclude is not None: - parts.extend((SQL('EXCLUDE'), self._exclude)) - ctx.sql(NodeList(parts)) - return ctx - - -class WindowAlias(Node): - def __init__(self, window): - self.window = window - - def alias(self, window_alias): - self.window._alias = window_alias - return self - - def __sql__(self, ctx): - return ctx.literal(self.window._alias or 'w') - - -def Case(predicate, expression_tuples, default=None): - clauses = [SQL('CASE')] - if predicate is not None: - clauses.append(predicate) - for expr, value in expression_tuples: - clauses.extend((SQL('WHEN'), expr, SQL('THEN'), value)) - if default is not None: - clauses.extend((SQL('ELSE'), default)) - clauses.append(SQL('END')) - return NodeList(clauses) - - -class NodeList(ColumnBase): - def __init__(self, nodes, glue=' ', parens=False): - self.nodes = nodes - self.glue = glue - self.parens = parens - if parens and len(self.nodes) == 1: - if isinstance(self.nodes[0], Expression): - # Hack to avoid double-parentheses. 
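
A sketch of how the Window and Case nodes above surface through the public API (assuming peewee 3.x; table and column names hypothetical):

    from peewee import Case, Table, Window, fn

    sales = Table('sales')
    w = Window(partition_by=[sales.c.region], order_by=[sales.c.amount])
    ranked = sales.select(sales.c.region, fn.RANK().over(w)).window(w)
    # -> SELECT "t1"."region", RANK() OVER w FROM "sales" AS "t1" WINDOW w AS (...)
    size = Case(None, ((sales.c.amount > 100, 'big'),), 'small')
    # -> CASE WHEN ("t1"."amount" > ?) THEN ? ELSE ? END
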
- self.nodes[0].flat = True - - def __sql__(self, ctx): - n_nodes = len(self.nodes) - if n_nodes == 0: - return ctx.literal('()') if self.parens else ctx - with ctx(parentheses=self.parens): - for i in range(n_nodes - 1): - ctx.sql(self.nodes[i]) - ctx.literal(self.glue) - ctx.sql(self.nodes[n_nodes - 1]) - return ctx - - -def CommaNodeList(nodes): - return NodeList(nodes, ', ') - - -def EnclosedNodeList(nodes): - return NodeList(nodes, ', ', True) - - -class _Namespace(Node): - __slots__ = ('_name',) - def __init__(self, name): - self._name = name - def __getattr__(self, attr): - return NamespaceAttribute(self, attr) - __getitem__ = __getattr__ - -class NamespaceAttribute(ColumnBase): - def __init__(self, namespace, attribute): - self._namespace = namespace - self._attribute = attribute - - def __sql__(self, ctx): - return (ctx - .literal(self._namespace._name + '.') - .sql(Entity(self._attribute))) - -EXCLUDED = _Namespace('EXCLUDED') - - -class DQ(ColumnBase): - def __init__(self, **query): - super(DQ, self).__init__() - self.query = query - self._negated = False - - @Node.copy - def __invert__(self): - self._negated = not self._negated - - def clone(self): - node = DQ(**self.query) - node._negated = self._negated - return node - -#: Represent a row tuple. -Tuple = lambda *a: EnclosedNodeList(a) - - -class QualifiedNames(WrappedNode): - def __sql__(self, ctx): - with ctx.scope_column(): - return ctx.sql(self.node) - - -def qualify_names(node): - # Search a node hierarchy to ensure that any column-like objects are - # referenced using fully-qualified names. - if isinstance(node, Expression): - return node.__class__(qualify_names(node.lhs), node.op, - qualify_names(node.rhs), node.flat) - elif isinstance(node, ColumnBase): - return QualifiedNames(node) - return node - - -class OnConflict(Node): - def __init__(self, action=None, update=None, preserve=None, where=None, - conflict_target=None, conflict_where=None, - conflict_constraint=None): - self._action = action - self._update = update - self._preserve = ensure_tuple(preserve) - self._where = where - if conflict_target is not None and conflict_constraint is not None: - raise ValueError('only one of "conflict_target" and ' - '"conflict_constraint" may be specified.') - self._conflict_target = ensure_tuple(conflict_target) - self._conflict_where = conflict_where - self._conflict_constraint = conflict_constraint - - def get_conflict_statement(self, ctx, query): - return ctx.state.conflict_statement(self, query) - - def get_conflict_update(self, ctx, query): - return ctx.state.conflict_update(self, query) - - @Node.copy - def preserve(self, *columns): - self._preserve = columns - - @Node.copy - def update(self, _data=None, **kwargs): - if _data and kwargs and not isinstance(_data, dict): - raise ValueError('Cannot mix data with keyword arguments in the ' - 'OnConflict update method.') - _data = _data or {} - if kwargs: - _data.update(kwargs) - self._update = _data - - @Node.copy - def where(self, *expressions): - if self._where is not None: - expressions = (self._where,) + expressions - self._where = reduce(operator.and_, expressions) - - @Node.copy - def conflict_target(self, *constraints): - self._conflict_constraint = None - self._conflict_target = constraints - - @Node.copy - def conflict_where(self, *expressions): - if self._conflict_where is not None: - expressions = (self._conflict_where,) + expressions - self._conflict_where = reduce(operator.and_, expressions) - - @Node.copy - def conflict_constraint(self, constraint): -
self._conflict_constraint = constraint - self._conflict_target = None - - -def database_required(method): - @wraps(method) - def inner(self, database=None, *args, **kwargs): - database = self._database if database is None else database - if not database: - raise InterfaceError('Query must be bound to a database in order ' - 'to call "%s".' % method.__name__) - return method(self, database, *args, **kwargs) - return inner - -# BASE QUERY INTERFACE. - -class BaseQuery(Node): - default_row_type = ROW.DICT - - def __init__(self, _database=None, **kwargs): - self._database = _database - self._cursor_wrapper = None - self._row_type = None - self._constructor = None - super(BaseQuery, self).__init__(**kwargs) - - def bind(self, database=None): - self._database = database - return self - - def clone(self): - query = super(BaseQuery, self).clone() - query._cursor_wrapper = None - return query - - @Node.copy - def dicts(self, as_dict=True): - self._row_type = ROW.DICT if as_dict else None - return self - - @Node.copy - def tuples(self, as_tuple=True): - self._row_type = ROW.TUPLE if as_tuple else None - return self - - @Node.copy - def namedtuples(self, as_namedtuple=True): - self._row_type = ROW.NAMED_TUPLE if as_namedtuple else None - return self - - @Node.copy - def objects(self, constructor=None): - self._row_type = ROW.CONSTRUCTOR if constructor else None - self._constructor = constructor - return self - - def _get_cursor_wrapper(self, cursor): - row_type = self._row_type or self.default_row_type - - if row_type == ROW.DICT: - return DictCursorWrapper(cursor) - elif row_type == ROW.TUPLE: - return CursorWrapper(cursor) - elif row_type == ROW.NAMED_TUPLE: - return NamedTupleCursorWrapper(cursor) - elif row_type == ROW.CONSTRUCTOR: - return ObjectCursorWrapper(cursor, self._constructor) - else: - raise ValueError('Unrecognized row type: "%s".' 
% row_type) - - def __sql__(self, ctx): - raise NotImplementedError - - def sql(self): - if self._database: - context = self._database.get_sql_context() - else: - context = Context() - return context.parse(self) - - @database_required - def execute(self, database): - return self._execute(database) - - def _execute(self, database): - raise NotImplementedError - - def iterator(self, database=None): - return iter(self.execute(database).iterator()) - - def _ensure_execution(self): - if not self._cursor_wrapper: - if not self._database: - raise ValueError('Query has not been executed.') - self.execute() - - def __iter__(self): - self._ensure_execution() - return iter(self._cursor_wrapper) - - def __getitem__(self, value): - self._ensure_execution() - if isinstance(value, slice): - index = value.stop - else: - index = value - if index is not None: - index = index + 1 if index >= 0 else 0 - self._cursor_wrapper.fill_cache(index) - return self._cursor_wrapper.row_cache[value] - - def __len__(self): - self._ensure_execution() - return len(self._cursor_wrapper) - - def __str__(self): - return query_to_string(self) - - -class RawQuery(BaseQuery): - def __init__(self, sql=None, params=None, **kwargs): - super(RawQuery, self).__init__(**kwargs) - self._sql = sql - self._params = params - - def __sql__(self, ctx): - ctx.literal(self._sql) - if self._params: - for param in self._params: - ctx.value(param, add_param=False) - return ctx - - def _execute(self, database): - if self._cursor_wrapper is None: - cursor = database.execute(self) - self._cursor_wrapper = self._get_cursor_wrapper(cursor) - return self._cursor_wrapper - - -class Query(BaseQuery): - def __init__(self, where=None, order_by=None, limit=None, offset=None, - **kwargs): - super(Query, self).__init__(**kwargs) - self._where = where - self._order_by = order_by - self._limit = limit - self._offset = offset - - self._cte_list = None - - @Node.copy - def with_cte(self, *cte_list): - self._cte_list = cte_list - - @Node.copy - def where(self, *expressions): - if self._where is not None: - expressions = (self._where,) + expressions - self._where = reduce(operator.and_, expressions) - - @Node.copy - def orwhere(self, *expressions): - if self._where is not None: - expressions = (self._where,) + expressions - self._where = reduce(operator.or_, expressions) - - @Node.copy - def order_by(self, *values): - self._order_by = values - - @Node.copy - def order_by_extend(self, *values): - self._order_by = ((self._order_by or ()) + values) or None - - @Node.copy - def limit(self, value=None): - self._limit = value - - @Node.copy - def offset(self, value=None): - self._offset = value - - @Node.copy - def paginate(self, page, paginate_by=20): - if page > 0: - page -= 1 - self._limit = paginate_by - self._offset = page * paginate_by - - def _apply_ordering(self, ctx): - if self._order_by: - (ctx - .literal(' ORDER BY ') - .sql(CommaNodeList(self._order_by))) - if self._limit is not None or (self._offset is not None and - ctx.state.limit_max): - ctx.literal(' LIMIT ').sql(self._limit or ctx.state.limit_max) - if self._offset is not None: - ctx.literal(' OFFSET ').sql(self._offset) - return ctx - - def __sql__(self, ctx): - if self._cte_list: - # The CTE scope is only used at the very beginning of the query, - # when we are describing the various CTEs we will be using. - recursive = any(cte._recursive for cte in self._cte_list) - - # Explicitly disable the "subquery" flag here, so as to avoid - # unnecessary parentheses around subsequent selects. 
- with ctx.scope_cte(subquery=False): - (ctx - .literal('WITH RECURSIVE ' if recursive else 'WITH ') - .sql(CommaNodeList(self._cte_list)) - .literal(' ')) - return ctx - - -def __compound_select__(operation, inverted=False): - def method(self, other): - if inverted: - self, other = other, self - return CompoundSelectQuery(self, operation, other) - return method - - -class SelectQuery(Query): - union_all = __add__ = __compound_select__('UNION ALL') - union = __or__ = __compound_select__('UNION') - intersect = __and__ = __compound_select__('INTERSECT') - except_ = __sub__ = __compound_select__('EXCEPT') - __radd__ = __compound_select__('UNION ALL', inverted=True) - __ror__ = __compound_select__('UNION', inverted=True) - __rand__ = __compound_select__('INTERSECT', inverted=True) - __rsub__ = __compound_select__('EXCEPT', inverted=True) - - def select_from(self, *columns): - if not columns: - raise ValueError('select_from() must specify one or more columns.') - - query = (Select((self,), columns) - .bind(self._database)) - if getattr(self, 'model', None) is not None: - # Bind to the sub-select's model type, if defined. - query = query.objects(self.model) - return query - - -class SelectBase(_HashableSource, Source, SelectQuery): - def _get_hash(self): - return hash((self.__class__, self._alias or id(self))) - - def _execute(self, database): - if self._cursor_wrapper is None: - cursor = database.execute(self) - self._cursor_wrapper = self._get_cursor_wrapper(cursor) - return self._cursor_wrapper - - @database_required - def peek(self, database, n=1): - rows = self.execute(database)[:n] - if rows: - return rows[0] if n == 1 else rows - - @database_required - def first(self, database, n=1): - if self._limit != n: - self._limit = n - self._cursor_wrapper = None - return self.peek(database, n=n) - - @database_required - def scalar(self, database, as_tuple=False): - row = self.tuples().peek(database) - return row[0] if row and not as_tuple else row - - @database_required - def count(self, database, clear_limit=False): - clone = self.order_by().alias('_wrapped') - if clear_limit: - clone._limit = clone._offset = None - try: - if clone._having is None and clone._group_by is None and \ - clone._windows is None and clone._distinct is None and \ - clone._simple_distinct is not True: - clone = clone.select(SQL('1')) - except AttributeError: - pass - return Select([clone], [fn.COUNT(SQL('1'))]).scalar(database) - - @database_required - def exists(self, database): - clone = self.columns(SQL('1')) - clone._limit = 1 - clone._offset = None - return bool(clone.scalar()) - - @database_required - def get(self, database): - self._cursor_wrapper = None - try: - return self.execute(database)[0] - except IndexError: - pass - - -# QUERY IMPLEMENTATIONS. 
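
The __compound_select__ overloads above let set operations compose with ordinary operators; schematically (a sketch, hypothetical tables):

    from peewee import Table

    cats, dogs = Table('cats'), Table('dogs')
    names = cats.select(cats.c.name) | dogs.select(dogs.c.name)  # UNION
    both = cats.select(cats.c.name) & dogs.select(dogs.c.name)   # INTERSECT
    # names.sql() renders the two SELECT statements joined by UNION.
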
- - -class CompoundSelectQuery(SelectBase): - def __init__(self, lhs, op, rhs): - super(CompoundSelectQuery, self).__init__() - self.lhs = lhs - self.op = op - self.rhs = rhs - - @property - def _returning(self): - return self.lhs._returning - - @database_required - def exists(self, database): - query = Select((self.limit(1),), (SQL('1'),)).bind(database) - return bool(query.scalar()) - - def _get_query_key(self): - return (self.lhs.get_query_key(), self.rhs.get_query_key()) - - def _wrap_parens(self, ctx, subq): - csq_setting = ctx.state.compound_select_parentheses - - if not csq_setting or csq_setting == CSQ_PARENTHESES_NEVER: - return False - elif csq_setting == CSQ_PARENTHESES_ALWAYS: - return True - elif csq_setting == CSQ_PARENTHESES_UNNESTED: - if ctx.state.in_expr or ctx.state.in_function: - # If this compound select query is being used inside an - # expression, e.g., an IN or EXISTS(). - return False - - # If the query on the left or right is itself a compound select - # query, then we do not apply parentheses. However, if it is a - # regular SELECT query, we will apply parentheses. - return not isinstance(subq, CompoundSelectQuery) - - def __sql__(self, ctx): - if ctx.scope == SCOPE_COLUMN: - return self.apply_column(ctx) - - outer_parens = ctx.subquery or (ctx.scope == SCOPE_SOURCE) - with ctx(parentheses=outer_parens): - # Should the left-hand query be wrapped in parentheses? - lhs_parens = self._wrap_parens(ctx, self.lhs) - with ctx.scope_normal(parentheses=lhs_parens, subquery=False): - ctx.sql(self.lhs) - ctx.literal(' %s ' % self.op) - with ctx.push_alias(): - # Should the right-hand query be wrapped in parentheses? - rhs_parens = self._wrap_parens(ctx, self.rhs) - with ctx.scope_normal(parentheses=rhs_parens, subquery=False): - ctx.sql(self.rhs) - - # Apply ORDER BY, LIMIT, OFFSET. We use the "values" scope so that - # entity names are not fully-qualified. This is a bit of a hack, as - # we're relying on the logic in Column.__sql__() to not fully - # qualify column names. 
- with ctx.scope_values(): - self._apply_ordering(ctx) - - return self.apply_alias(ctx) - - -class Select(SelectBase): - def __init__(self, from_list=None, columns=None, group_by=None, - having=None, distinct=None, windows=None, for_update=None, - **kwargs): - super(Select, self).__init__(**kwargs) - self._from_list = (list(from_list) if isinstance(from_list, tuple) - else from_list) or [] - self._returning = columns - self._group_by = group_by - self._having = having - self._windows = None - self._for_update = 'FOR UPDATE' if for_update is True else for_update - - self._distinct = self._simple_distinct = None - if distinct: - if isinstance(distinct, bool): - self._simple_distinct = distinct - else: - self._distinct = distinct - - self._cursor_wrapper = None - - def clone(self): - clone = super(Select, self).clone() - if clone._from_list: - clone._from_list = list(clone._from_list) - return clone - - @Node.copy - def columns(self, *columns, **kwargs): - self._returning = columns - select = columns - - @Node.copy - def select_extend(self, *columns): - self._returning = tuple(self._returning) + columns - - @Node.copy - def from_(self, *sources): - self._from_list = list(sources) - - @Node.copy - def join(self, dest, join_type='INNER', on=None): - if not self._from_list: - raise ValueError('No sources to join on.') - item = self._from_list.pop() - self._from_list.append(Join(item, dest, join_type, on)) - - @Node.copy - def group_by(self, *columns): - grouping = [] - for column in columns: - if isinstance(column, Table): - if not column._columns: - raise ValueError('Cannot pass a table to group_by() that ' - 'does not have columns explicitly ' - 'declared.') - grouping.extend([getattr(column, col_name) - for col_name in column._columns]) - else: - grouping.append(column) - self._group_by = grouping - - def group_by_extend(self, *values): - """@Node.copy used from group_by() call""" - group_by = tuple(self._group_by or ()) + values - return self.group_by(*group_by) - - @Node.copy - def having(self, *expressions): - if self._having is not None: - expressions = (self._having,) + expressions - self._having = reduce(operator.and_, expressions) - - @Node.copy - def distinct(self, *columns): - if len(columns) == 1 and (columns[0] is True or columns[0] is False): - self._simple_distinct = columns[0] - else: - self._simple_distinct = False - self._distinct = columns - - @Node.copy - def window(self, *windows): - self._windows = windows if windows else None - - @Node.copy - def for_update(self, for_update=True): - self._for_update = 'FOR UPDATE' if for_update is True else for_update - - def _get_query_key(self): - return self._alias - - def __sql_selection__(self, ctx, is_subquery=False): - return ctx.sql(CommaNodeList(self._returning)) - - def __sql__(self, ctx): - if ctx.scope == SCOPE_COLUMN: - return self.apply_column(ctx) - - is_subquery = ctx.subquery - state = { - 'converter': None, - 'in_function': False, - 'parentheses': is_subquery or (ctx.scope == SCOPE_SOURCE), - 'subquery': True, - } - if ctx.state.in_function and ctx.state.function_arg_count == 1: - state['parentheses'] = False - - with ctx.scope_normal(**state): - # Defer calling parent SQL until here. This ensures that any CTEs - # for this query will be properly nested if this query is a - # sub-select or is used in an expression. See GH#1809 for example. 
- super(Select, self).__sql__(ctx) - - ctx.literal('SELECT ') - if self._simple_distinct or self._distinct is not None: - ctx.literal('DISTINCT ') - if self._distinct: - (ctx - .literal('ON ') - .sql(EnclosedNodeList(self._distinct)) - .literal(' ')) - - with ctx.scope_source(): - ctx = self.__sql_selection__(ctx, is_subquery) - - if self._from_list: - with ctx.scope_source(parentheses=False): - ctx.literal(' FROM ').sql(CommaNodeList(self._from_list)) - - if self._where is not None: - ctx.literal(' WHERE ').sql(self._where) - - if self._group_by: - ctx.literal(' GROUP BY ').sql(CommaNodeList(self._group_by)) - - if self._having is not None: - ctx.literal(' HAVING ').sql(self._having) - - if self._windows is not None: - ctx.literal(' WINDOW ') - ctx.sql(CommaNodeList(self._windows)) - - # Apply ORDER BY, LIMIT, OFFSET. - self._apply_ordering(ctx) - - if self._for_update: - if not ctx.state.for_update: - raise ValueError('FOR UPDATE specified but not supported ' - 'by database.') - ctx.literal(' ') - ctx.sql(SQL(self._for_update)) - - # If the subquery is inside a function -or- we are evaluating a - # subquery on either side of an expression w/o an explicit alias, do - # not generate an alias + AS clause. - if ctx.state.in_function or (ctx.state.in_expr and - self._alias is None): - return ctx - - return self.apply_alias(ctx) - - -class _WriteQuery(Query): - def __init__(self, table, returning=None, **kwargs): - self.table = table - self._returning = returning - self._return_cursor = True if returning else False - super(_WriteQuery, self).__init__(**kwargs) - - @Node.copy - def returning(self, *returning): - self._returning = returning - self._return_cursor = True if returning else False - - def apply_returning(self, ctx): - if self._returning: - with ctx.scope_normal(): - ctx.literal(' RETURNING ').sql(CommaNodeList(self._returning)) - return ctx - - def _execute(self, database): - if self._returning: - cursor = self.execute_returning(database) - else: - cursor = database.execute(self) - return self.handle_result(database, cursor) - - def execute_returning(self, database): - if self._cursor_wrapper is None: - cursor = database.execute(self) - self._cursor_wrapper = self._get_cursor_wrapper(cursor) - return self._cursor_wrapper - - def handle_result(self, database, cursor): - if self._return_cursor: - return cursor - return database.rows_affected(cursor) - - def _set_table_alias(self, ctx): - ctx.alias_manager[self.table] = self.table.__name__ - - def __sql__(self, ctx): - super(_WriteQuery, self).__sql__(ctx) - # We explicitly set the table alias to the table's name, which ensures - # that if a sub-select references a column on the outer table, we won't - # assign it a new alias (e.g. t2) but will refer to it as table.column. 
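
For reference, the Update class above renders as follows when driven through the Table API (a sketch; table and column names hypothetical):

    from peewee import Table

    system = Table('system')
    q = system.update({system.c.updated: '1'})
    # q.sql() -> ('UPDATE "system" SET "updated" = ?', ['1'])
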
- self._set_table_alias(ctx) - return ctx - - -class Update(_WriteQuery): - def __init__(self, table, update=None, **kwargs): - super(Update, self).__init__(table, **kwargs) - self._update = update - self._from = None - - @Node.copy - def from_(self, *sources): - self._from = sources - - def __sql__(self, ctx): - super(Update, self).__sql__(ctx) - - with ctx.scope_values(subquery=True): - ctx.literal('UPDATE ') - - expressions = [] - for k, v in sorted(self._update.items(), key=ctx.column_sort_key): - if not isinstance(v, Node): - converter = k.db_value if isinstance(k, Field) else None - v = Value(v, converter=converter, unpack=False) - if not isinstance(v, Value): - v = qualify_names(v) - expressions.append(NodeList((k, SQL('='), v))) - - (ctx - .sql(self.table) - .literal(' SET ') - .sql(CommaNodeList(expressions))) - - if self._from: - with ctx.scope_source(parentheses=False): - ctx.literal(' FROM ').sql(CommaNodeList(self._from)) - - if self._where: - with ctx.scope_normal(): - ctx.literal(' WHERE ').sql(self._where) - self._apply_ordering(ctx) - return self.apply_returning(ctx) - - -class Insert(_WriteQuery): - SIMPLE = 0 - QUERY = 1 - MULTI = 2 - class DefaultValuesException(Exception): pass - - def __init__(self, table, insert=None, columns=None, on_conflict=None, - **kwargs): - super(Insert, self).__init__(table, **kwargs) - self._insert = insert - self._columns = columns - self._on_conflict = on_conflict - self._query_type = None - - def where(self, *expressions): - raise NotImplementedError('INSERT queries cannot have a WHERE clause.') - - @Node.copy - def on_conflict_ignore(self, ignore=True): - self._on_conflict = OnConflict('IGNORE') if ignore else None - - @Node.copy - def on_conflict_replace(self, replace=True): - self._on_conflict = OnConflict('REPLACE') if replace else None - - @Node.copy - def on_conflict(self, *args, **kwargs): - self._on_conflict = (OnConflict(*args, **kwargs) if (args or kwargs) - else None) - - def _simple_insert(self, ctx): - if not self._insert: - raise self.DefaultValuesException('Error: no data to insert.') - return self._generate_insert((self._insert,), ctx) - - def get_default_data(self): - return {} - - def get_default_columns(self): - if self.table._columns: - return [getattr(self.table, col) for col in self.table._columns - if col != self.table._primary_key] - - def _generate_insert(self, insert, ctx): - rows_iter = iter(insert) - columns = self._columns - - # Load and organize column defaults (if provided). - defaults = self.get_default_data() - - # First figure out what columns are being inserted (if they weren't - # specified explicitly). Resulting columns are normalized and ordered. - if not columns: - try: - row = next(rows_iter) - except StopIteration: - raise self.DefaultValuesException('Error: no rows to insert.') - - if not isinstance(row, Mapping): - columns = self.get_default_columns() - if columns is None: - raise ValueError('Bulk insert must specify columns.') - else: - # Infer column names from the dict of data being inserted. - accum = [] - for column in row: - if isinstance(column, basestring): - column = getattr(self.table, column) - accum.append(column) - - # Add any columns present in the default data that are not - # accounted for by the dictionary of row data. 
- column_set = set(accum) - for col in (set(defaults) - column_set): - accum.append(col) - - columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx)) - rows_iter = itertools.chain(iter((row,)), rows_iter) - else: - clean_columns = [] - seen = set() - for column in columns: - if isinstance(column, basestring): - column_obj = getattr(self.table, column) - else: - column_obj = column - clean_columns.append(column_obj) - seen.add(column_obj) - - columns = clean_columns - for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)): - if col not in seen: - columns.append(col) - - value_lookups = {} - for column in columns: - lookups = [column, column.name] - if isinstance(column, Field) and column.name != column.column_name: - lookups.append(column.column_name) - value_lookups[column] = lookups - - ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ') - columns_converters = [ - (column, column.db_value if isinstance(column, Field) else None) - for column in columns] - - all_values = [] - for row in rows_iter: - values = [] - is_dict = isinstance(row, Mapping) - for i, (column, converter) in enumerate(columns_converters): - try: - if is_dict: - # The logic is a bit convoluted, but in order to be - # flexible in what we accept (dict keyed by - # column/field, field name, or underlying column name), - # we try accessing the row data dict using each - # possible key. If no match is found, throw an error. - for lookup in value_lookups[column]: - try: - val = row[lookup] - except KeyError: pass - else: break - else: - raise KeyError - else: - val = row[i] - except (KeyError, IndexError): - if column in defaults: - val = defaults[column] - if callable_(val): - val = val() - else: - raise ValueError('Missing value for %s.' % column.name) - - if not isinstance(val, Node): - val = Value(val, converter=converter, unpack=False) - values.append(val) - - all_values.append(EnclosedNodeList(values)) - - if not all_values: - raise self.DefaultValuesException('Error: no data to insert.') - - with ctx.scope_values(subquery=True): - return ctx.sql(CommaNodeList(all_values)) - - def _query_insert(self, ctx): - return (ctx - .sql(EnclosedNodeList(self._columns)) - .literal(' ') - .sql(self._insert)) - - def _default_values(self, ctx): - if not self._database: - return ctx.literal('DEFAULT VALUES') - return self._database.default_values_insert(ctx) - - def __sql__(self, ctx): - super(Insert, self).__sql__(ctx) - with ctx.scope_values(): - stmt = None - if self._on_conflict is not None: - stmt = self._on_conflict.get_conflict_statement(ctx, self) - - (ctx - .sql(stmt or SQL('INSERT')) - .literal(' INTO ') - .sql(self.table) - .literal(' ')) - - if isinstance(self._insert, Mapping) and not self._columns: - try: - self._simple_insert(ctx) - except self.DefaultValuesException: - self._default_values(ctx) - self._query_type = Insert.SIMPLE - elif isinstance(self._insert, (SelectQuery, SQL)): - self._query_insert(ctx) - self._query_type = Insert.QUERY - else: - self._generate_insert(self._insert, ctx) - self._query_type = Insert.MULTI - - if self._on_conflict is not None: - update = self._on_conflict.get_conflict_update(ctx, self) - if update is not None: - ctx.literal(' ').sql(update) - - return self.apply_returning(ctx) - - def _execute(self, database): - if self._returning is None and database.returning_clause \ - and self.table._primary_key: - self._returning = (self.table._primary_key,) - try: - return super(Insert, self)._execute(database) - except self.DefaultValuesException: - pass - - def 
handle_result(self, database, cursor): - if self._return_cursor: - return cursor - return database.last_insert_id(cursor, self._query_type) - - -class Delete(_WriteQuery): - def __sql__(self, ctx): - super(Delete, self).__sql__(ctx) - - with ctx.scope_values(subquery=True): - ctx.literal('DELETE FROM ').sql(self.table) - if self._where is not None: - with ctx.scope_normal(): - ctx.literal(' WHERE ').sql(self._where) - - self._apply_ordering(ctx) - return self.apply_returning(ctx) - - -class Index(Node): - def __init__(self, name, table, expressions, unique=False, safe=False, - where=None, using=None): - self._name = name - self._table = Entity(table) if not isinstance(table, Table) else table - self._expressions = expressions - self._where = where - self._unique = unique - self._safe = safe - self._using = using - - @Node.copy - def safe(self, _safe=True): - self._safe = _safe - - @Node.copy - def where(self, *expressions): - if self._where is not None: - expressions = (self._where,) + expressions - self._where = reduce(operator.and_, expressions) - - @Node.copy - def using(self, _using=None): - self._using = _using - - def __sql__(self, ctx): - statement = 'CREATE UNIQUE INDEX ' if self._unique else 'CREATE INDEX ' - with ctx.scope_values(subquery=True): - ctx.literal(statement) - if self._safe: - ctx.literal('IF NOT EXISTS ') - - # Sqlite uses CREATE INDEX <schema> . <name> ON <table>, whereas most - # others use: CREATE INDEX <name> ON <schema> . <table>. - if ctx.state.index_schema_prefix and \ - isinstance(self._table, Table) and self._table._schema: - index_name = Entity(self._table._schema, self._name) - table_name = Entity(self._table.__name__) - else: - index_name = Entity(self._name) - table_name = self._table - - (ctx - .sql(index_name) - .literal(' ON ') - .sql(table_name) - .literal(' ')) - if self._using is not None: - ctx.literal('USING %s ' % self._using) - - ctx.sql(EnclosedNodeList([ - SQL(expr) if isinstance(expr, basestring) else expr - for expr in self._expressions])) - if self._where is not None: - ctx.literal(' WHERE ').sql(self._where) - - return ctx - - -class ModelIndex(Index): - def __init__(self, model, fields, unique=False, safe=True, where=None, - using=None, name=None): - self._model = model - if name is None: - name = self._generate_name_from_fields(model, fields) - if using is None: - for field in fields: - if isinstance(field, Field) and hasattr(field, 'index_type'): - using = field.index_type - super(ModelIndex, self).__init__( - name=name, - table=model._meta.table, - expressions=fields, - unique=unique, - safe=safe, - where=where, - using=using) - - def _generate_name_from_fields(self, model, fields): - accum = [] - for field in fields: - if isinstance(field, basestring): - accum.append(field.split()[0]) - else: - if isinstance(field, Node) and not isinstance(field, Field): - field = field.unwrap() - if isinstance(field, Field): - accum.append(field.column_name) - - if not accum: - raise ValueError('Unable to generate a name for the index, please ' - 'explicitly specify a name.') - - clean_field_names = re.sub('[^\w]+', '', '_'.join(accum)) - meta = model._meta - prefix = meta.name if meta.legacy_table_names else meta.table_name - return _truncate_constraint_name('_'.join((prefix, clean_field_names))) - - -def _truncate_constraint_name(constraint, maxlen=64): - if len(constraint) > maxlen: - name_hash = hashlib.md5(constraint.encode('utf-8')).hexdigest() - constraint = '%s_%s' % (constraint[:(maxlen - 8)], name_hash[:7]) - return constraint - - -# DB-API 2.0 EXCEPTIONS. - - -class PeeweeException(Exception): pass -class ImproperlyConfigured(PeeweeException): pass -class DatabaseError(PeeweeException): pass -class DataError(DatabaseError): pass -class IntegrityError(DatabaseError): pass -class InterfaceError(PeeweeException): pass -class InternalError(DatabaseError): pass -class NotSupportedError(DatabaseError): pass -class OperationalError(DatabaseError): pass -class ProgrammingError(DatabaseError): pass - - -class ExceptionWrapper(object): - __slots__ = ('exceptions',) - def __init__(self, exceptions): - self.exceptions = exceptions - def __enter__(self): pass - def __exit__(self, exc_type, exc_value, traceback): - if exc_type is None: - return - # psycopg2.8 shits out a million cute error types. Try to catch em all.
- if pg_errors is not None and exc_type.__name__ not in self.exceptions \ - and issubclass(exc_type, pg_errors.Error): - exc_type = exc_type.__bases__[0] - if exc_type.__name__ in self.exceptions: - new_type = self.exceptions[exc_type.__name__] - exc_args = exc_value.args - reraise(new_type, new_type(*exc_args), traceback) - - -EXCEPTIONS = { - 'ConstraintError': IntegrityError, - 'DatabaseError': DatabaseError, - 'DataError': DataError, - 'IntegrityError': IntegrityError, - 'InterfaceError': InterfaceError, - 'InternalError': InternalError, - 'NotSupportedError': NotSupportedError, - 'OperationalError': OperationalError, - 'ProgrammingError': ProgrammingError} - -__exception_wrapper__ = ExceptionWrapper(EXCEPTIONS) - - -# DATABASE INTERFACE AND CONNECTION MANAGEMENT. - - -IndexMetadata = collections.namedtuple( - 'IndexMetadata', - ('name', 'sql', 'columns', 'unique', 'table')) -ColumnMetadata = collections.namedtuple( - 'ColumnMetadata', - ('name', 'data_type', 'null', 'primary_key', 'table', 'default')) -ForeignKeyMetadata = collections.namedtuple( - 'ForeignKeyMetadata', - ('column', 'dest_table', 'dest_column', 'table')) -ViewMetadata = collections.namedtuple('ViewMetadata', ('name', 'sql')) - - -class _ConnectionState(object): - def __init__(self, **kwargs): - super(_ConnectionState, self).__init__(**kwargs) - self.reset() - - def reset(self): - self.closed = True - self.conn = None - self.ctx = [] - self.transactions = [] - - def set_connection(self, conn): - self.conn = conn - self.closed = False - self.ctx = [] - self.transactions = [] - - -class _ConnectionLocal(_ConnectionState, threading.local): pass -class _NoopLock(object): - __slots__ = () - def __enter__(self): return self - def __exit__(self, exc_type, exc_val, exc_tb): pass - - -class ConnectionContext(_callable_context_manager): - __slots__ = ('db',) - def __init__(self, db): self.db = db - def __enter__(self): - if self.db.is_closed(): - self.db.connect() - def __exit__(self, exc_type, exc_val, exc_tb): self.db.close() - - -class Database(_callable_context_manager): - context_class = Context - field_types = {} - operations = {} - param = '?' - quote = '""' - server_version = None - - # Feature toggles. - commit_select = False - compound_select_parentheses = CSQ_PARENTHESES_NEVER - for_update = False - index_schema_prefix = False - limit_max = None - nulls_ordering = False - returning_clause = False - safe_create_index = True - safe_drop_index = True - sequences = False - truncate_table = True - - def __init__(self, database, thread_safe=True, autorollback=False, - field_types=None, operations=None, autocommit=None, - autoconnect=True, **kwargs): - self._field_types = merge_dict(FIELD, self.field_types) - self._operations = merge_dict(OP, self.operations) - if field_types: - self._field_types.update(field_types) - if operations: - self._operations.update(operations) - - self.autoconnect = autoconnect - self.autorollback = autorollback - self.thread_safe = thread_safe - if thread_safe: - self._state = _ConnectionLocal() - self._lock = threading.Lock() - else: - self._state = _ConnectionState() - self._lock = _NoopLock() - - if autocommit is not None: - __deprecated__('Peewee no longer uses the "autocommit" option, as ' - 'the semantics now require it to always be True. 
' - 'Because some database-drivers also use the ' - '"autocommit" parameter, you are receiving a ' - 'warning so you may update your code and remove ' - 'the parameter, as in the future, specifying ' - 'autocommit could impact the behavior of the ' - 'database driver you are using.') - - self.connect_params = {} - self.init(database, **kwargs) - - def init(self, database, **kwargs): - if not self.is_closed(): - self.close() - self.database = database - self.connect_params.update(kwargs) - self.deferred = not bool(database) - - def __enter__(self): - if self.is_closed(): - self.connect() - ctx = self.atomic() - self._state.ctx.append(ctx) - ctx.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - ctx = self._state.ctx.pop() - try: - ctx.__exit__(exc_type, exc_val, exc_tb) - finally: - if not self._state.ctx: - self.close() - - def connection_context(self): - return ConnectionContext(self) - - def _connect(self): - raise NotImplementedError - - def connect(self, reuse_if_open=False): - with self._lock: - if self.deferred: - raise InterfaceError('Error, database must be initialized ' - 'before opening a connection.') - if not self._state.closed: - if reuse_if_open: - return False - raise OperationalError('Connection already opened.') - - self._state.reset() - with __exception_wrapper__: - self._state.set_connection(self._connect()) - if self.server_version is None: - self._set_server_version(self._state.conn) - self._initialize_connection(self._state.conn) - return True - - def _initialize_connection(self, conn): - pass - - def _set_server_version(self, conn): - self.server_version = 0 - - def close(self): - with self._lock: - if self.deferred: - raise InterfaceError('Error, database must be initialized ' - 'before opening a connection.') - if self.in_transaction(): - raise OperationalError('Attempting to close database while ' - 'transaction is open.') - is_open = not self._state.closed - try: - if is_open: - with __exception_wrapper__: - self._close(self._state.conn) - finally: - self._state.reset() - return is_open - - def _close(self, conn): - conn.close() - - def is_closed(self): - return self._state.closed - - def connection(self): - if self.is_closed(): - self.connect() - return self._state.conn - - def cursor(self, commit=None): - if self.is_closed(): - if self.autoconnect: - self.connect() - else: - raise InterfaceError('Error, database connection not opened.') - return self._state.conn.cursor() - - def execute_sql(self, sql, params=None, commit=SENTINEL): - logger.debug((sql, params)) - if commit is SENTINEL: - if self.in_transaction(): - commit = False - elif self.commit_select: - commit = True - else: - commit = not sql[:6].lower().startswith('select') - - with __exception_wrapper__: - cursor = self.cursor(commit) - try: - cursor.execute(sql, params or ()) - except Exception: - if self.autorollback and not self.in_transaction(): - self.rollback() - raise - else: - if commit and not self.in_transaction(): - self.commit() - return cursor - - def execute(self, query, commit=SENTINEL, **context_options): - ctx = self.get_sql_context(**context_options) - sql, params = ctx.sql(query).query() - return self.execute_sql(sql, params, commit=commit) - - def get_context_options(self): - return { - 'field_types': self._field_types, - 'operations': self._operations, - 'param': self.param, - 'quote': self.quote, - 'compound_select_parentheses': self.compound_select_parentheses, - 'conflict_statement': self.conflict_statement, - 'conflict_update': self.conflict_update, - 
'for_update': self.for_update, - 'index_schema_prefix': self.index_schema_prefix, - 'limit_max': self.limit_max, - 'nulls_ordering': self.nulls_ordering, - } - - def get_sql_context(self, **context_options): - context = self.get_context_options() - if context_options: - context.update(context_options) - return self.context_class(**context) - - def conflict_statement(self, on_conflict, query): - raise NotImplementedError - - def conflict_update(self, on_conflict, query): - raise NotImplementedError - - def _build_on_conflict_update(self, on_conflict, query): - if on_conflict._conflict_target: - stmt = SQL('ON CONFLICT') - target = EnclosedNodeList([ - Entity(col) if isinstance(col, basestring) else col - for col in on_conflict._conflict_target]) - if on_conflict._conflict_where is not None: - target = NodeList([target, SQL('WHERE'), - on_conflict._conflict_where]) - else: - stmt = SQL('ON CONFLICT ON CONSTRAINT') - target = on_conflict._conflict_constraint - if isinstance(target, basestring): - target = Entity(target) - - updates = [] - if on_conflict._preserve: - for column in on_conflict._preserve: - excluded = NodeList((SQL('EXCLUDED'), ensure_entity(column)), - glue='.') - expression = NodeList((ensure_entity(column), SQL('='), - excluded)) - updates.append(expression) - - if on_conflict._update: - for k, v in on_conflict._update.items(): - if not isinstance(v, Node): - # Attempt to resolve string field-names to their respective - # field object, to apply data-type conversions. - if isinstance(k, basestring): - k = getattr(query.table, k) - converter = k.db_value if isinstance(k, Field) else None - v = Value(v, converter=converter, unpack=False) - else: - v = QualifiedNames(v) - updates.append(NodeList((ensure_entity(k), SQL('='), v))) - - parts = [stmt, target, SQL('DO UPDATE SET'), CommaNodeList(updates)] - if on_conflict._where: - parts.extend((SQL('WHERE'), QualifiedNames(on_conflict._where))) - - return NodeList(parts) - - def last_insert_id(self, cursor, query_type=None): - return cursor.lastrowid - - def rows_affected(self, cursor): - return cursor.rowcount - - def default_values_insert(self, ctx): - return ctx.literal('DEFAULT VALUES') - - def session_start(self): - with self._lock: - return self.transaction().__enter__() - - def session_commit(self): - with self._lock: - try: - txn = self.pop_transaction() - except IndexError: - return False - txn.commit(begin=self.in_transaction()) - return True - - def session_rollback(self): - with self._lock: - try: - txn = self.pop_transaction() - except IndexError: - return False - txn.rollback(begin=self.in_transaction()) - return True - - def in_transaction(self): - return bool(self._state.transactions) - - def push_transaction(self, transaction): - self._state.transactions.append(transaction) - - def pop_transaction(self): - return self._state.transactions.pop() - - def transaction_depth(self): - return len(self._state.transactions) - - def top_transaction(self): - if self._state.transactions: - return self._state.transactions[-1] - - def atomic(self): - return _atomic(self) - - def manual_commit(self): - return _manual(self) - - def transaction(self): - return _transaction(self) - - def savepoint(self): - return _savepoint(self) - - def begin(self): - if self.is_closed(): - self.connect() - - def commit(self): - return self._state.conn.commit() - - def rollback(self): - return self._state.conn.rollback() - - def batch_commit(self, it, n): - for group in chunked(it, n): - with self.atomic(): - for obj in group: - yield obj - - def 
table_exists(self, table_name, schema=None): - return table_name in self.get_tables(schema=schema) - - def get_tables(self, schema=None): - raise NotImplementedError - - def get_indexes(self, table, schema=None): - raise NotImplementedError - - def get_columns(self, table, schema=None): - raise NotImplementedError - - def get_primary_keys(self, table, schema=None): - raise NotImplementedError - - def get_foreign_keys(self, table, schema=None): - raise NotImplementedError - - def sequence_exists(self, seq): - raise NotImplementedError - - def create_tables(self, models, **options): - for model in sort_models(models): - model.create_table(**options) - - def drop_tables(self, models, **kwargs): - for model in reversed(sort_models(models)): - model.drop_table(**kwargs) - - def extract_date(self, date_part, date_field): - raise NotImplementedError - - def truncate_date(self, date_part, date_field): - raise NotImplementedError - - def to_timestamp(self, date_field): - raise NotImplementedError - - def from_timestamp(self, date_field): - raise NotImplementedError - - def random(self): - return fn.random() - - def bind(self, models, bind_refs=True, bind_backrefs=True): - for model in models: - model.bind(self, bind_refs=bind_refs, bind_backrefs=bind_backrefs) - - def bind_ctx(self, models, bind_refs=True, bind_backrefs=True): - return _BoundModelsContext(models, self, bind_refs, bind_backrefs) - - def get_noop_select(self, ctx): - return ctx.sql(Select().columns(SQL('0')).where(SQL('0'))) - - -def __pragma__(name): - def __get__(self): - return self.pragma(name) - def __set__(self, value): - return self.pragma(name, value) - return property(__get__, __set__) - - -class SqliteDatabase(Database): - field_types = { - 'BIGAUTO': FIELD.AUTO, - 'BIGINT': FIELD.INT, - 'BOOL': FIELD.INT, - 'DOUBLE': FIELD.FLOAT, - 'SMALLINT': FIELD.INT, - 'UUID': FIELD.TEXT} - operations = { - 'LIKE': 'GLOB', - 'ILIKE': 'LIKE'} - index_schema_prefix = True - limit_max = -1 - server_version = __sqlite_version__ - truncate_table = False - - def __init__(self, database, *args, **kwargs): - self._pragmas = kwargs.pop('pragmas', ()) - super(SqliteDatabase, self).__init__(database, *args, **kwargs) - self._aggregates = {} - self._collations = {} - self._functions = {} - self._window_functions = {} - self._table_functions = [] - self._extensions = set() - self._attached = {} - self.register_function(_sqlite_date_part, 'date_part', 2) - self.register_function(_sqlite_date_trunc, 'date_trunc', 2) - - def init(self, database, pragmas=None, timeout=5, **kwargs): - if pragmas is not None: - self._pragmas = pragmas - if isinstance(self._pragmas, dict): - self._pragmas = list(self._pragmas.items()) - self._timeout = timeout - super(SqliteDatabase, self).init(database, **kwargs) - - def _set_server_version(self, conn): - pass - - def _connect(self): - if sqlite3 is None: - raise ImproperlyConfigured('SQLite driver not installed!') - conn = sqlite3.connect(self.database, timeout=self._timeout, - isolation_level=None, **self.connect_params) - try: - self._add_conn_hooks(conn) - except: - conn.close() - raise - return conn - - def _add_conn_hooks(self, conn): - if self._attached: - self._attach_databases(conn) - if self._pragmas: - self._set_pragmas(conn) - self._load_aggregates(conn) - self._load_collations(conn) - self._load_functions(conn) - if self.server_version >= (3, 25, 0): - self._load_window_functions(conn) - if self._table_functions: - for table_function in self._table_functions: - table_function.register(conn) - if 
self._extensions: - self._load_extensions(conn) - - def _set_pragmas(self, conn): - cursor = conn.cursor() - for pragma, value in self._pragmas: - cursor.execute('PRAGMA %s = %s;' % (pragma, value)) - cursor.close() - - def _attach_databases(self, conn): - cursor = conn.cursor() - for name, db in self._attached.items(): - cursor.execute('ATTACH DATABASE "%s" AS "%s"' % (db, name)) - cursor.close() - - def pragma(self, key, value=SENTINEL, permanent=False, schema=None): - if schema is not None: - key = '"%s".%s' % (schema, key) - sql = 'PRAGMA %s' % key - if value is not SENTINEL: - sql += ' = %s' % (value or 0) - if permanent: - pragmas = dict(self._pragmas or ()) - pragmas[key] = value - self._pragmas = list(pragmas.items()) - elif permanent: - raise ValueError('Cannot specify a permanent pragma without value') - row = self.execute_sql(sql).fetchone() - if row: - return row[0] - - cache_size = __pragma__('cache_size') - foreign_keys = __pragma__('foreign_keys') - journal_mode = __pragma__('journal_mode') - journal_size_limit = __pragma__('journal_size_limit') - mmap_size = __pragma__('mmap_size') - page_size = __pragma__('page_size') - read_uncommitted = __pragma__('read_uncommitted') - synchronous = __pragma__('synchronous') - wal_autocheckpoint = __pragma__('wal_autocheckpoint') - - @property - def timeout(self): - return self._timeout - - @timeout.setter - def timeout(self, seconds): - if self._timeout == seconds: - return - - self._timeout = seconds - if not self.is_closed(): - # PySQLite multiplies user timeout by 1000, but the unit of the - # timeout PRAGMA is actually milliseconds. - self.execute_sql('PRAGMA busy_timeout=%d;' % (seconds * 1000)) - - def _load_aggregates(self, conn): - for name, (klass, num_params) in self._aggregates.items(): - conn.create_aggregate(name, num_params, klass) - - def _load_collations(self, conn): - for name, fn in self._collations.items(): - conn.create_collation(name, fn) - - def _load_functions(self, conn): - for name, (fn, num_params) in self._functions.items(): - conn.create_function(name, num_params, fn) - - def _load_window_functions(self, conn): - for name, (klass, num_params) in self._window_functions.items(): - conn.create_window_function(name, num_params, klass) - - def register_aggregate(self, klass, name=None, num_params=-1): - self._aggregates[name or klass.__name__.lower()] = (klass, num_params) - if not self.is_closed(): - self._load_aggregates(self.connection()) - - def aggregate(self, name=None, num_params=-1): - def decorator(klass): - self.register_aggregate(klass, name, num_params) - return klass - return decorator - - def register_collation(self, fn, name=None): - name = name or fn.__name__ - def _collation(*args): - expressions = args + (SQL('collate %s' % name),) - return NodeList(expressions) - fn.collation = _collation - self._collations[name] = fn - if not self.is_closed(): - self._load_collations(self.connection()) - - def collation(self, name=None): - def decorator(fn): - self.register_collation(fn, name) - return fn - return decorator - - def register_function(self, fn, name=None, num_params=-1): - self._functions[name or fn.__name__] = (fn, num_params) - if not self.is_closed(): - self._load_functions(self.connection()) - - def func(self, name=None, num_params=-1): - def decorator(fn): - self.register_function(fn, name, num_params) - return fn - return decorator - - def register_window_function(self, klass, name=None, num_params=-1): - name = name or klass.__name__.lower() - self._window_functions[name] = (klass, 
num_params) - if not self.is_closed(): - self._load_window_functions(self.connection()) - - def window_function(self, name=None, num_params=-1): - def decorator(klass): - self.register_window_function(klass, name, num_params) - return klass - return decorator - - def register_table_function(self, klass, name=None): - if name is not None: - klass.name = name - self._table_functions.append(klass) - if not self.is_closed(): - klass.register(self.connection()) - - def table_function(self, name=None): - def decorator(klass): - self.register_table_function(klass, name) - return klass - return decorator - - def unregister_aggregate(self, name): - del(self._aggregates[name]) - - def unregister_collation(self, name): - del(self._collations[name]) - - def unregister_function(self, name): - del(self._functions[name]) - - def unregister_window_function(self, name): - del(self._window_functions[name]) - - def unregister_table_function(self, name): - for idx, klass in enumerate(self._table_functions): - if klass.name == name: - break - else: - return False - self._table_functions.pop(idx) - return True - - def _load_extensions(self, conn): - conn.enable_load_extension(True) - for extension in self._extensions: - conn.load_extension(extension) - - def load_extension(self, extension): - self._extensions.add(extension) - if not self.is_closed(): - conn = self.connection() - conn.enable_load_extension(True) - conn.load_extension(extension) - - def unload_extension(self, extension): - self._extensions.remove(extension) - - def attach(self, filename, name): - if name in self._attached: - if self._attached[name] == filename: - return False - raise OperationalError('schema "%s" already attached.' % name) - - self._attached[name] = filename - if not self.is_closed(): - self.execute_sql('ATTACH DATABASE "%s" AS "%s"' % (filename, name)) - return True - - def detach(self, name): - if name not in self._attached: - return False - - del self._attached[name] - if not self.is_closed(): - self.execute_sql('DETACH DATABASE "%s"' % name) - return True - - def atomic(self, lock_type=None): - return _atomic(self, lock_type=lock_type) - - def transaction(self, lock_type=None): - return _transaction(self, lock_type=lock_type) - - def begin(self, lock_type=None): - statement = 'BEGIN %s' % lock_type if lock_type else 'BEGIN' - self.execute_sql(statement, commit=False) - - def get_tables(self, schema=None): - schema = schema or 'main' - cursor = self.execute_sql('SELECT name FROM "%s".sqlite_master WHERE ' - 'type=? ORDER BY name' % schema, ('table',)) - return [row for row, in cursor.fetchall()] - - def get_views(self, schema=None): - sql = ('SELECT name, sql FROM "%s".sqlite_master WHERE type=? ' - 'ORDER BY name') % (schema or 'main') - return [ViewMetadata(*row) for row in self.execute_sql(sql, ('view',))] - - def get_indexes(self, table, schema=None): - schema = schema or 'main' - query = ('SELECT name, sql FROM "%s".sqlite_master ' - 'WHERE tbl_name = ? AND type = ? ORDER BY name') % schema - cursor = self.execute_sql(query, (table, 'index')) - index_to_sql = dict(cursor.fetchall()) - - # Determine which indexes have a unique constraint. - unique_indexes = set() - cursor = self.execute_sql('PRAGMA "%s".index_list("%s")' % - (schema, table)) - for row in cursor.fetchall(): - name = row[1] - is_unique = int(row[2]) == 1 - if is_unique: - unique_indexes.add(name) - - # Retrieve the indexed columns. 
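[reviewer note] On the "# Retrieve the indexed columns." step that continues below: PRAGMA index_info yields (seqno, cid, name) tuples, which is why the loop collects row[2]. A minimal, self-contained check (table and index names are made up for illustration):

    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE t (a TEXT, b TEXT)')
    conn.execute('CREATE INDEX idx_t_ab ON t (a, b)')
    rows = conn.execute('PRAGMA index_info("idx_t_ab")').fetchall()
    # Each row is (seqno, cid, name); index 2 holds the column name.
    assert [row[2] for row in rows] == ['a', 'b']
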
- index_columns = {} - for index_name in sorted(index_to_sql): - cursor = self.execute_sql('PRAGMA "%s".index_info("%s")' % - (schema, index_name)) - index_columns[index_name] = [row[2] for row in cursor.fetchall()] - - return [ - IndexMetadata( - name, - index_to_sql[name], - index_columns[name], - name in unique_indexes, - table) - for name in sorted(index_to_sql)] - - def get_columns(self, table, schema=None): - cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % - (schema or 'main', table)) - return [ColumnMetadata(r[1], r[2], not r[3], bool(r[5]), table, r[4]) - for r in cursor.fetchall()] - - def get_primary_keys(self, table, schema=None): - cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % - (schema or 'main', table)) - return [row[1] for row in filter(lambda r: r[-1], cursor.fetchall())] - - def get_foreign_keys(self, table, schema=None): - cursor = self.execute_sql('PRAGMA "%s".foreign_key_list("%s")' % - (schema or 'main', table)) - return [ForeignKeyMetadata(row[3], row[2], row[4], table) - for row in cursor.fetchall()] - - def get_binary_type(self): - return sqlite3.Binary - - def conflict_statement(self, on_conflict, query): - action = on_conflict._action.lower() if on_conflict._action else '' - if action and action not in ('nothing', 'update'): - return SQL('INSERT OR %s' % on_conflict._action.upper()) - - def conflict_update(self, oc, query): - # Sqlite prior to 3.24.0 does not support Postgres-style upsert. - if self.server_version < (3, 24, 0) and \ - any((oc._preserve, oc._update, oc._where, oc._conflict_target, - oc._conflict_constraint)): - raise ValueError('SQLite does not support specifying which values ' - 'to preserve or update.') - - action = oc._action.lower() if oc._action else '' - if action and action not in ('nothing', 'update', ''): - return - - if action == 'nothing': - return SQL('ON CONFLICT DO NOTHING') - elif not oc._update and not oc._preserve: - raise ValueError('If you are not performing any updates (or ' - 'preserving any INSERTed values), then the ' - 'conflict resolution action should be set to ' - '"NOTHING".') - elif oc._conflict_constraint: - raise ValueError('SQLite does not support specifying named ' - 'constraints for conflict resolution.') - elif not oc._conflict_target: - raise ValueError('SQLite requires that a conflict target be ' - 'specified when doing an upsert.') - - return self._build_on_conflict_update(oc, query) - - def extract_date(self, date_part, date_field): - return fn.date_part(date_part, date_field, python_value=int) - - def truncate_date(self, date_part, date_field): - return fn.date_trunc(date_part, date_field, - python_value=simple_date_time) - - def to_timestamp(self, date_field): - return fn.strftime('%s', date_field).cast('integer') - - def from_timestamp(self, date_field): - return fn.datetime(date_field, 'unixepoch') - - -class PostgresqlDatabase(Database): - field_types = { - 'AUTO': 'SERIAL', - 'BIGAUTO': 'BIGSERIAL', - 'BLOB': 'BYTEA', - 'BOOL': 'BOOLEAN', - 'DATETIME': 'TIMESTAMP', - 'DECIMAL': 'NUMERIC', - 'DOUBLE': 'DOUBLE PRECISION', - 'UUID': 'UUID', - 'UUIDB': 'BYTEA'} - operations = {'REGEXP': '~', 'IREGEXP': '~*'} - param = '%s' - - commit_select = True - compound_select_parentheses = CSQ_PARENTHESES_ALWAYS - for_update = True - nulls_ordering = True - returning_clause = True - safe_create_index = False - sequences = True - - def init(self, database, register_unicode=True, encoding=None, - isolation_level=None, **kwargs): - self._register_unicode = register_unicode - self._encoding = 
encoding - self._isolation_level = isolation_level - super(PostgresqlDatabase, self).init(database, **kwargs) - - def _connect(self): - if psycopg2 is None: - raise ImproperlyConfigured('Postgres driver not installed!') - conn = psycopg2.connect(database=self.database, **self.connect_params) - if self._register_unicode: - pg_extensions.register_type(pg_extensions.UNICODE, conn) - pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn) - if self._encoding: - conn.set_client_encoding(self._encoding) - if self._isolation_level: - conn.set_isolation_level(self._isolation_level) - return conn - - def _set_server_version(self, conn): - self.server_version = conn.server_version - if self.server_version >= 90600: - self.safe_create_index = True - - def last_insert_id(self, cursor, query_type=None): - try: - return cursor if query_type else cursor[0][0] - except (IndexError, KeyError, TypeError): - pass - - def get_tables(self, schema=None): - query = ('SELECT tablename FROM pg_catalog.pg_tables ' - 'WHERE schemaname = %s ORDER BY tablename') - cursor = self.execute_sql(query, (schema or 'public',)) - return [table for table, in cursor.fetchall()] - - def get_views(self, schema=None): - query = ('SELECT viewname, definition FROM pg_catalog.pg_views ' - 'WHERE schemaname = %s ORDER BY viewname') - cursor = self.execute_sql(query, (schema or 'public',)) - return [ViewMetadata(v, sql.strip()) for (v, sql) in cursor.fetchall()] - - def get_indexes(self, table, schema=None): - query = """ - SELECT - i.relname, idxs.indexdef, idx.indisunique, - array_to_string(array_agg(cols.attname), ',') - FROM pg_catalog.pg_class AS t - INNER JOIN pg_catalog.pg_index AS idx ON t.oid = idx.indrelid - INNER JOIN pg_catalog.pg_class AS i ON idx.indexrelid = i.oid - INNER JOIN pg_catalog.pg_indexes AS idxs ON - (idxs.tablename = t.relname AND idxs.indexname = i.relname) - LEFT OUTER JOIN pg_catalog.pg_attribute AS cols ON - (cols.attrelid = t.oid AND cols.attnum = ANY(idx.indkey)) - WHERE t.relname = %s AND t.relkind = %s AND idxs.schemaname = %s - GROUP BY i.relname, idxs.indexdef, idx.indisunique - ORDER BY idx.indisunique DESC, i.relname;""" - cursor = self.execute_sql(query, (table, 'r', schema or 'public')) - return [IndexMetadata(row[0], row[1], row[3].split(','), row[2], table) - for row in cursor.fetchall()] - - def get_columns(self, table, schema=None): - query = """ - SELECT column_name, is_nullable, data_type, column_default - FROM information_schema.columns - WHERE table_name = %s AND table_schema = %s - ORDER BY ordinal_position""" - cursor = self.execute_sql(query, (table, schema or 'public')) - pks = set(self.get_primary_keys(table, schema)) - return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) - for name, null, dt, df in cursor.fetchall()] - - def get_primary_keys(self, table, schema=None): - query = """ - SELECT kc.column_name - FROM information_schema.table_constraints AS tc - INNER JOIN information_schema.key_column_usage AS kc ON ( - tc.table_name = kc.table_name AND - tc.table_schema = kc.table_schema AND - tc.constraint_name = kc.constraint_name) - WHERE - tc.constraint_type = %s AND - tc.table_name = %s AND - tc.table_schema = %s""" - ctype = 'PRIMARY KEY' - cursor = self.execute_sql(query, (ctype, table, schema or 'public')) - return [pk for pk, in cursor.fetchall()] - - def get_foreign_keys(self, table, schema=None): - sql = """ - SELECT - kcu.column_name, ccu.table_name, ccu.column_name - FROM information_schema.table_constraints AS tc - JOIN 
information_schema.key_column_usage AS kcu - ON (tc.constraint_name = kcu.constraint_name AND - tc.constraint_schema = kcu.constraint_schema) - JOIN information_schema.constraint_column_usage AS ccu - ON (ccu.constraint_name = tc.constraint_name AND - ccu.constraint_schema = tc.constraint_schema) - WHERE - tc.constraint_type = 'FOREIGN KEY' AND - tc.table_name = %s AND - tc.table_schema = %s""" - cursor = self.execute_sql(sql, (table, schema or 'public')) - return [ForeignKeyMetadata(row[0], row[1], row[2], table) - for row in cursor.fetchall()] - - def sequence_exists(self, sequence): - res = self.execute_sql(""" - SELECT COUNT(*) FROM pg_class, pg_namespace - WHERE relkind='S' - AND pg_class.relnamespace = pg_namespace.oid - AND relname=%s""", (sequence,)) - return bool(res.fetchone()[0]) - - def get_binary_type(self): - return psycopg2.Binary - - def conflict_statement(self, on_conflict, query): - return - - def conflict_update(self, oc, query): - action = oc._action.lower() if oc._action else '' - if action in ('ignore', 'nothing'): - return SQL('ON CONFLICT DO NOTHING') - elif action and action != 'update': - raise ValueError('The only supported actions for conflict ' - 'resolution with Postgresql are "ignore" or ' - '"update".') - elif not oc._update and not oc._preserve: - raise ValueError('If you are not performing any updates (or ' - 'preserving any INSERTed values), then the ' - 'conflict resolution action should be set to ' - '"IGNORE".') - elif not (oc._conflict_target or oc._conflict_constraint): - raise ValueError('Postgres requires that a conflict target be ' - 'specified when doing an upsert.') - - return self._build_on_conflict_update(oc, query) - - def extract_date(self, date_part, date_field): - return fn.EXTRACT(NodeList((date_part, SQL('FROM'), date_field))) - - def truncate_date(self, date_part, date_field): - return fn.DATE_TRUNC(date_part, date_field) - - def to_timestamp(self, date_field): - return self.extract_date('EPOCH', date_field) - - def from_timestamp(self, date_field): - # Ironically, here, Postgres means "to the Postgresql timestamp type". 
- return fn.to_timestamp(date_field) - - def get_noop_select(self, ctx): - return ctx.sql(Select().columns(SQL('0')).where(SQL('false'))) - - def set_time_zone(self, timezone): - self.execute_sql('set time zone "%s";' % timezone) - - -class MySQLDatabase(Database): - field_types = { - 'AUTO': 'INTEGER AUTO_INCREMENT', - 'BIGAUTO': 'BIGINT AUTO_INCREMENT', - 'BOOL': 'BOOL', - 'DECIMAL': 'NUMERIC', - 'DOUBLE': 'DOUBLE PRECISION', - 'FLOAT': 'FLOAT', - 'UUID': 'VARCHAR(40)', - 'UUIDB': 'VARBINARY(16)'} - operations = { - 'LIKE': 'LIKE BINARY', - 'ILIKE': 'LIKE', - 'REGEXP': 'REGEXP BINARY', - 'IREGEXP': 'REGEXP', - 'XOR': 'XOR'} - param = '%s' - quote = '``' - - commit_select = True - compound_select_parentheses = CSQ_PARENTHESES_UNNESTED - for_update = True - limit_max = 2 ** 64 - 1 - safe_create_index = False - safe_drop_index = False - sql_mode = 'PIPES_AS_CONCAT' - - def init(self, database, **kwargs): - params = { - 'charset': 'utf8', - 'sql_mode': self.sql_mode, - 'use_unicode': True} - params.update(kwargs) - if 'password' in params and mysql_passwd: - params['passwd'] = params.pop('password') - super(MySQLDatabase, self).init(database, **params) - - def _connect(self): - if mysql is None: - raise ImproperlyConfigured('MySQL driver not installed!') - conn = mysql.connect(db=self.database, **self.connect_params) - return conn - - def _set_server_version(self, conn): - try: - version_raw = conn.server_version - except AttributeError: - version_raw = conn.get_server_info() - self.server_version = self._extract_server_version(version_raw) - - def _extract_server_version(self, version): - version = version.lower() - if 'maria' in version: - match_obj = re.search(r'(1\d\.\d+\.\d+)', version) - else: - match_obj = re.search(r'(\d\.\d+\.\d+)', version) - if match_obj is not None: - return tuple(int(num) for num in match_obj.groups()[0].split('.')) - - warnings.warn('Unable to determine MySQL version: "%s"' % version) - return (0, 0, 0) # Unable to determine version! 
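[reviewer note] The MySQL/MariaDB version sniffing that just ended is subtle enough to deserve a standalone sketch. The helper below re-creates the deleted _extract_server_version logic outside the class so it can be exercised directly; the sample version strings are illustrative, not taken from any real server:

    import re

    def extract_server_version(version_raw):
        # MariaDB embeds a 10.x.y version in a longer banner; stock MySQL
        # reports a single-digit major, hence the two regexes.
        version = version_raw.lower()
        if 'maria' in version:
            match = re.search(r'(1\d\.\d+\.\d+)', version)
        else:
            match = re.search(r'(\d\.\d+\.\d+)', version)
        if match is not None:
            return tuple(int(num) for num in match.group(1).split('.'))
        return (0, 0, 0)  # unknown, mirroring the deleted fallback

    assert extract_server_version('5.7.30-log') == (5, 7, 30)
    assert extract_server_version('10.4.13-MariaDB') == (10, 4, 13)
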
- - def default_values_insert(self, ctx): - return ctx.literal('() VALUES ()') - - def get_tables(self, schema=None): - query = ('SELECT table_name FROM information_schema.tables ' - 'WHERE table_schema = DATABASE() AND table_type != %s ' - 'ORDER BY table_name') - return [table for table, in self.execute_sql(query, ('VIEW',))] - - def get_views(self, schema=None): - query = ('SELECT table_name, view_definition ' - 'FROM information_schema.views ' - 'WHERE table_schema = DATABASE() ORDER BY table_name') - cursor = self.execute_sql(query) - return [ViewMetadata(*row) for row in cursor.fetchall()] - - def get_indexes(self, table, schema=None): - cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) - unique = set() - indexes = {} - for row in cursor.fetchall(): - if not row[1]: - unique.add(row[2]) - indexes.setdefault(row[2], []) - indexes[row[2]].append(row[4]) - return [IndexMetadata(name, None, indexes[name], name in unique, table) - for name in indexes] - - def get_columns(self, table, schema=None): - sql = """ - SELECT column_name, is_nullable, data_type, column_default - FROM information_schema.columns - WHERE table_name = %s AND table_schema = DATABASE()""" - cursor = self.execute_sql(sql, (table,)) - pks = set(self.get_primary_keys(table)) - return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) - for name, null, dt, df in cursor.fetchall()] - - def get_primary_keys(self, table, schema=None): - cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) - return [row[4] for row in - filter(lambda row: row[2] == 'PRIMARY', cursor.fetchall())] - - def get_foreign_keys(self, table, schema=None): - query = """ - SELECT column_name, referenced_table_name, referenced_column_name - FROM information_schema.key_column_usage - WHERE table_name = %s - AND table_schema = DATABASE() - AND referenced_table_name IS NOT NULL - AND referenced_column_name IS NOT NULL""" - cursor = self.execute_sql(query, (table,)) - return [ - ForeignKeyMetadata(column, dest_table, dest_column, table) - for column, dest_table, dest_column in cursor.fetchall()] - - def get_binary_type(self): - return mysql.Binary - - def conflict_statement(self, on_conflict, query): - if not on_conflict._action: return - - action = on_conflict._action.lower() - if action == 'replace': - return SQL('REPLACE') - elif action == 'ignore': - return SQL('INSERT IGNORE') - elif action != 'update': - raise ValueError('Un-supported action for conflict resolution. ' - 'MySQL supports REPLACE, IGNORE and UPDATE.') - - def conflict_update(self, on_conflict, query): - if on_conflict._where or on_conflict._conflict_target or \ - on_conflict._conflict_constraint: - raise ValueError('MySQL does not support the specification of ' - 'where clauses or conflict targets for conflict ' - 'resolution.') - - updates = [] - if on_conflict._preserve: - # Here we need to determine which function to use, which varies - # depending on the MySQL server version. MySQL and MariaDB prior to - # 10.3.3 use "VALUES", while MariaDB 10.3.3+ use "VALUE". 
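[reviewer note] The comment above ends on the VALUES/VALUE split; here is a compact sketch of that dialect switch, using a hypothetical preserve_clause helper that mirrors the version test in the block that follows:

    def preserve_clause(column, server_version):
        # MariaDB 10.3.3+ spells the function VALUE(); MySQL and older
        # MariaDB spell it VALUES() inside ON DUPLICATE KEY UPDATE.
        use_value = server_version[0] == 10 and server_version >= (10, 3, 3)
        fn_name = 'VALUE' if use_value else 'VALUES'
        return '`%s` = %s(`%s`)' % (column, fn_name, column)

    assert preserve_clause('title', (5, 7, 30)) == '`title` = VALUES(`title`)'
    assert preserve_clause('title', (10, 4, 13)) == '`title` = VALUE(`title`)'
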
- version = self.server_version or (0,) - if version[0] == 10 and version >= (10, 3, 3): - VALUE_FN = fn.VALUE - else: - VALUE_FN = fn.VALUES - - for column in on_conflict._preserve: - entity = ensure_entity(column) - expression = NodeList(( - ensure_entity(column), - SQL('='), - VALUE_FN(entity))) - updates.append(expression) - - if on_conflict._update: - for k, v in on_conflict._update.items(): - if not isinstance(v, Node): - # Attempt to resolve string field-names to their respective - # field object, to apply data-type conversions. - if isinstance(k, basestring): - k = getattr(query.table, k) - converter = k.db_value if isinstance(k, Field) else None - v = Value(v, converter=converter, unpack=False) - updates.append(NodeList((ensure_entity(k), SQL('='), v))) - - if updates: - return NodeList((SQL('ON DUPLICATE KEY UPDATE'), - CommaNodeList(updates))) - - def extract_date(self, date_part, date_field): - return fn.EXTRACT(NodeList((SQL(date_part), SQL('FROM'), date_field))) - - def truncate_date(self, date_part, date_field): - return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part], - python_value=simple_date_time) - - def to_timestamp(self, date_field): - return fn.UNIX_TIMESTAMP(date_field) - - def from_timestamp(self, date_field): - return fn.FROM_UNIXTIME(date_field) - - def random(self): - return fn.rand() - - def get_noop_select(self, ctx): - return ctx.literal('DO 0') - - -# TRANSACTION CONTROL. - - -class _manual(_callable_context_manager): - def __init__(self, db): - self.db = db - - def __enter__(self): - top = self.db.top_transaction() - if top and not isinstance(self.db.top_transaction(), _manual): - raise ValueError('Cannot enter manual commit block while a ' - 'transaction is active.') - self.db.push_transaction(self) - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.db.pop_transaction() is not self: - raise ValueError('Transaction stack corrupted while exiting ' - 'manual commit block.') - - -class _atomic(_callable_context_manager): - def __init__(self, db, lock_type=None): - self.db = db - self._lock_type = lock_type - self._transaction_args = (lock_type,) if lock_type is not None else () - - def __enter__(self): - if self.db.transaction_depth() == 0: - self._helper = self.db.transaction(*self._transaction_args) - elif isinstance(self.db.top_transaction(), _manual): - raise ValueError('Cannot enter atomic commit block while in ' - 'manual commit mode.') - else: - self._helper = self.db.savepoint() - return self._helper.__enter__() - - def __exit__(self, exc_type, exc_val, exc_tb): - return self._helper.__exit__(exc_type, exc_val, exc_tb) - - -class _transaction(_callable_context_manager): - def __init__(self, db, lock_type=None): - self.db = db - self._lock_type = lock_type - - def _begin(self): - if self._lock_type: - self.db.begin(self._lock_type) - else: - self.db.begin() - - def commit(self, begin=True): - self.db.commit() - if begin: - self._begin() - - def rollback(self, begin=True): - self.db.rollback() - if begin: - self._begin() - - def __enter__(self): - if self.db.transaction_depth() == 0: - self._begin() - self.db.push_transaction(self) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - try: - if exc_type: - self.rollback(False) - elif self.db.transaction_depth() == 1: - try: - self.commit(False) - except: - self.rollback(False) - raise - finally: - self.db.pop_transaction() - - -class _savepoint(_callable_context_manager): - def __init__(self, db, sid=None): - self.db = db - self.sid = sid or 's' + uuid.uuid4().hex - 
self.quoted_sid = self.sid.join(self.db.quote) - - def _begin(self): - self.db.execute_sql('SAVEPOINT %s;' % self.quoted_sid) - - def commit(self, begin=True): - self.db.execute_sql('RELEASE SAVEPOINT %s;' % self.quoted_sid) - if begin: self._begin() - - def rollback(self): - self.db.execute_sql('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) - - def __enter__(self): - self._begin() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type: - self.rollback() - else: - try: - self.commit(begin=False) - except: - self.rollback() - raise - - -# CURSOR REPRESENTATIONS. - - -class CursorWrapper(object): - def __init__(self, cursor): - self.cursor = cursor - self.count = 0 - self.index = 0 - self.initialized = False - self.populated = False - self.row_cache = [] - - def __iter__(self): - if self.populated: - return iter(self.row_cache) - return ResultIterator(self) - - def __getitem__(self, item): - if isinstance(item, slice): - stop = item.stop - if stop is None or stop < 0: - self.fill_cache() - else: - self.fill_cache(stop) - return self.row_cache[item] - elif isinstance(item, int): - self.fill_cache(item if item > 0 else 0) - return self.row_cache[item] - else: - raise ValueError('CursorWrapper only supports integer and slice ' - 'indexes.') - - def __len__(self): - self.fill_cache() - return self.count - - def initialize(self): - pass - - def iterate(self, cache=True): - row = self.cursor.fetchone() - if row is None: - self.populated = True - self.cursor.close() - raise StopIteration - elif not self.initialized: - self.initialize() # Lazy initialization. - self.initialized = True - self.count += 1 - result = self.process_row(row) - if cache: - self.row_cache.append(result) - return result - - def process_row(self, row): - return row - - def iterator(self): - """Efficient one-pass iteration over the result set.""" - while True: - try: - yield self.iterate(False) - except StopIteration: - return - - def fill_cache(self, n=0): - n = n or float('Inf') - if n < 0: - raise ValueError('Negative values are not supported.') - - iterator = ResultIterator(self) - iterator.index = self.count - while not self.populated and (n > self.count): - try: - iterator.next() - except StopIteration: - break - - -class DictCursorWrapper(CursorWrapper): - def _initialize_columns(self): - description = self.cursor.description - self.columns = [t[0][t[0].find('.') + 1:].strip('"') - for t in description] - self.ncols = len(description) - - initialize = _initialize_columns - - def _row_to_dict(self, row): - result = {} - for i in range(self.ncols): - result.setdefault(self.columns[i], row[i]) # Do not overwrite. 
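[reviewer note] An aside on the trailing "# Do not overwrite." above, before the method's return continues below: setdefault means that when a query yields duplicate column names (typical after a join), the first occurrence wins. A tiny illustration with made-up data:

    columns = ['id', 'title', 'id']   # duplicate name, e.g. after a join
    row = (1, 'Series A', 42)
    result = {}
    for i in range(len(columns)):
        result.setdefault(columns[i], row[i])  # keeps the first 'id'
    assert result == {'id': 1, 'title': 'Series A'}
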
- return result - - process_row = _row_to_dict - - -class NamedTupleCursorWrapper(CursorWrapper): - def initialize(self): - description = self.cursor.description - self.tuple_class = collections.namedtuple( - 'Row', - [col[0][col[0].find('.') + 1:].strip('"') for col in description]) - - def process_row(self, row): - return self.tuple_class(*row) - - -class ObjectCursorWrapper(DictCursorWrapper): - def __init__(self, cursor, constructor): - super(ObjectCursorWrapper, self).__init__(cursor) - self.constructor = constructor - - def process_row(self, row): - row_dict = self._row_to_dict(row) - return self.constructor(**row_dict) - - -class ResultIterator(object): - def __init__(self, cursor_wrapper): - self.cursor_wrapper = cursor_wrapper - self.index = 0 - - def __iter__(self): - return self - - def next(self): - if self.index < self.cursor_wrapper.count: - obj = self.cursor_wrapper.row_cache[self.index] - elif not self.cursor_wrapper.populated: - self.cursor_wrapper.iterate() - obj = self.cursor_wrapper.row_cache[self.index] - else: - raise StopIteration - self.index += 1 - return obj - - __next__ = next - -# FIELDS - -class FieldAccessor(object): - def __init__(self, model, field, name): - self.model = model - self.field = field - self.name = name - - def __get__(self, instance, instance_type=None): - if instance is not None: - return instance.__data__.get(self.name) - return self.field - - def __set__(self, instance, value): - instance.__data__[self.name] = value - instance._dirty.add(self.name) - - -class ForeignKeyAccessor(FieldAccessor): - def __init__(self, model, field, name): - super(ForeignKeyAccessor, self).__init__(model, field, name) - self.rel_model = field.rel_model - - def get_rel_instance(self, instance): - value = instance.__data__.get(self.name) - if value is not None or self.name in instance.__rel__: - if self.name not in instance.__rel__: - obj = self.rel_model.get(self.field.rel_field == value) - instance.__rel__[self.name] = obj - return instance.__rel__[self.name] - elif not self.field.null: - raise self.rel_model.DoesNotExist - return value - - def __get__(self, instance, instance_type=None): - if instance is not None: - return self.get_rel_instance(instance) - return self.field - - def __set__(self, instance, obj): - if isinstance(obj, self.rel_model): - instance.__data__[self.name] = getattr(obj, self.field.rel_field.name) - instance.__rel__[self.name] = obj - else: - fk_value = instance.__data__.get(self.name) - instance.__data__[self.name] = obj - if obj != fk_value and self.name in instance.__rel__: - del instance.__rel__[self.name] - instance._dirty.add(self.name) - - -class NoQueryForeignKeyAccessor(ForeignKeyAccessor): - def get_rel_instance(self, instance): - value = instance.__data__.get(self.name) - if value is not None: - return instance.__rel__.get(self.name, value) - elif not self.field.null: - raise self.rel_model.DoesNotExist - - -class BackrefAccessor(object): - def __init__(self, field): - self.field = field - self.model = field.rel_model - self.rel_model = field.model - - def __get__(self, instance, instance_type=None): - if instance is not None: - dest = self.field.rel_field.name - return (self.rel_model - .select() - .where(self.field == getattr(instance, dest))) - return self - - -class ObjectIdAccessor(object): - """Gives direct access to the underlying id""" - def __init__(self, field): - self.field = field - - def __get__(self, instance, instance_type=None): - if instance is not None: - return instance.__data__.get(self.field.name) - return 
self.field - - def __set__(self, instance, value): - setattr(instance, self.field.name, value) - - -class Field(ColumnBase): - _field_counter = 0 - _order = 0 - accessor_class = FieldAccessor - auto_increment = False - default_index_type = None - field_type = 'DEFAULT' - - def __init__(self, null=False, index=False, unique=False, column_name=None, - default=None, primary_key=False, constraints=None, - sequence=None, collation=None, unindexed=False, choices=None, - help_text=None, verbose_name=None, index_type=None, - db_column=None, _hidden=False): - if db_column is not None: - __deprecated__('"db_column" has been deprecated in favor of ' - '"column_name" for Field objects.') - column_name = db_column - - self.null = null - self.index = index - self.unique = unique - self.column_name = column_name - self.default = default - self.primary_key = primary_key - self.constraints = constraints # List of column constraints. - self.sequence = sequence # Name of sequence, e.g. foo_id_seq. - self.collation = collation - self.unindexed = unindexed - self.choices = choices - self.help_text = help_text - self.verbose_name = verbose_name - self.index_type = index_type or self.default_index_type - self._hidden = _hidden - - # Used internally for recovering the order in which Fields were defined - # on the Model class. - Field._field_counter += 1 - self._order = Field._field_counter - self._sort_key = (self.primary_key and 1 or 2), self._order - - def __hash__(self): - return hash(self.name + '.' + self.model.__name__) - - def __repr__(self): - if hasattr(self, 'model') and getattr(self, 'name', None): - return '<%s: %s.%s>' % (type(self).__name__, - self.model.__name__, - self.name) - return '<%s: (unbound)>' % type(self).__name__ - - def bind(self, model, name, set_attribute=True): - self.model = model - self.name = self.safe_name = name - self.column_name = self.column_name or name - if set_attribute: - setattr(model, name, self.accessor_class(model, self, name)) - - @property - def column(self): - return Column(self.model._meta.table, self.column_name) - - def adapt(self, value): - return value - - def db_value(self, value): - return value if value is None else self.adapt(value) - - def python_value(self, value): - return value if value is None else self.adapt(value) - - def get_sort_key(self, ctx): - return self._sort_key - - def __sql__(self, ctx): - return ctx.sql(self.column) - - def get_modifiers(self): - pass - - def ddl_datatype(self, ctx): - if ctx and ctx.state.field_types: - column_type = ctx.state.field_types.get(self.field_type, - self.field_type) - else: - column_type = self.field_type - - modifiers = self.get_modifiers() - if column_type and modifiers: - modifier_literal = ', '.join([str(m) for m in modifiers]) - return SQL('%s(%s)' % (column_type, modifier_literal)) - else: - return SQL(column_type) - - def ddl(self, ctx): - accum = [Entity(self.column_name)] - data_type = self.ddl_datatype(ctx) - if data_type: - accum.append(data_type) - if self.unindexed: - accum.append(SQL('UNINDEXED')) - if not self.null: - accum.append(SQL('NOT NULL')) - if self.primary_key: - accum.append(SQL('PRIMARY KEY')) - if self.sequence: - accum.append(SQL("DEFAULT NEXTVAL('%s')" % self.sequence)) - if self.constraints: - accum.extend(self.constraints) - if self.collation: - accum.append(SQL('COLLATE %s' % self.collation)) - return NodeList(accum) - - -class IntegerField(Field): - field_type = 'INT' - - def adapt(self, value): - try: - return int(value) - except ValueError: - return value - - -class 
BigIntegerField(IntegerField): - field_type = 'BIGINT' - - -class SmallIntegerField(IntegerField): - field_type = 'SMALLINT' - - -class AutoField(IntegerField): - auto_increment = True - field_type = 'AUTO' - - def __init__(self, *args, **kwargs): - if kwargs.get('primary_key') is False: - raise ValueError('%s must always be a primary key.' % type(self)) - kwargs['primary_key'] = True - super(AutoField, self).__init__(*args, **kwargs) - - -class BigAutoField(AutoField): - field_type = 'BIGAUTO' - - -class IdentityField(AutoField): - field_type = 'INT GENERATED BY DEFAULT AS IDENTITY' - - -class PrimaryKeyField(AutoField): - def __init__(self, *args, **kwargs): - __deprecated__('"PrimaryKeyField" has been renamed to "AutoField". ' - 'Please update your code accordingly as this will be ' - 'completely removed in a subsequent release.') - super(PrimaryKeyField, self).__init__(*args, **kwargs) - - -class FloatField(Field): - field_type = 'FLOAT' - - def adapt(self, value): - try: - return float(value) - except ValueError: - return value - - -class DoubleField(FloatField): - field_type = 'DOUBLE' - - -class DecimalField(Field): - field_type = 'DECIMAL' - - def __init__(self, max_digits=10, decimal_places=5, auto_round=False, - rounding=None, *args, **kwargs): - self.max_digits = max_digits - self.decimal_places = decimal_places - self.auto_round = auto_round - self.rounding = rounding or decimal.DefaultContext.rounding - self._exp = decimal.Decimal(10) ** (-self.decimal_places) - super(DecimalField, self).__init__(*args, **kwargs) - - def get_modifiers(self): - return [self.max_digits, self.decimal_places] - - def db_value(self, value): - D = decimal.Decimal - if not value: - return value if value is None else D(0) - if self.auto_round: - decimal_value = D(text_type(value)) - return decimal_value.quantize(self._exp, rounding=self.rounding) - return value - - def python_value(self, value): - if value is not None: - if isinstance(value, decimal.Decimal): - return value - return decimal.Decimal(text_type(value)) - - -class _StringField(Field): - def adapt(self, value): - if isinstance(value, text_type): - return value - elif isinstance(value, bytes_type): - return value.decode('utf-8') - return text_type(value) - - def __add__(self, other): return StringExpression(self, OP.CONCAT, other) - def __radd__(self, other): return StringExpression(other, OP.CONCAT, self) - - -class CharField(_StringField): - field_type = 'VARCHAR' - - def __init__(self, max_length=255, *args, **kwargs): - self.max_length = max_length - super(CharField, self).__init__(*args, **kwargs) - - def get_modifiers(self): - return self.max_length and [self.max_length] or None - - -class FixedCharField(CharField): - field_type = 'CHAR' - - def python_value(self, value): - value = super(FixedCharField, self).python_value(value) - if value: - value = value.strip() - return value - - -class TextField(_StringField): - field_type = 'TEXT' - - -class BlobField(Field): - field_type = 'BLOB' - - def _db_hook(self, database): - if database is None: - self._constructor = bytearray - else: - self._constructor = database.get_binary_type() - - def bind(self, model, name, set_attribute=True): - self._constructor = bytearray - if model._meta.database: - if isinstance(model._meta.database, Proxy): - model._meta.database.attach_callback(self._db_hook) - else: - self._db_hook(model._meta.database) - - # Attach a hook to the model metadata; in the event the database is - # changed or set at run-time, we will be sure to apply our callback and - # use 
the proper data-type for our database driver. - model._meta._db_hooks.append(self._db_hook) - return super(BlobField, self).bind(model, name, set_attribute) - - def db_value(self, value): - if isinstance(value, text_type): - value = value.encode('raw_unicode_escape') - if isinstance(value, bytes_type): - return self._constructor(value) - return value - - -class BitField(BitwiseMixin, BigIntegerField): - def __init__(self, *args, **kwargs): - kwargs.setdefault('default', 0) - super(BitField, self).__init__(*args, **kwargs) - self.__current_flag = 1 - - def flag(self, value=None): - if value is None: - value = self.__current_flag - self.__current_flag <<= 1 - else: - self.__current_flag = value << 1 - - class FlagDescriptor(object): - def __init__(self, field, value): - self._field = field - self._value = value - def __get__(self, instance, instance_type=None): - if instance is None: - return self._field.bin_and(self._value) != 0 - value = getattr(instance, self._field.name) or 0 - return (value & self._value) != 0 - def __set__(self, instance, is_set): - if is_set not in (True, False): - raise ValueError('Value must be either True or False') - value = getattr(instance, self._field.name) or 0 - if is_set: - value |= self._value - else: - value &= ~self._value - setattr(instance, self._field.name, value) - return FlagDescriptor(self, value) - - -class BigBitFieldData(object): - def __init__(self, instance, name): - self.instance = instance - self.name = name - value = self.instance.__data__.get(self.name) - if not value: - value = bytearray() - elif not isinstance(value, bytearray): - value = bytearray(value) - self._buffer = self.instance.__data__[self.name] = value - - def _ensure_length(self, idx): - byte_num, byte_offset = divmod(idx, 8) - cur_size = len(self._buffer) - if cur_size <= byte_num: - self._buffer.extend(b'\x00' * ((byte_num + 1) - cur_size)) - return byte_num, byte_offset - - def set_bit(self, idx): - byte_num, byte_offset = self._ensure_length(idx) - self._buffer[byte_num] |= (1 << byte_offset) - - def clear_bit(self, idx): - byte_num, byte_offset = self._ensure_length(idx) - self._buffer[byte_num] &= ~(1 << byte_offset) - - def toggle_bit(self, idx): - byte_num, byte_offset = self._ensure_length(idx) - self._buffer[byte_num] ^= (1 << byte_offset) - return bool(self._buffer[byte_num] & (1 << byte_offset)) - - def is_set(self, idx): - byte_num, byte_offset = self._ensure_length(idx) - return bool(self._buffer[byte_num] & (1 << byte_offset)) - - def __repr__(self): - return repr(self._buffer) - - -class BigBitFieldAccessor(FieldAccessor): - def __get__(self, instance, instance_type=None): - if instance is None: - return self.field - return BigBitFieldData(instance, self.name) - def __set__(self, instance, value): - if isinstance(value, memoryview): - value = value.tobytes() - elif isinstance(value, buffer_type): - value = bytes(value) - elif isinstance(value, bytearray): - value = bytes_type(value) - elif isinstance(value, BigBitFieldData): - value = bytes_type(value._buffer) - elif isinstance(value, text_type): - value = value.encode('utf-8') - elif not isinstance(value, bytes_type): - raise ValueError('Value must be either a bytes, memoryview or ' - 'BigBitFieldData instance.') - super(BigBitFieldAccessor, self).__set__(instance, value) - - -class BigBitField(BlobField): - accessor_class = BigBitFieldAccessor - - def __init__(self, *args, **kwargs): - kwargs.setdefault('default', bytes_type) - super(BigBitField, self).__init__(*args, **kwargs) - - def db_value(self, value): 
- return bytes_type(value) if value is not None else value - - -class UUIDField(Field): - field_type = 'UUID' - - def db_value(self, value): - if isinstance(value, basestring) and len(value) == 32: - # Hex string. No transformation is necessary. - return value - elif isinstance(value, bytes) and len(value) == 16: - # Allow raw binary representation. - value = uuid.UUID(bytes=value) - if isinstance(value, uuid.UUID): - return value.hex - try: - return uuid.UUID(value).hex - except: - return value - - def python_value(self, value): - if isinstance(value, uuid.UUID): - return value - return uuid.UUID(value) if value is not None else None - - -class BinaryUUIDField(BlobField): - field_type = 'UUIDB' - - def db_value(self, value): - if isinstance(value, bytes) and len(value) == 16: - # Raw binary value. No transformation is necessary. - return self._constructor(value) - elif isinstance(value, basestring) and len(value) == 32: - # Allow hex string representation. - value = uuid.UUID(hex=value) - if isinstance(value, uuid.UUID): - return self._constructor(value.bytes) - elif value is not None: - raise ValueError('value for binary UUID field must be UUID(), ' - 'a hexadecimal string, or a bytes object.') - - def python_value(self, value): - if isinstance(value, uuid.UUID): - return value - elif isinstance(value, memoryview): - value = value.tobytes() - elif value and not isinstance(value, bytes): - value = bytes(value) - return uuid.UUID(bytes=value) if value is not None else None - - -def _date_part(date_part): - def dec(self): - return self.model._meta.database.extract_date(date_part, self) - return dec - -def format_date_time(value, formats, post_process=None): - post_process = post_process or (lambda x: x) - for fmt in formats: - try: - return post_process(datetime.datetime.strptime(value, fmt)) - except ValueError: - pass - return value - -def simple_date_time(value): - try: - return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') - except (TypeError, ValueError): - return value - - -class _BaseFormattedField(Field): - formats = None - - def __init__(self, formats=None, *args, **kwargs): - if formats is not None: - self.formats = formats - super(_BaseFormattedField, self).__init__(*args, **kwargs) - - -class DateTimeField(_BaseFormattedField): - field_type = 'DATETIME' - formats = [ - '%Y-%m-%d %H:%M:%S.%f', - '%Y-%m-%d %H:%M:%S', - '%Y-%m-%d', - ] - - def adapt(self, value): - if value and isinstance(value, basestring): - return format_date_time(value, self.formats) - return value - - def to_timestamp(self): - return self.model._meta.database.to_timestamp(self) - - def truncate(self, part): - return self.model._meta.database.truncate_date(part, self) - - year = property(_date_part('year')) - month = property(_date_part('month')) - day = property(_date_part('day')) - hour = property(_date_part('hour')) - minute = property(_date_part('minute')) - second = property(_date_part('second')) - - -class DateField(_BaseFormattedField): - field_type = 'DATE' - formats = [ - '%Y-%m-%d', - '%Y-%m-%d %H:%M:%S', - '%Y-%m-%d %H:%M:%S.%f', - ] - - def adapt(self, value): - if value and isinstance(value, basestring): - pp = lambda x: x.date() - return format_date_time(value, self.formats, pp) - elif value and isinstance(value, datetime.datetime): - return value.date() - return value - - def to_timestamp(self): - return self.model._meta.database.to_timestamp(self) - - def truncate(self, part): - return self.model._meta.database.truncate_date(part, self) - - year = property(_date_part('year')) - month = 
property(_date_part('month')) - day = property(_date_part('day')) - - -class TimeField(_BaseFormattedField): - field_type = 'TIME' - formats = [ - '%H:%M:%S.%f', - '%H:%M:%S', - '%H:%M', - '%Y-%m-%d %H:%M:%S.%f', - '%Y-%m-%d %H:%M:%S', - ] - - def adapt(self, value): - if value: - if isinstance(value, basestring): - pp = lambda x: x.time() - return format_date_time(value, self.formats, pp) - elif isinstance(value, datetime.datetime): - return value.time() - if value is not None and isinstance(value, datetime.timedelta): - return (datetime.datetime.min + value).time() - return value - - hour = property(_date_part('hour')) - minute = property(_date_part('minute')) - second = property(_date_part('second')) - - -def _timestamp_date_part(date_part): - def dec(self): - db = self.model._meta.database - expr = ((self / Value(self.resolution, converter=False)) - if self.resolution > 1 else self) - return db.extract_date(date_part, db.from_timestamp(expr)) - return dec - - -class TimestampField(BigIntegerField): - # Support second -> microsecond resolution. - valid_resolutions = [10**i for i in range(7)] - - def __init__(self, *args, **kwargs): - self.resolution = kwargs.pop('resolution', None) - - if not self.resolution: - self.resolution = 1 - elif self.resolution in range(2, 7): - self.resolution = 10 ** self.resolution - elif self.resolution not in self.valid_resolutions: - raise ValueError('TimestampField resolution must be one of: %s' % - ', '.join(str(i) for i in self.valid_resolutions)) - self.ticks_to_microsecond = 1000000 // self.resolution - - self.utc = kwargs.pop('utc', False) or False - dflt = datetime.datetime.utcnow if self.utc else datetime.datetime.now - kwargs.setdefault('default', dflt) - super(TimestampField, self).__init__(*args, **kwargs) - - def local_to_utc(self, dt): - # Convert naive local datetime into naive UTC, e.g.: - # 2019-03-01T12:00:00 (local=US/Central) -> 2019-03-01T18:00:00. - # 2019-05-01T12:00:00 (local=US/Central) -> 2019-05-01T17:00:00. - # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00. - return datetime.datetime(*time.gmtime(time.mktime(dt.timetuple()))[:6]) - - def utc_to_local(self, dt): - # Convert a naive UTC datetime into local time, e.g.: - # 2019-03-01T18:00:00 (local=US/Central) -> 2019-03-01T12:00:00. - # 2019-05-01T17:00:00 (local=US/Central) -> 2019-05-01T12:00:00. - # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00. - ts = calendar.timegm(dt.utctimetuple()) - return datetime.datetime.fromtimestamp(ts) - - def get_timestamp(self, value): - if self.utc: - # If utc-mode is on, then we assume all naive datetimes are in UTC. 
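[reviewer note] The comment above states the UTC assumption, and the two branches below differ only in how they interpret a naive datetime. A quick, host-timezone-dependent illustration of that difference:

    import calendar
    import datetime
    import time

    dt = datetime.datetime(2019, 3, 1, 18, 0, 0)   # naive datetime
    as_utc = calendar.timegm(dt.utctimetuple())    # interpret as UTC
    as_local = time.mktime(dt.timetuple())         # interpret as local time
    # On a UTC host the two match; elsewhere they differ by the UTC offset.
    print(as_local - as_utc)
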
- return calendar.timegm(value.utctimetuple()) - else: - return time.mktime(value.timetuple()) - - def db_value(self, value): - if value is None: - return - - if isinstance(value, datetime.datetime): - pass - elif isinstance(value, datetime.date): - value = datetime.datetime(value.year, value.month, value.day) - else: - return int(round(value * self.resolution)) - - timestamp = self.get_timestamp(value) - if self.resolution > 1: - timestamp += (value.microsecond * .000001) - timestamp *= self.resolution - return int(round(timestamp)) - - def python_value(self, value): - if value is not None and isinstance(value, (int, float, long)): - if self.resolution > 1: - value, ticks = divmod(value, self.resolution) - microseconds = int(ticks * self.ticks_to_microsecond) - else: - microseconds = 0 - - if self.utc: - value = datetime.datetime.utcfromtimestamp(value) - else: - value = datetime.datetime.fromtimestamp(value) - - if microseconds: - value = value.replace(microsecond=microseconds) - - return value - - def from_timestamp(self): - expr = ((self / Value(self.resolution, converter=False)) - if self.resolution > 1 else self) - return self.model._meta.database.from_timestamp(expr) - - year = property(_timestamp_date_part('year')) - month = property(_timestamp_date_part('month')) - day = property(_timestamp_date_part('day')) - hour = property(_timestamp_date_part('hour')) - minute = property(_timestamp_date_part('minute')) - second = property(_timestamp_date_part('second')) - - -class IPField(BigIntegerField): - def db_value(self, val): - if val is not None: - return struct.unpack('!I', socket.inet_aton(val))[0] - - def python_value(self, val): - if val is not None: - return socket.inet_ntoa(struct.pack('!I', val)) - - -class BooleanField(Field): - field_type = 'BOOL' - adapt = bool - - -class BareField(Field): - def __init__(self, adapt=None, *args, **kwargs): - super(BareField, self).__init__(*args, **kwargs) - if adapt is not None: - self.adapt = adapt - - def ddl_datatype(self, ctx): - return - - -class ForeignKeyField(Field): - accessor_class = ForeignKeyAccessor - - def __init__(self, model, field=None, backref=None, on_delete=None, - on_update=None, deferrable=None, _deferred=None, - rel_model=None, to_field=None, object_id_name=None, - lazy_load=True, related_name=None, *args, **kwargs): - kwargs.setdefault('index', True) - - # If lazy_load is disable, we use a different descriptor/accessor that - # will ensure we don't accidentally perform a query. 
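[reviewer note] Before the lazy_load branch continues below: the default accessor may issue a SELECT on attribute access, while the no-query variant hands back whatever value is already stored. A minimal descriptor sketch of the no-query idea (these are not peewee's classes, just the pattern):

    class RawIdAccessor:
        """Hand back the stored key itself; never fetch the related row."""
        def __set_name__(self, owner, name):
            self.key = '_' + name

        def __get__(self, obj, objtype=None):
            if obj is None:
                return self
            return getattr(obj, self.key)   # raw foreign-key value, no SELECT

        def __set__(self, obj, value):
            setattr(obj, self.key, value)

    class Tweet:
        user = RawIdAccessor()

    tweet = Tweet()
    tweet.user = 7
    assert tweet.user == 7   # a lazy accessor would load the related row here
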
- if not lazy_load: - self.accessor_class = NoQueryForeignKeyAccessor - - super(ForeignKeyField, self).__init__(*args, **kwargs) - - if rel_model is not None: - __deprecated__('"rel_model" has been deprecated in favor of ' - '"model" for ForeignKeyField objects.') - model = rel_model - if to_field is not None: - __deprecated__('"to_field" has been deprecated in favor of ' - '"field" for ForeignKeyField objects.') - field = to_field - if related_name is not None: - __deprecated__('"related_name" has been deprecated in favor of ' - '"backref" for Field objects.') - backref = related_name - - self._is_self_reference = model == 'self' - self.rel_model = model - self.rel_field = field - self.declared_backref = backref - self.backref = None - self.on_delete = on_delete - self.on_update = on_update - self.deferrable = deferrable - self.deferred = _deferred - self.object_id_name = object_id_name - self.lazy_load = lazy_load - - @property - def field_type(self): - if not isinstance(self.rel_field, AutoField): - return self.rel_field.field_type - elif isinstance(self.rel_field, BigAutoField): - return BigIntegerField.field_type - return IntegerField.field_type - - def get_modifiers(self): - if not isinstance(self.rel_field, AutoField): - return self.rel_field.get_modifiers() - return super(ForeignKeyField, self).get_modifiers() - - def adapt(self, value): - return self.rel_field.adapt(value) - - def db_value(self, value): - if isinstance(value, self.rel_model): - value = value.get_id() - return self.rel_field.db_value(value) - - def python_value(self, value): - if isinstance(value, self.rel_model): - return value - return self.rel_field.python_value(value) - - def bind(self, model, name, set_attribute=True): - if not self.column_name: - self.column_name = name if name.endswith('_id') else name + '_id' - if not self.object_id_name: - self.object_id_name = self.column_name - if self.object_id_name == name: - self.object_id_name += '_id' - elif self.object_id_name == name: - raise ValueError('ForeignKeyField "%s"."%s" specifies an ' - 'object_id_name that conflicts with its field ' - 'name.' % (model._meta.name, name)) - if self._is_self_reference: - self.rel_model = model - if isinstance(self.rel_field, basestring): - self.rel_field = getattr(self.rel_model, self.rel_field) - elif self.rel_field is None: - self.rel_field = self.rel_model._meta.primary_key - - # Bind field before assigning backref, so field is bound when - # calling declared_backref() (if callable). - super(ForeignKeyField, self).bind(model, name, set_attribute) - self.safe_name = self.object_id_name - - if callable_(self.declared_backref): - self.backref = self.declared_backref(self) - else: - self.backref, self.declared_backref = self.declared_backref, None - if not self.backref: - self.backref = '%s_set' % model._meta.name - - if set_attribute: - setattr(model, self.object_id_name, ObjectIdAccessor(self)) - if self.backref not in '!+': - setattr(self.rel_model, self.backref, BackrefAccessor(self)) - - def foreign_key_constraint(self): - parts = [ - SQL('FOREIGN KEY'), - EnclosedNodeList((self,)), - SQL('REFERENCES'), - self.rel_model, - EnclosedNodeList((self.rel_field,))] - if self.on_delete: - parts.append(SQL('ON DELETE %s' % self.on_delete)) - if self.on_update: - parts.append(SQL('ON UPDATE %s' % self.on_update)) - if self.deferrable: - parts.append(SQL('DEFERRABLE %s' % self.deferrable)) - return NodeList(parts) - - def __getattr__(self, attr): - if attr.startswith('__'): - # Prevent recursion error when deep-copying. 
- raise AttributeError('Cannot look-up non-existant "__" methods.') - if attr in self.rel_model._meta.fields: - return self.rel_model._meta.fields[attr] - raise AttributeError('Foreign-key has no attribute %s, nor is it a ' - 'valid field on the related model.' % attr) - - -class DeferredForeignKey(Field): - _unresolved = set() - - def __init__(self, rel_model_name, **kwargs): - self.field_kwargs = kwargs - self.rel_model_name = rel_model_name.lower() - DeferredForeignKey._unresolved.add(self) - super(DeferredForeignKey, self).__init__( - column_name=kwargs.get('column_name'), - null=kwargs.get('null')) - - __hash__ = object.__hash__ - - def __deepcopy__(self, memo=None): - return DeferredForeignKey(self.rel_model_name, **self.field_kwargs) - - def set_model(self, rel_model): - field = ForeignKeyField(rel_model, _deferred=True, **self.field_kwargs) - self.model._meta.add_field(self.name, field) - - @staticmethod - def resolve(model_cls): - unresolved = sorted(DeferredForeignKey._unresolved, - key=operator.attrgetter('_order')) - for dr in unresolved: - if dr.rel_model_name == model_cls.__name__.lower(): - dr.set_model(model_cls) - DeferredForeignKey._unresolved.discard(dr) - - -class DeferredThroughModel(object): - def __init__(self): - self._refs = [] - - def set_field(self, model, field, name): - self._refs.append((model, field, name)) - - def set_model(self, through_model): - for src_model, m2mfield, name in self._refs: - m2mfield.through_model = through_model - src_model._meta.add_field(name, m2mfield) - - -class MetaField(Field): - column_name = default = model = name = None - primary_key = False - - -class ManyToManyFieldAccessor(FieldAccessor): - def __init__(self, model, field, name): - super(ManyToManyFieldAccessor, self).__init__(model, field, name) - self.model = field.model - self.rel_model = field.rel_model - self.through_model = field.through_model - src_fks = self.through_model._meta.model_refs[self.model] - dest_fks = self.through_model._meta.model_refs[self.rel_model] - if not src_fks: - raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' % - (self.model, self.through_model)) - elif not dest_fks: - raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' % - (self.rel_model, self.through_model)) - self.src_fk = src_fks[0] - self.dest_fk = dest_fks[0] - - def __get__(self, instance, instance_type=None, force_query=False): - if instance is not None: - if not force_query and self.src_fk.backref != '+': - backref = getattr(instance, self.src_fk.backref) - if isinstance(backref, list): - return [getattr(obj, self.dest_fk.name) for obj in backref] - - src_id = getattr(instance, self.src_fk.rel_field.name) - return (ManyToManyQuery(instance, self, self.rel_model) - .join(self.through_model) - .join(self.model) - .where(self.src_fk == src_id)) - - return self.field - - def __set__(self, instance, value): - query = self.__get__(instance, force_query=True) - query.add(value, clear_existing=True) - - -class ManyToManyField(MetaField): - accessor_class = ManyToManyFieldAccessor - - def __init__(self, model, backref=None, through_model=None, on_delete=None, - on_update=None, _is_backref=False): - if through_model is not None: - if not (isinstance(through_model, DeferredThroughModel) or - is_model(through_model)): - raise TypeError('Unexpected value for through_model. 
Expected ' - 'Model or DeferredThroughModel.') - if not _is_backref and (on_delete is not None or on_update is not None): - raise ValueError('Cannot specify on_delete or on_update when ' - 'through_model is specified.') - self.rel_model = model - self.backref = backref - self._through_model = through_model - self._on_delete = on_delete - self._on_update = on_update - self._is_backref = _is_backref - - def _get_descriptor(self): - return ManyToManyFieldAccessor(self) - - def bind(self, model, name, set_attribute=True): - if isinstance(self._through_model, DeferredThroughModel): - self._through_model.set_field(model, self, name) - return - - super(ManyToManyField, self).bind(model, name, set_attribute) - - if not self._is_backref: - many_to_many_field = ManyToManyField( - self.model, - backref=name, - through_model=self.through_model, - on_delete=self._on_delete, - on_update=self._on_update, - _is_backref=True) - self.backref = self.backref or model._meta.name + 's' - self.rel_model._meta.add_field(self.backref, many_to_many_field) - - def get_models(self): - return [model for _, model in sorted(( - (self._is_backref, self.model), - (not self._is_backref, self.rel_model)))] - - @property - def through_model(self): - if self._through_model is None: - self._through_model = self._create_through_model() - return self._through_model - - @through_model.setter - def through_model(self, value): - self._through_model = value - - def _create_through_model(self): - lhs, rhs = self.get_models() - tables = [model._meta.table_name for model in (lhs, rhs)] - - class Meta: - database = self.model._meta.database - schema = self.model._meta.schema - table_name = '%s_%s_through' % tuple(tables) - indexes = ( - ((lhs._meta.name, rhs._meta.name), - True),) - - params = {'on_delete': self._on_delete, 'on_update': self._on_update} - attrs = { - lhs._meta.name: ForeignKeyField(lhs, **params), - rhs._meta.name: ForeignKeyField(rhs, **params), - 'Meta': Meta} - - klass_name = '%s%sThrough' % (lhs.__name__, rhs.__name__) - return type(klass_name, (Model,), attrs) - - def get_through_model(self): - # XXX: Deprecated. Just use the "through_model" property. 
- return self.through_model - - -class VirtualField(MetaField): - field_class = None - - def __init__(self, field_class=None, *args, **kwargs): - Field = field_class if field_class is not None else self.field_class - self.field_instance = Field() if Field is not None else None - super(VirtualField, self).__init__(*args, **kwargs) - - def db_value(self, value): - if self.field_instance is not None: - return self.field_instance.db_value(value) - return value - - def python_value(self, value): - if self.field_instance is not None: - return self.field_instance.python_value(value) - return value - - def bind(self, model, name, set_attribute=True): - self.model = model - self.column_name = self.name = self.safe_name = name - setattr(model, name, self.accessor_class(model, self, name)) - - -class CompositeKey(MetaField): - sequence = None - - def __init__(self, *field_names): - self.field_names = field_names - - def __get__(self, instance, instance_type=None): - if instance is not None: - return tuple([getattr(instance, field_name) - for field_name in self.field_names]) - return self - - def __set__(self, instance, value): - if not isinstance(value, (list, tuple)): - raise TypeError('A list or tuple must be used to set the value of ' - 'a composite primary key.') - if len(value) != len(self.field_names): - raise ValueError('The length of the value must equal the number ' - 'of columns of the composite primary key.') - for idx, field_value in enumerate(value): - setattr(instance, self.field_names[idx], field_value) - - def __eq__(self, other): - expressions = [(self.model._meta.fields[field] == value) - for field, value in zip(self.field_names, other)] - return reduce(operator.and_, expressions) - - def __ne__(self, other): - return ~(self == other) - - def __hash__(self): - return hash((self.model.__name__, self.field_names)) - - def __sql__(self, ctx): - # If the composite PK is being selected, do not use parens. Elsewhere, - # such as in an expression, we want to use parentheses and treat it as - # a row value. 
- parens = ctx.scope != SCOPE_SOURCE - return ctx.sql(NodeList([self.model._meta.fields[field] - for field in self.field_names], ', ', parens)) - - def bind(self, model, name, set_attribute=True): - self.model = model - self.column_name = self.name = self.safe_name = name - setattr(model, self.name, self) - - -class _SortedFieldList(object): - __slots__ = ('_keys', '_items') - - def __init__(self): - self._keys = [] - self._items = [] - - def __getitem__(self, i): - return self._items[i] - - def __iter__(self): - return iter(self._items) - - def __contains__(self, item): - k = item._sort_key - i = bisect_left(self._keys, k) - j = bisect_right(self._keys, k) - return item in self._items[i:j] - - def index(self, field): - return self._keys.index(field._sort_key) - - def insert(self, item): - k = item._sort_key - i = bisect_left(self._keys, k) - self._keys.insert(i, k) - self._items.insert(i, item) - - def remove(self, item): - idx = self.index(item) - del self._items[idx] - del self._keys[idx] - - -# MODELS - - -class SchemaManager(object): - def __init__(self, model, database=None, **context_options): - self.model = model - self._database = database - context_options.setdefault('scope', SCOPE_VALUES) - self.context_options = context_options - - @property - def database(self): - db = self._database or self.model._meta.database - if db is None: - raise ImproperlyConfigured('database attribute does not appear to ' - 'be set on the model: %s' % self.model) - return db - - @database.setter - def database(self, value): - self._database = value - - def _create_context(self): - return self.database.get_sql_context(**self.context_options) - - def _create_table(self, safe=True, **options): - is_temp = options.pop('temporary', False) - ctx = self._create_context() - ctx.literal('CREATE TEMPORARY TABLE ' if is_temp else 'CREATE TABLE ') - if safe: - ctx.literal('IF NOT EXISTS ') - ctx.sql(self.model).literal(' ') - - columns = [] - constraints = [] - meta = self.model._meta - if meta.composite_key: - pk_columns = [meta.fields[field_name].column - for field_name in meta.primary_key.field_names] - constraints.append(NodeList((SQL('PRIMARY KEY'), - EnclosedNodeList(pk_columns)))) - - for field in meta.sorted_fields: - columns.append(field.ddl(ctx)) - if isinstance(field, ForeignKeyField) and not field.deferred: - constraints.append(field.foreign_key_constraint()) - - if meta.constraints: - constraints.extend(meta.constraints) - - constraints.extend(self._create_table_option_sql(options)) - ctx.sql(EnclosedNodeList(columns + constraints)) - - if meta.table_settings is not None: - table_settings = ensure_tuple(meta.table_settings) - for setting in table_settings: - if not isinstance(setting, basestring): - raise ValueError('table_settings must be strings') - ctx.literal(' ').literal(setting) - - if meta.without_rowid: - ctx.literal(' WITHOUT ROWID') - return ctx - - def _create_table_option_sql(self, options): - accum = [] - options = merge_dict(self.model._meta.options or {}, options) - if not options: - return accum - - for key, value in sorted(options.items()): - if not isinstance(value, Node): - if is_model(value): - value = value._meta.table - else: - value = SQL(str(value)) - accum.append(NodeList((SQL(key), value), glue='=')) - return accum - - def create_table(self, safe=True, **options): - self.database.execute(self._create_table(safe=safe, **options)) - - def _create_table_as(self, table_name, query, safe=True, **meta): - ctx = (self._create_context() - .literal('CREATE TEMPORARY TABLE ' - if 
meta.get('temporary') else 'CREATE TABLE ')) - if safe: - ctx.literal('IF NOT EXISTS ') - return (ctx - .sql(Entity(table_name)) - .literal(' AS ') - .sql(query)) - - def create_table_as(self, table_name, query, safe=True, **meta): - ctx = self._create_table_as(table_name, query, safe=safe, **meta) - self.database.execute(ctx) - - def _drop_table(self, safe=True, **options): - ctx = (self._create_context() - .literal('DROP TABLE IF EXISTS ' if safe else 'DROP TABLE ') - .sql(self.model)) - if options.get('cascade'): - ctx = ctx.literal(' CASCADE') - elif options.get('restrict'): - ctx = ctx.literal(' RESTRICT') - return ctx - - def drop_table(self, safe=True, **options): - self.database.execute(self._drop_table(safe=safe, **options)) - - def _truncate_table(self, restart_identity=False, cascade=False): - db = self.database - if not db.truncate_table: - return (self._create_context() - .literal('DELETE FROM ').sql(self.model)) - - ctx = self._create_context().literal('TRUNCATE TABLE ').sql(self.model) - if restart_identity: - ctx = ctx.literal(' RESTART IDENTITY') - if cascade: - ctx = ctx.literal(' CASCADE') - return ctx - - def truncate_table(self, restart_identity=False, cascade=False): - self.database.execute(self._truncate_table(restart_identity, cascade)) - - def _create_indexes(self, safe=True): - return [self._create_index(index, safe) - for index in self.model._meta.fields_to_index()] - - def _create_index(self, index, safe=True): - if isinstance(index, Index): - if not self.database.safe_create_index: - index = index.safe(False) - elif index._safe != safe: - index = index.safe(safe) - return self._create_context().sql(index) - - def create_indexes(self, safe=True): - for query in self._create_indexes(safe=safe): - self.database.execute(query) - - def _drop_indexes(self, safe=True): - return [self._drop_index(index, safe) - for index in self.model._meta.fields_to_index() - if isinstance(index, Index)] - - def _drop_index(self, index, safe): - statement = 'DROP INDEX ' - if safe and self.database.safe_drop_index: - statement += 'IF EXISTS ' - if isinstance(index._table, Table) and index._table._schema: - index_name = Entity(index._table._schema, index._name) - else: - index_name = Entity(index._name) - return (self - ._create_context() - .literal(statement) - .sql(index_name)) - - def drop_indexes(self, safe=True): - for query in self._drop_indexes(safe=safe): - self.database.execute(query) - - def _check_sequences(self, field): - if not field.sequence or not self.database.sequences: - raise ValueError('Sequences are either not supported, or are not ' - 'defined for "%s".' 
% field.name) - - def _sequence_for_field(self, field): - if field.model._meta.schema: - return Entity(field.model._meta.schema, field.sequence) - else: - return Entity(field.sequence) - - def _create_sequence(self, field): - self._check_sequences(field) - if not self.database.sequence_exists(field.sequence): - return (self - ._create_context() - .literal('CREATE SEQUENCE ') - .sql(self._sequence_for_field(field))) - - def create_sequence(self, field): - seq_ctx = self._create_sequence(field) - if seq_ctx is not None: - self.database.execute(seq_ctx) - - def _drop_sequence(self, field): - self._check_sequences(field) - if self.database.sequence_exists(field.sequence): - return (self - ._create_context() - .literal('DROP SEQUENCE ') - .sql(self._sequence_for_field(field))) - - def drop_sequence(self, field): - seq_ctx = self._drop_sequence(field) - if seq_ctx is not None: - self.database.execute(seq_ctx) - - def _create_foreign_key(self, field): - name = 'fk_%s_%s_refs_%s' % (field.model._meta.table_name, - field.column_name, - field.rel_model._meta.table_name) - return (self - ._create_context() - .literal('ALTER TABLE ') - .sql(field.model) - .literal(' ADD CONSTRAINT ') - .sql(Entity(_truncate_constraint_name(name))) - .literal(' ') - .sql(field.foreign_key_constraint())) - - def create_foreign_key(self, field): - self.database.execute(self._create_foreign_key(field)) - - def create_sequences(self): - if self.database.sequences: - for field in self.model._meta.sorted_fields: - if field.sequence: - self.create_sequence(field) - - def create_all(self, safe=True, **table_options): - self.create_sequences() - self.create_table(safe, **table_options) - self.create_indexes(safe=safe) - - def drop_sequences(self): - if self.database.sequences: - for field in self.model._meta.sorted_fields: - if field.sequence: - self.drop_sequence(field) - - def drop_all(self, safe=True, drop_sequences=True, **options): - self.drop_table(safe, **options) - if drop_sequences: - self.drop_sequences() - - -class Metadata(object): - def __init__(self, model, database=None, table_name=None, indexes=None, - primary_key=None, constraints=None, schema=None, - only_save_dirty=False, depends_on=None, options=None, - db_table=None, table_function=None, table_settings=None, - without_rowid=False, temporary=False, legacy_table_names=True, - **kwargs): - if db_table is not None: - __deprecated__('"db_table" has been deprecated in favor of ' - '"table_name" for Models.') - table_name = db_table - self.model = model - self.database = database - - self.fields = {} - self.columns = {} - self.combined = {} - - self._sorted_field_list = _SortedFieldList() - self.sorted_fields = [] - self.sorted_field_names = [] - - self.defaults = {} - self._default_by_name = {} - self._default_dict = {} - self._default_callables = {} - self._default_callable_list = [] - - self.name = model.__name__.lower() - self.table_function = table_function - self.legacy_table_names = legacy_table_names - if not table_name: - table_name = (self.table_function(model) - if self.table_function - else self.make_table_name()) - self.table_name = table_name - self._table = None - - self.indexes = list(indexes) if indexes else [] - self.constraints = constraints - self._schema = schema - self.primary_key = primary_key - self.composite_key = self.auto_increment = None - self.only_save_dirty = only_save_dirty - self.depends_on = depends_on - self.table_settings = table_settings - self.without_rowid = without_rowid - self.temporary = temporary - - self.refs = {} - 
self.backrefs = {} - self.model_refs = collections.defaultdict(list) - self.model_backrefs = collections.defaultdict(list) - self.manytomany = {} - - self.options = options or {} - for key, value in kwargs.items(): - setattr(self, key, value) - self._additional_keys = set(kwargs.keys()) - - # Allow objects to register hooks that are called if the model is bound - # to a different database. For example, BlobField uses a different - # Python data-type depending on the db driver / python version. When - # the database changes, we need to update any BlobField so they can use - # the appropriate data-type. - self._db_hooks = [] - - def make_table_name(self): - if self.legacy_table_names: - return re.sub('[^\w]+', '_', self.name) - return make_snake_case(self.model.__name__) - - def model_graph(self, refs=True, backrefs=True, depth_first=True): - if not refs and not backrefs: - raise ValueError('One of `refs` or `backrefs` must be True.') - - accum = [(None, self.model, None)] - seen = set() - queue = collections.deque((self,)) - method = queue.pop if depth_first else queue.popleft - - while queue: - curr = method() - if curr in seen: continue - seen.add(curr) - - if refs: - for fk, model in curr.refs.items(): - accum.append((fk, model, False)) - queue.append(model._meta) - if backrefs: - for fk, model in curr.backrefs.items(): - accum.append((fk, model, True)) - queue.append(model._meta) - - return accum - - def add_ref(self, field): - rel = field.rel_model - self.refs[field] = rel - self.model_refs[rel].append(field) - rel._meta.backrefs[field] = self.model - rel._meta.model_backrefs[self.model].append(field) - - def remove_ref(self, field): - rel = field.rel_model - del self.refs[field] - self.model_refs[rel].remove(field) - del rel._meta.backrefs[field] - rel._meta.model_backrefs[self.model].remove(field) - - def add_manytomany(self, field): - self.manytomany[field.name] = field - - def remove_manytomany(self, field): - del self.manytomany[field.name] - - @property - def table(self): - if self._table is None: - self._table = Table( - self.table_name, - [field.column_name for field in self.sorted_fields], - schema=self.schema, - _model=self.model, - _database=self.database) - return self._table - - @table.setter - def table(self, value): - raise AttributeError('Cannot set the "table".') - - @table.deleter - def table(self): - self._table = None - - @property - def schema(self): - return self._schema - - @schema.setter - def schema(self, value): - self._schema = value - del self.table - - @property - def entity(self): - if self._schema: - return Entity(self._schema, self.table_name) - else: - return Entity(self.table_name) - - def _update_sorted_fields(self): - self.sorted_fields = list(self._sorted_field_list) - self.sorted_field_names = [f.name for f in self.sorted_fields] - - def get_rel_for_model(self, model): - if isinstance(model, ModelAlias): - model = model.model - forwardrefs = self.model_refs.get(model, []) - backrefs = self.model_backrefs.get(model, []) - return (forwardrefs, backrefs) - - def add_field(self, field_name, field, set_attribute=True): - if field_name in self.fields: - self.remove_field(field_name) - elif field_name in self.manytomany: - self.remove_manytomany(self.manytomany[field_name]) - - if not isinstance(field, MetaField): - del self.table - field.bind(self.model, field_name, set_attribute) - self.fields[field.name] = field - self.columns[field.column_name] = field - self.combined[field.name] = field - self.combined[field.column_name] = field - - 
self._sorted_field_list.insert(field) - self._update_sorted_fields() - - if field.default is not None: - # This optimization helps speed up model instance construction. - self.defaults[field] = field.default - if callable_(field.default): - self._default_callables[field] = field.default - self._default_callable_list.append((field.name, - field.default)) - else: - self._default_dict[field] = field.default - self._default_by_name[field.name] = field.default - else: - field.bind(self.model, field_name, set_attribute) - - if isinstance(field, ForeignKeyField): - self.add_ref(field) - elif isinstance(field, ManyToManyField) and field.name: - self.add_manytomany(field) - - def remove_field(self, field_name): - if field_name not in self.fields: - return - - del self.table - original = self.fields.pop(field_name) - del self.columns[original.column_name] - del self.combined[field_name] - try: - del self.combined[original.column_name] - except KeyError: - pass - self._sorted_field_list.remove(original) - self._update_sorted_fields() - - if original.default is not None: - del self.defaults[original] - if self._default_callables.pop(original, None): - for i, (name, _) in enumerate(self._default_callable_list): - if name == field_name: - self._default_callable_list.pop(i) - break - else: - self._default_dict.pop(original, None) - self._default_by_name.pop(original.name, None) - - if isinstance(original, ForeignKeyField): - self.remove_ref(original) - - def set_primary_key(self, name, field): - self.composite_key = isinstance(field, CompositeKey) - self.add_field(name, field) - self.primary_key = field - self.auto_increment = ( - field.auto_increment or - bool(field.sequence)) - - def get_primary_keys(self): - if self.composite_key: - return tuple([self.fields[field_name] - for field_name in self.primary_key.field_names]) - else: - return (self.primary_key,) if self.primary_key is not False else () - - def get_default_dict(self): - dd = self._default_by_name.copy() - for field_name, default in self._default_callable_list: - dd[field_name] = default() - return dd - - def fields_to_index(self): - indexes = [] - for f in self.sorted_fields: - if f.primary_key: - continue - if f.index or f.unique: - indexes.append(ModelIndex(self.model, (f,), unique=f.unique, - using=f.index_type)) - - for index_obj in self.indexes: - if isinstance(index_obj, Node): - indexes.append(index_obj) - elif isinstance(index_obj, (list, tuple)): - index_parts, unique = index_obj - fields = [] - for part in index_parts: - if isinstance(part, basestring): - fields.append(self.combined[part]) - elif isinstance(part, Node): - fields.append(part) - else: - raise ValueError('Expected either a field name or a ' - 'subclass of Node. Got: %s' % part) - indexes.append(ModelIndex(self.model, fields, unique=unique)) - - return indexes - - def set_database(self, database): - self.database = database - self.model._schema._database = database - del self.table - - # Apply any hooks that have been registered. 
- for hook in self._db_hooks: - hook(database) - - def set_table_name(self, table_name): - self.table_name = table_name - del self.table - - -class SubclassAwareMetadata(Metadata): - models = [] - - def __init__(self, model, *args, **kwargs): - super(SubclassAwareMetadata, self).__init__(model, *args, **kwargs) - self.models.append(model) - - def map_models(self, fn): - for model in self.models: - fn(model) - - -class DoesNotExist(Exception): pass - - -class ModelBase(type): - inheritable = set(['constraints', 'database', 'indexes', 'primary_key', - 'options', 'schema', 'table_function', 'temporary', - 'only_save_dirty', 'legacy_table_names', - 'table_settings']) - - def __new__(cls, name, bases, attrs): - if name == MODEL_BASE or bases[0].__name__ == MODEL_BASE: - return super(ModelBase, cls).__new__(cls, name, bases, attrs) - - meta_options = {} - meta = attrs.pop('Meta', None) - if meta: - for k, v in meta.__dict__.items(): - if not k.startswith('_'): - meta_options[k] = v - - pk = getattr(meta, 'primary_key', None) - pk_name = parent_pk = None - - # Inherit any field descriptors by deep copying the underlying field - # into the attrs of the new model, additionally see if the bases define - # inheritable model options and swipe them. - for b in bases: - if not hasattr(b, '_meta'): - continue - - base_meta = b._meta - if parent_pk is None: - parent_pk = deepcopy(base_meta.primary_key) - all_inheritable = cls.inheritable | base_meta._additional_keys - for k in base_meta.__dict__: - if k in all_inheritable and k not in meta_options: - meta_options[k] = base_meta.__dict__[k] - meta_options.setdefault('schema', base_meta.schema) - - for (k, v) in b.__dict__.items(): - if k in attrs: continue - - if isinstance(v, FieldAccessor) and not v.field.primary_key: - attrs[k] = deepcopy(v.field) - - sopts = meta_options.pop('schema_options', None) or {} - Meta = meta_options.get('model_metadata_class', Metadata) - Schema = meta_options.get('schema_manager_class', SchemaManager) - - # Construct the new class. - cls = super(ModelBase, cls).__new__(cls, name, bases, attrs) - cls.__data__ = cls.__rel__ = None - - cls._meta = Meta(cls, **meta_options) - cls._schema = Schema(cls, **sopts) - - fields = [] - for key, value in cls.__dict__.items(): - if isinstance(value, Field): - if value.primary_key and pk: - raise ValueError('over-determined primary key %s.' % name) - elif value.primary_key: - pk, pk_name = value, key - else: - fields.append((key, value)) - - if pk is None: - if parent_pk is not False: - pk, pk_name = ((parent_pk, parent_pk.name) - if parent_pk is not None else - (AutoField(), 'id')) - else: - pk = False - elif isinstance(pk, CompositeKey): - pk_name = '__composite_key__' - cls._meta.composite_key = True - - if pk is not False: - cls._meta.set_primary_key(pk_name, pk) - - for name, field in fields: - cls._meta.add_field(name, field) - - # Create a repr and error class before finalizing. - if hasattr(cls, '__str__') and '__repr__' not in attrs: - setattr(cls, '__repr__', lambda self: '<%s: %s>' % ( - cls.__name__, self.__str__())) - - exc_name = '%sDoesNotExist' % cls.__name__ - exc_attrs = {'__module__': cls.__module__} - exception_class = type(exc_name, (DoesNotExist,), exc_attrs) - cls.DoesNotExist = exception_class - - # Call validation hook, allowing additional model validation. 
-        cls.validate_model()
-        DeferredForeignKey.resolve(cls)
-        return cls
-
-    def __repr__(self):
-        return '<Model: %s>' % self.__name__
-
-    def __iter__(self):
-        return iter(self.select())
-
-    def __getitem__(self, key):
-        return self.get_by_id(key)
-
-    def __setitem__(self, key, value):
-        self.set_by_id(key, value)
-
-    def __delitem__(self, key):
-        self.delete_by_id(key)
-
-    def __contains__(self, key):
-        try:
-            self.get_by_id(key)
-        except self.DoesNotExist:
-            return False
-        else:
-            return True
-
-    def __len__(self):
-        return self.select().count()
-    def __bool__(self): return True
-    __nonzero__ = __bool__  # Python 2.
-
-
-class _BoundModelsContext(_callable_context_manager):
-    def __init__(self, models, database, bind_refs, bind_backrefs):
-        self.models = models
-        self.database = database
-        self.bind_refs = bind_refs
-        self.bind_backrefs = bind_backrefs
-
-    def __enter__(self):
-        self._orig_database = []
-        for model in self.models:
-            self._orig_database.append(model._meta.database)
-            model.bind(self.database, self.bind_refs, self.bind_backrefs)
-        return self.models
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        for model, db in zip(self.models, self._orig_database):
-            model.bind(db, self.bind_refs, self.bind_backrefs)
-
-
-class Model(with_metaclass(ModelBase, Node)):
-    def __init__(self, *args, **kwargs):
-        if kwargs.pop('__no_default__', None):
-            self.__data__ = {}
-        else:
-            self.__data__ = self._meta.get_default_dict()
-        self._dirty = set(self.__data__)
-        self.__rel__ = {}
-
-        for k in kwargs:
-            setattr(self, k, kwargs[k])
-
-    def __str__(self):
-        return str(self._pk) if self._meta.primary_key is not False else 'n/a'
-
-    @classmethod
-    def validate_model(cls):
-        pass
-
-    @classmethod
-    def alias(cls, alias=None):
-        return ModelAlias(cls, alias)
-
-    @classmethod
-    def select(cls, *fields):
-        is_default = not fields
-        if not fields:
-            fields = cls._meta.sorted_fields
-        return ModelSelect(cls, fields, is_default=is_default)
-
-    @classmethod
-    def _normalize_data(cls, data, kwargs):
-        normalized = {}
-        if data:
-            if not isinstance(data, dict):
-                if kwargs:
-                    raise ValueError('Data cannot be mixed with keyword '
-                                     'arguments: %s' % data)
-                return data
-            for key in data:
-                try:
-                    field = (key if isinstance(key, Field)
-                             else cls._meta.combined[key])
-                except KeyError:
-                    raise ValueError('Unrecognized field name: "%s" in %s.' 
% - (key, data)) - normalized[field] = data[key] - if kwargs: - for key in kwargs: - try: - normalized[cls._meta.combined[key]] = kwargs[key] - except KeyError: - normalized[getattr(cls, key)] = kwargs[key] - return normalized - - @classmethod - def update(cls, __data=None, **update): - return ModelUpdate(cls, cls._normalize_data(__data, update)) - - @classmethod - def insert(cls, __data=None, **insert): - return ModelInsert(cls, cls._normalize_data(__data, insert)) - - @classmethod - def insert_many(cls, rows, fields=None): - return ModelInsert(cls, insert=rows, columns=fields) - - @classmethod - def insert_from(cls, query, fields): - columns = [getattr(cls, field) if isinstance(field, basestring) - else field for field in fields] - return ModelInsert(cls, insert=query, columns=columns) - - @classmethod - def replace(cls, __data=None, **insert): - return cls.insert(__data, **insert).on_conflict('REPLACE') - - @classmethod - def replace_many(cls, rows, fields=None): - return (cls - .insert_many(rows=rows, fields=fields) - .on_conflict('REPLACE')) - - @classmethod - def raw(cls, sql, *params): - return ModelRaw(cls, sql, params) - - @classmethod - def delete(cls): - return ModelDelete(cls) - - @classmethod - def create(cls, **query): - inst = cls(**query) - inst.save(force_insert=True) - return inst - - @classmethod - def bulk_create(cls, model_list, batch_size=None): - if batch_size is not None: - batches = chunked(model_list, batch_size) - else: - batches = [model_list] - - field_names = list(cls._meta.sorted_field_names) - if cls._meta.auto_increment: - pk_name = cls._meta.primary_key.name - field_names.remove(pk_name) - ids_returned = cls._meta.database.returning_clause - else: - ids_returned = False - - fields = [cls._meta.fields[field_name] for field_name in field_names] - for batch in batches: - accum = ([getattr(model, f) for f in field_names] - for model in batch) - res = cls.insert_many(accum, fields=fields).execute() - if ids_returned and res is not None: - for (obj_id,), model in zip(res, batch): - setattr(model, pk_name, obj_id) - - @classmethod - def bulk_update(cls, model_list, fields, batch_size=None): - if isinstance(cls._meta.primary_key, CompositeKey): - raise ValueError('bulk_update() is not supported for models with ' - 'a composite primary key.') - - # First normalize list of fields so all are field instances. - fields = [cls._meta.fields[f] if isinstance(f, basestring) else f - for f in fields] - # Now collect list of attribute names to use for values. - attrs = [field.object_id_name if isinstance(field, ForeignKeyField) - else field.name for field in fields] - - if batch_size is not None: - batches = chunked(model_list, batch_size) - else: - batches = [model_list] - - n = 0 - for batch in batches: - id_list = [model._pk for model in batch] - update = {} - for field, attr in zip(fields, attrs): - accum = [] - for model in batch: - value = getattr(model, attr) - if not isinstance(value, Node): - value = Value(value, converter=field.db_value) - accum.append((model._pk, value)) - case = Case(cls._meta.primary_key, accum) - update[field] = case - - n += (cls.update(update) - .where(cls._meta.primary_key.in_(id_list)) - .execute()) - return n - - @classmethod - def noop(cls): - return NoopModelSelect(cls, ()) - - @classmethod - def get(cls, *query, **filters): - sq = cls.select() - if query: - # Handle simple lookup using just the primary key. 
- if len(query) == 1 and isinstance(query[0], int): - sq = sq.where(cls._meta.primary_key == query[0]) - else: - sq = sq.where(*query) - if filters: - sq = sq.filter(**filters) - return sq.get() - - @classmethod - def get_or_none(cls, *query, **filters): - try: - return cls.get(*query, **filters) - except DoesNotExist: - pass - - @classmethod - def get_by_id(cls, pk): - return cls.get(cls._meta.primary_key == pk) - - @classmethod - def set_by_id(cls, key, value): - if key is None: - return cls.insert(value).execute() - else: - return (cls.update(value) - .where(cls._meta.primary_key == key).execute()) - - @classmethod - def delete_by_id(cls, pk): - return cls.delete().where(cls._meta.primary_key == pk).execute() - - @classmethod - def get_or_create(cls, **kwargs): - defaults = kwargs.pop('defaults', {}) - query = cls.select() - for field, value in kwargs.items(): - query = query.where(getattr(cls, field) == value) - - try: - return query.get(), False - except cls.DoesNotExist: - try: - if defaults: - kwargs.update(defaults) - with cls._meta.database.atomic(): - return cls.create(**kwargs), True - except IntegrityError as exc: - try: - return query.get(), False - except cls.DoesNotExist: - raise exc - - @classmethod - def filter(cls, *dq_nodes, **filters): - return cls.select().filter(*dq_nodes, **filters) - - def get_id(self): - # Using getattr(self, pk-name) could accidentally trigger a query if - # the primary-key is a foreign-key. So we use the safe_name attribute, - # which defaults to the field-name, but will be the object_id_name for - # foreign-key fields. - if self._meta.primary_key is not False: - return getattr(self, self._meta.primary_key.safe_name) - - _pk = property(get_id) - - @_pk.setter - def _pk(self, value): - setattr(self, self._meta.primary_key.name, value) - - def _pk_expr(self): - return self._meta.primary_key == self._pk - - def _prune_fields(self, field_dict, only): - new_data = {} - for field in only: - if isinstance(field, basestring): - field = self._meta.combined[field] - if field.name in field_dict: - new_data[field.name] = field_dict[field.name] - return new_data - - def _populate_unsaved_relations(self, field_dict): - for foreign_key_field in self._meta.refs: - foreign_key = foreign_key_field.name - conditions = ( - foreign_key in field_dict and - field_dict[foreign_key] is None and - self.__rel__.get(foreign_key) is not None) - if conditions: - setattr(self, foreign_key, getattr(self, foreign_key)) - field_dict[foreign_key] = self.__data__[foreign_key] - - def save(self, force_insert=False, only=None): - field_dict = self.__data__.copy() - if self._meta.primary_key is not False: - pk_field = self._meta.primary_key - pk_value = self._pk - else: - pk_field = pk_value = None - if only: - field_dict = self._prune_fields(field_dict, only) - elif self._meta.only_save_dirty and not force_insert: - field_dict = self._prune_fields(field_dict, self.dirty_fields) - if not field_dict: - self._dirty.clear() - return False - - self._populate_unsaved_relations(field_dict) - rows = 1 - - if pk_value is not None and not force_insert: - if self._meta.composite_key: - for pk_part_name in pk_field.field_names: - field_dict.pop(pk_part_name, None) - else: - field_dict.pop(pk_field.name, None) - if not field_dict: - raise ValueError('no data to save!') - rows = self.update(**field_dict).where(self._pk_expr()).execute() - elif pk_field is not None: - pk = self.insert(**field_dict).execute() - if pk is not None and (self._meta.auto_increment or - pk_value is None): - self._pk = pk 
- else: - self.insert(**field_dict).execute() - - self._dirty.clear() - return rows - - def is_dirty(self): - return bool(self._dirty) - - @property - def dirty_fields(self): - return [f for f in self._meta.sorted_fields if f.name in self._dirty] - - def dependencies(self, search_nullable=False): - model_class = type(self) - stack = [(type(self), None)] - seen = set() - - while stack: - klass, query = stack.pop() - if klass in seen: - continue - seen.add(klass) - for fk, rel_model in klass._meta.backrefs.items(): - if rel_model is model_class or query is None: - node = (fk == self.__data__[fk.rel_field.name]) - else: - node = fk << query - subquery = (rel_model.select(rel_model._meta.primary_key) - .where(node)) - if not fk.null or search_nullable: - stack.append((rel_model, subquery)) - yield (node, fk) - - def delete_instance(self, recursive=False, delete_nullable=False): - if recursive: - dependencies = self.dependencies(delete_nullable) - for query, fk in reversed(list(dependencies)): - model = fk.model - if fk.null and not delete_nullable: - model.update(**{fk.name: None}).where(query).execute() - else: - model.delete().where(query).execute() - return type(self).delete().where(self._pk_expr()).execute() - - def __hash__(self): - return hash((self.__class__, self._pk)) - - def __eq__(self, other): - return ( - other.__class__ == self.__class__ and - self._pk is not None and - self._pk == other._pk) - - def __ne__(self, other): - return not self == other - - def __sql__(self, ctx): - return ctx.sql(Value(getattr(self, self._meta.primary_key.name), - converter=self._meta.primary_key.db_value)) - - @classmethod - def bind(cls, database, bind_refs=True, bind_backrefs=True): - is_different = cls._meta.database is not database - cls._meta.set_database(database) - if bind_refs or bind_backrefs: - G = cls._meta.model_graph(refs=bind_refs, backrefs=bind_backrefs) - for _, model, is_backref in G: - model._meta.set_database(database) - return is_different - - @classmethod - def bind_ctx(cls, database, bind_refs=True, bind_backrefs=True): - return _BoundModelsContext((cls,), database, bind_refs, bind_backrefs) - - @classmethod - def table_exists(cls): - M = cls._meta - return cls._schema.database.table_exists(M.table.__name__, M.schema) - - @classmethod - def create_table(cls, safe=True, **options): - if 'fail_silently' in options: - __deprecated__('"fail_silently" has been deprecated in favor of ' - '"safe" for the create_table() method.') - safe = options.pop('fail_silently') - - if safe and not cls._schema.database.safe_create_index \ - and cls.table_exists(): - return - if cls._meta.temporary: - options.setdefault('temporary', cls._meta.temporary) - cls._schema.create_all(safe, **options) - - @classmethod - def drop_table(cls, safe=True, drop_sequences=True, **options): - if safe and not cls._schema.database.safe_drop_index \ - and not cls.table_exists(): - return - if cls._meta.temporary: - options.setdefault('temporary', cls._meta.temporary) - cls._schema.drop_all(safe, drop_sequences, **options) - - @classmethod - def truncate_table(cls, **options): - cls._schema.truncate_table(**options) - - @classmethod - def index(cls, *fields, **kwargs): - return ModelIndex(cls, fields, **kwargs) - - @classmethod - def add_index(cls, *fields, **kwargs): - if len(fields) == 1 and isinstance(fields[0], (SQL, Index)): - cls._meta.indexes.append(fields[0]) - else: - cls._meta.indexes.append(ModelIndex(cls, fields, **kwargs)) - - -class ModelAlias(Node): - """Provide a separate reference to a model in a 
query.""" - def __init__(self, model, alias=None): - self.__dict__['model'] = model - self.__dict__['alias'] = alias - - def __getattr__(self, attr): - # Hack to work-around the fact that properties or other objects - # implementing the descriptor protocol (on the model being aliased), - # will not work correctly when we use getattr(). So we explicitly pass - # the model alias to the descriptor's getter. - try: - obj = self.model.__dict__[attr] - except KeyError: - pass - else: - if isinstance(obj, ModelDescriptor): - return obj.__get__(None, self) - - model_attr = getattr(self.model, attr) - if isinstance(model_attr, Field): - self.__dict__[attr] = FieldAlias.create(self, model_attr) - return self.__dict__[attr] - return model_attr - - def __setattr__(self, attr, value): - raise AttributeError('Cannot set attributes on model aliases.') - - def get_field_aliases(self): - return [getattr(self, n) for n in self.model._meta.sorted_field_names] - - def select(self, *selection): - if not selection: - selection = self.get_field_aliases() - return ModelSelect(self, selection) - - def __call__(self, **kwargs): - return self.model(**kwargs) - - def __sql__(self, ctx): - if ctx.scope == SCOPE_VALUES: - # Return the quoted table name. - return ctx.sql(self.model) - - if self.alias: - ctx.alias_manager[self] = self.alias - - if ctx.scope == SCOPE_SOURCE: - # Define the table and its alias. - return (ctx - .sql(self.model._meta.entity) - .literal(' AS ') - .sql(Entity(ctx.alias_manager[self]))) - else: - # Refer to the table using the alias. - return ctx.sql(Entity(ctx.alias_manager[self])) - - -class FieldAlias(Field): - def __init__(self, source, field): - self.source = source - self.model = source.model - self.field = field - - @classmethod - def create(cls, source, field): - class _FieldAlias(cls, type(field)): - pass - return _FieldAlias(source, field) - - def clone(self): - return FieldAlias(self.source, self.field) - - def adapt(self, value): return self.field.adapt(value) - def python_value(self, value): return self.field.python_value(value) - def db_value(self, value): return self.field.db_value(value) - def __getattr__(self, attr): - return self.source if attr == 'model' else getattr(self.field, attr) - - def __sql__(self, ctx): - return ctx.sql(Column(self.source, self.field.column_name)) - - -def sort_models(models): - models = set(models) - seen = set() - ordering = [] - def dfs(model): - if model in models and model not in seen: - seen.add(model) - for foreign_key, rel_model in model._meta.refs.items(): - # Do not depth-first search deferred foreign-keys as this can - # cause tables to be created in the incorrect order. 
- if not foreign_key.deferred: - dfs(rel_model) - if model._meta.depends_on: - for dependency in model._meta.depends_on: - dfs(dependency) - ordering.append(model) - - names = lambda m: (m._meta.name, m._meta.table_name) - for m in sorted(models, key=names): - dfs(m) - return ordering - - -class _ModelQueryHelper(object): - default_row_type = ROW.MODEL - - def __init__(self, *args, **kwargs): - super(_ModelQueryHelper, self).__init__(*args, **kwargs) - if not self._database: - self._database = self.model._meta.database - - @Node.copy - def objects(self, constructor=None): - self._row_type = ROW.CONSTRUCTOR - self._constructor = self.model if constructor is None else constructor - - def _get_cursor_wrapper(self, cursor): - row_type = self._row_type or self.default_row_type - if row_type == ROW.MODEL: - return self._get_model_cursor_wrapper(cursor) - elif row_type == ROW.DICT: - return ModelDictCursorWrapper(cursor, self.model, self._returning) - elif row_type == ROW.TUPLE: - return ModelTupleCursorWrapper(cursor, self.model, self._returning) - elif row_type == ROW.NAMED_TUPLE: - return ModelNamedTupleCursorWrapper(cursor, self.model, - self._returning) - elif row_type == ROW.CONSTRUCTOR: - return ModelObjectCursorWrapper(cursor, self.model, - self._returning, self._constructor) - else: - raise ValueError('Unrecognized row type: "%s".' % row_type) - - def _get_model_cursor_wrapper(self, cursor): - return ModelObjectCursorWrapper(cursor, self.model, [], self.model) - - -class ModelRaw(_ModelQueryHelper, RawQuery): - def __init__(self, model, sql, params, **kwargs): - self.model = model - self._returning = () - super(ModelRaw, self).__init__(sql=sql, params=params, **kwargs) - - def get(self): - try: - return self.execute()[0] - except IndexError: - sql, params = self.sql() - raise self.model.DoesNotExist('%s instance matching query does ' - 'not exist:\nSQL: %s\nParams: %s' % - (self.model, sql, params)) - - -class BaseModelSelect(_ModelQueryHelper): - def union_all(self, rhs): - return ModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) - __add__ = union_all - - def union(self, rhs): - return ModelCompoundSelectQuery(self.model, self, 'UNION', rhs) - __or__ = union - - def intersect(self, rhs): - return ModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) - __and__ = intersect - - def except_(self, rhs): - return ModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) - __sub__ = except_ - - def __iter__(self): - if not self._cursor_wrapper: - self.execute() - return iter(self._cursor_wrapper) - - def prefetch(self, *subqueries): - return prefetch(self, *subqueries) - - def get(self, database=None): - clone = self.paginate(1, 1) - clone._cursor_wrapper = None - try: - return clone.execute(database)[0] - except IndexError: - sql, params = clone.sql() - raise self.model.DoesNotExist('%s instance matching query does ' - 'not exist:\nSQL: %s\nParams: %s' % - (clone.model, sql, params)) - - @Node.copy - def group_by(self, *columns): - grouping = [] - for column in columns: - if is_model(column): - grouping.extend(column._meta.sorted_fields) - elif isinstance(column, Table): - if not column._columns: - raise ValueError('Cannot pass a table to group_by() that ' - 'does not have columns explicitly ' - 'declared.') - grouping.extend([getattr(column, col_name) - for col_name in column._columns]) - else: - grouping.append(column) - self._group_by = grouping - - -class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): - def __init__(self, model, *args, **kwargs): - 
self.model = model - super(ModelCompoundSelectQuery, self).__init__(*args, **kwargs) - - def _get_model_cursor_wrapper(self, cursor): - return self.lhs._get_model_cursor_wrapper(cursor) - - -def _normalize_model_select(fields_or_models): - fields = [] - for fm in fields_or_models: - if is_model(fm): - fields.extend(fm._meta.sorted_fields) - elif isinstance(fm, ModelAlias): - fields.extend(fm.get_field_aliases()) - elif isinstance(fm, Table) and fm._columns: - fields.extend([getattr(fm, col) for col in fm._columns]) - else: - fields.append(fm) - return fields - - -class ModelSelect(BaseModelSelect, Select): - def __init__(self, model, fields_or_models, is_default=False): - self.model = self._join_ctx = model - self._joins = {} - self._is_default = is_default - fields = _normalize_model_select(fields_or_models) - super(ModelSelect, self).__init__([model], fields) - - def clone(self): - clone = super(ModelSelect, self).clone() - if clone._joins: - clone._joins = dict(clone._joins) - return clone - - def select(self, *fields_or_models): - if fields_or_models or not self._is_default: - self._is_default = False - fields = _normalize_model_select(fields_or_models) - return super(ModelSelect, self).select(*fields) - return self - - def switch(self, ctx=None): - self._join_ctx = self.model if ctx is None else ctx - return self - - def _get_model(self, src): - if is_model(src): - return src, True - elif isinstance(src, Table) and src._model: - return src._model, False - elif isinstance(src, ModelAlias): - return src.model, False - elif isinstance(src, ModelSelect): - return src.model, False - return None, False - - def _normalize_join(self, src, dest, on, attr): - # Allow "on" expression to have an alias that determines the - # destination attribute for the joined data. - on_alias = isinstance(on, Alias) - if on_alias: - attr = attr or on._alias - on = on.alias() - - # Obtain references to the source and destination models being joined. - src_model, src_is_model = self._get_model(src) - dest_model, dest_is_model = self._get_model(dest) - - if src_model and dest_model: - self._join_ctx = dest - constructor = dest_model - - # In the case where the "on" clause is a Column or Field, we will - # convert that field into the appropriate predicate expression. - if not (src_is_model and dest_is_model) and isinstance(on, Column): - if on.source is src: - to_field = src_model._meta.columns[on.name] - elif on.source is dest: - to_field = dest_model._meta.columns[on.name] - else: - raise AttributeError('"on" clause Column %s does not ' - 'belong to %s or %s.' 
% - (on, src_model, dest_model)) - on = None - elif isinstance(on, Field): - to_field = on - on = None - else: - to_field = None - - fk_field, is_backref = self._generate_on_clause( - src_model, dest_model, to_field, on) - - if on is None: - src_attr = 'name' if src_is_model else 'column_name' - dest_attr = 'name' if dest_is_model else 'column_name' - if is_backref: - lhs = getattr(dest, getattr(fk_field, dest_attr)) - rhs = getattr(src, getattr(fk_field.rel_field, src_attr)) - else: - lhs = getattr(src, getattr(fk_field, src_attr)) - rhs = getattr(dest, getattr(fk_field.rel_field, dest_attr)) - on = (lhs == rhs) - - if not attr: - if fk_field is not None and not is_backref: - attr = fk_field.name - else: - attr = dest_model._meta.name - elif on_alias and fk_field is not None and \ - attr == fk_field.object_id_name and not is_backref: - raise ValueError('Cannot assign join alias to "%s", as this ' - 'attribute is the object_id_name for the ' - 'foreign-key field "%s"' % (attr, fk_field)) - - elif isinstance(dest, Source): - constructor = dict - attr = attr or dest._alias - if not attr and isinstance(dest, Table): - attr = attr or dest.__name__ - - return (on, attr, constructor) - - def _generate_on_clause(self, src, dest, to_field=None, on=None): - meta = src._meta - is_backref = fk_fields = False - - # Get all the foreign keys between source and dest, and determine if - # the join is via a back-reference. - if dest in meta.model_refs: - fk_fields = meta.model_refs[dest] - elif dest in meta.model_backrefs: - fk_fields = meta.model_backrefs[dest] - is_backref = True - - if not fk_fields: - if on is not None: - return None, False - raise ValueError('Unable to find foreign key between %s and %s. ' - 'Please specify an explicit join condition.' % - (src, dest)) - elif to_field is not None: - # If the foreign-key field was specified explicitly, remove all - # other foreign-key fields from the list. - target = (to_field.field if isinstance(to_field, FieldAlias) - else to_field) - fk_fields = [f for f in fk_fields if ( - (f is target) or - (is_backref and f.rel_field is to_field))] - - if len(fk_fields) == 1: - return fk_fields[0], is_backref - - if on is None: - # If multiple foreign-keys exist, try using the FK whose name - # matches that of the related model. If not, raise an error as this - # is ambiguous. - for fk in fk_fields: - if fk.name == dest._meta.name: - return fk, is_backref - - raise ValueError('More than one foreign key between %s and %s.' - ' Please specify which you are joining on.' % - (src, dest)) - - # If there are multiple foreign-keys to choose from and the join - # predicate is an expression, we'll try to figure out which - # foreign-key field we're joining on so that we can assign to the - # correct attribute when resolving the model graph. - to_field = None - if isinstance(on, Expression): - lhs, rhs = on.lhs, on.rhs - # Coerce to set() so that we force Python to compare using the - # object's hash rather than equality test, which returns a - # false-positive due to overriding __eq__. 
- fk_set = set(fk_fields) - - if isinstance(lhs, Field): - lhs_f = lhs.field if isinstance(lhs, FieldAlias) else lhs - if lhs_f in fk_set: - to_field = lhs_f - elif isinstance(rhs, Field): - rhs_f = rhs.field if isinstance(rhs, FieldAlias) else rhs - if rhs_f in fk_set: - to_field = rhs_f - - return to_field, False - - @Node.copy - def join(self, dest, join_type='INNER', on=None, src=None, attr=None): - src = self._join_ctx if src is None else src - - if join_type != JOIN.CROSS: - on, attr, constructor = self._normalize_join(src, dest, on, attr) - if attr: - self._joins.setdefault(src, []) - self._joins[src].append((dest, attr, constructor, join_type)) - elif on is not None: - raise ValueError('Cannot specify on clause with cross join.') - - if not self._from_list: - raise ValueError('No sources to join on.') - - item = self._from_list.pop() - self._from_list.append(Join(item, dest, join_type, on)) - - def join_from(self, src, dest, join_type='INNER', on=None, attr=None): - return self.join(dest, join_type, on, src, attr) - - def _get_model_cursor_wrapper(self, cursor): - if len(self._from_list) == 1 and not self._joins: - return ModelObjectCursorWrapper(cursor, self.model, - self._returning, self.model) - return ModelCursorWrapper(cursor, self.model, self._returning, - self._from_list, self._joins) - - def ensure_join(self, lm, rm, on=None, **join_kwargs): - join_ctx = self._join_ctx - for dest, _, constructor, _ in self._joins.get(lm, []): - if dest == rm: - return self - return self.switch(lm).join(rm, on=on, **join_kwargs).switch(join_ctx) - - def convert_dict_to_node(self, qdict): - accum = [] - joins = [] - fks = (ForeignKeyField, BackrefAccessor) - for key, value in sorted(qdict.items()): - curr = self.model - if '__' in key and key.rsplit('__', 1)[1] in DJANGO_MAP: - key, op = key.rsplit('__', 1) - op = DJANGO_MAP[op] - elif value is None: - op = DJANGO_MAP['is'] - else: - op = DJANGO_MAP['eq'] - - if '__' not in key: - # Handle simplest case. This avoids joining over-eagerly when a - # direct FK lookup is all that is required. - model_attr = getattr(curr, key) - else: - for piece in key.split('__'): - for dest, attr, _, _ in self._joins.get(curr, ()): - if attr == piece or (isinstance(dest, ModelAlias) and - dest.alias == piece): - curr = dest - break - else: - model_attr = getattr(curr, piece) - if value is not None and isinstance(model_attr, fks): - curr = model_attr.rel_model - joins.append(model_attr) - accum.append(op(model_attr, value)) - return accum, joins - - def filter(self, *args, **kwargs): - # normalize args and kwargs into a new expression - dq_node = ColumnBase() - if args: - dq_node &= reduce(operator.and_, [a.clone() for a in args]) - if kwargs: - dq_node &= DQ(**kwargs) - - # dq_node should now be an Expression, lhs = Node(), rhs = ... - q = collections.deque([dq_node]) - dq_joins = set() - while q: - curr = q.popleft() - if not isinstance(curr, Expression): - continue - for side, piece in (('lhs', curr.lhs), ('rhs', curr.rhs)): - if isinstance(piece, DQ): - query, joins = self.convert_dict_to_node(piece.query) - dq_joins.update(joins) - expression = reduce(operator.and_, query) - # Apply values from the DQ object. 
- if piece._negated: - expression = Negated(expression) - #expression._alias = piece._alias - setattr(curr, side, expression) - else: - q.append(piece) - - dq_node = dq_node.rhs - - query = self.clone() - for field in dq_joins: - if isinstance(field, ForeignKeyField): - lm, rm = field.model, field.rel_model - field_obj = field - elif isinstance(field, BackrefAccessor): - lm, rm = field.model, field.rel_model - field_obj = field.field - query = query.ensure_join(lm, rm, field_obj) - return query.where(dq_node) - - def create_table(self, name, safe=True, **meta): - return self.model._schema.create_table_as(name, self, safe, **meta) - - def __sql_selection__(self, ctx, is_subquery=False): - if self._is_default and is_subquery and len(self._returning) > 1 and \ - self.model._meta.primary_key is not False: - return ctx.sql(self.model._meta.primary_key) - - return ctx.sql(CommaNodeList(self._returning)) - - -class NoopModelSelect(ModelSelect): - def __sql__(self, ctx): - return self.model._meta.database.get_noop_select(ctx) - - def _get_cursor_wrapper(self, cursor): - return CursorWrapper(cursor) - - -class _ModelWriteQueryHelper(_ModelQueryHelper): - def __init__(self, model, *args, **kwargs): - self.model = model - super(_ModelWriteQueryHelper, self).__init__(model, *args, **kwargs) - - def returning(self, *returning): - accum = [] - for item in returning: - if is_model(item): - accum.extend(item._meta.sorted_fields) - else: - accum.append(item) - return super(_ModelWriteQueryHelper, self).returning(*accum) - - def _set_table_alias(self, ctx): - table = self.model._meta.table - ctx.alias_manager[table] = table.__name__ - - -class ModelUpdate(_ModelWriteQueryHelper, Update): - pass - - -class ModelInsert(_ModelWriteQueryHelper, Insert): - default_row_type = ROW.TUPLE - - def __init__(self, *args, **kwargs): - super(ModelInsert, self).__init__(*args, **kwargs) - if self._returning is None and self.model._meta.database is not None: - if self.model._meta.database.returning_clause: - self._returning = self.model._meta.get_primary_keys() - - def returning(self, *returning): - # By default ModelInsert will yield a `tuple` containing the - # primary-key of the newly inserted row. But if we are explicitly - # specifying a returning clause and have not set a row type, we will - # default to returning model instances instead. 
- if returning and self._row_type is None: - self._row_type = ROW.MODEL - return super(ModelInsert, self).returning(*returning) - - def get_default_data(self): - return self.model._meta.defaults - - def get_default_columns(self): - fields = self.model._meta.sorted_fields - return fields[1:] if self.model._meta.auto_increment else fields - - -class ModelDelete(_ModelWriteQueryHelper, Delete): - pass - - -class ManyToManyQuery(ModelSelect): - def __init__(self, instance, accessor, rel, *args, **kwargs): - self._instance = instance - self._accessor = accessor - self._src_attr = accessor.src_fk.rel_field.name - self._dest_attr = accessor.dest_fk.rel_field.name - super(ManyToManyQuery, self).__init__(rel, (rel,), *args, **kwargs) - - def _id_list(self, model_or_id_list): - if isinstance(model_or_id_list[0], Model): - return [getattr(obj, self._dest_attr) for obj in model_or_id_list] - return model_or_id_list - - def add(self, value, clear_existing=False): - if clear_existing: - self.clear() - - accessor = self._accessor - src_id = getattr(self._instance, self._src_attr) - if isinstance(value, SelectQuery): - query = value.columns( - Value(src_id), - accessor.dest_fk.rel_field) - accessor.through_model.insert_from( - fields=[accessor.src_fk, accessor.dest_fk], - query=query).execute() - else: - value = ensure_tuple(value) - if not value: return - - inserts = [{ - accessor.src_fk.name: src_id, - accessor.dest_fk.name: rel_id} - for rel_id in self._id_list(value)] - accessor.through_model.insert_many(inserts).execute() - - def remove(self, value): - src_id = getattr(self._instance, self._src_attr) - if isinstance(value, SelectQuery): - column = getattr(value.model, self._dest_attr) - subquery = value.columns(column) - return (self._accessor.through_model - .delete() - .where( - (self._accessor.dest_fk << subquery) & - (self._accessor.src_fk == src_id)) - .execute()) - else: - value = ensure_tuple(value) - if not value: - return - return (self._accessor.through_model - .delete() - .where( - (self._accessor.dest_fk << self._id_list(value)) & - (self._accessor.src_fk == src_id)) - .execute()) - - def clear(self): - src_id = getattr(self._instance, self._src_attr) - return (self._accessor.through_model - .delete() - .where(self._accessor.src_fk == src_id) - .execute()) - - -def safe_python_value(conv_func): - def validate(value): - try: - return conv_func(value) - except (TypeError, ValueError): - return value - return validate - - -class BaseModelCursorWrapper(DictCursorWrapper): - def __init__(self, cursor, model, columns): - super(BaseModelCursorWrapper, self).__init__(cursor) - self.model = model - self.select = columns or [] - - def _initialize_columns(self): - combined = self.model._meta.combined - table = self.model._meta.table - description = self.cursor.description - - self.ncols = len(self.cursor.description) - self.columns = [] - self.converters = converters = [None] * self.ncols - self.fields = fields = [None] * self.ncols - - for idx, description_item in enumerate(description): - column = description_item[0] - dot_index = column.find('.') - if dot_index != -1: - column = column[dot_index + 1:] - - column = column.strip('"') - self.columns.append(column) - try: - raw_node = self.select[idx] - except IndexError: - if column in combined: - raw_node = node = combined[column] - else: - continue - else: - node = raw_node.unwrap() - - # Heuristics used to attempt to get the field associated with a - # given SELECT column, so that we can accurately convert the value - # returned by the 
database-cursor into a Python object. - if isinstance(node, Field): - if raw_node._coerce: - converters[idx] = node.python_value - fields[idx] = node - if not raw_node.is_alias(): - self.columns[idx] = node.name - elif isinstance(node, Function) and node._coerce: - if node._python_value is not None: - converters[idx] = node._python_value - elif node.arguments and isinstance(node.arguments[0], Node): - # If the first argument is a field or references a column - # on a Model, try using that field's conversion function. - # This usually works, but we use "safe_python_value()" so - # that if a TypeError or ValueError occurs during - # conversion we can just fall-back to the raw cursor value. - first = node.arguments[0].unwrap() - if isinstance(first, Entity): - path = first._path[-1] # Try to look-up by name. - first = combined.get(path) - if isinstance(first, Field): - converters[idx] = safe_python_value(first.python_value) - elif column in combined: - if node._coerce: - converters[idx] = combined[column].python_value - if isinstance(node, Column) and node.source == table: - fields[idx] = combined[column] - - initialize = _initialize_columns - - def process_row(self, row): - raise NotImplementedError - - -class ModelDictCursorWrapper(BaseModelCursorWrapper): - def process_row(self, row): - result = {} - columns, converters = self.columns, self.converters - fields = self.fields - - for i in range(self.ncols): - attr = columns[i] - if attr in result: continue # Don't overwrite if we have dupes. - if converters[i] is not None: - result[attr] = converters[i](row[i]) - else: - result[attr] = row[i] - - return result - - -class ModelTupleCursorWrapper(ModelDictCursorWrapper): - constructor = tuple - - def process_row(self, row): - columns, converters = self.columns, self.converters - return self.constructor([ - (converters[i](row[i]) if converters[i] is not None else row[i]) - for i in range(self.ncols)]) - - -class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper): - def initialize(self): - self._initialize_columns() - attributes = [] - for i in range(self.ncols): - attributes.append(self.columns[i]) - self.tuple_class = collections.namedtuple('Row', attributes) - self.constructor = lambda row: self.tuple_class(*row) - - -class ModelObjectCursorWrapper(ModelDictCursorWrapper): - def __init__(self, cursor, model, select, constructor): - self.constructor = constructor - self.is_model = is_model(constructor) - super(ModelObjectCursorWrapper, self).__init__(cursor, model, select) - - def process_row(self, row): - data = super(ModelObjectCursorWrapper, self).process_row(row) - if self.is_model: - # Clear out any dirty fields before returning to the user. 
- obj = self.constructor(__no_default__=1, **data) - obj._dirty.clear() - return obj - else: - return self.constructor(**data) - - -class ModelCursorWrapper(BaseModelCursorWrapper): - def __init__(self, cursor, model, select, from_list, joins): - super(ModelCursorWrapper, self).__init__(cursor, model, select) - self.from_list = from_list - self.joins = joins - - def initialize(self): - self._initialize_columns() - selected_src = set([field.model for field in self.fields - if field is not None]) - select, columns = self.select, self.columns - - self.key_to_constructor = {self.model: self.model} - self.src_is_dest = {} - self.src_to_dest = [] - accum = collections.deque(self.from_list) - dests = set() - - while accum: - curr = accum.popleft() - if isinstance(curr, Join): - accum.append(curr.lhs) - accum.append(curr.rhs) - continue - - if curr not in self.joins: - continue - - is_dict = isinstance(curr, dict) - for key, attr, constructor, join_type in self.joins[curr]: - if key not in self.key_to_constructor: - self.key_to_constructor[key] = constructor - - # (src, attr, dest, is_dict, join_type). - self.src_to_dest.append((curr, attr, key, is_dict, - join_type)) - dests.add(key) - accum.append(key) - - # Ensure that we accommodate everything selected. - for src in selected_src: - if src not in self.key_to_constructor: - if is_model(src): - self.key_to_constructor[src] = src - elif isinstance(src, ModelAlias): - self.key_to_constructor[src] = src.model - - # Indicate which sources are also dests. - for src, _, dest, _, _ in self.src_to_dest: - self.src_is_dest[src] = src in dests and (dest in selected_src - or src in selected_src) - - self.column_keys = [] - for idx, node in enumerate(select): - key = self.model - field = self.fields[idx] - if field is not None: - if isinstance(field, FieldAlias): - key = field.source - else: - key = field.model - else: - if isinstance(node, Node): - node = node.unwrap() - if isinstance(node, Column): - key = node.source - - self.column_keys.append(key) - - def process_row(self, row): - objects = {} - object_list = [] - for key, constructor in self.key_to_constructor.items(): - objects[key] = constructor(__no_default__=True) - object_list.append(objects[key]) - - set_keys = set() - for idx, key in enumerate(self.column_keys): - instance = objects[key] - column = self.columns[idx] - value = row[idx] - if value is not None: - set_keys.add(key) - if self.converters[idx]: - value = self.converters[idx](value) - - if isinstance(instance, dict): - instance[column] = value - else: - setattr(instance, column, value) - - # Need to do some analysis on the joins before this. - for (src, attr, dest, is_dict, join_type) in self.src_to_dest: - instance = objects[src] - try: - joined_instance = objects[dest] - except KeyError: - continue - - # If no fields were set on the destination instance then do not - # assign an "empty" instance. - if instance is None or dest is None or \ - (dest not in set_keys and not self.src_is_dest.get(dest)): - continue - - # If no fields were set on either the source or the destination, - # then we have nothing to do here. - if instance not in set_keys and dest not in set_keys \ - and join_type.endswith('OUTER'): - continue - - if is_dict: - instance[attr] = joined_instance - else: - setattr(instance, attr, joined_instance) - - # When instantiating models from a cursor, we clear the dirty fields. 
- for instance in object_list: - if isinstance(instance, Model): - instance._dirty.clear() - - return objects[self.model] - - -class PrefetchQuery(collections.namedtuple('_PrefetchQuery', ( - 'query', 'fields', 'is_backref', 'rel_models', 'field_to_name', 'model'))): - def __new__(cls, query, fields=None, is_backref=None, rel_models=None, - field_to_name=None, model=None): - if fields: - if is_backref: - if rel_models is None: - rel_models = [field.model for field in fields] - foreign_key_attrs = [field.rel_field.name for field in fields] - else: - if rel_models is None: - rel_models = [field.rel_model for field in fields] - foreign_key_attrs = [field.name for field in fields] - field_to_name = list(zip(fields, foreign_key_attrs)) - model = query.model - return super(PrefetchQuery, cls).__new__( - cls, query, fields, is_backref, rel_models, field_to_name, model) - - def populate_instance(self, instance, id_map): - if self.is_backref: - for field in self.fields: - identifier = instance.__data__[field.name] - key = (field, identifier) - if key in id_map: - setattr(instance, field.name, id_map[key]) - else: - for field, attname in self.field_to_name: - identifier = instance.__data__[field.rel_field.name] - key = (field, identifier) - rel_instances = id_map.get(key, []) - for inst in rel_instances: - setattr(inst, attname, instance) - setattr(instance, field.backref, rel_instances) - - def store_instance(self, instance, id_map): - for field, attname in self.field_to_name: - identity = field.rel_field.python_value(instance.__data__[attname]) - key = (field, identity) - if self.is_backref: - id_map[key] = instance - else: - id_map.setdefault(key, []) - id_map[key].append(instance) - - -def prefetch_add_subquery(sq, subqueries): - fixed_queries = [PrefetchQuery(sq)] - for i, subquery in enumerate(subqueries): - if isinstance(subquery, tuple): - subquery, target_model = subquery - else: - target_model = None - if not isinstance(subquery, Query) and is_model(subquery) or \ - isinstance(subquery, ModelAlias): - subquery = subquery.select() - subquery_model = subquery.model - fks = backrefs = None - for j in reversed(range(i + 1)): - fixed = fixed_queries[j] - last_query = fixed.query - last_model = last_obj = fixed.model - if isinstance(last_model, ModelAlias): - last_model = last_model.model - rels = subquery_model._meta.model_refs.get(last_model, []) - if rels: - fks = [getattr(subquery_model, fk.name) for fk in rels] - pks = [getattr(last_obj, fk.rel_field.name) for fk in rels] - else: - backrefs = subquery_model._meta.model_backrefs.get(last_model) - if (fks or backrefs) and ((target_model is last_obj) or - (target_model is None)): - break - - if not fks and not backrefs: - tgt_err = ' using %s' % target_model if target_model else '' - raise AttributeError('Error: unable to find foreign key for ' - 'query: %s%s' % (subquery, tgt_err)) - - dest = (target_model,) if target_model else None - - if fks: - expr = reduce(operator.or_, [ - (fk << last_query.select(pk)) - for (fk, pk) in zip(fks, pks)]) - subquery = subquery.where(expr) - fixed_queries.append(PrefetchQuery(subquery, fks, False, dest)) - elif backrefs: - expressions = [] - for backref in backrefs: - rel_field = getattr(subquery_model, backref.rel_field.name) - fk_field = getattr(last_obj, backref.name) - expressions.append(rel_field << last_query.select(fk_field)) - subquery = subquery.where(reduce(operator.or_, expressions)) - fixed_queries.append(PrefetchQuery(subquery, backrefs, True, dest)) - - return fixed_queries - - -def 
prefetch(sq, *subqueries): - if not subqueries: - return sq - - fixed_queries = prefetch_add_subquery(sq, subqueries) - deps = {} - rel_map = {} - for pq in reversed(fixed_queries): - query_model = pq.model - if pq.fields: - for rel_model in pq.rel_models: - rel_map.setdefault(rel_model, []) - rel_map[rel_model].append(pq) - - deps.setdefault(query_model, {}) - id_map = deps[query_model] - has_relations = bool(rel_map.get(query_model)) - - for instance in pq.query: - if pq.fields: - pq.store_instance(instance, id_map) - if has_relations: - for rel in rel_map[query_model]: - rel.populate_instance(instance, deps[rel.model]) - - return list(pq.query) diff --git a/libs/playhouse/README.md b/libs/playhouse/README.md deleted file mode 100644 index faebd6902..000000000 --- a/libs/playhouse/README.md +++ /dev/null @@ -1,48 +0,0 @@ -## Playhouse - -The `playhouse` namespace contains numerous extensions to Peewee. These include vendor-specific database extensions, high-level abstractions to simplify working with databases, and tools for low-level database operations and introspection. - -### Vendor extensions - -* [SQLite extensions](http://docs.peewee-orm.com/en/latest/peewee/sqlite_ext.html) - * Full-text search (FTS3/4/5) - * BM25 ranking algorithm implemented as SQLite C extension, backported to FTS4 - * Virtual tables and C extensions - * Closure tables - * JSON extension support - * LSM1 (key/value database) support - * BLOB API - * Online backup API -* [APSW extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#apsw): use Peewee with the powerful [APSW](https://github.com/rogerbinns/apsw) SQLite driver. -* [SQLCipher](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#sqlcipher-ext): encrypted SQLite databases. -* [SqliteQ](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#sqliteq): dedicated writer thread for multi-threaded SQLite applications. [More info here](http://charlesleifer.com/blog/multi-threaded-sqlite-without-the-operationalerrors/). -* [Postgresql extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#postgres-ext) - * JSON and JSONB - * HStore - * Arrays - * Server-side cursors - * Full-text search -* [MySQL extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#mysql-ext) - -### High-level libraries - -* [Extra fields](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#extra-fields) - * Compressed field - * PickleField -* [Shortcuts / helpers](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#shortcuts) - * Model-to-dict serializer - * Dict-to-model deserializer -* [Hybrid attributes](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#hybrid) -* [Signals](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#signals): pre/post-save, pre/post-delete, pre-init. -* [Dataset](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#dataset): high-level API for working with databases popularized by the [project of the same name](https://dataset.readthedocs.io/). -* [Key/Value Store](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#kv): key/value store using SQLite. Supports *smart indexing* for *Pandas*-style queries. - -### Database management and framework support - -* [pwiz](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#pwiz): generate model code from a pre-existing database. -* [Schema migrations](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#migrate): modify your schema using high-level APIs. Even supports dropping or renaming columns in SQLite.
-* [Connection pool](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#pool): simple connection pooling. -* [Reflection](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#reflection): low-level, cross-platform database introspection. -* [Database URLs](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#db-url): use URLs to connect to a database. -* [Test utils](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#test-utils): helpers for unit-testing Peewee applications. -* [Flask utils](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#flask-utils): paginated object lists, database connection management, and more. diff --git a/libs/playhouse/__init__.py b/libs/playhouse/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/libs/playhouse/_pysqlite/cache.h b/libs/playhouse/_pysqlite/cache.h deleted file mode 100644 index 06f957a77..000000000 --- a/libs/playhouse/_pysqlite/cache.h +++ /dev/null @@ -1,73 +0,0 @@ -/* cache.h - definitions for the LRU cache - * - * Copyright (C) 2004-2015 Gerhard Häring - * - * This file is part of pysqlite. - * - * This software is provided 'as-is', without any express or implied - * warranty. In no event will the authors be held liable for any damages - * arising from the use of this software. - * - * Permission is granted to anyone to use this software for any purpose, - * including commercial applications, and to alter it and redistribute it - * freely, subject to the following restrictions: - * - * 1. The origin of this software must not be misrepresented; you must not - * claim that you wrote the original software. If you use this software - * in a product, an acknowledgment in the product documentation would be - * appreciated but is not required. - * 2. Altered source versions must be plainly marked as such, and must not be - * misrepresented as being the original software. - * 3. This notice may not be removed or altered from any source distribution. - */ - -#ifndef PYSQLITE_CACHE_H -#define PYSQLITE_CACHE_H -#include "Python.h" - -/* The LRU cache is implemented as a combination of a doubly-linked list with a - * dictionary. The list items are of type 'Node' and the dictionary has the - * nodes as values. */ - -typedef struct _pysqlite_Node -{ - PyObject_HEAD - PyObject* key; - PyObject* data; - long count; - struct _pysqlite_Node* prev; - struct _pysqlite_Node* next; -} pysqlite_Node; - -typedef struct -{ - PyObject_HEAD - int size; - - /* a dictionary mapping keys to Node entries */ - PyObject* mapping; - - /* the factory callable */ - PyObject* factory; - - pysqlite_Node* first; - pysqlite_Node* last; - - /* if set, decrement the factory function when the Cache is deallocated.
- * this is almost always desirable, but not in the pysqlite context */ - int decref_factory; -} pysqlite_Cache; - -extern PyTypeObject pysqlite_NodeType; -extern PyTypeObject pysqlite_CacheType; - -int pysqlite_node_init(pysqlite_Node* self, PyObject* args, PyObject* kwargs); -void pysqlite_node_dealloc(pysqlite_Node* self); - -int pysqlite_cache_init(pysqlite_Cache* self, PyObject* args, PyObject* kwargs); -void pysqlite_cache_dealloc(pysqlite_Cache* self); -PyObject* pysqlite_cache_get(pysqlite_Cache* self, PyObject* args); - -int pysqlite_cache_setup_types(void); - -#endif diff --git a/libs/playhouse/_pysqlite/connection.h b/libs/playhouse/_pysqlite/connection.h deleted file mode 100644 index d35c13f9a..000000000 --- a/libs/playhouse/_pysqlite/connection.h +++ /dev/null @@ -1,129 +0,0 @@ -/* connection.h - definitions for the connection type - * - * Copyright (C) 2004-2015 Gerhard Häring - * - * This file is part of pysqlite. - * - * This software is provided 'as-is', without any express or implied - * warranty. In no event will the authors be held liable for any damages - * arising from the use of this software. - * - * Permission is granted to anyone to use this software for any purpose, - * including commercial applications, and to alter it and redistribute it - * freely, subject to the following restrictions: - * - * 1. The origin of this software must not be misrepresented; you must not - * claim that you wrote the original software. If you use this software - * in a product, an acknowledgment in the product documentation would be - * appreciated but is not required. - * 2. Altered source versions must be plainly marked as such, and must not be - * misrepresented as being the original software. - * 3. This notice may not be removed or altered from any source distribution. - */ - -#ifndef PYSQLITE_CONNECTION_H -#define PYSQLITE_CONNECTION_H -#include "Python.h" -#include "pythread.h" -#include "structmember.h" - -#include "cache.h" -#include "module.h" - -#include "sqlite3.h" - -typedef struct -{ - PyObject_HEAD - sqlite3* db; - - /* the type detection mode. Only 0, PARSE_DECLTYPES, PARSE_COLNAMES or a - * bitwise combination thereof makes sense */ - int detect_types; - - /* the timeout value in seconds for database locks */ - double timeout; - - /* for internal use in the timeout handler: when did the timeout handler - * first get called with count=0? */ - double timeout_started; - - /* None for autocommit, otherwise a PyString with the isolation level */ - PyObject* isolation_level; - - /* NULL for autocommit, otherwise a string with the BEGIN statement; will be - * freed in connection destructor */ - char* begin_statement; - - /* 1 if a check should be performed for each API call if the connection is - * used from the same thread it was created in */ - int check_same_thread; - - int initialized; - - /* thread identification of the thread the connection was created in */ - long thread_ident; - - pysqlite_Cache* statement_cache; - - /* Lists of weak references to statements and cursors used within this connection */ - PyObject* statements; - PyObject* cursors; - - /* Counters for how many statements/cursors were created in the connection. 
May be - * reset to 0 at certain intervals */ - int created_statements; - int created_cursors; - - PyObject* row_factory; - - /* Determines how bytestrings from SQLite are converted to Python objects: - * - PyUnicode_Type: Python Unicode objects are constructed from UTF-8 bytestrings - * - OptimizedUnicode: Like before, but for ASCII data, only PyStrings are created. - * - PyString_Type: PyStrings are created as-is. - * - Any custom callable: Any object returned from the callable called with the bytestring - * as single parameter. - */ - PyObject* text_factory; - - /* remember references to functions/classes used in - * create_function/create/aggregate, use these as dictionary keys, so we - * can keep the total system refcount constant by clearing that dictionary - * in connection_dealloc */ - PyObject* function_pinboard; - - /* a dictionary of registered collation name => collation callable mappings */ - PyObject* collations; - - /* Exception objects */ - PyObject* Warning; - PyObject* Error; - PyObject* InterfaceError; - PyObject* DatabaseError; - PyObject* DataError; - PyObject* OperationalError; - PyObject* IntegrityError; - PyObject* InternalError; - PyObject* ProgrammingError; - PyObject* NotSupportedError; -} pysqlite_Connection; - -extern PyTypeObject pysqlite_ConnectionType; - -PyObject* pysqlite_connection_alloc(PyTypeObject* type, int aware); -void pysqlite_connection_dealloc(pysqlite_Connection* self); -PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs); -PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args); -PyObject* _pysqlite_connection_begin(pysqlite_Connection* self); -PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args); -PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args); -PyObject* pysqlite_connection_new(PyTypeObject* type, PyObject* args, PyObject* kw); -int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject* kwargs); - -int pysqlite_connection_register_cursor(pysqlite_Connection* connection, PyObject* cursor); -int pysqlite_check_thread(pysqlite_Connection* self); -int pysqlite_check_connection(pysqlite_Connection* con); - -int pysqlite_connection_setup_types(void); - -#endif diff --git a/libs/playhouse/_pysqlite/module.h b/libs/playhouse/_pysqlite/module.h deleted file mode 100644 index 08c566257..000000000 --- a/libs/playhouse/_pysqlite/module.h +++ /dev/null @@ -1,58 +0,0 @@ -/* module.h - definitions for the module - * - * Copyright (C) 2004-2015 Gerhard Häring - * - * This file is part of pysqlite. - * - * This software is provided 'as-is', without any express or implied - * warranty. In no event will the authors be held liable for any damages - * arising from the use of this software. - * - * Permission is granted to anyone to use this software for any purpose, - * including commercial applications, and to alter it and redistribute it - * freely, subject to the following restrictions: - * - * 1. The origin of this software must not be misrepresented; you must not - * claim that you wrote the original software. If you use this software - * in a product, an acknowledgment in the product documentation would be - * appreciated but is not required. - * 2. Altered source versions must be plainly marked as such, and must not be - * misrepresented as being the original software. - * 3. This notice may not be removed or altered from any source distribution. 
- */ - -#ifndef PYSQLITE_MODULE_H -#define PYSQLITE_MODULE_H -#include "Python.h" - -#define PYSQLITE_VERSION "2.8.2" - -extern PyObject* pysqlite_Error; -extern PyObject* pysqlite_Warning; -extern PyObject* pysqlite_InterfaceError; -extern PyObject* pysqlite_DatabaseError; -extern PyObject* pysqlite_InternalError; -extern PyObject* pysqlite_OperationalError; -extern PyObject* pysqlite_ProgrammingError; -extern PyObject* pysqlite_IntegrityError; -extern PyObject* pysqlite_DataError; -extern PyObject* pysqlite_NotSupportedError; - -extern PyObject* pysqlite_OptimizedUnicode; - -/* the functions time.time() and time.sleep() */ -extern PyObject* time_time; -extern PyObject* time_sleep; - -/* A dictionary, mapping column types (INTEGER, VARCHAR, etc.) to converter - * functions that convert the SQL value to the appropriate Python value. - * The key is uppercase. - */ -extern PyObject* converters; - -extern int _enable_callback_tracebacks; -extern int pysqlite_BaseTypeAdapted; - -#define PARSE_DECLTYPES 1 -#define PARSE_COLNAMES 2 -#endif diff --git a/libs/playhouse/_sqlite_ext.pyx b/libs/playhouse/_sqlite_ext.pyx deleted file mode 100644 index 7fa3949e0..000000000 --- a/libs/playhouse/_sqlite_ext.pyx +++ /dev/null @@ -1,1579 +0,0 @@ -import hashlib -import zlib - -cimport cython -from cpython cimport datetime -from cpython.bytes cimport PyBytes_AsStringAndSize -from cpython.bytes cimport PyBytes_Check -from cpython.bytes cimport PyBytes_FromStringAndSize -from cpython.bytes cimport PyBytes_AS_STRING -from cpython.object cimport PyObject -from cpython.ref cimport Py_INCREF, Py_DECREF -from cpython.unicode cimport PyUnicode_AsUTF8String -from cpython.unicode cimport PyUnicode_Check -from cpython.unicode cimport PyUnicode_DecodeUTF8 -from cpython.version cimport PY_MAJOR_VERSION -from libc.float cimport DBL_MAX -from libc.math cimport ceil, log, sqrt -from libc.math cimport pow as cpow -#from libc.stdint cimport ssize_t -from libc.stdint cimport uint8_t -from libc.stdint cimport uint32_t -from libc.stdlib cimport calloc, free, malloc, rand -from libc.string cimport memcpy, memset, strlen - -from peewee import InterfaceError -from peewee import Node -from peewee import OperationalError -from peewee import sqlite3 as pysqlite - -import traceback - - -cdef struct sqlite3_index_constraint: - int iColumn # Column constrained, -1 for rowid. - unsigned char op # Constraint operator. - unsigned char usable # True if this constraint is usable. - int iTermOffset # Used internally - xBestIndex should ignore. - - -cdef struct sqlite3_index_orderby: - int iColumn - unsigned char desc - - -cdef struct sqlite3_index_constraint_usage: - int argvIndex # if > 0, constraint is part of argv to xFilter. - unsigned char omit - - -cdef extern from "sqlite3.h" nogil: - ctypedef struct sqlite3: - int busyTimeout - ctypedef struct sqlite3_backup - ctypedef struct sqlite3_blob - ctypedef struct sqlite3_context - ctypedef struct sqlite3_value - ctypedef long long sqlite3_int64 - ctypedef unsigned long long sqlite_uint64 - - # Virtual tables. - ctypedef struct sqlite3_module # Forward reference.
- ctypedef struct sqlite3_vtab: - const sqlite3_module *pModule - int nRef - char *zErrMsg - ctypedef struct sqlite3_vtab_cursor: - sqlite3_vtab *pVtab - - ctypedef struct sqlite3_index_info: - int nConstraint - sqlite3_index_constraint *aConstraint - int nOrderBy - sqlite3_index_orderby *aOrderBy - sqlite3_index_constraint_usage *aConstraintUsage - int idxNum - char *idxStr - int needToFreeIdxStr - int orderByConsumed - double estimatedCost - sqlite3_int64 estimatedRows - int idxFlags - - ctypedef struct sqlite3_module: - int iVersion - int (*xCreate)(sqlite3*, void *pAux, int argc, const char *const*argv, - sqlite3_vtab **ppVTab, char**) - int (*xConnect)(sqlite3*, void *pAux, int argc, const char *const*argv, - sqlite3_vtab **ppVTab, char**) - int (*xBestIndex)(sqlite3_vtab *pVTab, sqlite3_index_info*) - int (*xDisconnect)(sqlite3_vtab *pVTab) - int (*xDestroy)(sqlite3_vtab *pVTab) - int (*xOpen)(sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor) - int (*xClose)(sqlite3_vtab_cursor*) - int (*xFilter)(sqlite3_vtab_cursor*, int idxNum, const char *idxStr, - int argc, sqlite3_value **argv) - int (*xNext)(sqlite3_vtab_cursor*) - int (*xEof)(sqlite3_vtab_cursor*) - int (*xColumn)(sqlite3_vtab_cursor*, sqlite3_context *, int) - int (*xRowid)(sqlite3_vtab_cursor*, sqlite3_int64 *pRowid) - int (*xUpdate)(sqlite3_vtab *pVTab, int, sqlite3_value **, - sqlite3_int64 **) - int (*xBegin)(sqlite3_vtab *pVTab) - int (*xSync)(sqlite3_vtab *pVTab) - int (*xCommit)(sqlite3_vtab *pVTab) - int (*xRollback)(sqlite3_vtab *pVTab) - int (*xFindFunction)(sqlite3_vtab *pVTab, int nArg, const char *zName, - void (**pxFunc)(sqlite3_context *, int, - sqlite3_value **), - void **ppArg) - int (*xRename)(sqlite3_vtab *pVTab, const char *zNew) - int (*xSavepoint)(sqlite3_vtab *pVTab, int) - int (*xRelease)(sqlite3_vtab *pVTab, int) - int (*xRollbackTo)(sqlite3_vtab *pVTab, int) - - cdef int sqlite3_declare_vtab(sqlite3 *db, const char *zSQL) - cdef int sqlite3_create_module(sqlite3 *db, const char *zName, - const sqlite3_module *p, void *pClientData) - - cdef const char sqlite3_version[] - - # Encoding. - cdef int SQLITE_UTF8 = 1 - - # Return values. - cdef int SQLITE_OK = 0 - cdef int SQLITE_ERROR = 1 - cdef int SQLITE_INTERNAL = 2 - cdef int SQLITE_PERM = 3 - cdef int SQLITE_ABORT = 4 - cdef int SQLITE_BUSY = 5 - cdef int SQLITE_LOCKED = 6 - cdef int SQLITE_NOMEM = 7 - cdef int SQLITE_READONLY = 8 - cdef int SQLITE_INTERRUPT = 9 - cdef int SQLITE_DONE = 101 - - # Function type. - cdef int SQLITE_DETERMINISTIC = 0x800 - - # Types of filtering operations. - cdef int SQLITE_INDEX_CONSTRAINT_EQ = 2 - cdef int SQLITE_INDEX_CONSTRAINT_GT = 4 - cdef int SQLITE_INDEX_CONSTRAINT_LE = 8 - cdef int SQLITE_INDEX_CONSTRAINT_LT = 16 - cdef int SQLITE_INDEX_CONSTRAINT_GE = 32 - cdef int SQLITE_INDEX_CONSTRAINT_MATCH = 64 - - # sqlite_value_type. - cdef int SQLITE_INTEGER = 1 - cdef int SQLITE_FLOAT = 2 - cdef int SQLITE3_TEXT = 3 - cdef int SQLITE_TEXT = 3 - cdef int SQLITE_BLOB = 4 - cdef int SQLITE_NULL = 5 - - ctypedef void (*sqlite3_destructor_type)(void*) - - # Converting from Sqlite -> Python. 
- cdef const void *sqlite3_value_blob(sqlite3_value*) - cdef int sqlite3_value_bytes(sqlite3_value*) - cdef double sqlite3_value_double(sqlite3_value*) - cdef int sqlite3_value_int(sqlite3_value*) - cdef sqlite3_int64 sqlite3_value_int64(sqlite3_value*) - cdef const unsigned char *sqlite3_value_text(sqlite3_value*) - cdef int sqlite3_value_type(sqlite3_value*) - cdef int sqlite3_value_numeric_type(sqlite3_value*) - - # Converting from Python -> Sqlite. - cdef void sqlite3_result_blob(sqlite3_context*, const void *, int, - void(*)(void*)) - cdef void sqlite3_result_double(sqlite3_context*, double) - cdef void sqlite3_result_error(sqlite3_context*, const char*, int) - cdef void sqlite3_result_error_toobig(sqlite3_context*) - cdef void sqlite3_result_error_nomem(sqlite3_context*) - cdef void sqlite3_result_error_code(sqlite3_context*, int) - cdef void sqlite3_result_int(sqlite3_context*, int) - cdef void sqlite3_result_int64(sqlite3_context*, sqlite3_int64) - cdef void sqlite3_result_null(sqlite3_context*) - cdef void sqlite3_result_text(sqlite3_context*, const char*, int, - void(*)(void*)) - cdef void sqlite3_result_value(sqlite3_context*, sqlite3_value*) - - # Memory management. - cdef void* sqlite3_malloc(int) - cdef void sqlite3_free(void *) - - cdef int sqlite3_changes(sqlite3 *db) - cdef int sqlite3_get_autocommit(sqlite3 *db) - cdef sqlite3_int64 sqlite3_last_insert_rowid(sqlite3 *db) - - cdef void *sqlite3_commit_hook(sqlite3 *, int(*)(void *), void *) - cdef void *sqlite3_rollback_hook(sqlite3 *, void(*)(void *), void *) - cdef void *sqlite3_update_hook( - sqlite3 *, - void(*)(void *, int, char *, char *, sqlite3_int64), - void *) - - cdef int SQLITE_STATUS_MEMORY_USED = 0 - cdef int SQLITE_STATUS_PAGECACHE_USED = 1 - cdef int SQLITE_STATUS_PAGECACHE_OVERFLOW = 2 - cdef int SQLITE_STATUS_SCRATCH_USED = 3 - cdef int SQLITE_STATUS_SCRATCH_OVERFLOW = 4 - cdef int SQLITE_STATUS_MALLOC_SIZE = 5 - cdef int SQLITE_STATUS_PARSER_STACK = 6 - cdef int SQLITE_STATUS_PAGECACHE_SIZE = 7 - cdef int SQLITE_STATUS_SCRATCH_SIZE = 8 - cdef int SQLITE_STATUS_MALLOC_COUNT = 9 - cdef int sqlite3_status(int op, int *pCurrent, int *pHighwater, int resetFlag) - - cdef int SQLITE_DBSTATUS_LOOKASIDE_USED = 0 - cdef int SQLITE_DBSTATUS_CACHE_USED = 1 - cdef int SQLITE_DBSTATUS_SCHEMA_USED = 2 - cdef int SQLITE_DBSTATUS_STMT_USED = 3 - cdef int SQLITE_DBSTATUS_LOOKASIDE_HIT = 4 - cdef int SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5 - cdef int SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6 - cdef int SQLITE_DBSTATUS_CACHE_HIT = 7 - cdef int SQLITE_DBSTATUS_CACHE_MISS = 8 - cdef int SQLITE_DBSTATUS_CACHE_WRITE = 9 - cdef int SQLITE_DBSTATUS_DEFERRED_FKS = 10 - #cdef int SQLITE_DBSTATUS_CACHE_USED_SHARED = 11 - cdef int sqlite3_db_status(sqlite3 *, int op, int *pCur, int *pHigh, int reset) - - cdef int SQLITE_DELETE = 9 - cdef int SQLITE_INSERT = 18 - cdef int SQLITE_UPDATE = 23 - - cdef int SQLITE_CONFIG_SINGLETHREAD = 1 # None - cdef int SQLITE_CONFIG_MULTITHREAD = 2 # None - cdef int SQLITE_CONFIG_SERIALIZED = 3 # None - cdef int SQLITE_CONFIG_SCRATCH = 6 # void *, int sz, int N - cdef int SQLITE_CONFIG_PAGECACHE = 7 # void *, int sz, int N - cdef int SQLITE_CONFIG_HEAP = 8 # void *, int nByte, int min - cdef int SQLITE_CONFIG_MEMSTATUS = 9 # boolean - cdef int SQLITE_CONFIG_LOOKASIDE = 13 # int, int - cdef int SQLITE_CONFIG_URI = 17 # int - cdef int SQLITE_CONFIG_MMAP_SIZE = 22 # sqlite3_int64, sqlite3_int64 - cdef int SQLITE_CONFIG_STMTJRNL_SPILL = 26 # int nByte - cdef int SQLITE_DBCONFIG_MAINDBNAME = 1000 # const 
char* - cdef int SQLITE_DBCONFIG_LOOKASIDE = 1001 # void* int int - cdef int SQLITE_DBCONFIG_ENABLE_FKEY = 1002 # int int* - cdef int SQLITE_DBCONFIG_ENABLE_TRIGGER = 1003 # int int* - cdef int SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER = 1004 # int int* - cdef int SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION = 1005 # int int* - cdef int SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE = 1006 # int int* - cdef int SQLITE_DBCONFIG_ENABLE_QPSG = 1007 # int int* - - cdef int sqlite3_config(int, ...) - cdef int sqlite3_db_config(sqlite3*, int op, ...) - - # Misc. - cdef int sqlite3_busy_handler(sqlite3 *db, int(*)(void *, int), void *) - cdef int sqlite3_sleep(int ms) - cdef sqlite3_backup *sqlite3_backup_init( - sqlite3 *pDest, - const char *zDestName, - sqlite3 *pSource, - const char *zSourceName) - - # Backup. - cdef int sqlite3_backup_step(sqlite3_backup *p, int nPage) - cdef int sqlite3_backup_finish(sqlite3_backup *p) - cdef int sqlite3_backup_remaining(sqlite3_backup *p) - cdef int sqlite3_backup_pagecount(sqlite3_backup *p) - - # Error handling. - cdef int sqlite3_errcode(sqlite3 *db) - cdef int sqlite3_errstr(int) - cdef const char *sqlite3_errmsg(sqlite3 *db) - - cdef int sqlite3_blob_open( - sqlite3*, - const char *zDb, - const char *zTable, - const char *zColumn, - sqlite3_int64 iRow, - int flags, - sqlite3_blob **ppBlob) - cdef int sqlite3_blob_reopen(sqlite3_blob *, sqlite3_int64) - cdef int sqlite3_blob_close(sqlite3_blob *) - cdef int sqlite3_blob_bytes(sqlite3_blob *) - cdef int sqlite3_blob_read(sqlite3_blob *, void *Z, int N, int iOffset) - cdef int sqlite3_blob_write(sqlite3_blob *, const void *z, int n, - int iOffset) - - -cdef extern from "_pysqlite/connection.h": - ctypedef struct pysqlite_Connection: - sqlite3* db - double timeout - int initialized - - -cdef sqlite_to_python(int argc, sqlite3_value **params): - cdef: - int i - int vtype - list pyargs = [] - - for i in range(argc): - vtype = sqlite3_value_type(params[i]) - if vtype == SQLITE_INTEGER: - pyval = sqlite3_value_int(params[i]) - elif vtype == SQLITE_FLOAT: - pyval = sqlite3_value_double(params[i]) - elif vtype == SQLITE_TEXT: - pyval = PyUnicode_DecodeUTF8( - sqlite3_value_text(params[i]), - sqlite3_value_bytes(params[i]), NULL) - elif vtype == SQLITE_BLOB: - pyval = PyBytes_FromStringAndSize( - sqlite3_value_blob(params[i]), - sqlite3_value_bytes(params[i])) - elif vtype == SQLITE_NULL: - pyval = None - else: - pyval = None - - pyargs.append(pyval) - - return pyargs - - -cdef python_to_sqlite(sqlite3_context *context, value): - if value is None: - sqlite3_result_null(context) - elif isinstance(value, (int, long)): - sqlite3_result_int64(context, value) - elif isinstance(value, float): - sqlite3_result_double(context, value) - elif isinstance(value, unicode): - bval = PyUnicode_AsUTF8String(value) - sqlite3_result_text( - context, - bval, - len(bval), - -1) - elif isinstance(value, bytes): - if PY_MAJOR_VERSION > 2: - sqlite3_result_blob( - context, - (value), - len(value), - -1) - else: - sqlite3_result_text( - context, - value, - len(value), - -1) - else: - sqlite3_result_error( - context, - encode('Unsupported type %s' % type(value)), - -1) - return SQLITE_ERROR - - return SQLITE_OK - - -cdef int SQLITE_CONSTRAINT = 19 # Abort due to constraint violation. - -USE_SQLITE_CONSTRAINT = sqlite3_version[:4] >= b'3.26' - -# The peewee_vtab struct embeds the base sqlite3_vtab struct, and adds a field -# to store a reference to the Python implementation. 
-ctypedef struct peewee_vtab: - sqlite3_vtab base - void *table_func_cls - - -# Like peewee_vtab, the peewee_cursor embeds the base sqlite3_vtab_cursor and -# adds fields to store references to the current index, the Python -# implementation, the current rows' data, and a flag for whether the cursor has -# been exhausted. -ctypedef struct peewee_cursor: - sqlite3_vtab_cursor base - long long idx - void *table_func - void *row_data - bint stopped - - -# We define an xConnect function, but leave xCreate NULL so that the -# table-function can be called eponymously. -cdef int pwConnect(sqlite3 *db, void *pAux, int argc, const char *const*argv, - sqlite3_vtab **ppVtab, char **pzErr) with gil: - cdef: - int rc - object table_func_cls = pAux - peewee_vtab *pNew = 0 - - rc = sqlite3_declare_vtab( - db, - encode('CREATE TABLE x(%s);' % - table_func_cls.get_table_columns_declaration())) - if rc == SQLITE_OK: - pNew = sqlite3_malloc(sizeof(pNew[0])) - memset(pNew, 0, sizeof(pNew[0])) - ppVtab[0] = &(pNew.base) - - pNew.table_func_cls = table_func_cls - Py_INCREF(table_func_cls) - - return rc - - -cdef int pwDisconnect(sqlite3_vtab *pBase) with gil: - cdef: - peewee_vtab *pVtab = pBase - object table_func_cls = (pVtab.table_func_cls) - - Py_DECREF(table_func_cls) - sqlite3_free(pVtab) - return SQLITE_OK - - -# The xOpen method is used to initialize a cursor. In this method we -# instantiate the TableFunction class and zero out a new cursor for iteration. -cdef int pwOpen(sqlite3_vtab *pBase, sqlite3_vtab_cursor **ppCursor) with gil: - cdef: - peewee_vtab *pVtab = pBase - peewee_cursor *pCur = 0 - object table_func_cls = pVtab.table_func_cls - - pCur = sqlite3_malloc(sizeof(pCur[0])) - memset(pCur, 0, sizeof(pCur[0])) - ppCursor[0] = &(pCur.base) - pCur.idx = 0 - try: - table_func = table_func_cls() - except: - if table_func_cls.print_tracebacks: - traceback.print_exc() - sqlite3_free(pCur) - return SQLITE_ERROR - - Py_INCREF(table_func) - pCur.table_func = table_func - pCur.stopped = False - return SQLITE_OK - - -cdef int pwClose(sqlite3_vtab_cursor *pBase) with gil: - cdef: - peewee_cursor *pCur = pBase - object table_func = pCur.table_func - Py_DECREF(table_func) - sqlite3_free(pCur) - return SQLITE_OK - - -# Iterate once, advancing the cursor's index and assigning the row data to the -# `row_data` field on the peewee_cursor struct. -cdef int pwNext(sqlite3_vtab_cursor *pBase) with gil: - cdef: - peewee_cursor *pCur = pBase - object table_func = pCur.table_func - tuple result - - if pCur.row_data: - Py_DECREF(pCur.row_data) - - pCur.row_data = NULL - try: - result = tuple(table_func.iterate(pCur.idx)) - except StopIteration: - pCur.stopped = True - except: - if table_func.print_tracebacks: - traceback.print_exc() - return SQLITE_ERROR - else: - Py_INCREF(result) - pCur.row_data = result - pCur.idx += 1 - pCur.stopped = False - - return SQLITE_OK - - -# Return the requested column from the current row. 
-cdef int pwColumn(sqlite3_vtab_cursor *pBase, sqlite3_context *ctx, - int iCol) with gil: - cdef: - bytes bval - peewee_cursor *pCur = pBase - sqlite3_int64 x = 0 - tuple row_data - - if iCol == -1: - sqlite3_result_int64(ctx, pCur.idx) - return SQLITE_OK - - if not pCur.row_data: - sqlite3_result_error(ctx, encode('no row data'), -1) - return SQLITE_ERROR - - row_data = pCur.row_data - return python_to_sqlite(ctx, row_data[iCol]) - - -cdef int pwRowid(sqlite3_vtab_cursor *pBase, sqlite3_int64 *pRowid): - cdef: - peewee_cursor *pCur = pBase - pRowid[0] = pCur.idx - return SQLITE_OK - - -# Return a boolean indicating whether the cursor has been consumed. -cdef int pwEof(sqlite3_vtab_cursor *pBase): - cdef: - peewee_cursor *pCur = pBase - return 1 if pCur.stopped else 0 - - -# The filter method is called on the first iteration. This method is where we -# get access to the parameters that the function was called with, and call the -# TableFunction's `initialize()` function. -cdef int pwFilter(sqlite3_vtab_cursor *pBase, int idxNum, - const char *idxStr, int argc, sqlite3_value **argv) with gil: - cdef: - peewee_cursor *pCur = pBase - object table_func = pCur.table_func - dict query = {} - int idx - int value_type - tuple row_data - void *row_data_raw - - if not idxStr or argc == 0 and len(table_func.params): - return SQLITE_ERROR - elif len(idxStr): - params = decode(idxStr).split(',') - else: - params = [] - - py_values = sqlite_to_python(argc, argv) - - for idx, param in enumerate(params): - value = argv[idx] - if not value: - query[param] = None - else: - query[param] = py_values[idx] - - try: - table_func.initialize(**query) - except: - if table_func.print_tracebacks: - traceback.print_exc() - return SQLITE_ERROR - - pCur.stopped = False - try: - row_data = tuple(table_func.iterate(0)) - except StopIteration: - pCur.stopped = True - except: - if table_func.print_tracebacks: - traceback.print_exc() - return SQLITE_ERROR - else: - Py_INCREF(row_data) - pCur.row_data = row_data - pCur.idx += 1 - return SQLITE_OK - - -# SQLite will (in some cases, repeatedly) call the xBestIndex method to try and -# find the best query plan. -cdef int pwBestIndex(sqlite3_vtab *pBase, sqlite3_index_info *pIdxInfo) \ - with gil: - cdef: - int i - int idxNum = 0, nArg = 0 - peewee_vtab *pVtab = pBase - object table_func_cls = pVtab.table_func_cls - sqlite3_index_constraint *pConstraint = 0 - list columns = [] - char *idxStr - int nParams = len(table_func_cls.params) - - for i in range(pIdxInfo.nConstraint): - pConstraint = pIdxInfo.aConstraint + i - if not pConstraint.usable: - continue - if pConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ: - continue - - columns.append(table_func_cls.params[pConstraint.iColumn - - table_func_cls._ncols]) - nArg += 1 - pIdxInfo.aConstraintUsage[i].argvIndex = nArg - pIdxInfo.aConstraintUsage[i].omit = 1 - - if nArg > 0 or nParams == 0: - if nArg == nParams: - # All parameters are present, this is ideal. - pIdxInfo.estimatedCost = 1 - pIdxInfo.estimatedRows = 10 - else: - # Penalize score based on number of missing params. - pIdxInfo.estimatedCost = 10000000000000 * (nParams - nArg) - pIdxInfo.estimatedRows = 10 ** (nParams - nArg) - - # Store a reference to the columns in the index info structure. 
- joinedCols = encode(','.join(columns)) - idxStr = sqlite3_malloc((len(joinedCols) + 1) * sizeof(char)) - memcpy(idxStr, joinedCols, len(joinedCols)) - idxStr[len(joinedCols)] = '\x00' - pIdxInfo.idxStr = idxStr - pIdxInfo.needToFreeIdxStr = 0 - elif USE_SQLITE_CONSTRAINT: - return SQLITE_CONSTRAINT - else: - pIdxInfo.estimatedCost = DBL_MAX - pIdxInfo.estimatedRows = 100000 - return SQLITE_OK - - -cdef class _TableFunctionImpl(object): - cdef: - sqlite3_module module - object table_function - - def __cinit__(self, table_function): - self.table_function = table_function - - cdef create_module(self, pysqlite_Connection* sqlite_conn): - cdef: - bytes name = encode(self.table_function.name) - sqlite3 *db = sqlite_conn.db - int rc - - # Populate the SQLite module struct members. - self.module.iVersion = 0 - self.module.xCreate = NULL - self.module.xConnect = pwConnect - self.module.xBestIndex = pwBestIndex - self.module.xDisconnect = pwDisconnect - self.module.xDestroy = NULL - self.module.xOpen = pwOpen - self.module.xClose = pwClose - self.module.xFilter = pwFilter - self.module.xNext = pwNext - self.module.xEof = pwEof - self.module.xColumn = pwColumn - self.module.xRowid = pwRowid - self.module.xUpdate = NULL - self.module.xBegin = NULL - self.module.xSync = NULL - self.module.xCommit = NULL - self.module.xRollback = NULL - self.module.xFindFunction = NULL - self.module.xRename = NULL - - # Create the SQLite virtual table. - rc = sqlite3_create_module( - db, - name, - &self.module, - (self.table_function)) - - Py_INCREF(self) - - return rc == SQLITE_OK - - -class TableFunction(object): - columns = None - params = None - name = None - print_tracebacks = True - _ncols = None - - @classmethod - def register(cls, conn): - cdef _TableFunctionImpl impl = _TableFunctionImpl(cls) - impl.create_module(conn) - cls._ncols = len(cls.columns) - - def initialize(self, **filters): - raise NotImplementedError - - def iterate(self, idx): - raise NotImplementedError - - @classmethod - def get_table_columns_declaration(cls): - cdef list accum = [] - - for column in cls.columns: - if isinstance(column, tuple): - if len(column) != 2: - raise ValueError('Column must be either a string or a ' - '2-tuple of name, type') - accum.append('%s %s' % column) - else: - accum.append(column) - - for param in cls.params: - accum.append('%s HIDDEN' % param) - - return ', '.join(accum) - - -cdef tuple SQLITE_DATETIME_FORMATS = ( - '%Y-%m-%d %H:%M:%S', - '%Y-%m-%d %H:%M:%S.%f', - '%Y-%m-%d', - '%H:%M:%S', - '%H:%M:%S.%f', - '%H:%M') - -cdef dict SQLITE_DATE_TRUNC_MAPPING = { - 'year': '%Y', - 'month': '%Y-%m', - 'day': '%Y-%m-%d', - 'hour': '%Y-%m-%d %H', - 'minute': '%Y-%m-%d %H:%M', - 'second': '%Y-%m-%d %H:%M:%S'} - - -cdef tuple validate_and_format_datetime(lookup, date_str): - if not date_str or not lookup: - return - - lookup = lookup.lower() - if lookup not in SQLITE_DATE_TRUNC_MAPPING: - return - - cdef datetime.datetime date_obj - cdef bint success = False - - for date_format in SQLITE_DATETIME_FORMATS: - try: - date_obj = datetime.datetime.strptime(date_str, date_format) - except ValueError: - pass - else: - return (date_obj, lookup) - - -cdef inline bytes encode(key): - cdef bytes bkey - if PyUnicode_Check(key): - bkey = PyUnicode_AsUTF8String(key) - elif PyBytes_Check(key): - bkey = key - elif key is None: - return None - else: - bkey = PyUnicode_AsUTF8String(str(key)) - return bkey - - -cdef inline unicode decode(key): - cdef unicode ukey - if PyBytes_Check(key): - ukey = key.decode('utf-8') - elif 
PyUnicode_Check(key): - ukey = key - elif key is None: - return None - else: - ukey = unicode(key) - return ukey - - -cdef double *get_weights(int ncol, tuple raw_weights): - cdef: - int argc = len(raw_weights) - int icol - double *weights = malloc(sizeof(double) * ncol) - - for icol in range(ncol): - if argc == 0: - weights[icol] = 1.0 - elif icol < argc: - weights[icol] = raw_weights[icol] - else: - weights[icol] = 0.0 - return weights - - -def peewee_rank(py_match_info, *raw_weights): - cdef: - unsigned int *match_info - unsigned int *phrase_info - bytes _match_info_buf = bytes(py_match_info) - char *match_info_buf = _match_info_buf - int nphrase, ncol, icol, iphrase, hits, global_hits - int P_O = 0, C_O = 1, X_O = 2 - double score = 0.0, weight - double *weights - - match_info = match_info_buf - nphrase = match_info[P_O] - ncol = match_info[C_O] - weights = get_weights(ncol, raw_weights) - - # matchinfo X value corresponds to, for each phrase in the search query, a - # list of 3 values for each column in the search table. - # So if we have a two-phrase search query and three columns of data, the - # following would be the layout: - # p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8] - # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17] - for iphrase in range(nphrase): - phrase_info = &match_info[X_O + iphrase * ncol * 3] - for icol in range(ncol): - weight = weights[icol] - if weight == 0: - continue - - # The idea is that we count the number of times the phrase appears - # in this column of the current row, compared to how many times it - # appears in this column across all rows. The ratio of these values - # provides a rough way to score based on "high value" terms. - hits = phrase_info[3 * icol] - global_hits = phrase_info[3 * icol + 1] - if hits > 0: - score += weight * (hits / global_hits) - - free(weights) - return -1 * score - - -def peewee_lucene(py_match_info, *raw_weights): - # Usage: peewee_lucene(matchinfo(table, 'pcnalx'), 1) - cdef: - unsigned int *match_info - bytes _match_info_buf = bytes(py_match_info) - char *match_info_buf = _match_info_buf - int nphrase, ncol - double total_docs, term_frequency - double doc_length, docs_with_term, avg_length - double idf, weight, rhs, denom - double *weights - int P_O = 0, C_O = 1, N_O = 2, L_O, X_O - int iphrase, icol, x - double score = 0.0 - - match_info = match_info_buf - nphrase = match_info[P_O] - ncol = match_info[C_O] - total_docs = match_info[N_O] - - L_O = 3 + ncol - X_O = L_O + ncol - weights = get_weights(ncol, raw_weights) - - for iphrase in range(nphrase): - for icol in range(ncol): - weight = weights[icol] - if weight == 0: - continue - doc_length = match_info[L_O + icol] - x = X_O + (3 * (icol + iphrase * ncol)) - term_frequency = match_info[x] # f(qi) - docs_with_term = match_info[x + 2] or 1. # n(qi) - idf = log(total_docs / (docs_with_term + 1.)) - tf = sqrt(term_frequency) - fieldNorms = 1.0 / sqrt(doc_length) - score += (idf * tf * fieldNorms) - - free(weights) - return -1 * score - - -def peewee_bm25(py_match_info, *raw_weights): - # Usage: peewee_bm25(matchinfo(table, 'pcnalx'), 1) - # where the second parameter is the index of the column and - # the 3rd and 4th specify k and b. 
- cdef: - unsigned int *match_info - bytes _match_info_buf = bytes(py_match_info) - char *match_info_buf = _match_info_buf - int nphrase, ncol - double B = 0.75, K = 1.2 - double total_docs, term_frequency - double doc_length, docs_with_term, avg_length - double idf, weight, ratio, num, b_part, denom, pc_score - double *weights - int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O - int iphrase, icol, x - double score = 0.0 - - match_info = match_info_buf - # PCNALX = matchinfo format. - # P = 1 = phrase count within query. - # C = 1 = searchable columns in table. - # N = 1 = total rows in table. - # A = c = for each column, avg number of tokens - # L = c = for each column, length of current row (in tokens) - # X = 3 * c * p = for each phrase and table column, - # * phrase count within column for current row. - # * phrase count within column for all rows. - # * total rows for which column contains phrase. - nphrase = match_info[P_O] # n - ncol = match_info[C_O] - total_docs = match_info[N_O] # N - - L_O = A_O + ncol - X_O = L_O + ncol - weights = get_weights(ncol, raw_weights) - - for iphrase in range(nphrase): - for icol in range(ncol): - weight = weights[icol] - if weight == 0: - continue - - x = X_O + (3 * (icol + iphrase * ncol)) - term_frequency = match_info[x] # f(qi, D) - docs_with_term = match_info[x + 2] # n(qi) - - # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - idf = log( - (total_docs - docs_with_term + 0.5) / - (docs_with_term + 0.5)) - if idf <= 0.0: - idf = 1e-6 - - doc_length = match_info[L_O + icol] # |D| - avg_length = match_info[A_O + icol] # avgdl - if avg_length == 0: - avg_length = 1 - ratio = doc_length / avg_length - - num = term_frequency * (K + 1) - b_part = 1 - B + (B * ratio) - denom = term_frequency + (K * b_part) - - pc_score = idf * (num / denom) - score += (pc_score * weight) - - free(weights) - return -1 * score - - -def peewee_bm25f(py_match_info, *raw_weights): - # Usage: peewee_bm25f(matchinfo(table, 'pcnalx'), 1) - # where the second parameter is the index of the column and - # the 3rd and 4th specify k and b. - cdef: - unsigned int *match_info - bytes _match_info_buf = bytes(py_match_info) - char *match_info_buf = _match_info_buf - int nphrase, ncol - double B = 0.75, K = 1.2, epsilon - double total_docs, term_frequency, docs_with_term - double doc_length = 0.0, avg_length = 0.0 - double idf, weight, ratio, num, b_part, denom, pc_score - double *weights - int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O - int iphrase, icol, x - double score = 0.0 - - match_info = match_info_buf - nphrase = match_info[P_O] # n - ncol = match_info[C_O] - total_docs = match_info[N_O] # N - - L_O = A_O + ncol - X_O = L_O + ncol - - for icol in range(ncol): - avg_length += match_info[A_O + icol] - doc_length += match_info[L_O + icol] - - epsilon = 1.0 / (total_docs * avg_length) - if avg_length == 0: - avg_length = 1 - ratio = doc_length / avg_length - weights = get_weights(ncol, raw_weights) - - for iphrase in range(nphrase): - for icol in range(ncol): - weight = weights[icol] - if weight == 0: - continue - - x = X_O + (3 * (icol + iphrase * ncol)) - term_frequency = match_info[x] # f(qi, D) - docs_with_term = match_info[x + 2] # n(qi) - - # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - idf = log( - (total_docs - docs_with_term + 0.5) / - (docs_with_term + 0.5)) - idf = epsilon if idf <= 0 else idf - - num = term_frequency * (K + 1) - b_part = 1 - B + (B * ratio) - denom = term_frequency + (K * b_part) - - pc_score = idf * ((num / denom) + 1.) 
-            score += (pc_score * weight)
-
-    free(weights)
-    return -1 * score
-
-
-cdef uint32_t murmurhash2(const unsigned char *key, ssize_t nlen,
-                          uint32_t seed):
-    cdef:
-        uint32_t m = 0x5bd1e995
-        int r = 24
-        const unsigned char *data = key
-        uint32_t h = seed ^ nlen
-        uint32_t k
-
-    while nlen >= 4:
-        k = <uint32_t>((<uint32_t *>data)[0])
-
-        k *= m
-        k = k ^ (k >> r)
-        k *= m
-
-        h *= m
-        h = h ^ k
-
-        data += 4
-        nlen -= 4
-
-    if nlen == 3:
-        h = h ^ (data[2] << 16)
-    if nlen >= 2:
-        h = h ^ (data[1] << 8)
-    if nlen >= 1:
-        h = h ^ (data[0])
-        h *= m
-
-    h = h ^ (h >> 13)
-    h *= m
-    h = h ^ (h >> 15)
-    return h
-
-
-def peewee_murmurhash(key, seed=None):
-    if key is None:
-        return
-
-    cdef:
-        bytes bkey = encode(key)
-        int nseed = seed or 0
-
-    if key:
-        return murmurhash2(<unsigned char *>bkey, len(bkey), nseed)
-    return 0
-
-
-def make_hash(hash_impl):
-    def inner(*items):
-        state = hash_impl()
-        for item in items:
-            state.update(encode(item))
-        return state.hexdigest()
-    return inner
-
-
-peewee_md5 = make_hash(hashlib.md5)
-peewee_sha1 = make_hash(hashlib.sha1)
-peewee_sha256 = make_hash(hashlib.sha256)
-
-
-def _register_functions(database, pairs):
-    for func, name in pairs:
-        database.register_function(func, name)
-
-
-def register_hash_functions(database):
-    _register_functions(database, (
-        (peewee_murmurhash, 'murmurhash'),
-        (peewee_md5, 'md5'),
-        (peewee_sha1, 'sha1'),
-        (peewee_sha256, 'sha256'),
-        (zlib.adler32, 'adler32'),
-        (zlib.crc32, 'crc32')))
-
-
-def register_rank_functions(database):
-    _register_functions(database, (
-        (peewee_bm25, 'fts_bm25'),
-        (peewee_bm25f, 'fts_bm25f'),
-        (peewee_lucene, 'fts_lucene'),
-        (peewee_rank, 'fts_rank')))
-
-
-ctypedef struct bf_t:
-    void *bits
-    size_t size
-
-cdef int seeds[10]
-seeds[:] = [0, 1337, 37, 0xabcd, 0xdead, 0xface, 97, 0xed11, 0xcad9, 0x827b]
-
-
-cdef bf_t *bf_create(size_t size):
-    cdef bf_t *bf = <bf_t *>calloc(1, sizeof(bf_t))
-    bf.size = size
-    bf.bits = malloc(size)
-    return bf
-
-@cython.cdivision(True)
-cdef uint32_t bf_bitindex(bf_t *bf, unsigned char *key, size_t klen, int seed):
-    cdef:
-        uint32_t h = murmurhash2(key, klen, seed)
-    return h % (bf.size * 8)
-
-@cython.cdivision(True)
-cdef bf_add(bf_t *bf, unsigned char *key):
-    cdef:
-        uint8_t *bits = <uint8_t *>(bf.bits)
-        uint32_t h
-        int pos, seed
-        size_t keylen = strlen(<char *>key)
-
-    for seed in seeds:
-        h = bf_bitindex(bf, key, keylen, seed)
-        pos = h / 8
-        bits[pos] = bits[pos] | (1 << (h % 8))
-
-@cython.cdivision(True)
-cdef int bf_contains(bf_t *bf, unsigned char *key):
-    cdef:
-        uint8_t *bits = <uint8_t *>(bf.bits)
-        uint32_t h
-        int pos, seed
-        size_t keylen = strlen(<char *>key)
-
-    for seed in seeds:
-        h = bf_bitindex(bf, key, keylen, seed)
-        pos = h / 8
-        if not (bits[pos] & (1 << (h % 8))):
-            return 0
-    return 1
-
-cdef bf_free(bf_t *bf):
-    free(bf.bits)
-    free(bf)
-
-
-cdef class BloomFilter(object):
-    cdef:
-        bf_t *bf
-
-    def __init__(self, size=1024 * 32):
-        self.bf = bf_create(size)
-
-    def __dealloc__(self):
-        if self.bf:
-            bf_free(self.bf)
-
-    def add(self, *keys):
-        cdef bytes bkey
-
-        for key in keys:
-            bkey = encode(key)
-            bf_add(self.bf, <unsigned char *>bkey)
-
-    def __contains__(self, key):
-        cdef bytes bkey = encode(key)
-        return bf_contains(self.bf, <unsigned char *>bkey)
-
-    def to_buffer(self):
-        # We have to do this so that embedded NULL bytes are preserved.
-        cdef bytes buf = PyBytes_FromStringAndSize(<char *>(self.bf.bits),
-                                                   self.bf.size)
-        # Similarly we wrap in a buffer object so pysqlite preserves the
-        # embedded NULL bytes.
-        return buf
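The bloom-filter primitives above are plain bit arithmetic over a byte buffer. A pure-Python sketch of the same add/contains scheme, with zlib.crc32 standing in for murmurhash2 (an illustration, not the deleted implementation):

    import zlib

    SEEDS = [0, 1337, 37, 0xabcd, 0xdead, 0xface, 97, 0xed11, 0xcad9, 0x827b]

    class PyBloomFilter:
        def __init__(self, size=1024 * 32):   # size in bytes, as in bf_create().
            self.size = size
            self.bits = bytearray(size)       # zero-filled bit array

        def _bitindexes(self, key):
            # One bit index per seed, modulo the total number of bits.
            return [zlib.crc32(key, seed) % (self.size * 8) for seed in SEEDS]

        def add(self, key):
            for h in self._bitindexes(key):
                self.bits[h // 8] |= 1 << (h % 8)

        def __contains__(self, key):
            # All bits set -> probably present; any bit clear -> definitely absent.
            return all(self.bits[h // 8] & (1 << (h % 8))
                       for h in self._bitindexes(key))

    bf = PyBloomFilter()
    bf.add(b'huey')
    print(b'huey' in bf, b'mickey' in bf)   # True False (false positives possible)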
-
-    @classmethod
-    def calculate_size(cls, double n, double p):
-        cdef double m = ceil((n * log(p)) / log(1.0 / (pow(2.0, log(2.0)))))
-        return m
-
-
-cdef class BloomFilterAggregate(object):
-    cdef:
-        BloomFilter bf
-
-    def __init__(self):
-        self.bf = None
-
-    def step(self, value, size=None):
-        if not self.bf:
-            size = size or 1024
-            self.bf = BloomFilter(size)
-
-        self.bf.add(value)
-
-    def finalize(self):
-        if not self.bf:
-            return None
-
-        return pysqlite.Binary(self.bf.to_buffer())
-
-
-def peewee_bloomfilter_contains(key, data):
-    cdef:
-        bf_t bf
-        bytes bkey
-        bytes bdata = bytes(data)
-        unsigned char *cdata = <unsigned char *>bdata
-
-    bf.size = len(data)
-    bf.bits = <void *>cdata
-    bkey = encode(key)
-
-    return bf_contains(&bf, <unsigned char *>bkey)
-
-
-def peewee_bloomfilter_calculate_size(n_items, error_p):
-    return BloomFilter.calculate_size(n_items, error_p)
-
-
-def register_bloomfilter(database):
-    database.register_aggregate(BloomFilterAggregate, 'bloomfilter')
-    database.register_function(peewee_bloomfilter_contains,
-                               'bloomfilter_contains')
-    database.register_function(peewee_bloomfilter_calculate_size,
-                               'bloomfilter_calculate_size')
-
-
-cdef inline int _check_connection(pysqlite_Connection *conn) except -1:
-    """
-    Check that the underlying SQLite database connection is usable. Raises an
-    InterfaceError if the connection is either uninitialized or closed.
-    """
-    if not conn.db:
-        raise InterfaceError('Cannot operate on closed database.')
-    return 1
-
-
-class ZeroBlob(Node):
-    def __init__(self, length):
-        if not isinstance(length, int) or length < 0:
-            raise ValueError('Length must be a positive integer.')
-        self.length = length
-
-    def __sql__(self, ctx):
-        return ctx.literal('zeroblob(%s)' % self.length)
-
-
-cdef class Blob(object)  # Forward declaration.
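calculate_size() above encodes the textbook sizing rule m = -n * ln(p) / (ln 2)^2, written with log(1 / 2^ln 2) in the denominator. A worked example; note the result is a bit count, while BloomFilter() takes a byte count (bf_bitindex() addresses bf.size * 8 bits):

    from math import ceil, log, pow

    n, p = 10000, 0.01    # expected number of keys, target false-positive rate
    m = ceil((n * log(p)) / log(1.0 / (pow(2.0, log(2.0)))))
    print(m)              # 95851 bits
    print(ceil(m / 8))    # 11982 -> e.g. BloomFilter(size=11982)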
-
-
-cdef inline int _check_blob_closed(Blob blob) except -1:
-    if not blob.pBlob:
-        raise InterfaceError('Cannot operate on closed blob.')
-    return 1
-
-
-cdef class Blob(object):
-    cdef:
-        int offset
-        pysqlite_Connection *conn
-        sqlite3_blob *pBlob
-
-    def __init__(self, database, table, column, rowid,
-                 read_only=False):
-        cdef:
-            bytes btable = encode(table)
-            bytes bcolumn = encode(column)
-            int flags = 0 if read_only else 1
-            int rc
-            sqlite3_blob *blob
-
-        self.conn = <pysqlite_Connection *>(database._state.conn)
-        _check_connection(self.conn)
-
-        rc = sqlite3_blob_open(
-            self.conn.db,
-            'main',
-            btable,
-            bcolumn,
-            <long long>rowid,
-            flags,
-            &blob)
-        if rc != SQLITE_OK:
-            raise OperationalError('Unable to open blob.')
-        if not blob:
-            raise MemoryError('Unable to allocate blob.')
-
-        self.pBlob = blob
-        self.offset = 0
-
-    cdef _close(self):
-        if self.pBlob:
-            sqlite3_blob_close(self.pBlob)
-        self.pBlob = <sqlite3_blob *>0
-
-    def __dealloc__(self):
-        self._close()
-
-    def __len__(self):
-        _check_blob_closed(self)
-        return sqlite3_blob_bytes(self.pBlob)
-
-    def read(self, n=None):
-        cdef:
-            bytes pybuf
-            int length = -1
-            int size
-            char *buf
-
-        if n is not None:
-            length = n
-
-        _check_blob_closed(self)
-        size = sqlite3_blob_bytes(self.pBlob)
-        if self.offset == size or length == 0:
-            return b''
-
-        if length < 0:
-            length = size - self.offset
-
-        if self.offset + length > size:
-            length = size - self.offset
-
-        pybuf = PyBytes_FromStringAndSize(NULL, length)
-        buf = PyBytes_AS_STRING(pybuf)
-        if sqlite3_blob_read(self.pBlob, buf, length, self.offset):
-            self._close()
-            raise OperationalError('Error reading from blob.')
-
-        self.offset += length
-        return bytes(pybuf)
-
-    def seek(self, offset, frame_of_reference=0):
-        cdef int size
-        _check_blob_closed(self)
-        size = sqlite3_blob_bytes(self.pBlob)
-        if frame_of_reference == 0:
-            if offset < 0 or offset > size:
-                raise ValueError('seek() offset outside of valid range.')
-            self.offset = offset
-        elif frame_of_reference == 1:
-            if self.offset + offset < 0 or self.offset + offset > size:
-                raise ValueError('seek() offset outside of valid range.')
-            self.offset += offset
-        elif frame_of_reference == 2:
-            if size + offset < 0 or size + offset > size:
-                raise ValueError('seek() offset outside of valid range.')
-            self.offset = size + offset
-        else:
-            raise ValueError('seek() frame of reference must be 0, 1 or 2.')
-
-    def tell(self):
-        _check_blob_closed(self)
-        return self.offset
-
-    def write(self, bytes data):
-        cdef:
-            char *buf
-            int size
-            Py_ssize_t buflen
-
-        _check_blob_closed(self)
-        size = sqlite3_blob_bytes(self.pBlob)
-        PyBytes_AsStringAndSize(data, &buf, &buflen)
-        if ((buflen + self.offset)) < self.offset:
-            raise ValueError('Data is too large (integer wrap)')
-        if ((buflen + self.offset)) > size:
-            raise ValueError('Data would go beyond end of blob')
-        if sqlite3_blob_write(self.pBlob, buf, buflen, self.offset):
-            raise OperationalError('Error writing to blob.')
-        self.offset += buflen
-
-    def close(self):
-        self._close()
-
-    def reopen(self, rowid):
-        _check_blob_closed(self)
-        self.offset = 0
-        if sqlite3_blob_reopen(self.pBlob, <long long>rowid):
-            self._close()
-            raise OperationalError('Unable to re-open blob.')
-
-
-def sqlite_get_status(flag):
-    cdef:
-        int current, highwater, rc
-
-    rc = sqlite3_status(flag, &current, &highwater, 0)
-    if rc == SQLITE_OK:
-        return (current, highwater)
-    raise Exception('Error requesting status: %s' % rc)
-
-
-def sqlite_get_db_status(conn, flag):
-    cdef:
-        int current, highwater, rc
-        pysqlite_Connection *c_conn = <pysqlite_Connection *>conn
-
-    rc = sqlite3_db_status(c_conn.db, flag, &current, &highwater, 0)
-    if rc == SQLITE_OK:
-        return (current, highwater)
-    raise Exception('Error requesting db status: %s' % rc)
-
-
-cdef class ConnectionHelper(object):
-    cdef:
-        object _commit_hook, _rollback_hook, _update_hook
-        pysqlite_Connection *conn
-
-    def __init__(self, connection):
-        self.conn = <pysqlite_Connection *>connection
-        self._commit_hook = self._rollback_hook = self._update_hook = None
-
-    def __dealloc__(self):
-        # When deallocating a Database object, we need to ensure that we clear
-        # any commit, rollback or update hooks that may have been applied.
-        if not self.conn.initialized or not self.conn.db:
-            return
-
-        if self._commit_hook is not None:
-            sqlite3_commit_hook(self.conn.db, NULL, NULL)
-        if self._rollback_hook is not None:
-            sqlite3_rollback_hook(self.conn.db, NULL, NULL)
-        if self._update_hook is not None:
-            sqlite3_update_hook(self.conn.db, NULL, NULL)
-
-    def set_commit_hook(self, fn):
-        self._commit_hook = fn
-        if fn is None:
-            sqlite3_commit_hook(self.conn.db, NULL, NULL)
-        else:
-            sqlite3_commit_hook(self.conn.db, _commit_callback, <void *>fn)
-
-    def set_rollback_hook(self, fn):
-        self._rollback_hook = fn
-        if fn is None:
-            sqlite3_rollback_hook(self.conn.db, NULL, NULL)
-        else:
-            sqlite3_rollback_hook(self.conn.db, _rollback_callback, <void *>fn)
-
-    def set_update_hook(self, fn):
-        self._update_hook = fn
-        if fn is None:
-            sqlite3_update_hook(self.conn.db, NULL, NULL)
-        else:
-            sqlite3_update_hook(self.conn.db, _update_callback, <void *>fn)
-
-    def set_busy_handler(self, timeout=5):
-        """
-        Replace the default busy handler with one that introduces some "jitter"
-        into the amount of time delayed between checks.
-        """
-        cdef sqlite3_int64 n = timeout * 1000
-        sqlite3_busy_handler(self.conn.db, _aggressive_busy_handler, <void *>n)
-        return True
-
-    def changes(self):
-        return sqlite3_changes(self.conn.db)
-
-    def last_insert_rowid(self):
-        return sqlite3_last_insert_rowid(self.conn.db)
-
-    def autocommit(self):
-        return sqlite3_get_autocommit(self.conn.db) != 0
-
-
-cdef int _commit_callback(void *userData) with gil:
-    # C-callback that delegates to the Python commit handler. If the Python
-    # function raises a ValueError, then the commit is aborted and the
-    # transaction rolled back. Otherwise, regardless of the function return
-    # value, the transaction will commit.
-    cdef object fn = <object>userData
-    try:
-        fn()
-    except ValueError:
-        return 1
-    else:
-        return SQLITE_OK
-
-
-cdef void _rollback_callback(void *userData) with gil:
-    # C-callback that delegates to the Python rollback handler.
-    cdef object fn = <object>userData
-    fn()
-
-
-cdef void _update_callback(void *userData, int queryType, const char *database,
-                           const char *table, sqlite3_int64 rowid) with gil:
-    # C-callback that delegates to a Python function that is executed whenever
-    # the database is updated (insert/update/delete queries). The Python
-    # callback receives a string indicating the query type, the name of the
-    # database, the name of the table being updated, and the rowid of the row
-    # being updated.
-    cdef object fn = <object>userData
-    if queryType == SQLITE_INSERT:
-        query = 'INSERT'
-    elif queryType == SQLITE_UPDATE:
-        query = 'UPDATE'
-    elif queryType == SQLITE_DELETE:
-        query = 'DELETE'
-    else:
-        query = ''
-    fn(query, decode(database), decode(table), rowid)
-
-
-def backup(src_conn, dest_conn, pages=None, name=None, progress=None):
-    cdef:
-        bytes bname = encode(name or 'main')
-        int page_step = pages or -1
-        int rc
-        pysqlite_Connection *src = <pysqlite_Connection *>src_conn
-        pysqlite_Connection *dest = <pysqlite_Connection *>dest_conn
-        sqlite3 *src_db = src.db
-        sqlite3 *dest_db = dest.db
-        sqlite3_backup *backup
-
-    # We always backup to the "main" database in the dest db.
-    backup = sqlite3_backup_init(dest_db, b'main', src_db, bname)
-    if backup == NULL:
-        raise OperationalError('Unable to initialize backup.')
-
-    while True:
-        with nogil:
-            rc = sqlite3_backup_step(backup, page_step)
-        if progress is not None:
-            # Progress-handler is called with (remaining, page count, is done?)
-            remaining = sqlite3_backup_remaining(backup)
-            page_count = sqlite3_backup_pagecount(backup)
-            try:
-                progress(remaining, page_count, rc == SQLITE_DONE)
-            except:
-                sqlite3_backup_finish(backup)
-                raise
-        if rc == SQLITE_BUSY or rc == SQLITE_LOCKED:
-            with nogil:
-                sqlite3_sleep(250)
-        elif rc == SQLITE_DONE:
-            break
-
-    with nogil:
-        sqlite3_backup_finish(backup)
-    if sqlite3_errcode(dest_db):
-        raise OperationalError('Error backing up database: %s' %
-                               sqlite3_errmsg(dest_db))
-    return True
-
-
-def backup_to_file(src_conn, filename, pages=None, name=None, progress=None):
-    dest_conn = pysqlite.connect(filename)
-    backup(src_conn, dest_conn, pages=pages, name=name, progress=progress)
-    dest_conn.close()
-    return True
-
-
-cdef int _aggressive_busy_handler(void *ptr, int n) nogil:
-    # In concurrent environments, it often seems that if multiple queries are
-    # kicked off at around the same time, they proceed in lock-step to check
-    # for the availability of the lock. By introducing some "jitter" we can
-    # ensure that this doesn't happen. Furthermore, this function makes more
-    # attempts in the same time period than the default handler.
-    cdef:
-        sqlite3_int64 busyTimeout = <sqlite3_int64>ptr
-        int current, total
-
-    if n < 20:
-        current = 25 - (rand() % 10)  # ~20ms
-        total = n * 20
-    elif n < 40:
-        current = 50 - (rand() % 20)  # ~40ms
-        total = 400 + ((n - 20) * 40)
-    else:
-        current = 120 - (rand() % 40)  # ~100ms
-        total = 1200 + ((n - 40) * 100)  # Estimate the amount of time slept.
-
-    if total + current > busyTimeout:
-        current = busyTimeout - total
-    if current > 0:
-        sqlite3_sleep(current)
-        return 1
-    return 0
diff --git a/libs/playhouse/_sqlite_udf.pyx b/libs/playhouse/_sqlite_udf.pyx
deleted file mode 100644
index 9ff6e7430..000000000
--- a/libs/playhouse/_sqlite_udf.pyx
+++ /dev/null
@@ -1,137 +0,0 @@
-import sys
-from difflib import SequenceMatcher
-from random import randint
-
-
-IS_PY3K = sys.version_info[0] == 3
-
-# String UDF.
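For reference, the online-backup loop in the deleted _sqlite_ext.pyx above (step, sleep on busy, optional progress callback) maps onto the sqlite3 module's own API on Python 3.7+. A rough equivalent, with hypothetical file names:

    import sqlite3

    def progress(status, remaining, total):
        # Mirrors the (remaining, page count, done?) reporting above.
        print('%s of %s pages remain' % (remaining, total))

    src = sqlite3.connect('bazarr.db')
    dst = sqlite3.connect('bazarr-backup.db')
    with dst:
        src.backup(dst, pages=100, progress=progress, sleep=0.250)
    dst.close()
    src.close()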
-def damerau_levenshtein_dist(s1, s2): - cdef: - int i, j, del_cost, add_cost, sub_cost - int s1_len = len(s1), s2_len = len(s2) - list one_ago, two_ago, current_row - list zeroes = [0] * (s2_len + 1) - - if IS_PY3K: - current_row = list(range(1, s2_len + 2)) - else: - current_row = range(1, s2_len + 2) - - current_row[-1] = 0 - one_ago = None - - for i in range(s1_len): - two_ago = one_ago - one_ago = current_row - current_row = list(zeroes) - current_row[-1] = i + 1 - for j in range(s2_len): - del_cost = one_ago[j] + 1 - add_cost = current_row[j - 1] + 1 - sub_cost = one_ago[j - 1] + (s1[i] != s2[j]) - current_row[j] = min(del_cost, add_cost, sub_cost) - - # Handle transpositions. - if (i > 0 and j > 0 and s1[i] == s2[j - 1] - and s1[i-1] == s2[j] and s1[i] != s2[j]): - current_row[j] = min(current_row[j], two_ago[j - 2] + 1) - - return current_row[s2_len - 1] - -# String UDF. -def levenshtein_dist(a, b): - cdef: - int add, delete, change - int i, j - int n = len(a), m = len(b) - list current, previous - list zeroes - - if n > m: - a, b = b, a - n, m = m, n - - zeroes = [0] * (m + 1) - - if IS_PY3K: - current = list(range(n + 1)) - else: - current = range(n + 1) - - for i in range(1, m + 1): - previous = current - current = list(zeroes) - current[0] = i - - for j in range(1, n + 1): - add = previous[j] + 1 - delete = current[j - 1] + 1 - change = previous[j - 1] - if a[j - 1] != b[i - 1]: - change +=1 - current[j] = min(add, delete, change) - - return current[n] - -# String UDF. -def str_dist(a, b): - cdef: - int t = 0 - - for i in SequenceMatcher(None, a, b).get_opcodes(): - if i[0] == 'equal': - continue - t = t + max(i[4] - i[3], i[2] - i[1]) - return t - -# Math Aggregate. -cdef class median(object): - cdef: - int ct - list items - - def __init__(self): - self.ct = 0 - self.items = [] - - cdef selectKth(self, int k, int s=0, int e=-1): - cdef: - int idx - if e < 0: - e = len(self.items) - idx = randint(s, e-1) - idx = self.partition_k(idx, s, e) - if idx > k: - return self.selectKth(k, s, idx) - elif idx < k: - return self.selectKth(k, idx + 1, e) - else: - return self.items[idx] - - cdef int partition_k(self, int pi, int s, int e): - cdef: - int i, x - - val = self.items[pi] - # Swap pivot w/last item. - self.items[e - 1], self.items[pi] = self.items[pi], self.items[e - 1] - x = s - for i in range(s, e): - if self.items[i] < val: - self.items[i], self.items[x] = self.items[x], self.items[i] - x += 1 - self.items[x], self.items[e-1] = self.items[e-1], self.items[x] - return x - - def step(self, item): - self.items.append(item) - self.ct += 1 - - def finalize(self): - if self.ct == 0: - return None - elif self.ct < 3: - return self.items[0] - else: - return self.selectKth(self.ct / 2) diff --git a/libs/playhouse/apsw_ext.py b/libs/playhouse/apsw_ext.py deleted file mode 100644 index 0aa35939b..000000000 --- a/libs/playhouse/apsw_ext.py +++ /dev/null @@ -1,136 +0,0 @@ -""" -Peewee integration with APSW, "another python sqlite wrapper". - -Project page: https://rogerbinns.github.io/apsw/ - -APSW is a really neat library that provides a thin wrapper on top of SQLite's -C interface. - -Here are just a few reasons to use APSW, taken from the documentation: - -* APSW gives all functionality of SQLite, including virtual tables, virtual - file system, blob i/o, backups and file control. -* Connections can be shared across threads without any additional locking. -* Transactions are managed explicitly by your code. -* APSW can handle nested transactions. -* Unicode is handled correctly. 
-* APSW is faster. -""" -import apsw -from peewee import * -from peewee import __exception_wrapper__ -from peewee import BooleanField as _BooleanField -from peewee import DateField as _DateField -from peewee import DateTimeField as _DateTimeField -from peewee import DecimalField as _DecimalField -from peewee import TimeField as _TimeField -from peewee import logger - -from playhouse.sqlite_ext import SqliteExtDatabase - - -class APSWDatabase(SqliteExtDatabase): - server_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.')) - - def __init__(self, database, **kwargs): - self._modules = {} - super(APSWDatabase, self).__init__(database, **kwargs) - - def register_module(self, mod_name, mod_inst): - self._modules[mod_name] = mod_inst - if not self.is_closed(): - self.connection().createmodule(mod_name, mod_inst) - - def unregister_module(self, mod_name): - del(self._modules[mod_name]) - - def _connect(self): - conn = apsw.Connection(self.database, **self.connect_params) - if self._timeout is not None: - conn.setbusytimeout(self._timeout * 1000) - try: - self._add_conn_hooks(conn) - except: - conn.close() - raise - return conn - - def _add_conn_hooks(self, conn): - super(APSWDatabase, self)._add_conn_hooks(conn) - self._load_modules(conn) # APSW-only. - - def _load_modules(self, conn): - for mod_name, mod_inst in self._modules.items(): - conn.createmodule(mod_name, mod_inst) - return conn - - def _load_aggregates(self, conn): - for name, (klass, num_params) in self._aggregates.items(): - def make_aggregate(): - return (klass(), klass.step, klass.finalize) - conn.createaggregatefunction(name, make_aggregate) - - def _load_collations(self, conn): - for name, fn in self._collations.items(): - conn.createcollation(name, fn) - - def _load_functions(self, conn): - for name, (fn, num_params) in self._functions.items(): - conn.createscalarfunction(name, fn, num_params) - - def _load_extensions(self, conn): - conn.enableloadextension(True) - for extension in self._extensions: - conn.loadextension(extension) - - def load_extension(self, extension): - self._extensions.add(extension) - if not self.is_closed(): - conn = self.connection() - conn.enableloadextension(True) - conn.loadextension(extension) - - def last_insert_id(self, cursor, query_type=None): - return cursor.getconnection().last_insert_rowid() - - def rows_affected(self, cursor): - return cursor.getconnection().changes() - - def begin(self, lock_type='deferred'): - self.cursor().execute('begin %s;' % lock_type) - - def commit(self): - self.cursor().execute('commit;') - - def rollback(self): - self.cursor().execute('rollback;') - - def execute_sql(self, sql, params=None, commit=True): - logger.debug((sql, params)) - with __exception_wrapper__: - cursor = self.cursor() - cursor.execute(sql, params or ()) - return cursor - - -def nh(s, v): - if v is not None: - return str(v) - -class BooleanField(_BooleanField): - def db_value(self, v): - v = super(BooleanField, self).db_value(v) - if v is not None: - return v and 1 or 0 - -class DateField(_DateField): - db_value = nh - -class TimeField(_TimeField): - db_value = nh - -class DateTimeField(_DateTimeField): - db_value = nh - -class DecimalField(_DecimalField): - db_value = nh diff --git a/libs/playhouse/dataset.py b/libs/playhouse/dataset.py deleted file mode 100644 index f5bbf8b28..000000000 --- a/libs/playhouse/dataset.py +++ /dev/null @@ -1,452 +0,0 @@ -import csv -import datetime -from decimal import Decimal -import json -import operator -try: - from urlparse import urlparse -except 
ImportError: - from urllib.parse import urlparse -import sys - -from peewee import * -from playhouse.db_url import connect -from playhouse.migrate import migrate -from playhouse.migrate import SchemaMigrator -from playhouse.reflection import Introspector - -if sys.version_info[0] == 3: - basestring = str - from functools import reduce - def open_file(f, mode): - return open(f, mode, encoding='utf8') -else: - open_file = open - - -class DataSet(object): - def __init__(self, url, bare_fields=False): - if isinstance(url, Database): - self._url = None - self._database = url - self._database_path = self._database.database - else: - self._url = url - parse_result = urlparse(url) - self._database_path = parse_result.path[1:] - - # Connect to the database. - self._database = connect(url) - - self._database.connect() - - # Introspect the database and generate models. - self._introspector = Introspector.from_database(self._database) - self._models = self._introspector.generate_models( - skip_invalid=True, - literal_column_names=True, - bare_fields=bare_fields) - self._migrator = SchemaMigrator.from_database(self._database) - - class BaseModel(Model): - class Meta: - database = self._database - self._base_model = BaseModel - self._export_formats = self.get_export_formats() - self._import_formats = self.get_import_formats() - - def __repr__(self): - return '' % self._database_path - - def get_export_formats(self): - return { - 'csv': CSVExporter, - 'json': JSONExporter, - 'tsv': TSVExporter} - - def get_import_formats(self): - return { - 'csv': CSVImporter, - 'json': JSONImporter, - 'tsv': TSVImporter} - - def __getitem__(self, table): - if table not in self._models and table in self.tables: - self.update_cache(table) - return Table(self, table, self._models.get(table)) - - @property - def tables(self): - return self._database.get_tables() - - def __contains__(self, table): - return table in self.tables - - def connect(self): - self._database.connect() - - def close(self): - self._database.close() - - def update_cache(self, table=None): - if table: - dependencies = [table] - if table in self._models: - model_class = self._models[table] - dependencies.extend([ - related._meta.table_name for _, related, _ in - model_class._meta.model_graph()]) - else: - dependencies.extend(self.get_table_dependencies(table)) - else: - dependencies = None # Update all tables. - self._models = {} - updated = self._introspector.generate_models( - skip_invalid=True, - table_names=dependencies, - literal_column_names=True) - self._models.update(updated) - - def get_table_dependencies(self, table): - stack = [table] - accum = [] - seen = set() - while stack: - table = stack.pop() - for fk_meta in self._database.get_foreign_keys(table): - dest = fk_meta.dest_table - if dest not in seen: - stack.append(dest) - accum.append(dest) - return accum - - def __enter__(self): - self.connect() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if not self._database.is_closed(): - self.close() - - def query(self, sql, params=None, commit=True): - return self._database.execute_sql(sql, params, commit) - - def transaction(self): - if self._database.transaction_depth() == 0: - return self._database.transaction() - else: - return self._database.savepoint() - - def _check_arguments(self, filename, file_obj, format, format_dict): - if filename and file_obj: - raise ValueError('file is over-specified. 
Please use either ' - 'filename or file_obj, but not both.') - if not filename and not file_obj: - raise ValueError('A filename or file-like object must be ' - 'specified.') - if format not in format_dict: - valid_formats = ', '.join(sorted(format_dict.keys())) - raise ValueError('Unsupported format "%s". Use one of %s.' % ( - format, valid_formats)) - - def freeze(self, query, format='csv', filename=None, file_obj=None, - **kwargs): - self._check_arguments(filename, file_obj, format, self._export_formats) - if filename: - file_obj = open_file(filename, 'w') - - exporter = self._export_formats[format](query) - exporter.export(file_obj, **kwargs) - - if filename: - file_obj.close() - - def thaw(self, table, format='csv', filename=None, file_obj=None, - strict=False, **kwargs): - self._check_arguments(filename, file_obj, format, self._export_formats) - if filename: - file_obj = open_file(filename, 'r') - - importer = self._import_formats[format](self[table], strict) - count = importer.load(file_obj, **kwargs) - - if filename: - file_obj.close() - - return count - - -class Table(object): - def __init__(self, dataset, name, model_class): - self.dataset = dataset - self.name = name - if model_class is None: - model_class = self._create_model() - model_class.create_table() - self.dataset._models[name] = model_class - - @property - def model_class(self): - return self.dataset._models[self.name] - - def __repr__(self): - return '' % self.name - - def __len__(self): - return self.find().count() - - def __iter__(self): - return iter(self.find().iterator()) - - def _create_model(self): - class Meta: - table_name = self.name - return type( - str(self.name), - (self.dataset._base_model,), - {'Meta': Meta}) - - def create_index(self, columns, unique=False): - self.dataset._database.create_index( - self.model_class, - columns, - unique=unique) - - def _guess_field_type(self, value): - if isinstance(value, basestring): - return TextField - if isinstance(value, (datetime.date, datetime.datetime)): - return DateTimeField - elif value is True or value is False: - return BooleanField - elif isinstance(value, int): - return IntegerField - elif isinstance(value, float): - return FloatField - elif isinstance(value, Decimal): - return DecimalField - return TextField - - @property - def columns(self): - return [f.name for f in self.model_class._meta.sorted_fields] - - def _migrate_new_columns(self, data): - new_keys = set(data) - set(self.model_class._meta.fields) - if new_keys: - operations = [] - for key in new_keys: - field_class = self._guess_field_type(data[key]) - field = field_class(null=True) - operations.append( - self.dataset._migrator.add_column(self.name, key, field)) - field.bind(self.model_class, key) - - migrate(*operations) - - self.dataset.update_cache(self.name) - - def __getitem__(self, item): - try: - return self.model_class[item] - except self.model_class.DoesNotExist: - pass - - def __setitem__(self, item, value): - if not isinstance(value, dict): - raise ValueError('Table.__setitem__() value must be a dict') - - pk = self.model_class._meta.primary_key - value[pk.name] = item - - try: - with self.dataset.transaction() as txn: - self.insert(**value) - except IntegrityError: - self.dataset.update_cache(self.name) - self.update(columns=[pk.name], **value) - - def __delitem__(self, item): - del self.model_class[item] - - def insert(self, **data): - self._migrate_new_columns(data) - return self.model_class.insert(**data).execute() - - def _apply_where(self, query, filters, conjunction=None): - 
conjunction = conjunction or operator.and_ - if filters: - expressions = [ - (self.model_class._meta.fields[column] == value) - for column, value in filters.items()] - query = query.where(reduce(conjunction, expressions)) - return query - - def update(self, columns=None, conjunction=None, **data): - self._migrate_new_columns(data) - filters = {} - if columns: - for column in columns: - filters[column] = data.pop(column) - - return self._apply_where( - self.model_class.update(**data), - filters, - conjunction).execute() - - def _query(self, **query): - return self._apply_where(self.model_class.select(), query) - - def find(self, **query): - return self._query(**query).dicts() - - def find_one(self, **query): - try: - return self.find(**query).get() - except self.model_class.DoesNotExist: - return None - - def all(self): - return self.find() - - def delete(self, **query): - return self._apply_where(self.model_class.delete(), query).execute() - - def freeze(self, *args, **kwargs): - return self.dataset.freeze(self.all(), *args, **kwargs) - - def thaw(self, *args, **kwargs): - return self.dataset.thaw(self.name, *args, **kwargs) - - -class Exporter(object): - def __init__(self, query): - self.query = query - - def export(self, file_obj): - raise NotImplementedError - - -class JSONExporter(Exporter): - def __init__(self, query, iso8601_datetimes=False): - super(JSONExporter, self).__init__(query) - self.iso8601_datetimes = iso8601_datetimes - - def _make_default(self): - datetime_types = (datetime.datetime, datetime.date, datetime.time) - - if self.iso8601_datetimes: - def default(o): - if isinstance(o, datetime_types): - return o.isoformat() - elif isinstance(o, Decimal): - return str(o) - raise TypeError('Unable to serialize %r as JSON' % o) - else: - def default(o): - if isinstance(o, datetime_types + (Decimal,)): - return str(o) - raise TypeError('Unable to serialize %r as JSON' % o) - return default - - def export(self, file_obj, **kwargs): - json.dump( - list(self.query), - file_obj, - default=self._make_default(), - **kwargs) - - -class CSVExporter(Exporter): - def export(self, file_obj, header=True, **kwargs): - writer = csv.writer(file_obj, **kwargs) - tuples = self.query.tuples().execute() - tuples.initialize() - if header and getattr(tuples, 'columns', None): - writer.writerow([column for column in tuples.columns]) - for row in tuples: - writer.writerow(row) - - -class TSVExporter(CSVExporter): - def export(self, file_obj, header=True, **kwargs): - kwargs.setdefault('delimiter', '\t') - return super(TSVExporter, self).export(file_obj, header, **kwargs) - - -class Importer(object): - def __init__(self, table, strict=False): - self.table = table - self.strict = strict - - model = self.table.model_class - self.columns = model._meta.columns - self.columns.update(model._meta.fields) - - def load(self, file_obj): - raise NotImplementedError - - -class JSONImporter(Importer): - def load(self, file_obj, **kwargs): - data = json.load(file_obj, **kwargs) - count = 0 - - for row in data: - if self.strict: - obj = {} - for key in row: - field = self.columns.get(key) - if field is not None: - obj[field.name] = field.python_value(row[key]) - else: - obj = row - - if obj: - self.table.insert(**obj) - count += 1 - - return count - - -class CSVImporter(Importer): - def load(self, file_obj, header=True, **kwargs): - count = 0 - reader = csv.reader(file_obj, **kwargs) - if header: - try: - header_keys = next(reader) - except StopIteration: - return count - - if self.strict: - header_fields = [] - for 
idx, key in enumerate(header_keys): - if key in self.columns: - header_fields.append((idx, self.columns[key])) - else: - header_fields = list(enumerate(header_keys)) - else: - header_fields = list(enumerate(self.model._meta.sorted_fields)) - - if not header_fields: - return count - - for row in reader: - obj = {} - for idx, field in header_fields: - if self.strict: - obj[field.name] = field.python_value(row[idx]) - else: - obj[field] = row[idx] - - self.table.insert(**obj) - count += 1 - - return count - - -class TSVImporter(CSVImporter): - def load(self, file_obj, header=True, **kwargs): - kwargs.setdefault('delimiter', '\t') - return super(TSVImporter, self).load(file_obj, header, **kwargs) diff --git a/libs/playhouse/db_url.py b/libs/playhouse/db_url.py deleted file mode 100644 index fcc4ab87a..000000000 --- a/libs/playhouse/db_url.py +++ /dev/null @@ -1,124 +0,0 @@ -try: - from urlparse import parse_qsl, unquote, urlparse -except ImportError: - from urllib.parse import parse_qsl, unquote, urlparse - -from peewee import * -from playhouse.pool import PooledMySQLDatabase -from playhouse.pool import PooledPostgresqlDatabase -from playhouse.pool import PooledSqliteDatabase -from playhouse.pool import PooledSqliteExtDatabase -from playhouse.sqlite_ext import SqliteExtDatabase - - -schemes = { - 'mysql': MySQLDatabase, - 'mysql+pool': PooledMySQLDatabase, - 'postgres': PostgresqlDatabase, - 'postgresql': PostgresqlDatabase, - 'postgres+pool': PooledPostgresqlDatabase, - 'postgresql+pool': PooledPostgresqlDatabase, - 'sqlite': SqliteDatabase, - 'sqliteext': SqliteExtDatabase, - 'sqlite+pool': PooledSqliteDatabase, - 'sqliteext+pool': PooledSqliteExtDatabase, -} - -def register_database(db_class, *names): - global schemes - for name in names: - schemes[name] = db_class - -def parseresult_to_dict(parsed, unquote_password=False): - - # urlparse in python 2.6 is broken so query will be empty and instead - # appended to path complete with '?' - path_parts = parsed.path[1:].split('?') - try: - query = path_parts[1] - except IndexError: - query = parsed.query - - connect_kwargs = {'database': path_parts[0]} - if parsed.username: - connect_kwargs['user'] = parsed.username - if parsed.password: - connect_kwargs['password'] = parsed.password - if unquote_password: - connect_kwargs['password'] = unquote(connect_kwargs['password']) - if parsed.hostname: - connect_kwargs['host'] = parsed.hostname - if parsed.port: - connect_kwargs['port'] = parsed.port - - # Adjust parameters for MySQL. - if parsed.scheme == 'mysql' and 'password' in connect_kwargs: - connect_kwargs['passwd'] = connect_kwargs.pop('password') - elif 'sqlite' in parsed.scheme and not connect_kwargs['database']: - connect_kwargs['database'] = ':memory:' - - # Get additional connection args from the query string - qs_args = parse_qsl(query, keep_blank_values=True) - for key, value in qs_args: - if value.lower() == 'false': - value = False - elif value.lower() == 'true': - value = True - elif value.isdigit(): - value = int(value) - elif '.' 
in value and all(p.isdigit() for p in value.split('.', 1)): - try: - value = float(value) - except ValueError: - pass - elif value.lower() in ('null', 'none'): - value = None - - connect_kwargs[key] = value - - return connect_kwargs - -def parse(url, unquote_password=False): - parsed = urlparse(url) - return parseresult_to_dict(parsed, unquote_password) - -def connect(url, unquote_password=False, **connect_params): - parsed = urlparse(url) - connect_kwargs = parseresult_to_dict(parsed, unquote_password) - connect_kwargs.update(connect_params) - database_class = schemes.get(parsed.scheme) - - if database_class is None: - if database_class in schemes: - raise RuntimeError('Attempted to use "%s" but a required library ' - 'could not be imported.' % parsed.scheme) - else: - raise RuntimeError('Unrecognized or unsupported scheme: "%s".' % - parsed.scheme) - - return database_class(**connect_kwargs) - -# Conditionally register additional databases. -try: - from playhouse.pool import PooledPostgresqlExtDatabase -except ImportError: - pass -else: - register_database( - PooledPostgresqlExtDatabase, - 'postgresext+pool', - 'postgresqlext+pool') - -try: - from playhouse.apsw_ext import APSWDatabase -except ImportError: - pass -else: - register_database(APSWDatabase, 'apsw') - -try: - from playhouse.postgres_ext import PostgresqlExtDatabase -except ImportError: - pass -else: - register_database(PostgresqlExtDatabase, 'postgresext', 'postgresqlext') diff --git a/libs/playhouse/fields.py b/libs/playhouse/fields.py deleted file mode 100644 index fce1a3d6d..000000000 --- a/libs/playhouse/fields.py +++ /dev/null @@ -1,64 +0,0 @@ -try: - import bz2 -except ImportError: - bz2 = None -try: - import zlib -except ImportError: - zlib = None -try: - import cPickle as pickle -except ImportError: - import pickle -import sys - -from peewee import BlobField -from peewee import buffer_type - - -PY2 = sys.version_info[0] == 2 - - -class CompressedField(BlobField): - ZLIB = 'zlib' - BZ2 = 'bz2' - algorithm_to_import = { - ZLIB: zlib, - BZ2: bz2, - } - - def __init__(self, compression_level=6, algorithm=ZLIB, *args, - **kwargs): - self.compression_level = compression_level - if algorithm not in self.algorithm_to_import: - raise ValueError('Unrecognized algorithm %s' % algorithm) - compress_module = self.algorithm_to_import[algorithm] - if compress_module is None: - raise ValueError('Missing library required for %s.' 
% algorithm) - - self.algorithm = algorithm - self.compress = compress_module.compress - self.decompress = compress_module.decompress - super(CompressedField, self).__init__(*args, **kwargs) - - def python_value(self, value): - if value is not None: - return self.decompress(value) - - def db_value(self, value): - if value is not None: - return self._constructor( - self.compress(value, self.compression_level)) - - -class PickleField(BlobField): - def python_value(self, value): - if value is not None: - if isinstance(value, buffer_type): - value = bytes(value) - return pickle.loads(value) - - def db_value(self, value): - if value is not None: - pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL) - return self._constructor(pickled) diff --git a/libs/playhouse/flask_utils.py b/libs/playhouse/flask_utils.py deleted file mode 100644 index 76a2a62f4..000000000 --- a/libs/playhouse/flask_utils.py +++ /dev/null @@ -1,185 +0,0 @@ -import math -import sys - -from flask import abort -from flask import render_template -from flask import request -from peewee import Database -from peewee import DoesNotExist -from peewee import Model -from peewee import Proxy -from peewee import SelectQuery -from playhouse.db_url import connect as db_url_connect - - -class PaginatedQuery(object): - def __init__(self, query_or_model, paginate_by, page_var='page', page=None, - check_bounds=False): - self.paginate_by = paginate_by - self.page_var = page_var - self.page = page or None - self.check_bounds = check_bounds - - if isinstance(query_or_model, SelectQuery): - self.query = query_or_model - self.model = self.query.model - else: - self.model = query_or_model - self.query = self.model.select() - - def get_page(self): - if self.page is not None: - return self.page - - curr_page = request.args.get(self.page_var) - if curr_page and curr_page.isdigit(): - return max(1, int(curr_page)) - return 1 - - def get_page_count(self): - if not hasattr(self, '_page_count'): - self._page_count = int(math.ceil( - float(self.query.count()) / self.paginate_by)) - return self._page_count - - def get_object_list(self): - if self.check_bounds and self.get_page() > self.get_page_count(): - abort(404) - return self.query.paginate(self.get_page(), self.paginate_by) - - -def get_object_or_404(query_or_model, *query): - if not isinstance(query_or_model, SelectQuery): - query_or_model = query_or_model.select() - try: - return query_or_model.where(*query).get() - except DoesNotExist: - abort(404) - -def object_list(template_name, query, context_variable='object_list', - paginate_by=20, page_var='page', page=None, check_bounds=True, - **kwargs): - paginated_query = PaginatedQuery( - query, - paginate_by=paginate_by, - page_var=page_var, - page=page, - check_bounds=check_bounds) - kwargs[context_variable] = paginated_query.get_object_list() - return render_template( - template_name, - pagination=paginated_query, - page=paginated_query.get_page(), - **kwargs) - -def get_current_url(): - if not request.query_string: - return request.path - return '%s?%s' % (request.path, request.query_string) - -def get_next_url(default='/'): - if request.args.get('next'): - return request.args['next'] - elif request.form.get('next'): - return request.form['next'] - return default - -class FlaskDB(object): - def __init__(self, app=None, database=None, model_class=Model): - self.database = None # Reference to actual Peewee database instance. - self.base_model_class = model_class - self._app = app - self._db = database # dict, url, Database, or None (default). 
- if app is not None: - self.init_app(app) - - def init_app(self, app): - self._app = app - - if self._db is None: - if 'DATABASE' in app.config: - initial_db = app.config['DATABASE'] - elif 'DATABASE_URL' in app.config: - initial_db = app.config['DATABASE_URL'] - else: - raise ValueError('Missing required configuration data for ' - 'database: DATABASE or DATABASE_URL.') - else: - initial_db = self._db - - self._load_database(app, initial_db) - self._register_handlers(app) - - def _load_database(self, app, config_value): - if isinstance(config_value, Database): - database = config_value - elif isinstance(config_value, dict): - database = self._load_from_config_dict(dict(config_value)) - else: - # Assume a database connection URL. - database = db_url_connect(config_value) - - if isinstance(self.database, Proxy): - self.database.initialize(database) - else: - self.database = database - - def _load_from_config_dict(self, config_dict): - try: - name = config_dict.pop('name') - engine = config_dict.pop('engine') - except KeyError: - raise RuntimeError('DATABASE configuration must specify a ' - '`name` and `engine`.') - - if '.' in engine: - path, class_name = engine.rsplit('.', 1) - else: - path, class_name = 'peewee', engine - - try: - __import__(path) - module = sys.modules[path] - database_class = getattr(module, class_name) - assert issubclass(database_class, Database) - except ImportError: - raise RuntimeError('Unable to import %s' % engine) - except AttributeError: - raise RuntimeError('Database engine not found %s' % engine) - except AssertionError: - raise RuntimeError('Database engine not a subclass of ' - 'peewee.Database: %s' % engine) - - return database_class(name, **config_dict) - - def _register_handlers(self, app): - app.before_request(self.connect_db) - app.teardown_request(self.close_db) - - def get_model_class(self): - if self.database is None: - raise RuntimeError('Database must be initialized.') - - class BaseModel(self.base_model_class): - class Meta: - database = self.database - - return BaseModel - - @property - def Model(self): - if self._app is None: - database = getattr(self, 'database', None) - if database is None: - self.database = Proxy() - - if not hasattr(self, '_model_class'): - self._model_class = self.get_model_class() - return self._model_class - - def connect_db(self): - self.database.connect() - - def close_db(self, exc): - if not self.database.is_closed(): - self.database.close() diff --git a/libs/playhouse/hybrid.py b/libs/playhouse/hybrid.py deleted file mode 100644 index 50531cc35..000000000 --- a/libs/playhouse/hybrid.py +++ /dev/null @@ -1,53 +0,0 @@ -from peewee import ModelDescriptor - - -# Hybrid methods/attributes, based on similar functionality in SQLAlchemy: -# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html -class hybrid_method(ModelDescriptor): - def __init__(self, func, expr=None): - self.func = func - self.expr = expr or func - - def __get__(self, instance, instance_type): - if instance is None: - return self.expr.__get__(instance_type, instance_type.__class__) - return self.func.__get__(instance, instance_type) - - def expression(self, expr): - self.expr = expr - return self - - -class hybrid_property(ModelDescriptor): - def __init__(self, fget, fset=None, fdel=None, expr=None): - self.fget = fget - self.fset = fset - self.fdel = fdel - self.expr = expr or fget - - def __get__(self, instance, instance_type): - if instance is None: - return self.expr(instance_type) - return self.fget(instance) - - def __set__(self, 
instance, value): - if self.fset is None: - raise AttributeError('Cannot set attribute.') - self.fset(instance, value) - - def __delete__(self, instance): - if self.fdel is None: - raise AttributeError('Cannot delete attribute.') - self.fdel(instance) - - def setter(self, fset): - self.fset = fset - return self - - def deleter(self, fdel): - self.fdel = fdel - return self - - def expression(self, expr): - self.expr = expr - return self diff --git a/libs/playhouse/kv.py b/libs/playhouse/kv.py deleted file mode 100644 index 742b49cad..000000000 --- a/libs/playhouse/kv.py +++ /dev/null @@ -1,172 +0,0 @@ -import operator - -from peewee import * -from peewee import Expression -from playhouse.fields import PickleField -try: - from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase -except ImportError: - from playhouse.sqlite_ext import SqliteExtDatabase - - -Sentinel = type('Sentinel', (object,), {}) - - -class KeyValue(object): - """ - Persistent dictionary. - - :param Field key_field: field to use for key. Defaults to CharField. - :param Field value_field: field to use for value. Defaults to PickleField. - :param bool ordered: data should be returned in key-sorted order. - :param Database database: database where key/value data is stored. - :param str table_name: table name for data. - """ - def __init__(self, key_field=None, value_field=None, ordered=False, - database=None, table_name='keyvalue'): - if key_field is None: - key_field = CharField(max_length=255, primary_key=True) - if not key_field.primary_key: - raise ValueError('key_field must have primary_key=True.') - - if value_field is None: - value_field = PickleField() - - self._key_field = key_field - self._value_field = value_field - self._ordered = ordered - self._database = database or SqliteExtDatabase(':memory:') - self._table_name = table_name - if isinstance(self._database, PostgresqlDatabase): - self.upsert = self._postgres_upsert - self.update = self._postgres_update - else: - self.upsert = self._upsert - self.update = self._update - - self.model = self.create_model() - self.key = self.model.key - self.value = self.model.value - - # Ensure table exists. 
- self.model.create_table() - - def create_model(self): - class KeyValue(Model): - key = self._key_field - value = self._value_field - class Meta: - database = self._database - table_name = self._table_name - return KeyValue - - def query(self, *select): - query = self.model.select(*select).tuples() - if self._ordered: - query = query.order_by(self.key) - return query - - def convert_expression(self, expr): - if not isinstance(expr, Expression): - return (self.key == expr), True - return expr, False - - def __contains__(self, key): - expr, _ = self.convert_expression(key) - return self.model.select().where(expr).exists() - - def __len__(self): - return len(self.model) - - def __getitem__(self, expr): - converted, is_single = self.convert_expression(expr) - query = self.query(self.value).where(converted) - item_getter = operator.itemgetter(0) - result = [item_getter(row) for row in query] - if len(result) == 0 and is_single: - raise KeyError(expr) - elif is_single: - return result[0] - return result - - def _upsert(self, key, value): - (self.model - .insert(key=key, value=value) - .on_conflict('replace') - .execute()) - - def _postgres_upsert(self, key, value): - (self.model - .insert(key=key, value=value) - .on_conflict(conflict_target=[self.key], - preserve=[self.value]) - .execute()) - - def __setitem__(self, expr, value): - if isinstance(expr, Expression): - self.model.update(value=value).where(expr).execute() - else: - self.upsert(expr, value) - - def __delitem__(self, expr): - converted, _ = self.convert_expression(expr) - self.model.delete().where(converted).execute() - - def __iter__(self): - return iter(self.query().execute()) - - def keys(self): - return map(operator.itemgetter(0), self.query(self.key)) - - def values(self): - return map(operator.itemgetter(0), self.query(self.value)) - - def items(self): - return iter(self.query().execute()) - - def _update(self, __data=None, **mapping): - if __data is not None: - mapping.update(__data) - return (self.model - .insert_many(list(mapping.items()), - fields=[self.key, self.value]) - .on_conflict('replace') - .execute()) - - def _postgres_update(self, __data=None, **mapping): - if __data is not None: - mapping.update(__data) - return (self.model - .insert_many(list(mapping.items()), - fields=[self.key, self.value]) - .on_conflict(conflict_target=[self.key], - preserve=[self.value]) - .execute()) - - def get(self, key, default=None): - try: - return self[key] - except KeyError: - return default - - def setdefault(self, key, default=None): - try: - return self[key] - except KeyError: - self[key] = default - return default - - def pop(self, key, default=Sentinel): - with self._database.atomic(): - try: - result = self[key] - except KeyError: - if default is Sentinel: - raise - return default - del self[key] - - return result - - def clear(self): - self.model.delete().execute() diff --git a/libs/playhouse/migrate.py b/libs/playhouse/migrate.py deleted file mode 100644 index 4d90b70ec..000000000 --- a/libs/playhouse/migrate.py +++ /dev/null @@ -1,823 +0,0 @@ -""" -Lightweight schema migrations. - -NOTE: Currently tested with SQLite and Postgresql. MySQL may be missing some -features. - -Example Usage -------------- - -Instantiate a migrator: - - # Postgres example: - my_db = PostgresqlDatabase(...) 
- migrator = PostgresqlMigrator(my_db) - - # SQLite example: - my_db = SqliteDatabase('my_database.db') - migrator = SqliteMigrator(my_db) - -Then you will use the `migrate` function to run various `Operation`s which -are generated by the migrator: - - migrate( - migrator.add_column('some_table', 'column_name', CharField(default='')) - ) - -Migrations are not run inside a transaction, so if you wish the migration to -run in a transaction you will need to wrap the call to `migrate` in a -transaction block, e.g.: - - with my_db.transaction(): - migrate(...) - -Supported Operations --------------------- - -Add new field(s) to an existing model: - - # Create your field instances. For non-null fields you must specify a - # default value. - pubdate_field = DateTimeField(null=True) - comment_field = TextField(default='') - - # Run the migration, specifying the database table, field name and field. - migrate( - migrator.add_column('comment_tbl', 'pub_date', pubdate_field), - migrator.add_column('comment_tbl', 'comment', comment_field), - ) - -Renaming a field: - - # Specify the table, original name of the column, and its new name. - migrate( - migrator.rename_column('story', 'pub_date', 'publish_date'), - migrator.rename_column('story', 'mod_date', 'modified_date'), - ) - -Dropping a field: - - migrate( - migrator.drop_column('story', 'some_old_field'), - ) - -Making a field nullable or not nullable: - - # Note that when making a field not null that field must not have any - # NULL values present. - migrate( - # Make `pub_date` allow NULL values. - migrator.drop_not_null('story', 'pub_date'), - - # Prevent `modified_date` from containing NULL values. - migrator.add_not_null('story', 'modified_date'), - ) - -Renaming a table: - - migrate( - migrator.rename_table('story', 'stories_tbl'), - ) - -Adding an index: - - # Specify the table, column names, and whether the index should be - # UNIQUE or not. - migrate( - # Create an index on the `pub_date` column. - migrator.add_index('story', ('pub_date',), False), - - # Create a multi-column index on the `pub_date` and `status` fields. - migrator.add_index('story', ('pub_date', 'status'), False), - - # Create a unique index on the category and title fields. - migrator.add_index('story', ('category_id', 'title'), True), - ) - -Dropping an index: - - # Specify the index name. - migrate(migrator.drop_index('story', 'story_pub_date_status')) - -Adding or dropping table constraints: - -.. code-block:: python - - # Add a CHECK() constraint to enforce the price cannot be negative. - migrate(migrator.add_constraint( - 'products', - 'price_check', - Check('price >= 0'))) - - # Remove the price check constraint. - migrate(migrator.drop_constraint('products', 'price_check')) - - # Add a UNIQUE constraint on the first and last names. 
- migrate(migrator.add_unique('person', 'first_name', 'last_name')) -""" -from collections import namedtuple -import functools -import hashlib -import re - -from peewee import * -from peewee import CommaNodeList -from peewee import EnclosedNodeList -from peewee import Entity -from peewee import Expression -from peewee import Node -from peewee import NodeList -from peewee import OP -from peewee import callable_ -from peewee import sort_models -from peewee import _truncate_constraint_name - - -class Operation(object): - """Encapsulate a single schema altering operation.""" - def __init__(self, migrator, method, *args, **kwargs): - self.migrator = migrator - self.method = method - self.args = args - self.kwargs = kwargs - - def execute(self, node): - self.migrator.database.execute(node) - - def _handle_result(self, result): - if isinstance(result, (Node, Context)): - self.execute(result) - elif isinstance(result, Operation): - result.run() - elif isinstance(result, (list, tuple)): - for item in result: - self._handle_result(item) - - def run(self): - kwargs = self.kwargs.copy() - kwargs['with_context'] = True - method = getattr(self.migrator, self.method) - self._handle_result(method(*self.args, **kwargs)) - - -def operation(fn): - @functools.wraps(fn) - def inner(self, *args, **kwargs): - with_context = kwargs.pop('with_context', False) - if with_context: - return fn(self, *args, **kwargs) - return Operation(self, fn.__name__, *args, **kwargs) - return inner - - -def make_index_name(table_name, columns): - index_name = '_'.join((table_name,) + tuple(columns)) - if len(index_name) > 64: - index_hash = hashlib.md5(index_name.encode('utf-8')).hexdigest() - index_name = '%s_%s' % (index_name[:56], index_hash[:7]) - return index_name - - -class SchemaMigrator(object): - explicit_create_foreign_key = False - explicit_delete_foreign_key = False - - def __init__(self, database): - self.database = database - - def make_context(self): - return self.database.get_sql_context() - - @classmethod - def from_database(cls, database): - if isinstance(database, PostgresqlDatabase): - return PostgresqlMigrator(database) - elif isinstance(database, MySQLDatabase): - return MySQLMigrator(database) - elif isinstance(database, SqliteDatabase): - return SqliteMigrator(database) - raise ValueError('Unsupported database: %s' % database) - - @operation - def apply_default(self, table, column_name, field): - default = field.default - if callable_(default): - default = default() - - return (self.make_context() - .literal('UPDATE ') - .sql(Entity(table)) - .literal(' SET ') - .sql(Expression( - Entity(column_name), - OP.EQ, - field.db_value(default), - flat=True))) - - def _alter_table(self, ctx, table): - return ctx.literal('ALTER TABLE ').sql(Entity(table)) - - def _alter_column(self, ctx, table, column): - return (self - ._alter_table(ctx, table) - .literal(' ALTER COLUMN ') - .sql(Entity(column))) - - @operation - def alter_add_column(self, table, column_name, field): - # Make field null at first. 
- ctx = self.make_context() - field_null, field.null = field.null, True - field.name = field.column_name = column_name - (self - ._alter_table(ctx, table) - .literal(' ADD COLUMN ') - .sql(field.ddl(ctx))) - - field.null = field_null - if isinstance(field, ForeignKeyField): - self.add_inline_fk_sql(ctx, field) - return ctx - - @operation - def add_constraint(self, table, name, constraint): - return (self - ._alter_table(self.make_context(), table) - .literal(' ADD CONSTRAINT ') - .sql(Entity(name)) - .literal(' ') - .sql(constraint)) - - @operation - def add_unique(self, table, *column_names): - constraint_name = 'uniq_%s' % '_'.join(column_names) - constraint = NodeList(( - SQL('UNIQUE'), - EnclosedNodeList([Entity(column) for column in column_names]))) - return self.add_constraint(table, constraint_name, constraint) - - @operation - def drop_constraint(self, table, name): - return (self - ._alter_table(self.make_context(), table) - .literal(' DROP CONSTRAINT ') - .sql(Entity(name))) - - def add_inline_fk_sql(self, ctx, field): - ctx = (ctx - .literal(' REFERENCES ') - .sql(Entity(field.rel_model._meta.table_name)) - .literal(' ') - .sql(EnclosedNodeList((Entity(field.rel_field.column_name),)))) - if field.on_delete is not None: - ctx = ctx.literal(' ON DELETE %s' % field.on_delete) - if field.on_update is not None: - ctx = ctx.literal(' ON UPDATE %s' % field.on_update) - return ctx - - @operation - def add_foreign_key_constraint(self, table, column_name, rel, rel_column, - on_delete=None, on_update=None): - constraint = 'fk_%s_%s_refs_%s' % (table, column_name, rel) - ctx = (self - .make_context() - .literal('ALTER TABLE ') - .sql(Entity(table)) - .literal(' ADD CONSTRAINT ') - .sql(Entity(_truncate_constraint_name(constraint))) - .literal(' FOREIGN KEY ') - .sql(EnclosedNodeList((Entity(column_name),))) - .literal(' REFERENCES ') - .sql(Entity(rel)) - .literal(' (') - .sql(Entity(rel_column)) - .literal(')')) - if on_delete is not None: - ctx = ctx.literal(' ON DELETE %s' % on_delete) - if on_update is not None: - ctx = ctx.literal(' ON UPDATE %s' % on_update) - return ctx - - @operation - def add_column(self, table, column_name, field): - # Adding a column is complicated by the fact that if there are rows - # present and the field is non-null, then we need to first add the - # column as a nullable field, then set the value, then add a not null - # constraint. - if not field.null and field.default is None: - raise ValueError('%s is not null but has no default' % column_name) - - is_foreign_key = isinstance(field, ForeignKeyField) - if is_foreign_key and not field.rel_field: - raise ValueError('Foreign keys must specify a `field`.') - - operations = [self.alter_add_column(table, column_name, field)] - - # In the event the field is *not* nullable, update with the default - # value and set not null. 
- if not field.null: - operations.extend([ - self.apply_default(table, column_name, field), - self.add_not_null(table, column_name)]) - - if is_foreign_key and self.explicit_create_foreign_key: - operations.append( - self.add_foreign_key_constraint( - table, - column_name, - field.rel_model._meta.table_name, - field.rel_field.column_name, - field.on_delete, - field.on_update)) - - if field.index or field.unique: - using = getattr(field, 'index_type', None) - operations.append(self.add_index(table, (column_name,), - field.unique, using)) - - return operations - - @operation - def drop_foreign_key_constraint(self, table, column_name): - raise NotImplementedError - - @operation - def drop_column(self, table, column_name, cascade=True): - ctx = self.make_context() - (self._alter_table(ctx, table) - .literal(' DROP COLUMN ') - .sql(Entity(column_name))) - - if cascade: - ctx.literal(' CASCADE') - - fk_columns = [ - foreign_key.column - for foreign_key in self.database.get_foreign_keys(table)] - if column_name in fk_columns and self.explicit_delete_foreign_key: - return [self.drop_foreign_key_constraint(table, column_name), ctx] - - return ctx - - @operation - def rename_column(self, table, old_name, new_name): - return (self - ._alter_table(self.make_context(), table) - .literal(' RENAME COLUMN ') - .sql(Entity(old_name)) - .literal(' TO ') - .sql(Entity(new_name))) - - @operation - def add_not_null(self, table, column): - return (self - ._alter_column(self.make_context(), table, column) - .literal(' SET NOT NULL')) - - @operation - def drop_not_null(self, table, column): - return (self - ._alter_column(self.make_context(), table, column) - .literal(' DROP NOT NULL')) - - @operation - def rename_table(self, old_name, new_name): - return (self - ._alter_table(self.make_context(), old_name) - .literal(' RENAME TO ') - .sql(Entity(new_name))) - - @operation - def add_index(self, table, columns, unique=False, using=None): - ctx = self.make_context() - index_name = make_index_name(table, columns) - table_obj = Table(table) - cols = [getattr(table_obj.c, column) for column in columns] - index = Index(index_name, table_obj, cols, unique=unique, using=using) - return ctx.sql(index) - - @operation - def drop_index(self, table, index_name): - return (self - .make_context() - .literal('DROP INDEX ') - .sql(Entity(index_name))) - - -class PostgresqlMigrator(SchemaMigrator): - def _primary_key_columns(self, tbl): - query = """ - SELECT pg_attribute.attname - FROM pg_index, pg_class, pg_attribute - WHERE - pg_class.oid = '%s'::regclass AND - indrelid = pg_class.oid AND - pg_attribute.attrelid = pg_class.oid AND - pg_attribute.attnum = any(pg_index.indkey) AND - indisprimary; - """ - cursor = self.database.execute_sql(query % tbl) - return [row[0] for row in cursor.fetchall()] - - @operation - def set_search_path(self, schema_name): - return (self - .make_context() - .literal('SET search_path TO %s' % schema_name)) - - @operation - def rename_table(self, old_name, new_name): - pk_names = self._primary_key_columns(old_name) - ParentClass = super(PostgresqlMigrator, self) - - operations = [ - ParentClass.rename_table(old_name, new_name, with_context=True)] - - if len(pk_names) == 1: - # Check for existence of primary key sequence. 
- seq_name = '%s_%s_seq' % (old_name, pk_names[0]) - query = """ - SELECT 1 - FROM information_schema.sequences - WHERE LOWER(sequence_name) = LOWER(%s) - """ - cursor = self.database.execute_sql(query, (seq_name,)) - if bool(cursor.fetchone()): - new_seq_name = '%s_%s_seq' % (new_name, pk_names[0]) - operations.append(ParentClass.rename_table( - seq_name, new_seq_name)) - - return operations - - -class MySQLColumn(namedtuple('_Column', ('name', 'definition', 'null', 'pk', - 'default', 'extra'))): - @property - def is_pk(self): - return self.pk == 'PRI' - - @property - def is_unique(self): - return self.pk == 'UNI' - - @property - def is_null(self): - return self.null == 'YES' - - def sql(self, column_name=None, is_null=None): - if is_null is None: - is_null = self.is_null - if column_name is None: - column_name = self.name - parts = [ - Entity(column_name), - SQL(self.definition)] - if self.is_unique: - parts.append(SQL('UNIQUE')) - if is_null: - parts.append(SQL('NULL')) - else: - parts.append(SQL('NOT NULL')) - if self.is_pk: - parts.append(SQL('PRIMARY KEY')) - if self.extra: - parts.append(SQL(self.extra)) - return NodeList(parts) - - -class MySQLMigrator(SchemaMigrator): - explicit_create_foreign_key = True - explicit_delete_foreign_key = True - - @operation - def rename_table(self, old_name, new_name): - return (self - .make_context() - .literal('RENAME TABLE ') - .sql(Entity(old_name)) - .literal(' TO ') - .sql(Entity(new_name))) - - def _get_column_definition(self, table, column_name): - cursor = self.database.execute_sql('DESCRIBE `%s`;' % table) - rows = cursor.fetchall() - for row in rows: - column = MySQLColumn(*row) - if column.name == column_name: - return column - return False - - def get_foreign_key_constraint(self, table, column_name): - cursor = self.database.execute_sql( - ('SELECT constraint_name ' - 'FROM information_schema.key_column_usage WHERE ' - 'table_schema = DATABASE() AND ' - 'table_name = %s AND ' - 'column_name = %s AND ' - 'referenced_table_name IS NOT NULL AND ' - 'referenced_column_name IS NOT NULL;'), - (table, column_name)) - result = cursor.fetchone() - if not result: - raise AttributeError( - 'Unable to find foreign key constraint for ' - '"%s" on table "%s".' 
% (table, column_name)) - return result[0] - - @operation - def drop_foreign_key_constraint(self, table, column_name): - fk_constraint = self.get_foreign_key_constraint(table, column_name) - return (self - .make_context() - .literal('ALTER TABLE ') - .sql(Entity(table)) - .literal(' DROP FOREIGN KEY ') - .sql(Entity(fk_constraint))) - - def add_inline_fk_sql(self, ctx, field): - pass - - @operation - def add_not_null(self, table, column): - column_def = self._get_column_definition(table, column) - add_not_null = (self - .make_context() - .literal('ALTER TABLE ') - .sql(Entity(table)) - .literal(' MODIFY ') - .sql(column_def.sql(is_null=False))) - - fk_objects = dict( - (fk.column, fk) - for fk in self.database.get_foreign_keys(table)) - if column not in fk_objects: - return add_not_null - - fk_metadata = fk_objects[column] - return (self.drop_foreign_key_constraint(table, column), - add_not_null, - self.add_foreign_key_constraint( - table, - column, - fk_metadata.dest_table, - fk_metadata.dest_column)) - - @operation - def drop_not_null(self, table, column): - column = self._get_column_definition(table, column) - if column.is_pk: - raise ValueError('Primary keys can not be null') - return (self - .make_context() - .literal('ALTER TABLE ') - .sql(Entity(table)) - .literal(' MODIFY ') - .sql(column.sql(is_null=True))) - - @operation - def rename_column(self, table, old_name, new_name): - fk_objects = dict( - (fk.column, fk) - for fk in self.database.get_foreign_keys(table)) - is_foreign_key = old_name in fk_objects - - column = self._get_column_definition(table, old_name) - rename_ctx = (self - .make_context() - .literal('ALTER TABLE ') - .sql(Entity(table)) - .literal(' CHANGE ') - .sql(Entity(old_name)) - .literal(' ') - .sql(column.sql(column_name=new_name))) - if is_foreign_key: - fk_metadata = fk_objects[old_name] - return [ - self.drop_foreign_key_constraint(table, old_name), - rename_ctx, - self.add_foreign_key_constraint( - table, - new_name, - fk_metadata.dest_table, - fk_metadata.dest_column), - ] - else: - return rename_ctx - - @operation - def drop_index(self, table, index_name): - return (self - .make_context() - .literal('DROP INDEX ') - .sql(Entity(index_name)) - .literal(' ON ') - .sql(Entity(table))) - - -class SqliteMigrator(SchemaMigrator): - """ - SQLite supports a subset of ALTER TABLE queries, view the docs for the - full details http://sqlite.org/lang_altertable.html - """ - column_re = re.compile('(.+?)\((.+)\)') - column_split_re = re.compile(r'(?:[^,(]|\([^)]*\))+') - column_name_re = re.compile('["`\']?([\w]+)') - fk_re = re.compile('FOREIGN KEY\s+\("?([\w]+)"?\)\s+', re.I) - - def _get_column_names(self, table): - res = self.database.execute_sql('select * from "%s" limit 1' % table) - return [item[0] for item in res.description] - - def _get_create_table(self, table): - res = self.database.execute_sql( - ('select name, sql from sqlite_master ' - 'where type=? and LOWER(name)=?'), - ['table', table.lower()]) - return res.fetchone() - - @operation - def _update_column(self, table, column_to_update, fn): - columns = set(column.name.lower() - for column in self.database.get_columns(table)) - if column_to_update.lower() not in columns: - raise ValueError('Column "%s" does not exist on "%s"' % - (column_to_update, table)) - - # Get the SQL used to create the given table. - table, create_table = self._get_create_table(table) - - # Get the indexes and SQL to re-create indexes. - indexes = self.database.get_indexes(table) - - # Find any foreign keys we may need to remove. 
- self.database.get_foreign_keys(table) - - # Make sure the create_table does not contain any newlines or tabs, - # allowing the regex to work correctly. - create_table = re.sub(r'\s+', ' ', create_table) - - # Parse out the `CREATE TABLE` and column list portions of the query. - raw_create, raw_columns = self.column_re.search(create_table).groups() - - # Clean up the individual column definitions. - split_columns = self.column_split_re.findall(raw_columns) - column_defs = [col.strip() for col in split_columns] - - new_column_defs = [] - new_column_names = [] - original_column_names = [] - constraint_terms = ('foreign ', 'primary ', 'constraint ') - - for column_def in column_defs: - column_name, = self.column_name_re.match(column_def).groups() - - if column_name == column_to_update: - new_column_def = fn(column_name, column_def) - if new_column_def: - new_column_defs.append(new_column_def) - original_column_names.append(column_name) - column_name, = self.column_name_re.match( - new_column_def).groups() - new_column_names.append(column_name) - else: - new_column_defs.append(column_def) - - # Avoid treating constraints as columns. - if not column_def.lower().startswith(constraint_terms): - new_column_names.append(column_name) - original_column_names.append(column_name) - - # Create a mapping of original columns to new columns. - original_to_new = dict(zip(original_column_names, new_column_names)) - new_column = original_to_new.get(column_to_update) - - fk_filter_fn = lambda column_def: column_def - if not new_column: - # Remove any foreign keys associated with this column. - fk_filter_fn = lambda column_def: None - elif new_column != column_to_update: - # Update any foreign keys for this column. - fk_filter_fn = lambda column_def: self.fk_re.sub( - 'FOREIGN KEY ("%s") ' % new_column, - column_def) - - cleaned_columns = [] - for column_def in new_column_defs: - match = self.fk_re.match(column_def) - if match is not None and match.groups()[0] == column_to_update: - column_def = fk_filter_fn(column_def) - if column_def: - cleaned_columns.append(column_def) - - # Update the name of the new CREATE TABLE query. - temp_table = table + '__tmp__' - rgx = re.compile('("?)%s("?)' % table, re.I) - create = rgx.sub( - '\\1%s\\2' % temp_table, - raw_create) - - # Create the new table. - columns = ', '.join(cleaned_columns) - queries = [ - NodeList([SQL('DROP TABLE IF EXISTS'), Entity(temp_table)]), - SQL('%s (%s)' % (create.strip(), columns))] - - # Populate new table. - populate_table = NodeList(( - SQL('INSERT INTO'), - Entity(temp_table), - EnclosedNodeList([Entity(col) for col in new_column_names]), - SQL('SELECT'), - CommaNodeList([Entity(col) for col in original_column_names]), - SQL('FROM'), - Entity(table))) - drop_original = NodeList([SQL('DROP TABLE'), Entity(table)]) - - # Drop existing table and rename temp table. - queries += [ - populate_table, - drop_original, - self.rename_table(temp_table, table)] - - # Re-create user-defined indexes. User-defined indexes will have a - # non-empty SQL attribute. - for index in filter(lambda idx: idx.sql, indexes): - if column_to_update not in index.columns: - queries.append(SQL(index.sql)) - elif new_column: - sql = self._fix_index(index.sql, column_to_update, new_column) - if sql is not None: - queries.append(SQL(sql)) - - return queries - - def _fix_index(self, sql, column_to_update, new_column): - # Split on the name of the column to update. If it splits into two - # pieces, then there's no ambiguity and we can simply replace the - # old with the new. 
- parts = sql.split(column_to_update) - if len(parts) == 2: - return sql.replace(column_to_update, new_column) - - # Find the list of columns in the index expression. - lhs, rhs = sql.rsplit('(', 1) - - # Apply the same "split in two" logic to the column list portion of - # the query. - if len(rhs.split(column_to_update)) == 2: - return '%s(%s' % (lhs, rhs.replace(column_to_update, new_column)) - - # Strip off the trailing parentheses and go through each column. - parts = rhs.rsplit(')', 1)[0].split(',') - columns = [part.strip('"`[]\' ') for part in parts] - - # `columns` looks something like: ['status', 'timestamp" DESC'] - # https://www.sqlite.org/lang_keywords.html - # Strip out any junk after the column name. - clean = [] - for column in columns: - if re.match('%s(?:[\'"`\]]?\s|$)' % column_to_update, column): - column = new_column + column[len(column_to_update):] - clean.append(column) - - return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean)) - - @operation - def drop_column(self, table, column_name, cascade=True): - return self._update_column(table, column_name, lambda a, b: None) - - @operation - def rename_column(self, table, old_name, new_name): - def _rename(column_name, column_def): - return column_def.replace(column_name, new_name) - return self._update_column(table, old_name, _rename) - - @operation - def add_not_null(self, table, column): - def _add_not_null(column_name, column_def): - return column_def + ' NOT NULL' - return self._update_column(table, column, _add_not_null) - - @operation - def drop_not_null(self, table, column): - def _drop_not_null(column_name, column_def): - return column_def.replace('NOT NULL', '') - return self._update_column(table, column, _drop_not_null) - - @operation - def add_constraint(self, table, name, constraint): - raise NotImplementedError - - @operation - def drop_constraint(self, table, name): - raise NotImplementedError - - @operation - def add_foreign_key_constraint(self, table, column_name, field, - on_delete=None, on_update=None): - raise NotImplementedError - - -def migrate(*operations, **kwargs): - for operation in operations: - operation.run() diff --git a/libs/playhouse/mysql_ext.py b/libs/playhouse/mysql_ext.py deleted file mode 100644 index 9ee265573..000000000 --- a/libs/playhouse/mysql_ext.py +++ /dev/null @@ -1,49 +0,0 @@ -import json - -try: - import mysql.connector as mysql_connector -except ImportError: - mysql_connector = None - -from peewee import ImproperlyConfigured -from peewee import MySQLDatabase -from peewee import NodeList -from peewee import SQL -from peewee import TextField -from peewee import fn - - -class MySQLConnectorDatabase(MySQLDatabase): - def _connect(self): - if mysql_connector is None: - raise ImproperlyConfigured('MySQL connector not installed!') - return mysql_connector.connect(db=self.database, **self.connect_params) - - def cursor(self, commit=None): - if self.is_closed(): - if self.autoconnect: - self.connect() - else: - raise InterfaceError('Error, database connection not opened.') - return self._state.conn.cursor(buffered=True) - - -class JSONField(TextField): - field_type = 'JSON' - - def db_value(self, value): - if value is not None: - return json.dumps(value) - - def python_value(self, value): - if value is not None: - return json.loads(value) - - -def Match(columns, expr, modifier=None): - if isinstance(columns, (list, tuple)): - match = fn.MATCH(*columns) # Tuple of one or more columns / fields. - else: - match = fn.MATCH(columns) # Single column / field. 
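For reference, the Match() helper removed here builds a MySQL full-text MATCH ... AGAINST expression from one or more columns. A minimal usage sketch, assuming a hypothetical Note model with a FULLTEXT index on content (the model and search string are illustrative, not part of this diff):

    from peewee import Model, TextField
    from playhouse.mysql_ext import Match

    class Note(Model):
        content = TextField()  # assumed to carry a FULLTEXT index

    # Renders roughly as: MATCH(content) AGAINST('bazarr subtitles' IN BOOLEAN MODE)
    query = Note.select().where(
        Match(Note.content, 'bazarr subtitles', modifier='IN BOOLEAN MODE'))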
- args = expr if modifier is None else NodeList((expr, SQL(modifier))) - return NodeList((match, fn.AGAINST(args))) diff --git a/libs/playhouse/pool.py b/libs/playhouse/pool.py deleted file mode 100644 index 2ee3b486f..000000000 --- a/libs/playhouse/pool.py +++ /dev/null @@ -1,318 +0,0 @@ -""" -Lightweight connection pooling for peewee. - -In a multi-threaded application, up to `max_connections` will be opened. Each -thread (or, if using gevent, greenlet) will have it's own connection. - -In a single-threaded application, only one connection will be created. It will -be continually recycled until either it exceeds the stale timeout or is closed -explicitly (using `.manual_close()`). - -By default, all your application needs to do is ensure that connections are -closed when you are finished with them, and they will be returned to the pool. -For web applications, this typically means that at the beginning of a request, -you will open a connection, and when you return a response, you will close the -connection. - -Simple Postgres pool example code: - - # Use the special postgresql extensions. - from playhouse.pool import PooledPostgresqlExtDatabase - - db = PooledPostgresqlExtDatabase( - 'my_app', - max_connections=32, - stale_timeout=300, # 5 minutes. - user='postgres') - - class BaseModel(Model): - class Meta: - database = db - -That's it! -""" -import heapq -import logging -import random -import time -from collections import namedtuple -from itertools import chain - -try: - from psycopg2.extensions import TRANSACTION_STATUS_IDLE - from psycopg2.extensions import TRANSACTION_STATUS_INERROR - from psycopg2.extensions import TRANSACTION_STATUS_UNKNOWN -except ImportError: - TRANSACTION_STATUS_IDLE = \ - TRANSACTION_STATUS_INERROR = \ - TRANSACTION_STATUS_UNKNOWN = None - -from peewee import MySQLDatabase -from peewee import PostgresqlDatabase -from peewee import SqliteDatabase - -logger = logging.getLogger('peewee.pool') - - -def make_int(val): - if val is not None and not isinstance(val, (int, float)): - return int(val) - return val - - -class MaxConnectionsExceeded(ValueError): pass - - -PoolConnection = namedtuple('PoolConnection', ('timestamp', 'connection', - 'checked_out')) - - -class PooledDatabase(object): - def __init__(self, database, max_connections=20, stale_timeout=None, - timeout=None, **kwargs): - self._max_connections = make_int(max_connections) - self._stale_timeout = make_int(stale_timeout) - self._wait_timeout = make_int(timeout) - if self._wait_timeout == 0: - self._wait_timeout = float('inf') - - # Available / idle connections stored in a heap, sorted oldest first. - self._connections = [] - - # Mapping of connection id to PoolConnection. Ordinarily we would want - # to use something like a WeakKeyDictionary, but Python typically won't - # allow us to create weak references to connection objects. - self._in_use = {} - - # Use the memory address of the connection as the key in the event the - # connection object is not hashable. Connections will not get - # garbage-collected, however, because a reference to them will persist - # in "_in_use" as long as the conn has not been closed. 
- self.conn_key = id - - super(PooledDatabase, self).__init__(database, **kwargs) - - def init(self, database, max_connections=None, stale_timeout=None, - timeout=None, **connect_kwargs): - super(PooledDatabase, self).init(database, **connect_kwargs) - if max_connections is not None: - self._max_connections = make_int(max_connections) - if stale_timeout is not None: - self._stale_timeout = make_int(stale_timeout) - if timeout is not None: - self._wait_timeout = make_int(timeout) - if self._wait_timeout == 0: - self._wait_timeout = float('inf') - - def connect(self, reuse_if_open=False): - if not self._wait_timeout: - return super(PooledDatabase, self).connect(reuse_if_open) - - expires = time.time() + self._wait_timeout - while expires > time.time(): - try: - ret = super(PooledDatabase, self).connect(reuse_if_open) - except MaxConnectionsExceeded: - time.sleep(0.1) - else: - return ret - raise MaxConnectionsExceeded('Max connections exceeded, timed out ' - 'attempting to connect.') - - def _connect(self): - while True: - try: - # Remove the oldest connection from the heap. - ts, conn = heapq.heappop(self._connections) - key = self.conn_key(conn) - except IndexError: - ts = conn = None - logger.debug('No connection available in pool.') - break - else: - if self._is_closed(conn): - # This connecton was closed, but since it was not stale - # it got added back to the queue of available conns. We - # then closed it and marked it as explicitly closed, so - # it's safe to throw it away now. - # (Because Database.close() calls Database._close()). - logger.debug('Connection %s was closed.', key) - ts = conn = None - elif self._stale_timeout and self._is_stale(ts): - # If we are attempting to check out a stale connection, - # then close it. We don't need to mark it in the "closed" - # set, because it is not in the list of available conns - # anymore. - logger.debug('Connection %s was stale, closing.', key) - self._close(conn, True) - ts = conn = None - else: - break - - if conn is None: - if self._max_connections and ( - len(self._in_use) >= self._max_connections): - raise MaxConnectionsExceeded('Exceeded maximum connections.') - conn = super(PooledDatabase, self)._connect() - ts = time.time() - random.random() / 1000 - key = self.conn_key(conn) - logger.debug('Created new connection %s.', key) - - self._in_use[key] = PoolConnection(ts, conn, time.time()) - return conn - - def _is_stale(self, timestamp): - # Called on check-out and check-in to ensure the connection has - # not outlived the stale timeout. - return (time.time() - timestamp) > self._stale_timeout - - def _is_closed(self, conn): - return False - - def _can_reuse(self, conn): - # Called on check-in to make sure the connection can be re-used. - return True - - def _close(self, conn, close_conn=False): - key = self.conn_key(conn) - if close_conn: - super(PooledDatabase, self)._close(conn) - elif key in self._in_use: - pool_conn = self._in_use.pop(key) - if self._stale_timeout and self._is_stale(pool_conn.timestamp): - logger.debug('Closing stale connection %s.', key) - super(PooledDatabase, self)._close(conn) - elif self._can_reuse(conn): - logger.debug('Returning %s to pool.', key) - heapq.heappush(self._connections, (pool_conn.timestamp, conn)) - else: - logger.debug('Closed %s.', key) - - def manual_close(self): - """ - Close the underlying connection without returning it to the pool. - """ - if self.is_closed(): - return False - - # Obtain reference to the connection in-use by the calling thread. 
- conn = self.connection() - - # A connection will only be re-added to the available list if it is - # marked as "in use" at the time it is closed. We will explicitly - # remove it from the "in use" list, call "close()" for the - # side-effects, and then explicitly close the connection. - self._in_use.pop(self.conn_key(conn), None) - self.close() - self._close(conn, close_conn=True) - - def close_idle(self): - # Close any open connections that are not currently in-use. - with self._lock: - for _, conn in self._connections: - self._close(conn, close_conn=True) - self._connections = [] - - def close_stale(self, age=600): - # Close any connections that are in-use but were checked out quite some - # time ago and can be considered stale. - with self._lock: - in_use = {} - cutoff = time.time() - age - n = 0 - for key, pool_conn in self._in_use.items(): - if pool_conn.checked_out < cutoff: - self._close(pool_conn.connection, close_conn=True) - n += 1 - else: - in_use[key] = pool_conn - self._in_use = in_use - return n - - def close_all(self): - # Close all connections -- available and in-use. Warning: may break any - # active connections used by other threads. - self.close() - with self._lock: - for _, conn in self._connections: - self._close(conn, close_conn=True) - for pool_conn in self._in_use.values(): - self._close(pool_conn.connection, close_conn=True) - self._connections = [] - self._in_use = {} - - -class PooledMySQLDatabase(PooledDatabase, MySQLDatabase): - def _is_closed(self, conn): - try: - conn.ping(False) - except: - return True - else: - return False - - -class _PooledPostgresqlDatabase(PooledDatabase): - def _is_closed(self, conn): - if conn.closed: - return True - - txn_status = conn.get_transaction_status() - if txn_status == TRANSACTION_STATUS_UNKNOWN: - return True - elif txn_status != TRANSACTION_STATUS_IDLE: - conn.rollback() - return False - - def _can_reuse(self, conn): - txn_status = conn.get_transaction_status() - # Do not return connection in an error state, as subsequent queries - # will all fail. If the status is unknown then we lost the connection - # to the server and the connection should not be re-used. 
- if txn_status == TRANSACTION_STATUS_UNKNOWN: - return False - elif txn_status == TRANSACTION_STATUS_INERROR: - conn.reset() - elif txn_status != TRANSACTION_STATUS_IDLE: - conn.rollback() - return True - -class PooledPostgresqlDatabase(_PooledPostgresqlDatabase, PostgresqlDatabase): - pass - -try: - from playhouse.postgres_ext import PostgresqlExtDatabase - - class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase, PostgresqlExtDatabase): - pass -except ImportError: - PooledPostgresqlExtDatabase = None - - -class _PooledSqliteDatabase(PooledDatabase): - def _is_closed(self, conn): - try: - conn.total_changes - except: - return True - else: - return False - -class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase): - pass - -try: - from playhouse.sqlite_ext import SqliteExtDatabase - - class PooledSqliteExtDatabase(_PooledSqliteDatabase, SqliteExtDatabase): - pass -except ImportError: - PooledSqliteExtDatabase = None - -try: - from playhouse.sqlite_ext import CSqliteExtDatabase - - class PooledCSqliteExtDatabase(_PooledSqliteDatabase, CSqliteExtDatabase): - pass -except ImportError: - PooledCSqliteExtDatabase = None diff --git a/libs/playhouse/postgres_ext.py b/libs/playhouse/postgres_ext.py deleted file mode 100644 index 64f44073f..000000000 --- a/libs/playhouse/postgres_ext.py +++ /dev/null @@ -1,474 +0,0 @@ -""" -Collection of postgres-specific extensions, currently including: - -* Support for hstore, a key/value type storage -""" -import json -import logging -import uuid - -from peewee import * -from peewee import ColumnBase -from peewee import Expression -from peewee import Node -from peewee import NodeList -from peewee import SENTINEL -from peewee import __exception_wrapper__ - -try: - from psycopg2cffi import compat - compat.register() -except ImportError: - pass - -from psycopg2.extras import register_hstore -try: - from psycopg2.extras import Json -except: - Json = None - - -logger = logging.getLogger('peewee') - - -HCONTAINS_DICT = '@>' -HCONTAINS_KEYS = '?&' -HCONTAINS_KEY = '?' -HCONTAINS_ANY_KEY = '?|' -HKEY = '->' -HUPDATE = '||' -ACONTAINS = '@>' -ACONTAINS_ANY = '&&' -TS_MATCH = '@@' -JSONB_CONTAINS = '@>' -JSONB_CONTAINED_BY = '<@' -JSONB_CONTAINS_KEY = '?' -JSONB_CONTAINS_ANY_KEY = '?|' -JSONB_CONTAINS_ALL_KEYS = '?&' -JSONB_EXISTS = '?' 
-JSONB_REMOVE = '-' - - -class _LookupNode(ColumnBase): - def __init__(self, node, parts): - self.node = node - self.parts = parts - super(_LookupNode, self).__init__() - - def clone(self): - return type(self)(self.node, list(self.parts)) - - -class _JsonLookupBase(_LookupNode): - def __init__(self, node, parts, as_json=False): - super(_JsonLookupBase, self).__init__(node, parts) - self._as_json = as_json - - def clone(self): - return type(self)(self.node, list(self.parts), self._as_json) - - @Node.copy - def as_json(self, as_json=True): - self._as_json = as_json - - def concat(self, rhs): - return Expression(self.as_json(True), OP.CONCAT, Json(rhs)) - - def contains(self, other): - clone = self.as_json(True) - if isinstance(other, (list, dict)): - return Expression(clone, JSONB_CONTAINS, Json(other)) - return Expression(clone, JSONB_EXISTS, other) - - def contains_any(self, *keys): - return Expression( - self.as_json(True), - JSONB_CONTAINS_ANY_KEY, - Value(list(keys), unpack=False)) - - def contains_all(self, *keys): - return Expression( - self.as_json(True), - JSONB_CONTAINS_ALL_KEYS, - Value(list(keys), unpack=False)) - - def has_key(self, key): - return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key) - - -class JsonLookup(_JsonLookupBase): - def __getitem__(self, value): - return JsonLookup(self.node, self.parts + [value], self._as_json) - - def __sql__(self, ctx): - ctx.sql(self.node) - for part in self.parts[:-1]: - ctx.literal('->').sql(part) - if self.parts: - (ctx - .literal('->' if self._as_json else '->>') - .sql(self.parts[-1])) - - return ctx - - -class JsonPath(_JsonLookupBase): - def __sql__(self, ctx): - return (ctx - .sql(self.node) - .literal('#>' if self._as_json else '#>>') - .sql(Value('{%s}' % ','.join(map(str, self.parts))))) - - -class ObjectSlice(_LookupNode): - @classmethod - def create(cls, node, value): - if isinstance(value, slice): - parts = [value.start or 0, value.stop or 0] - elif isinstance(value, int): - parts = [value] - else: - parts = map(int, value.split(':')) - return cls(node, parts) - - def __sql__(self, ctx): - return (ctx - .sql(self.node) - .literal('[%s]' % ':'.join(str(p + 1) for p in self.parts))) - - def __getitem__(self, value): - return ObjectSlice.create(self, value) - - -class IndexedFieldMixin(object): - default_index_type = 'GIN' - - def __init__(self, *args, **kwargs): - kwargs.setdefault('index', True) # By default, use an index. 
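IndexedFieldMixin, whose constructor concludes just below, is what gave the postgres_ext field types (ArrayField, HStoreField, the JSON fields, TSVectorField) an index by default, with GIN as the default index type. A minimal sketch of how these removed fields were typically declared and queried, assuming a hypothetical Event model (illustrative only, not part of this diff):

    from peewee import Model, CharField
    from playhouse.postgres_ext import ArrayField, BinaryJSONField

    class Event(Model):
        tags = ArrayField(CharField)          # VARCHAR[] column, GIN-indexed by default
        payload = BinaryJSONField(null=True)  # JSONB column, also GIN-indexed

    # Array containment compiles to the '@>' (ACONTAINS) operator defined above:
    query = Event.select().where(Event.tags.contains('subtitle'))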
- super(IndexedFieldMixin, self).__init__(*args, **kwargs) - - -class ArrayField(IndexedFieldMixin, Field): - passthrough = True - - def __init__(self, field_class=IntegerField, field_kwargs=None, - dimensions=1, convert_values=False, *args, **kwargs): - self.__field = field_class(**(field_kwargs or {})) - self.dimensions = dimensions - self.convert_values = convert_values - self.field_type = self.__field.field_type - super(ArrayField, self).__init__(*args, **kwargs) - - def bind(self, model, name, set_attribute=True): - ret = super(ArrayField, self).bind(model, name, set_attribute) - self.__field.bind(model, '__array_%s' % name, False) - return ret - - def ddl_datatype(self, ctx): - data_type = self.__field.ddl_datatype(ctx) - return NodeList((data_type, SQL('[]' * self.dimensions)), glue='') - - def db_value(self, value): - if value is None or isinstance(value, Node): - return value - elif self.convert_values: - return self._process(self.__field.db_value, value, self.dimensions) - else: - return value if isinstance(value, list) else list(value) - - def python_value(self, value): - if self.convert_values and value is not None: - conv = self.__field.python_value - if isinstance(value, list): - return self._process(conv, value, self.dimensions) - else: - return conv(value) - else: - return value - - def _process(self, conv, value, dimensions): - dimensions -= 1 - if dimensions == 0: - return [conv(v) for v in value] - else: - return [self._process(conv, v, dimensions) for v in value] - - def __getitem__(self, value): - return ObjectSlice.create(self, value) - - def _e(op): - def inner(self, rhs): - return Expression(self, op, ArrayValue(self, rhs)) - return inner - __eq__ = _e(OP.EQ) - __ne__ = _e(OP.NE) - __gt__ = _e(OP.GT) - __ge__ = _e(OP.GTE) - __lt__ = _e(OP.LT) - __le__ = _e(OP.LTE) - __hash__ = Field.__hash__ - - def contains(self, *items): - return Expression(self, ACONTAINS, ArrayValue(self, items)) - - def contains_any(self, *items): - return Expression(self, ACONTAINS_ANY, ArrayValue(self, items)) - - -class ArrayValue(Node): - def __init__(self, field, value): - self.field = field - self.value = value - - def __sql__(self, ctx): - return (ctx - .sql(Value(self.value, unpack=False)) - .literal('::') - .sql(self.field.ddl_datatype(ctx))) - - -class DateTimeTZField(DateTimeField): - field_type = 'TIMESTAMPTZ' - - -class HStoreField(IndexedFieldMixin, Field): - field_type = 'HSTORE' - __hash__ = Field.__hash__ - - def __getitem__(self, key): - return Expression(self, HKEY, Value(key)) - - def keys(self): - return fn.akeys(self) - - def values(self): - return fn.avals(self) - - def items(self): - return fn.hstore_to_matrix(self) - - def slice(self, *args): - return fn.slice(self, Value(list(args), unpack=False)) - - def exists(self, key): - return fn.exist(self, key) - - def defined(self, key): - return fn.defined(self, key) - - def update(self, **data): - return Expression(self, HUPDATE, data) - - def delete(self, *keys): - return fn.delete(self, Value(list(keys), unpack=False)) - - def contains(self, value): - if isinstance(value, dict): - rhs = Value(value, unpack=False) - return Expression(self, HCONTAINS_DICT, rhs) - elif isinstance(value, (list, tuple)): - rhs = Value(value, unpack=False) - return Expression(self, HCONTAINS_KEYS, rhs) - return Expression(self, HCONTAINS_KEY, value) - - def contains_any(self, *keys): - return Expression(self, HCONTAINS_ANY_KEY, Value(list(keys), - unpack=False)) - - -class JSONField(Field): - field_type = 'JSON' - _json_datatype = 'json' - - def 
__init__(self, dumps=None, *args, **kwargs): - if Json is None: - raise Exception('Your version of psycopg2 does not support JSON.') - self.dumps = dumps or json.dumps - super(JSONField, self).__init__(*args, **kwargs) - - def db_value(self, value): - if value is None: - return value - if not isinstance(value, Json): - return Cast(self.dumps(value), self._json_datatype) - return value - - def __getitem__(self, value): - return JsonLookup(self, [value]) - - def path(self, *keys): - return JsonPath(self, keys) - - def concat(self, value): - return super(JSONField, self).concat(Json(value)) - - -def cast_jsonb(node): - return NodeList((node, SQL('::jsonb')), glue='') - - -class BinaryJSONField(IndexedFieldMixin, JSONField): - field_type = 'JSONB' - _json_datatype = 'jsonb' - __hash__ = Field.__hash__ - - def contains(self, other): - if isinstance(other, (list, dict)): - return Expression(self, JSONB_CONTAINS, Json(other)) - return Expression(cast_jsonb(self), JSONB_EXISTS, other) - - def contained_by(self, other): - return Expression(cast_jsonb(self), JSONB_CONTAINED_BY, Json(other)) - - def contains_any(self, *items): - return Expression( - cast_jsonb(self), - JSONB_CONTAINS_ANY_KEY, - Value(list(items), unpack=False)) - - def contains_all(self, *items): - return Expression( - cast_jsonb(self), - JSONB_CONTAINS_ALL_KEYS, - Value(list(items), unpack=False)) - - def has_key(self, key): - return Expression(cast_jsonb(self), JSONB_CONTAINS_KEY, key) - - def remove(self, *items): - return Expression( - cast_jsonb(self), - JSONB_REMOVE, - Value(list(items), unpack=False)) - - -class TSVectorField(IndexedFieldMixin, TextField): - field_type = 'TSVECTOR' - __hash__ = Field.__hash__ - - def match(self, query, language=None, plain=False): - params = (language, query) if language is not None else (query,) - func = fn.plainto_tsquery if plain else fn.to_tsquery - return Expression(self, TS_MATCH, func(*params)) - - -def Match(field, query, language=None): - params = (language, query) if language is not None else (query,) - field_params = (language, field) if language is not None else (field,) - return Expression( - fn.to_tsvector(*field_params), - TS_MATCH, - fn.to_tsquery(*params)) - - -class IntervalField(Field): - field_type = 'INTERVAL' - - -class FetchManyCursor(object): - __slots__ = ('cursor', 'array_size', 'exhausted', 'iterable') - - def __init__(self, cursor, array_size=None): - self.cursor = cursor - self.array_size = array_size or cursor.itersize - self.exhausted = False - self.iterable = self.row_gen() - - @property - def description(self): - return self.cursor.description - - def close(self): - self.cursor.close() - - def row_gen(self): - while True: - rows = self.cursor.fetchmany(self.array_size) - if not rows: - return - for row in rows: - yield row - - def fetchone(self): - if self.exhausted: - return - try: - return next(self.iterable) - except StopIteration: - self.exhausted = True - - -class ServerSideQuery(Node): - def __init__(self, query, array_size=None): - self.query = query - self.array_size = array_size - self._cursor_wrapper = None - - def __sql__(self, ctx): - return self.query.__sql__(ctx) - - def __iter__(self): - if self._cursor_wrapper is None: - self._execute(self.query._database) - return iter(self._cursor_wrapper.iterator()) - - def _execute(self, database): - if self._cursor_wrapper is None: - cursor = database.execute(self.query, named_cursor=True, - array_size=self.array_size) - self._cursor_wrapper = self.query._get_cursor_wrapper(cursor) - return 
self._cursor_wrapper - - -def ServerSide(query, database=None, array_size=None): - if database is None: - database = query._database - with database.transaction(): - server_side_query = ServerSideQuery(query, array_size=array_size) - for row in server_side_query: - yield row - - -class _empty_object(object): - __slots__ = () - def __nonzero__(self): - return False - __bool__ = __nonzero__ - -__named_cursor__ = _empty_object() - - -class PostgresqlExtDatabase(PostgresqlDatabase): - def __init__(self, *args, **kwargs): - self._register_hstore = kwargs.pop('register_hstore', False) - self._server_side_cursors = kwargs.pop('server_side_cursors', False) - super(PostgresqlExtDatabase, self).__init__(*args, **kwargs) - - def _connect(self): - conn = super(PostgresqlExtDatabase, self)._connect() - if self._register_hstore: - register_hstore(conn, globally=True) - return conn - - def cursor(self, commit=None): - if self.is_closed(): - if self.autoconnect: - self.connect() - else: - raise InterfaceError('Error, database connection not opened.') - if commit is __named_cursor__: - return self._state.conn.cursor(name=str(uuid.uuid1())) - return self._state.conn.cursor() - - def execute(self, query, commit=SENTINEL, named_cursor=False, - array_size=None, **context_options): - ctx = self.get_sql_context(**context_options) - sql, params = ctx.sql(query).query() - named_cursor = named_cursor or (self._server_side_cursors and - sql[:6].lower() == 'select') - if named_cursor: - commit = __named_cursor__ - cursor = self.execute_sql(sql, params, commit=commit) - if named_cursor: - cursor = FetchManyCursor(cursor, array_size) - return cursor diff --git a/libs/playhouse/reflection.py b/libs/playhouse/reflection.py deleted file mode 100644 index 3a8f525eb..000000000 --- a/libs/playhouse/reflection.py +++ /dev/null @@ -1,799 +0,0 @@ -try: - from collections import OrderedDict -except ImportError: - OrderedDict = dict -from collections import namedtuple -from inspect import isclass -import re - -from peewee import * -from peewee import _StringField -from peewee import _query_val_transform -from peewee import CommaNodeList -from peewee import SCOPE_VALUES -from peewee import make_snake_case -from peewee import text_type -try: - from pymysql.constants import FIELD_TYPE -except ImportError: - try: - from MySQLdb.constants import FIELD_TYPE - except ImportError: - FIELD_TYPE = None -try: - from playhouse import postgres_ext -except ImportError: - postgres_ext = None - -RESERVED_WORDS = set([ - 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif', - 'else', 'except', 'exec', 'finally', 'for', 'from', 'global', 'if', - 'import', 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', 'raise', - 'return', 'try', 'while', 'with', 'yield', -]) - - -class UnknownField(object): - pass - - -class Column(object): - """ - Store metadata about a database column. - """ - primary_key_types = (IntegerField, AutoField) - - def __init__(self, name, field_class, raw_column_type, nullable, - primary_key=False, column_name=None, index=False, - unique=False, default=None, extra_parameters=None): - self.name = name - self.field_class = field_class - self.raw_column_type = raw_column_type - self.nullable = nullable - self.primary_key = primary_key - self.column_name = column_name - self.index = index - self.unique = unique - self.default = default - self.extra_parameters = extra_parameters - - # Foreign key metadata. 
- self.rel_model = None - self.related_name = None - self.to_field = None - - def __repr__(self): - attrs = [ - 'field_class', - 'raw_column_type', - 'nullable', - 'primary_key', - 'column_name'] - keyword_args = ', '.join( - '%s=%s' % (attr, getattr(self, attr)) - for attr in attrs) - return 'Column(%s, %s)' % (self.name, keyword_args) - - def get_field_parameters(self): - params = {} - if self.extra_parameters is not None: - params.update(self.extra_parameters) - - # Set up default attributes. - if self.nullable: - params['null'] = True - if self.field_class is ForeignKeyField or self.name != self.column_name: - params['column_name'] = "'%s'" % self.column_name - if self.primary_key and not issubclass(self.field_class, AutoField): - params['primary_key'] = True - if self.default is not None: - params['constraints'] = '[SQL("DEFAULT %s")]' % self.default - - # Handle ForeignKeyField-specific attributes. - if self.is_foreign_key(): - params['model'] = self.rel_model - if self.to_field: - params['field'] = "'%s'" % self.to_field - if self.related_name: - params['backref'] = "'%s'" % self.related_name - - # Handle indexes on column. - if not self.is_primary_key(): - if self.unique: - params['unique'] = 'True' - elif self.index and not self.is_foreign_key(): - params['index'] = 'True' - - return params - - def is_primary_key(self): - return self.field_class is AutoField or self.primary_key - - def is_foreign_key(self): - return self.field_class is ForeignKeyField - - def is_self_referential_fk(self): - return (self.field_class is ForeignKeyField and - self.rel_model == "'self'") - - def set_foreign_key(self, foreign_key, model_names, dest=None, - related_name=None): - self.foreign_key = foreign_key - self.field_class = ForeignKeyField - if foreign_key.dest_table == foreign_key.table: - self.rel_model = "'self'" - else: - self.rel_model = model_names[foreign_key.dest_table] - self.to_field = dest and dest.name or None - self.related_name = related_name or None - - def get_field(self): - # Generate the field definition for this column. - field_params = {} - for key, value in self.get_field_parameters().items(): - if isclass(value) and issubclass(value, Field): - value = value.__name__ - field_params[key] = value - - param_str = ', '.join('%s=%s' % (k, v) - for k, v in sorted(field_params.items())) - field = '%s = %s(%s)' % ( - self.name, - self.field_class.__name__, - param_str) - - if self.field_class is UnknownField: - field = '%s # %s' % (field, self.raw_column_type) - - return field - - -class Metadata(object): - column_map = {} - extension_import = '' - - def __init__(self, database): - self.database = database - self.requires_extension = False - - def execute(self, sql, *params): - return self.database.execute_sql(sql, params) - - def get_columns(self, table, schema=None): - metadata = OrderedDict( - (metadata.name, metadata) - for metadata in self.database.get_columns(table, schema)) - - # Look up the actual column type for each column. - column_types, extra_params = self.get_column_types(table, schema) - - # Look up the primary keys. 
- pk_names = self.get_primary_keys(table, schema) - if len(pk_names) == 1: - pk = pk_names[0] - if column_types[pk] is IntegerField: - column_types[pk] = AutoField - elif column_types[pk] is BigIntegerField: - column_types[pk] = BigAutoField - - columns = OrderedDict() - for name, column_data in metadata.items(): - field_class = column_types[name] - default = self._clean_default(field_class, column_data.default) - - columns[name] = Column( - name, - field_class=field_class, - raw_column_type=column_data.data_type, - nullable=column_data.null, - primary_key=column_data.primary_key, - column_name=name, - default=default, - extra_parameters=extra_params.get(name)) - - return columns - - def get_column_types(self, table, schema=None): - raise NotImplementedError - - def _clean_default(self, field_class, default): - if default is None or field_class in (AutoField, BigAutoField) or \ - default.lower() == 'null': - return - if issubclass(field_class, _StringField) and \ - isinstance(default, text_type) and not default.startswith("'"): - default = "'%s'" % default - return default or "''" - - def get_foreign_keys(self, table, schema=None): - return self.database.get_foreign_keys(table, schema) - - def get_primary_keys(self, table, schema=None): - return self.database.get_primary_keys(table, schema) - - def get_indexes(self, table, schema=None): - return self.database.get_indexes(table, schema) - - -class PostgresqlMetadata(Metadata): - column_map = { - 16: BooleanField, - 17: BlobField, - 20: BigIntegerField, - 21: IntegerField, - 23: IntegerField, - 25: TextField, - 700: FloatField, - 701: DoubleField, - 1042: CharField, # blank-padded CHAR - 1043: CharField, - 1082: DateField, - 1114: DateTimeField, - 1184: DateTimeField, - 1083: TimeField, - 1266: TimeField, - 1700: DecimalField, - 2950: TextField, # UUID - } - array_types = { - 1000: BooleanField, - 1001: BlobField, - 1005: SmallIntegerField, - 1007: IntegerField, - 1009: TextField, - 1014: CharField, - 1015: CharField, - 1016: BigIntegerField, - 1115: DateTimeField, - 1182: DateField, - 1183: TimeField, - } - extension_import = 'from playhouse.postgres_ext import *' - - def __init__(self, database): - super(PostgresqlMetadata, self).__init__(database) - - if postgres_ext is not None: - # Attempt to add types like HStore and JSON. - cursor = self.execute('select oid, typname, format_type(oid, NULL)' - ' from pg_type;') - results = cursor.fetchall() - - for oid, typname, formatted_type in results: - if typname == 'json': - self.column_map[oid] = postgres_ext.JSONField - elif typname == 'jsonb': - self.column_map[oid] = postgres_ext.BinaryJSONField - elif typname == 'hstore': - self.column_map[oid] = postgres_ext.HStoreField - elif typname == 'tsvector': - self.column_map[oid] = postgres_ext.TSVectorField - - for oid in self.array_types: - self.column_map[oid] = postgres_ext.ArrayField - - def get_column_types(self, table, schema): - column_types = {} - extra_params = {} - extension_types = set(( - postgres_ext.ArrayField, - postgres_ext.BinaryJSONField, - postgres_ext.JSONField, - postgres_ext.TSVectorField, - postgres_ext.HStoreField)) if postgres_ext is not None else set() - - # Look up the actual column type for each column. - identifier = '"%s"."%s"' % (schema, table) - cursor = self.execute('SELECT * FROM %s LIMIT 1' % identifier) - - # Store column metadata in dictionary keyed by column name. 
- for column_description in cursor.description: - name = column_description.name - oid = column_description.type_code - column_types[name] = self.column_map.get(oid, UnknownField) - if column_types[name] in extension_types: - self.requires_extension = True - if oid in self.array_types: - extra_params[name] = {'field_class': self.array_types[oid]} - - return column_types, extra_params - - def get_columns(self, table, schema=None): - schema = schema or 'public' - return super(PostgresqlMetadata, self).get_columns(table, schema) - - def get_foreign_keys(self, table, schema=None): - schema = schema or 'public' - return super(PostgresqlMetadata, self).get_foreign_keys(table, schema) - - def get_primary_keys(self, table, schema=None): - schema = schema or 'public' - return super(PostgresqlMetadata, self).get_primary_keys(table, schema) - - def get_indexes(self, table, schema=None): - schema = schema or 'public' - return super(PostgresqlMetadata, self).get_indexes(table, schema) - - -class MySQLMetadata(Metadata): - if FIELD_TYPE is None: - column_map = {} - else: - column_map = { - FIELD_TYPE.BLOB: TextField, - FIELD_TYPE.CHAR: CharField, - FIELD_TYPE.DATE: DateField, - FIELD_TYPE.DATETIME: DateTimeField, - FIELD_TYPE.DECIMAL: DecimalField, - FIELD_TYPE.DOUBLE: FloatField, - FIELD_TYPE.FLOAT: FloatField, - FIELD_TYPE.INT24: IntegerField, - FIELD_TYPE.LONG_BLOB: TextField, - FIELD_TYPE.LONG: IntegerField, - FIELD_TYPE.LONGLONG: BigIntegerField, - FIELD_TYPE.MEDIUM_BLOB: TextField, - FIELD_TYPE.NEWDECIMAL: DecimalField, - FIELD_TYPE.SHORT: IntegerField, - FIELD_TYPE.STRING: CharField, - FIELD_TYPE.TIMESTAMP: DateTimeField, - FIELD_TYPE.TIME: TimeField, - FIELD_TYPE.TINY_BLOB: TextField, - FIELD_TYPE.TINY: IntegerField, - FIELD_TYPE.VAR_STRING: CharField, - } - - def __init__(self, database, **kwargs): - if 'password' in kwargs: - kwargs['passwd'] = kwargs.pop('password') - super(MySQLMetadata, self).__init__(database, **kwargs) - - def get_column_types(self, table, schema=None): - column_types = {} - - # Look up the actual column type for each column. - cursor = self.execute('SELECT * FROM `%s` LIMIT 1' % table) - - # Store column metadata in dictionary keyed by column name. - for column_description in cursor.description: - name, type_code = column_description[:2] - column_types[name] = self.column_map.get(type_code, UnknownField) - - return column_types, {} - - -class SqliteMetadata(Metadata): - column_map = { - 'bigint': BigIntegerField, - 'blob': BlobField, - 'bool': BooleanField, - 'boolean': BooleanField, - 'char': CharField, - 'date': DateField, - 'datetime': DateTimeField, - 'decimal': DecimalField, - 'float': FloatField, - 'integer': IntegerField, - 'integer unsigned': IntegerField, - 'int': IntegerField, - 'long': BigIntegerField, - 'numeric': DecimalField, - 'real': FloatField, - 'smallinteger': IntegerField, - 'smallint': IntegerField, - 'smallint unsigned': IntegerField, - 'text': TextField, - 'time': TimeField, - 'varchar': CharField, - } - - begin = '(?:["\[\(]+)?' - end = '(?:["\]\)]+)?' - re_foreign_key = ( - '(?:FOREIGN KEY\s*)?' - '{begin}(.+?){end}\s+(?:.+\s+)?' 
- 'references\s+{begin}(.+?){end}' - '\s*\(["|\[]?(.+?)["|\]]?\)').format(begin=begin, end=end) - re_varchar = r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$' - - def _map_col(self, column_type): - raw_column_type = column_type.lower() - if raw_column_type in self.column_map: - field_class = self.column_map[raw_column_type] - elif re.search(self.re_varchar, raw_column_type): - field_class = CharField - else: - column_type = re.sub('\(.+\)', '', raw_column_type) - if column_type == '': - field_class = BareField - else: - field_class = self.column_map.get(column_type, UnknownField) - return field_class - - def get_column_types(self, table, schema=None): - column_types = {} - columns = self.database.get_columns(table) - - for column in columns: - column_types[column.name] = self._map_col(column.data_type) - - return column_types, {} - - -_DatabaseMetadata = namedtuple('_DatabaseMetadata', ( - 'columns', - 'primary_keys', - 'foreign_keys', - 'model_names', - 'indexes')) - - -class DatabaseMetadata(_DatabaseMetadata): - def multi_column_indexes(self, table): - accum = [] - for index in self.indexes[table]: - if len(index.columns) > 1: - field_names = [self.columns[table][column].name - for column in index.columns - if column in self.columns[table]] - accum.append((field_names, index.unique)) - return accum - - def column_indexes(self, table): - accum = {} - for index in self.indexes[table]: - if len(index.columns) == 1: - accum[index.columns[0]] = index.unique - return accum - - -class Introspector(object): - pk_classes = [AutoField, IntegerField] - - def __init__(self, metadata, schema=None): - self.metadata = metadata - self.schema = schema - - def __repr__(self): - return '<Introspector: %s>' % self.metadata.database - - @classmethod - def from_database(cls, database, schema=None): - if isinstance(database, PostgresqlDatabase): - metadata = PostgresqlMetadata(database) - elif isinstance(database, MySQLDatabase): - metadata = MySQLMetadata(database) - elif isinstance(database, SqliteDatabase): - metadata = SqliteMetadata(database) - else: - raise ValueError('Introspection not supported for %r' % database) - return cls(metadata, schema=schema) - - def get_database_class(self): - return type(self.metadata.database) - - def get_database_name(self): - return self.metadata.database.database - - def get_database_kwargs(self): - return self.metadata.database.connect_params - - def get_additional_imports(self): - if self.metadata.requires_extension: - return '\n' + self.metadata.extension_import - return '' - - def make_model_name(self, table, snake_case=True): - if snake_case: - table = make_snake_case(table) - model = re.sub('[^\w]+', '', table) - model_name = ''.join(sub.title() for sub in model.split('_')) - if not model_name[0].isalpha(): - model_name = 'T' + model_name - return model_name - - def make_column_name(self, column, is_foreign_key=False, snake_case=True): - column = column.strip() - if snake_case: - column = make_snake_case(column) - column = column.lower() - if is_foreign_key: - # Strip "_id" from foreign keys, unless the foreign-key happens to - # be named "_id", in which case the name is retained. - column = re.sub('_id$', '', column) or column - - # Remove characters that are invalid for Python identifiers.
- column = re.sub('[^\w]+', '_', column) - if column in RESERVED_WORDS: - column += '_' - if len(column) and column[0].isdigit(): - column = '_' + column - return column - - def introspect(self, table_names=None, literal_column_names=False, - include_views=False, snake_case=True): - # Retrieve all the tables in the database. - tables = self.metadata.database.get_tables(schema=self.schema) - if include_views: - views = self.metadata.database.get_views(schema=self.schema) - tables.extend([view.name for view in views]) - - if table_names is not None: - tables = [table for table in tables if table in table_names] - table_set = set(tables) - - # Store a mapping of table name -> dictionary of columns. - columns = {} - - # Store a mapping of table name -> set of primary key columns. - primary_keys = {} - - # Store a mapping of table -> foreign keys. - foreign_keys = {} - - # Store a mapping of table name -> model name. - model_names = {} - - # Store a mapping of table name -> indexes. - indexes = {} - - # Gather the columns for each table. - for table in tables: - table_indexes = self.metadata.get_indexes(table, self.schema) - table_columns = self.metadata.get_columns(table, self.schema) - try: - foreign_keys[table] = self.metadata.get_foreign_keys( - table, self.schema) - except ValueError as exc: - err(*exc.args) - foreign_keys[table] = [] - else: - # If there is a possibility we could exclude a dependent table, - # ensure that we introspect it so FKs will work. - if table_names is not None: - for foreign_key in foreign_keys[table]: - if foreign_key.dest_table not in table_set: - tables.append(foreign_key.dest_table) - table_set.add(foreign_key.dest_table) - - model_names[table] = self.make_model_name(table, snake_case) - - # Collect sets of all the column names as well as all the - # foreign-key column names. - lower_col_names = set(column_name.lower() - for column_name in table_columns) - fks = set(fk_col.column for fk_col in foreign_keys[table]) - - for col_name, column in table_columns.items(): - if literal_column_names: - new_name = re.sub('[^\w]+', '_', col_name) - else: - new_name = self.make_column_name(col_name, col_name in fks, - snake_case) - - # If we have two columns, "parent" and "parent_id", ensure - # that when we don't introduce naming conflicts. - lower_name = col_name.lower() - if lower_name.endswith('_id') and new_name in lower_col_names: - new_name = col_name.lower() - - column.name = new_name - - for index in table_indexes: - if len(index.columns) == 1: - column = index.columns[0] - if column in table_columns: - table_columns[column].unique = index.unique - table_columns[column].index = True - - primary_keys[table] = self.metadata.get_primary_keys( - table, self.schema) - columns[table] = table_columns - indexes[table] = table_indexes - - # Gather all instances where we might have a `related_name` conflict, - # either due to multiple FKs on a table pointing to the same table, - # or a related_name that would conflict with an existing field. - related_names = {} - sort_fn = lambda foreign_key: foreign_key.column - for table in tables: - models_referenced = set() - for foreign_key in sorted(foreign_keys[table], key=sort_fn): - try: - column = columns[table][foreign_key.column] - except KeyError: - continue - - dest_table = foreign_key.dest_table - if dest_table in models_referenced: - related_names[column] = '%s_%s_set' % ( - dest_table, - column.name) - else: - models_referenced.add(dest_table) - - # On the second pass convert all foreign keys. 
- for table in tables: - for foreign_key in foreign_keys[table]: - src = columns[foreign_key.table][foreign_key.column] - try: - dest = columns[foreign_key.dest_table][ - foreign_key.dest_column] - except KeyError: - dest = None - - src.set_foreign_key( - foreign_key=foreign_key, - model_names=model_names, - dest=dest, - related_name=related_names.get(src)) - - return DatabaseMetadata( - columns, - primary_keys, - foreign_keys, - model_names, - indexes) - - def generate_models(self, skip_invalid=False, table_names=None, - literal_column_names=False, bare_fields=False, - include_views=False): - database = self.introspect(table_names, literal_column_names, - include_views) - models = {} - - class BaseModel(Model): - class Meta: - database = self.metadata.database - schema = self.schema - - def _create_model(table, models): - for foreign_key in database.foreign_keys[table]: - dest = foreign_key.dest_table - - if dest not in models and dest != table: - _create_model(dest, models) - - primary_keys = [] - columns = database.columns[table] - for column_name, column in columns.items(): - if column.primary_key: - primary_keys.append(column.name) - - multi_column_indexes = database.multi_column_indexes(table) - column_indexes = database.column_indexes(table) - - class Meta: - indexes = multi_column_indexes - table_name = table - - # Fix models with multi-column primary keys. - composite_key = False - if len(primary_keys) == 0: - primary_keys = columns.keys() - if len(primary_keys) > 1: - Meta.primary_key = CompositeKey(*[ - field.name for col, field in columns.items() - if col in primary_keys]) - composite_key = True - - attrs = {'Meta': Meta} - for column_name, column in columns.items(): - FieldClass = column.field_class - if FieldClass is not ForeignKeyField and bare_fields: - FieldClass = BareField - elif FieldClass is UnknownField: - FieldClass = BareField - - params = { - 'column_name': column_name, - 'null': column.nullable} - if column.primary_key and composite_key: - if FieldClass is AutoField: - FieldClass = IntegerField - params['primary_key'] = False - elif column.primary_key and FieldClass is not AutoField: - params['primary_key'] = True - if column.is_foreign_key(): - if column.is_self_referential_fk(): - params['model'] = 'self' - else: - dest_table = column.foreign_key.dest_table - params['model'] = models[dest_table] - if column.to_field: - params['field'] = column.to_field - - # Generate a unique related name. - params['backref'] = '%s_%s_rel' % (table, column_name) - - if column.default is not None: - constraint = SQL('DEFAULT %s' % column.default) - params['constraints'] = [constraint] - - if column_name in column_indexes and not \ - column.is_primary_key(): - if column_indexes[column_name]: - params['unique'] = True - elif not column.is_foreign_key(): - params['index'] = True - - attrs[column.name] = FieldClass(**params) - - try: - models[table] = type(str(table), (BaseModel,), attrs) - except ValueError: - if not skip_invalid: - raise - - # Actually generate Model classes. 
- for table, model in sorted(database.model_names.items()): - if table not in models: - _create_model(table, models) - - return models - - -def introspect(database, schema=None): - introspector = Introspector.from_database(database, schema=schema) - return introspector.introspect() - - -def generate_models(database, schema=None, **options): - introspector = Introspector.from_database(database, schema=schema) - return introspector.generate_models(**options) - - -def print_model(model, indexes=True, inline_indexes=False): - print(model._meta.name) - for field in model._meta.sorted_fields: - parts = [' %s %s' % (field.name, field.field_type)] - if field.primary_key: - parts.append(' PK') - elif inline_indexes: - if field.unique: - parts.append(' UNIQUE') - elif field.index: - parts.append(' INDEX') - if isinstance(field, ForeignKeyField): - parts.append(' FK: %s.%s' % (field.rel_model.__name__, - field.rel_field.name)) - print(''.join(parts)) - - if indexes: - index_list = model._meta.fields_to_index() - if not index_list: - return - - print('\nindex(es)') - for index in index_list: - parts = [' '] - ctx = model._meta.database.get_sql_context() - with ctx.scope_values(param='%s', quote='""'): - ctx.sql(CommaNodeList(index._expressions)) - if index._where: - ctx.literal(' WHERE ') - ctx.sql(index._where) - sql, params = ctx.query() - - clean = sql % tuple(map(_query_val_transform, params)) - parts.append(clean.replace('"', '')) - - if index._unique: - parts.append(' UNIQUE') - print(''.join(parts)) - - -def get_table_sql(model): - sql, params = model._schema._create_table().query() - if model._meta.database.param != '%s': - sql = sql.replace(model._meta.database.param, '%s') - - # Format and indent the table declaration, simplest possible approach. - match_obj = re.match('^(.+?\()(.+)(\).*)', sql) - create, columns, extra = match_obj.groups() - indented = ',\n'.join(' %s' % column for column in columns.split(', ')) - - clean = '\n'.join((create, indented, extra)).strip() - return clean % tuple(map(_query_val_transform, params)) - -def print_table_sql(model): - print(get_table_sql(model)) diff --git a/libs/playhouse/shortcuts.py b/libs/playhouse/shortcuts.py deleted file mode 100644 index 1772cf1d3..000000000 --- a/libs/playhouse/shortcuts.py +++ /dev/null @@ -1,228 +0,0 @@ -from peewee import * -from peewee import Alias -from peewee import SENTINEL -from peewee import callable_ - - -_clone_set = lambda s: set(s) if s else set() - - -def model_to_dict(model, recurse=True, backrefs=False, only=None, - exclude=None, seen=None, extra_attrs=None, - fields_from_query=None, max_depth=None, manytomany=False): - """ - Convert a model instance (and any related objects) to a dictionary. - - :param bool recurse: Whether foreign-keys should be recursed. - :param bool backrefs: Whether lists of related objects should be recursed. - :param only: A list (or set) of field instances indicating which fields - should be included. - :param exclude: A list (or set) of field instances that should be - excluded from the dictionary. - :param list extra_attrs: Names of model instance attributes or methods - that should be included. - :param SelectQuery fields_from_query: Query that was source of model. Take - fields explicitly selected by the query and serialize them. - :param int max_depth: Maximum depth to recurse, value <= 0 means no max. - :param bool manytomany: Process many-to-many fields. 
- """ - max_depth = -1 if max_depth is None else max_depth - if max_depth == 0: - recurse = False - - only = _clone_set(only) - extra_attrs = _clone_set(extra_attrs) - should_skip = lambda n: (n in exclude) or (only and (n not in only)) - - if fields_from_query is not None: - for item in fields_from_query._returning: - if isinstance(item, Field): - only.add(item) - elif isinstance(item, Alias): - extra_attrs.add(item._alias) - - data = {} - exclude = _clone_set(exclude) - seen = _clone_set(seen) - exclude |= seen - model_class = type(model) - - if manytomany: - for name, m2m in model._meta.manytomany.items(): - if should_skip(name): - continue - - exclude.update((m2m, m2m.rel_model._meta.manytomany[m2m.backref])) - for fkf in m2m.through_model._meta.refs: - exclude.add(fkf) - - accum = [] - for rel_obj in getattr(model, name): - accum.append(model_to_dict( - rel_obj, - recurse=recurse, - backrefs=backrefs, - only=only, - exclude=exclude, - max_depth=max_depth - 1)) - data[name] = accum - - for field in model._meta.sorted_fields: - if should_skip(field): - continue - - field_data = model.__data__.get(field.name) - if isinstance(field, ForeignKeyField) and recurse: - if field_data is not None: - seen.add(field) - rel_obj = getattr(model, field.name) - field_data = model_to_dict( - rel_obj, - recurse=recurse, - backrefs=backrefs, - only=only, - exclude=exclude, - seen=seen, - max_depth=max_depth - 1) - else: - field_data = None - - data[field.name] = field_data - - if extra_attrs: - for attr_name in extra_attrs: - attr = getattr(model, attr_name) - if callable_(attr): - data[attr_name] = attr() - else: - data[attr_name] = attr - - if backrefs and recurse: - for foreign_key, rel_model in model._meta.backrefs.items(): - if foreign_key.backref == '+': continue - descriptor = getattr(model_class, foreign_key.backref) - if descriptor in exclude or foreign_key in exclude: - continue - if only and (descriptor not in only) and (foreign_key not in only): - continue - - accum = [] - exclude.add(foreign_key) - related_query = getattr(model, foreign_key.backref) - - for rel_obj in related_query: - accum.append(model_to_dict( - rel_obj, - recurse=recurse, - backrefs=backrefs, - only=only, - exclude=exclude, - max_depth=max_depth - 1)) - - data[foreign_key.backref] = accum - - return data - - -def update_model_from_dict(instance, data, ignore_unknown=False): - meta = instance._meta - backrefs = dict([(fk.backref, fk) for fk in meta.backrefs]) - - for key, value in data.items(): - if key in meta.combined: - field = meta.combined[key] - is_backref = False - elif key in backrefs: - field = backrefs[key] - is_backref = True - elif ignore_unknown: - setattr(instance, key, value) - continue - else: - raise AttributeError('Unrecognized attribute "%s" for model ' - 'class %s.' 
% (key, type(instance))) - - is_foreign_key = isinstance(field, ForeignKeyField) - - if not is_backref and is_foreign_key and isinstance(value, dict): - try: - rel_instance = instance.__rel__[field.name] - except KeyError: - rel_instance = field.rel_model() - setattr( - instance, - field.name, - update_model_from_dict(rel_instance, value, ignore_unknown)) - elif is_backref and isinstance(value, (list, tuple)): - instances = [ - dict_to_model(field.model, row_data, ignore_unknown) - for row_data in value] - for rel_instance in instances: - setattr(rel_instance, field.name, instance) - setattr(instance, field.backref, instances) - else: - setattr(instance, field.name, value) - - return instance - - -def dict_to_model(model_class, data, ignore_unknown=False): - return update_model_from_dict(model_class(), data, ignore_unknown) - - -class ReconnectMixin(object): - """ - Mixin class that attempts to automatically reconnect to the database under - certain error conditions. - - For example, MySQL servers will typically close connections that are idle - for 28800 seconds ("wait_timeout" setting). If your application makes use - of long-lived connections, you may find your connections are closed after - a period of no activity. This mixin will attempt to reconnect automatically - when these errors occur. - - This mixin class probably should not be used with Postgres (unless you - REALLY know what you are doing) and definitely has no business being used - with Sqlite. If you wish to use with Postgres, you will need to adapt the - `reconnect_errors` attribute to something appropriate for Postgres. - """ - reconnect_errors = ( - # Error class, error message fragment (or empty string for all). - (OperationalError, '2006'), # MySQL server has gone away. - (OperationalError, '2013'), # Lost connection to MySQL server. - (OperationalError, '2014'), # Commands out of sync. - - # mysql-connector raises a slightly different error when an idle - # connection is terminated by the server. This is equivalent to 2013. - (OperationalError, 'MySQL Connection not available.'), - ) - - def __init__(self, *args, **kwargs): - super(ReconnectMixin, self).__init__(*args, **kwargs) - - # Normalize the reconnect errors to a more efficient data-structure. - self._reconnect_errors = {} - for exc_class, err_fragment in self.reconnect_errors: - self._reconnect_errors.setdefault(exc_class, []) - self._reconnect_errors[exc_class].append(err_fragment.lower()) - - def execute_sql(self, sql, params=None, commit=SENTINEL): - try: - return super(ReconnectMixin, self).execute_sql(sql, params, commit) - except Exception as exc: - exc_class = type(exc) - if exc_class not in self._reconnect_errors: - raise exc - - exc_repr = str(exc).lower() - for err_fragment in self._reconnect_errors[exc_class]: - if err_fragment in exc_repr: - break - else: - raise exc - - if not self.is_closed(): - self.close() - self.connect() - - return super(ReconnectMixin, self).execute_sql(sql, params, commit) diff --git a/libs/playhouse/signals.py b/libs/playhouse/signals.py deleted file mode 100644 index 4e92872e5..000000000 --- a/libs/playhouse/signals.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Provide django-style hooks for model events. 
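The `ReconnectMixin` deleted above is meant to be composed with a concrete database class; the usual pattern from the peewee documentation is:

    from peewee import MySQLDatabase
    from playhouse.shortcuts import ReconnectMixin

    class ReconnectMySQLDatabase(ReconnectMixin, MySQLDatabase):
        pass

    # Queries issued after the server drops an idle connection are retried
    # once on a fresh connection instead of surfacing error 2006/2013.
    db = ReconnectMySQLDatabase('my_app', host='127.0.0.1', user='app')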
-""" -from peewee import Model as _Model - - -class Signal(object): - def __init__(self): - self._flush() - - def _flush(self): - self._receivers = set() - self._receiver_list = [] - - def connect(self, receiver, name=None, sender=None): - name = name or receiver.__name__ - key = (name, sender) - if key not in self._receivers: - self._receivers.add(key) - self._receiver_list.append((name, receiver, sender)) - else: - raise ValueError('receiver named %s (for sender=%s) already ' - 'connected' % (name, sender or 'any')) - - def disconnect(self, receiver=None, name=None, sender=None): - if receiver: - name = name or receiver.__name__ - if not name: - raise ValueError('a receiver or a name must be provided') - - key = (name, sender) - if key not in self._receivers: - raise ValueError('receiver named %s for sender=%s not found.' % - (name, sender or 'any')) - - self._receivers.remove(key) - self._receiver_list = [(n, r, s) for n, r, s in self._receiver_list - if n != name and s != sender] - - def __call__(self, name=None, sender=None): - def decorator(fn): - self.connect(fn, name, sender) - return fn - return decorator - - def send(self, instance, *args, **kwargs): - sender = type(instance) - responses = [] - for n, r, s in self._receiver_list: - if s is None or isinstance(instance, s): - responses.append((r, r(sender, instance, *args, **kwargs))) - return responses - - -pre_save = Signal() -post_save = Signal() -pre_delete = Signal() -post_delete = Signal() -pre_init = Signal() - - -class Model(_Model): - def __init__(self, *args, **kwargs): - super(Model, self).__init__(*args, **kwargs) - pre_init.send(self) - - def save(self, *args, **kwargs): - pk_value = self._pk if self._meta.primary_key else True - created = kwargs.get('force_insert', False) or not bool(pk_value) - pre_save.send(self, created=created) - ret = super(Model, self).save(*args, **kwargs) - post_save.send(self, created=created) - return ret - - def delete_instance(self, *args, **kwargs): - pre_delete.send(self) - ret = super(Model, self).delete_instance(*args, **kwargs) - post_delete.send(self) - return ret diff --git a/libs/playhouse/sqlcipher_ext.py b/libs/playhouse/sqlcipher_ext.py deleted file mode 100644 index 9bad1eca6..000000000 --- a/libs/playhouse/sqlcipher_ext.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -Peewee integration with pysqlcipher. - -Project page: https://github.com/leapcode/pysqlcipher/ - -**WARNING!!! EXPERIMENTAL!!!** - -* Although this extention's code is short, it has not been properly - peer-reviewed yet and may have introduced vulnerabilities. - -Also note that this code relies on pysqlcipher and sqlcipher, and -the code there might have vulnerabilities as well, but since these -are widely used crypto modules, we can expect "short zero days" there. - -Example usage: - - from peewee.playground.ciphersql_ext import SqlCipherDatabase - db = SqlCipherDatabase('/path/to/my.db', passphrase="don'tuseme4real") - -* `passphrase`: should be "long enough". - Note that *length beats vocabulary* (much exponential), and even - a lowercase-only passphrase like easytorememberyethardforotherstoguess - packs more noise than 8 random printable characters and *can* be memorized. - -When opening an existing database, passphrase should be the one used when the -database was created. If the passphrase is incorrect, an exception will only be -raised **when you access the database**. 
- -If you need to ask for an interactive passphrase, here's example code you can -put after the `db = ...` line: - - try: # Just access the database so that it checks the encryption. - db.get_tables() - # We're looking for a DatabaseError with a specific error message. - except peewee.DatabaseError as e: - # Check whether the message *means* "passphrase is wrong" - if e.args[0] == 'file is encrypted or is not a database': - raise Exception('Developer should Prompt user for passphrase ' - 'again.') - else: - # A different DatabaseError. Raise it. - raise e - -See a more elaborate example with this code at -https://gist.github.com/thedod/11048875 -""" -import datetime -import decimal -import sys - -from peewee import * -from playhouse.sqlite_ext import SqliteExtDatabase -if sys.version_info[0] != 3: - from pysqlcipher import dbapi2 as sqlcipher -else: - try: - from sqlcipher3 import dbapi2 as sqlcipher - except ImportError: - from pysqlcipher3 import dbapi2 as sqlcipher - -sqlcipher.register_adapter(decimal.Decimal, str) -sqlcipher.register_adapter(datetime.date, str) -sqlcipher.register_adapter(datetime.time, str) - - -class _SqlCipherDatabase(object): - def _connect(self): - params = dict(self.connect_params) - passphrase = params.pop('passphrase', '').replace("'", "''") - - conn = sqlcipher.connect(self.database, isolation_level=None, **params) - try: - if passphrase: - conn.execute("PRAGMA key='%s'" % passphrase) - self._add_conn_hooks(conn) - except: - conn.close() - raise - return conn - - def set_passphrase(self, passphrase): - if not self.is_closed(): - raise ImproperlyConfigured('Cannot set passphrase when database ' - 'is open. To change passphrase of an ' - 'open database use the rekey() method.') - - self.connect_params['passphrase'] = passphrase - - def rekey(self, passphrase): - if self.is_closed(): - self.connect() - - self.execute_sql("PRAGMA rekey='%s'" % passphrase.replace("'", "''")) - self.connect_params['passphrase'] = passphrase - return True - - -class SqlCipherDatabase(_SqlCipherDatabase, SqliteDatabase): - pass - - -class SqlCipherExtDatabase(_SqlCipherDatabase, SqliteExtDatabase): - pass diff --git a/libs/playhouse/sqlite_changelog.py b/libs/playhouse/sqlite_changelog.py deleted file mode 100644 index b036af20f..000000000 --- a/libs/playhouse/sqlite_changelog.py +++ /dev/null @@ -1,123 +0,0 @@ -from peewee import * -from playhouse.sqlite_ext import JSONField - - -class BaseChangeLog(Model): - timestamp = DateTimeField(constraints=[SQL('DEFAULT CURRENT_TIMESTAMP')]) - action = TextField() - table = TextField() - primary_key = IntegerField() - changes = JSONField() - - -class ChangeLog(object): - # Model class that will serve as the base for the changelog. This model - # will be subclassed and mapped to your application database. - base_model = BaseChangeLog - - # Template for the triggers that handle updating the changelog table. 
- # table: table name - # action: insert / update / delete - # new_old: NEW or OLD (OLD is for DELETE) - # primary_key: table primary key column name - # column_array: output of build_column_array() - # change_table: changelog table name - template = """CREATE TRIGGER IF NOT EXISTS %(table)s_changes_%(action)s - AFTER %(action)s ON %(table)s - BEGIN - INSERT INTO %(change_table)s - ("action", "table", "primary_key", "changes") - SELECT - '%(action)s', '%(table)s', %(new_old)s."%(primary_key)s", "changes" - FROM ( - SELECT json_group_object( - col, - json_array("oldval", "newval")) AS "changes" - FROM ( - SELECT json_extract(value, '$[0]') as "col", - json_extract(value, '$[1]') as "oldval", - json_extract(value, '$[2]') as "newval" - FROM json_each(json_array(%(column_array)s)) - WHERE "oldval" IS NOT "newval" - ) - ); - END;""" - - drop_template = 'DROP TRIGGER IF EXISTS %(table)s_changes_%(action)s' - - _actions = ('INSERT', 'UPDATE', 'DELETE') - - def __init__(self, db, table_name='changelog'): - self.db = db - self.table_name = table_name - - def _build_column_array(self, model, use_old, use_new, skip_fields=None): - # Builds a list of SQL expressions for each field we are tracking. This - # is used as the data source for change tracking in our trigger. - col_array = [] - for field in model._meta.sorted_fields: - if field.primary_key: - continue - - if skip_fields is not None and field.name in skip_fields: - continue - - column = field.column_name - new = 'NULL' if not use_new else 'NEW."%s"' % column - old = 'NULL' if not use_old else 'OLD."%s"' % column - - if isinstance(field, JSONField): - # Ensure that values are cast to JSON so that the serialization - # is preserved when calculating the old / new. - if use_old: old = 'json(%s)' % old - if use_new: new = 'json(%s)' % new - - col_array.append("json_array('%s', %s, %s)" % (column, old, new)) - - return ', '.join(col_array) - - def trigger_sql(self, model, action, skip_fields=None): - assert action in self._actions - use_old = action != 'INSERT' - use_new = action != 'DELETE' - cols = self._build_column_array(model, use_old, use_new, skip_fields) - return self.template % { - 'table': model._meta.table_name, - 'action': action, - 'new_old': 'NEW' if action != 'DELETE' else 'OLD', - 'primary_key': model._meta.primary_key.column_name, - 'column_array': cols, - 'change_table': self.table_name} - - def drop_trigger_sql(self, model, action): - assert action in self._actions - return self.drop_template % { - 'table': model._meta.table_name, - 'action': action} - - @property - def model(self): - if not hasattr(self, '_changelog_model'): - class ChangeLog(self.base_model): - class Meta: - database = self.db - table_name = self.table_name - self._changelog_model = ChangeLog - - return self._changelog_model - - def install(self, model, skip_fields=None, drop=True, insert=True, - update=True, delete=True, create_table=True): - ChangeLog = self.model - if create_table: - ChangeLog.create_table() - - actions = list(zip((insert, update, delete), self._actions)) - if drop: - for _, action in actions: - self.db.execute_sql(self.drop_trigger_sql(model, action)) - - for enabled, action in actions: - if enabled: - sql = self.trigger_sql(model, action, skip_fields) - self.db.execute_sql(sql) diff --git a/libs/playhouse/sqlite_ext.py b/libs/playhouse/sqlite_ext.py deleted file mode 100644 index d9504c5fd..000000000 --- a/libs/playhouse/sqlite_ext.py +++ /dev/null @@ -1,1293 +0,0 @@ -import json -import math -import re -import struct -import sys - -from 
peewee import * -from peewee import ColumnBase -from peewee import EnclosedNodeList -from peewee import Entity -from peewee import Expression -from peewee import Node -from peewee import NodeList -from peewee import OP -from peewee import VirtualField -from peewee import merge_dict -from peewee import sqlite3 -try: - from playhouse._sqlite_ext import ( - backup, - backup_to_file, - Blob, - ConnectionHelper, - register_bloomfilter, - register_hash_functions, - register_rank_functions, - sqlite_get_db_status, - sqlite_get_status, - TableFunction, - ZeroBlob, - ) - CYTHON_SQLITE_EXTENSIONS = True -except ImportError: - CYTHON_SQLITE_EXTENSIONS = False - - -if sys.version_info[0] == 3: - basestring = str - - -FTS3_MATCHINFO = 'pcx' -FTS4_MATCHINFO = 'pcnalx' -if sqlite3 is not None: - FTS_VERSION = 4 if sqlite3.sqlite_version_info[:3] >= (3, 7, 4) else 3 -else: - FTS_VERSION = 3 - -FTS5_MIN_SQLITE_VERSION = (3, 9, 0) - - -class RowIDField(AutoField): - auto_increment = True - column_name = name = required_name = 'rowid' - - def bind(self, model, name, *args): - if name != self.required_name: - raise ValueError('%s must be named "%s".' % - (type(self), self.required_name)) - super(RowIDField, self).bind(model, name, *args) - - -class DocIDField(RowIDField): - column_name = name = required_name = 'docid' - - -class AutoIncrementField(AutoField): - def ddl(self, ctx): - node_list = super(AutoIncrementField, self).ddl(ctx) - return NodeList((node_list, SQL('AUTOINCREMENT'))) - - -class TDecimalField(DecimalField): - field_type = 'TEXT' - def get_modifiers(self): pass - - -class JSONPath(ColumnBase): - def __init__(self, field, path=None): - super(JSONPath, self).__init__() - self._field = field - self._path = path or () - - @property - def path(self): - return Value('$%s' % ''.join(self._path)) - - def __getitem__(self, idx): - if isinstance(idx, int): - item = '[%s]' % idx - else: - item = '.%s' % idx - return JSONPath(self._field, self._path + (item,)) - - def set(self, value, as_json=None): - if as_json or isinstance(value, (list, dict)): - value = fn.json(self._field._json_dumps(value)) - return fn.json_set(self._field, self.path, value) - - def update(self, value): - return self.set(fn.json_patch(self, self._field._json_dumps(value))) - - def remove(self): - return fn.json_remove(self._field, self.path) - - def json_type(self): - return fn.json_type(self._field, self.path) - - def length(self): - return fn.json_array_length(self._field, self.path) - - def children(self): - return fn.json_each(self._field, self.path) - - def tree(self): - return fn.json_tree(self._field, self.path) - - def __sql__(self, ctx): - return ctx.sql(fn.json_extract(self._field, self.path) - if self._path else self._field) - - -class JSONField(TextField): - field_type = 'JSON' - - def __init__(self, json_dumps=None, json_loads=None, **kwargs): - self._json_dumps = json_dumps or json.dumps - self._json_loads = json_loads or json.loads - super(JSONField, self).__init__(**kwargs) - - def python_value(self, value): - if value is not None: - try: - return self._json_loads(value) - except (TypeError, ValueError): - return value - - def db_value(self, value): - if value is not None: - if not isinstance(value, Node): - value = fn.json(self._json_dumps(value)) - return value - - def _e(op): - def inner(self, rhs): - if isinstance(rhs, (list, dict)): - rhs = Value(rhs, converter=self.db_value, unpack=False) - return Expression(self, op, rhs) - return inner - __eq__ = _e(OP.EQ) - __ne__ = _e(OP.NE) - __gt__ = _e(OP.GT) - __ge__ 
= _e(OP.GTE) - __lt__ = _e(OP.LT) - __le__ = _e(OP.LTE) - __hash__ = Field.__hash__ - - def __getitem__(self, item): - return JSONPath(self)[item] - - def set(self, value, as_json=None): - return JSONPath(self).set(value, as_json) - - def update(self, data): - return JSONPath(self).update(data) - - def remove(self): - return JSONPath(self).remove() - - def json_type(self): - return fn.json_type(self) - - def length(self): - return fn.json_array_length(self) - - def children(self): - """ - Schema of `json_each` and `json_tree`: - - key, - value, - type TEXT (object, array, string, etc), - atom (value for primitive/scalar types, NULL for array and object) - id INTEGER (unique identifier for element) - parent INTEGER (unique identifier of parent element or NULL) - fullkey TEXT (full path describing element) - path TEXT (path to the container of the current element) - json JSON hidden (1st input parameter to function) - root TEXT hidden (2nd input parameter, path at which to start) - """ - return fn.json_each(self) - - def tree(self): - return fn.json_tree(self) - - -class SearchField(Field): - def __init__(self, unindexed=False, column_name=None, **k): - if k: - raise ValueError('SearchField does not accept these keyword ' - 'arguments: %s.' % sorted(k)) - super(SearchField, self).__init__(unindexed=unindexed, - column_name=column_name, null=True) - - def match(self, term): - return match(self, term) - - -class VirtualTableSchemaManager(SchemaManager): - def _create_virtual_table(self, safe=True, **options): - options = self.model.clean_options( - merge_dict(self.model._meta.options, options)) - - # Structure: - # CREATE VIRTUAL TABLE - # USING - # ([prefix_arguments, ...] fields, ... [arguments, ...], [options...]) - ctx = self._create_context() - ctx.literal('CREATE VIRTUAL TABLE ') - if safe: - ctx.literal('IF NOT EXISTS ') - (ctx - .sql(self.model) - .literal(' USING ')) - - ext_module = self.model._meta.extension_module - if isinstance(ext_module, Node): - return ctx.sql(ext_module) - - ctx.sql(SQL(ext_module)).literal(' ') - arguments = [] - meta = self.model._meta - - if meta.prefix_arguments: - arguments.extend([SQL(a) for a in meta.prefix_arguments]) - - # Constraints, data-types, foreign and primary keys are all omitted. - for field in meta.sorted_fields: - if isinstance(field, (RowIDField)) or field._hidden: - continue - field_def = [Entity(field.column_name)] - if field.unindexed: - field_def.append(SQL('UNINDEXED')) - arguments.append(NodeList(field_def)) - - if meta.arguments: - arguments.extend([SQL(a) for a in meta.arguments]) - - if options: - arguments.extend(self._create_table_option_sql(options)) - return ctx.sql(EnclosedNodeList(arguments)) - - def _create_table(self, safe=True, **options): - if issubclass(self.model, VirtualModel): - return self._create_virtual_table(safe, **options) - - return super(VirtualTableSchemaManager, self)._create_table( - safe, **options) - - -class VirtualModel(Model): - class Meta: - arguments = None - extension_module = None - prefix_arguments = None - primary_key = False - schema_manager_class = VirtualTableSchemaManager - - @classmethod - def clean_options(cls, options): - return options - - -class BaseFTSModel(VirtualModel): - @classmethod - def clean_options(cls, options): - content = options.get('content') - prefix = options.get('prefix') - tokenize = options.get('tokenize') - - if isinstance(content, basestring) and content == '': - # Special-case content-less full-text search tables. 
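`JSONField` and its `JSONPath` companion above wrap SQLite's json1 functions; assuming an SQLite build with json1 compiled in, usage looks like:

    from peewee import Model, TextField
    from playhouse.sqlite_ext import SqliteExtDatabase, JSONField

    db = SqliteExtDatabase(':memory:')

    class KV(Model):  # hypothetical model for illustration
        key = TextField()
        value = JSONField()
        class Meta:
            database = db

    db.create_tables([KV])
    KV.create(key='a', value={'x': [1, 2]})

    # Attribute access returns the deserialized dict; expression access on
    # the field class compiles to json_extract(value, '$.x[0]').
    KV.get(KV.key == 'a').value               # {'x': [1, 2]}
    KV.select().where(KV.value['x'][0] == 1)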
- options['content'] = "''" - elif isinstance(content, Field): - # Special-case to ensure fields are fully-qualified. - options['content'] = Entity(content.model._meta.table_name, - content.column_name) - - if prefix: - if isinstance(prefix, (list, tuple)): - prefix = ','.join([str(i) for i in prefix]) - options['prefix'] = "'%s'" % prefix.strip("' ") - - if tokenize and cls._meta.extension_module.lower() == 'fts5': - # Tokenizers need to be in quoted string for FTS5, but not for FTS3 - # or FTS4. - options['tokenize'] = '"%s"' % tokenize - - return options - - -class FTSModel(BaseFTSModel): - """ - VirtualModel class for creating tables that use either the FTS3 or FTS4 - search extensions. Peewee automatically determines which version of the - FTS extension is supported and will use FTS4 if possible. - """ - # FTS3/4 uses "docid" in the same way a normal table uses "rowid". - docid = DocIDField() - - class Meta: - extension_module = 'FTS%s' % FTS_VERSION - - @classmethod - def _fts_cmd(cls, cmd): - tbl = cls._meta.table_name - res = cls._meta.database.execute_sql( - "INSERT INTO %s(%s) VALUES('%s');" % (tbl, tbl, cmd)) - return res.fetchone() - - @classmethod - def optimize(cls): - return cls._fts_cmd('optimize') - - @classmethod - def rebuild(cls): - return cls._fts_cmd('rebuild') - - @classmethod - def integrity_check(cls): - return cls._fts_cmd('integrity-check') - - @classmethod - def merge(cls, blocks=200, segments=8): - return cls._fts_cmd('merge=%s,%s' % (blocks, segments)) - - @classmethod - def automerge(cls, state=True): - return cls._fts_cmd('automerge=%s' % (state and '1' or '0')) - - @classmethod - def match(cls, term): - """ - Generate a `MATCH` expression appropriate for searching this table. - """ - return match(cls._meta.entity, term) - - @classmethod - def rank(cls, *weights): - matchinfo = fn.matchinfo(cls._meta.entity, FTS3_MATCHINFO) - return fn.fts_rank(matchinfo, *weights) - - @classmethod - def bm25(cls, *weights): - match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) - return fn.fts_bm25(match_info, *weights) - - @classmethod - def bm25f(cls, *weights): - match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) - return fn.fts_bm25f(match_info, *weights) - - @classmethod - def lucene(cls, *weights): - match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) - return fn.fts_lucene(match_info, *weights) - - @classmethod - def _search(cls, term, weights, with_score, score_alias, score_fn, - explicit_ordering): - if not weights: - rank = score_fn() - elif isinstance(weights, dict): - weight_args = [] - for field in cls._meta.sorted_fields: - # Attempt to get the specified weight of the field by looking - # it up using it's field instance followed by name. 
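The `_search()` helper below drives the public `search*` classmethods; a minimal FTS3/4 round trip, assuming an SQLite build with the FTS4 extension, might be:

    from playhouse.sqlite_ext import SqliteExtDatabase, FTSModel, SearchField

    db = SqliteExtDatabase(':memory:')  # registers fts_rank/fts_bm25 by default

    class DocIndex(FTSModel):  # hypothetical index for illustration
        content = SearchField()
        class Meta:
            database = db

    db.create_tables([DocIndex])
    DocIndex.create(content='forced english subtitles for the pilot')

    for doc in DocIndex.search('subtitles', with_score=True):
        print(doc.content, doc.score)  # best matches first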
- field_weight = weights.get(field, weights.get(field.name, 1.0)) - weight_args.append(field_weight) - rank = score_fn(*weight_args) - else: - rank = score_fn(*weights) - - selection = () - order_by = rank - if with_score: - selection = (cls, rank.alias(score_alias)) - if with_score and not explicit_ordering: - order_by = SQL(score_alias) - - return (cls - .select(*selection) - .where(cls.match(term)) - .order_by(order_by)) - - @classmethod - def search(cls, term, weights=None, with_score=False, score_alias='score', - explicit_ordering=False): - """Full-text search using selected `term`.""" - return cls._search( - term, - weights, - with_score, - score_alias, - cls.rank, - explicit_ordering) - - @classmethod - def search_bm25(cls, term, weights=None, with_score=False, - score_alias='score', explicit_ordering=False): - """Full-text search for selected `term` using BM25 algorithm.""" - return cls._search( - term, - weights, - with_score, - score_alias, - cls.bm25, - explicit_ordering) - - @classmethod - def search_bm25f(cls, term, weights=None, with_score=False, - score_alias='score', explicit_ordering=False): - """Full-text search for selected `term` using BM25 algorithm.""" - return cls._search( - term, - weights, - with_score, - score_alias, - cls.bm25f, - explicit_ordering) - - @classmethod - def search_lucene(cls, term, weights=None, with_score=False, - score_alias='score', explicit_ordering=False): - """Full-text search for selected `term` using BM25 algorithm.""" - return cls._search( - term, - weights, - with_score, - score_alias, - cls.lucene, - explicit_ordering) - - -_alphabet = 'abcdefghijklmnopqrstuvwxyz' -_alphanum = (set('\t ,"(){}*:_+0123456789') | - set(_alphabet) | - set(_alphabet.upper()) | - set((chr(26),))) -_invalid_ascii = set(chr(p) for p in range(128) if chr(p) not in _alphanum) -_quote_re = re.compile('(?:[^\s"]|"(?:\\.|[^"])*")+') - - -class FTS5Model(BaseFTSModel): - """ - Requires SQLite >= 3.9.0. - - Table options: - - content: table name of external content, or empty string for "contentless" - content_rowid: column name of external content primary key - prefix: integer(s). Ex: '2' or '2 3 4' - tokenize: porter, unicode61, ascii. Ex: 'porter unicode61' - - The unicode tokenizer supports the following parameters: - - * remove_diacritics (1 or 0, default is 1) - * tokenchars (string of characters, e.g. '-_' - * separators (string of characters) - - Parameters are passed as alternating parameter name and value, so: - - {'tokenize': "unicode61 remove_diacritics 0 tokenchars '-_'"} - - Content-less tables: - - If you don't need the full-text content in it's original form, you can - specify a content-less table. Searches and auxiliary functions will work - as usual, but the only values returned when SELECT-ing can be rowid. Also - content-less tables do not support UPDATE or DELETE. - - External content tables: - - You can set up triggers to sync these, e.g. - - -- Create a table. And an external content fts5 table to index it. - CREATE TABLE tbl(a INTEGER PRIMARY KEY, b); - CREATE VIRTUAL TABLE ft USING fts5(b, content='tbl', content_rowid='a'); - - -- Triggers to keep the FTS index up to date. 
- CREATE TRIGGER tbl_ai AFTER INSERT ON tbl BEGIN - INSERT INTO ft(rowid, b) VALUES (new.a, new.b); - END; - CREATE TRIGGER tbl_ad AFTER DELETE ON tbl BEGIN - INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b); - END; - CREATE TRIGGER tbl_au AFTER UPDATE ON tbl BEGIN - INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b); - INSERT INTO ft(rowid, b) VALUES (new.a, new.b); - END; - - Built-in auxiliary functions: - - * bm25(tbl[, weight_0, ... weight_n]) - * highlight(tbl, col_idx, prefix, suffix) - * snippet(tbl, col_idx, prefix, suffix, ?, max_tokens) - """ - # FTS5 does not support declared primary keys, but we can use the - # implicit rowid. - rowid = RowIDField() - - class Meta: - extension_module = 'fts5' - - _error_messages = { - 'field_type': ('Besides the implicit `rowid` column, all columns must ' - 'be instances of SearchField'), - 'index': 'Secondary indexes are not supported for FTS5 models', - 'pk': 'FTS5 models must use the default `rowid` primary key', - } - - @classmethod - def validate_model(cls): - # Perform FTS5-specific validation and options post-processing. - if cls._meta.primary_key.name != 'rowid': - raise ImproperlyConfigured(cls._error_messages['pk']) - for field in cls._meta.fields.values(): - if not isinstance(field, (SearchField, RowIDField)): - raise ImproperlyConfigured(cls._error_messages['field_type']) - if cls._meta.indexes: - raise ImproperlyConfigured(cls._error_messages['index']) - - @classmethod - def fts5_installed(cls): - if sqlite3.sqlite_version_info[:3] < FTS5_MIN_SQLITE_VERSION: - return False - - # Test in-memory DB to determine if the FTS5 extension is installed. - tmp_db = sqlite3.connect(':memory:') - try: - tmp_db.execute('CREATE VIRTUAL TABLE fts5test USING fts5 (data);') - except: - try: - tmp_db.enable_load_extension(True) - tmp_db.load_extension('fts5') - except: - return False - else: - cls._meta.database.load_extension('fts5') - finally: - tmp_db.close() - - return True - - @staticmethod - def validate_query(query): - """ - Simple helper function to indicate whether a search query is a - valid FTS5 query. Note: this simply looks at the characters being - used, and is not guaranteed to catch all problematic queries. - """ - tokens = _quote_re.findall(query) - for token in tokens: - if token.startswith('"') and token.endswith('"'): - continue - if set(token) & _invalid_ascii: - return False - return True - - @staticmethod - def clean_query(query, replace=chr(26)): - """ - Clean a query of invalid tokens. - """ - accum = [] - any_invalid = False - tokens = _quote_re.findall(query) - for token in tokens: - if token.startswith('"') and token.endswith('"'): - accum.append(token) - continue - token_set = set(token) - invalid_for_token = token_set & _invalid_ascii - if invalid_for_token: - any_invalid = True - for c in invalid_for_token: - token = token.replace(c, replace) - accum.append(token) - - if any_invalid: - return ' '.join(accum) - return query - - @classmethod - def match(cls, term): - """ - Generate a `MATCH` expression appropriate for searching this table. 
- """ - return match(cls._meta.entity, term) - - @classmethod - def rank(cls, *args): - return cls.bm25(*args) if args else SQL('rank') - - @classmethod - def bm25(cls, *weights): - return fn.bm25(cls._meta.entity, *weights) - - @classmethod - def search(cls, term, weights=None, with_score=False, score_alias='score', - explicit_ordering=False): - """Full-text search using selected `term`.""" - return cls.search_bm25( - FTS5Model.clean_query(term), - weights, - with_score, - score_alias, - explicit_ordering) - - @classmethod - def search_bm25(cls, term, weights=None, with_score=False, - score_alias='score', explicit_ordering=False): - """Full-text search using selected `term`.""" - if not weights: - rank = SQL('rank') - elif isinstance(weights, dict): - weight_args = [] - for field in cls._meta.sorted_fields: - if isinstance(field, SearchField) and not field.unindexed: - weight_args.append( - weights.get(field, weights.get(field.name, 1.0))) - rank = fn.bm25(cls._meta.entity, *weight_args) - else: - rank = fn.bm25(cls._meta.entity, *weights) - - selection = () - order_by = rank - if with_score: - selection = (cls, rank.alias(score_alias)) - if with_score and not explicit_ordering: - order_by = SQL(score_alias) - - return (cls - .select(*selection) - .where(cls.match(FTS5Model.clean_query(term))) - .order_by(order_by)) - - @classmethod - def _fts_cmd_sql(cls, cmd, **extra_params): - tbl = cls._meta.entity - columns = [tbl] - values = [cmd] - for key, value in extra_params.items(): - columns.append(Entity(key)) - values.append(value) - - return NodeList(( - SQL('INSERT INTO'), - cls._meta.entity, - EnclosedNodeList(columns), - SQL('VALUES'), - EnclosedNodeList(values))) - - @classmethod - def _fts_cmd(cls, cmd, **extra_params): - query = cls._fts_cmd_sql(cmd, **extra_params) - return cls._meta.database.execute(query) - - @classmethod - def automerge(cls, level): - if not (0 <= level <= 16): - raise ValueError('level must be between 0 and 16') - return cls._fts_cmd('automerge', rank=level) - - @classmethod - def merge(cls, npages): - return cls._fts_cmd('merge', rank=npages) - - @classmethod - def set_pgsz(cls, pgsz): - return cls._fts_cmd('pgsz', rank=pgsz) - - @classmethod - def set_rank(cls, rank_expression): - return cls._fts_cmd('rank', rank=rank_expression) - - @classmethod - def delete_all(cls): - return cls._fts_cmd('delete-all') - - @classmethod - def VocabModel(cls, table_type='row', table=None): - if table_type not in ('row', 'col', 'instance'): - raise ValueError('table_type must be either "row", "col" or ' - '"instance".') - - attr = '_vocab_model_%s' % table_type - - if not hasattr(cls, attr): - class Meta: - database = cls._meta.database - table_name = table or cls._meta.table_name + '_v' - extension_module = fn.fts5vocab( - cls._meta.entity, - SQL(table_type)) - - attrs = { - 'term': VirtualField(TextField), - 'doc': IntegerField(), - 'cnt': IntegerField(), - 'rowid': RowIDField(), - 'Meta': Meta, - } - if table_type == 'col': - attrs['col'] = VirtualField(TextField) - elif table_type == 'instance': - attrs['offset'] = VirtualField(IntegerField) - - class_name = '%sVocab' % cls.__name__ - setattr(cls, attr, type(class_name, (VirtualModel,), attrs)) - - return getattr(cls, attr) - - -def ClosureTable(model_class, foreign_key=None, referencing_class=None, - referencing_key=None): - """Model factory for the transitive closure extension.""" - if referencing_class is None: - referencing_class = model_class - - if foreign_key is None: - for field_obj in model_class._meta.refs: - if 
field_obj.rel_model is model_class: - foreign_key = field_obj - break - else: - raise ValueError('Unable to find self-referential foreign key.') - - source_key = model_class._meta.primary_key - if referencing_key is None: - referencing_key = source_key - - class BaseClosureTable(VirtualModel): - depth = VirtualField(IntegerField) - id = VirtualField(IntegerField) - idcolumn = VirtualField(TextField) - parentcolumn = VirtualField(TextField) - root = VirtualField(IntegerField) - tablename = VirtualField(TextField) - - class Meta: - extension_module = 'transitive_closure' - - @classmethod - def descendants(cls, node, depth=None, include_node=False): - query = (model_class - .select(model_class, cls.depth.alias('depth')) - .join(cls, on=(source_key == cls.id)) - .where(cls.root == node) - .objects()) - if depth is not None: - query = query.where(cls.depth == depth) - elif not include_node: - query = query.where(cls.depth > 0) - return query - - @classmethod - def ancestors(cls, node, depth=None, include_node=False): - query = (model_class - .select(model_class, cls.depth.alias('depth')) - .join(cls, on=(source_key == cls.root)) - .where(cls.id == node) - .objects()) - if depth: - query = query.where(cls.depth == depth) - elif not include_node: - query = query.where(cls.depth > 0) - return query - - @classmethod - def siblings(cls, node, include_node=False): - if referencing_class is model_class: - # self-join - fk_value = node.__data__.get(foreign_key.name) - query = model_class.select().where(foreign_key == fk_value) - else: - # siblings as given in reference_class - siblings = (referencing_class - .select(referencing_key) - .join(cls, on=(foreign_key == cls.root)) - .where((cls.id == node) & (cls.depth == 1))) - - # the according models - query = (model_class - .select() - .where(source_key << siblings) - .objects()) - - if not include_node: - query = query.where(source_key != node) - - return query - - class Meta: - database = referencing_class._meta.database - options = { - 'tablename': referencing_class._meta.table_name, - 'idcolumn': referencing_key.column_name, - 'parentcolumn': foreign_key.column_name} - primary_key = False - - name = '%sClosure' % model_class.__name__ - return type(name, (BaseClosureTable,), {'Meta': Meta}) - - -class LSMTable(VirtualModel): - class Meta: - extension_module = 'lsm1' - filename = None - - @classmethod - def clean_options(cls, options): - filename = cls._meta.filename - if not filename: - raise ValueError('LSM1 extension requires that you specify a ' - 'filename for the LSM database.') - else: - if len(filename) >= 2 and filename[0] != '"': - filename = '"%s"' % filename - if not cls._meta.primary_key: - raise ValueError('LSM1 models must specify a primary-key field.') - - key = cls._meta.primary_key - if isinstance(key, AutoField): - raise ValueError('LSM1 models must explicitly declare a primary ' - 'key field.') - if not isinstance(key, (TextField, BlobField, IntegerField)): - raise ValueError('LSM1 key must be a TextField, BlobField, or ' - 'IntegerField.') - key._hidden = True - if isinstance(key, IntegerField): - data_type = 'UINT' - elif isinstance(key, BlobField): - data_type = 'BLOB' - else: - data_type = 'TEXT' - cls._meta.prefix_arguments = [filename, '"%s"' % key.name, data_type] - - # Does the key map to a scalar value, or a tuple of values? 
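`ClosureTable()` above produces a `VirtualModel` bound to the `transitive_closure` loadable extension; assuming that extension has been compiled and is on the library path, wiring it up follows the documented pattern:

    from peewee import Model, TextField, ForeignKeyField
    from playhouse.sqlite_ext import SqliteExtDatabase, ClosureTable

    db = SqliteExtDatabase('categories.db')  # hypothetical database
    db.load_extension('closure')             # path to the compiled extension

    class Category(Model):
        name = TextField()
        parent = ForeignKeyField('self', backref='children', null=True)
        class Meta:
            database = db

    CategoryClosure = ClosureTable(Category)
    db.create_tables([Category, CategoryClosure])

    # All descendants of `root`, annotated with their depth:
    # CategoryClosure.descendants(root, include_node=False)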
- if len(cls._meta.sorted_fields) == 2: - cls._meta._value_field = cls._meta.sorted_fields[1] - else: - cls._meta._value_field = None - - return options - - @classmethod - def load_extension(cls, path='lsm.so'): - cls._meta.database.load_extension(path) - - @staticmethod - def slice_to_expr(key, idx): - if idx.start is not None and idx.stop is not None: - return key.between(idx.start, idx.stop) - elif idx.start is not None: - return key >= idx.start - elif idx.stop is not None: - return key <= idx.stop - - @staticmethod - def _apply_lookup_to_query(query, key, lookup): - if isinstance(lookup, slice): - expr = LSMTable.slice_to_expr(key, lookup) - if expr is not None: - query = query.where(expr) - return query, False - elif isinstance(lookup, Expression): - return query.where(lookup), False - else: - return query.where(key == lookup), True - - @classmethod - def get_by_id(cls, pk): - query, is_single = cls._apply_lookup_to_query( - cls.select().namedtuples(), - cls._meta.primary_key, - pk) - - if is_single: - try: - row = query.get() - except cls.DoesNotExist: - raise KeyError(pk) - return row[1] if cls._meta._value_field is not None else row - else: - return query - - @classmethod - def set_by_id(cls, key, value): - if cls._meta._value_field is not None: - data = {cls._meta._value_field: value} - elif isinstance(value, tuple): - data = {} - for field, fval in zip(cls._meta.sorted_fields[1:], value): - data[field] = fval - elif isinstance(value, dict): - data = value - elif isinstance(value, cls): - data = value.__dict__ - data[cls._meta.primary_key] = key - cls.replace(data).execute() - - @classmethod - def delete_by_id(cls, pk): - query, is_single = cls._apply_lookup_to_query( - cls.delete(), - cls._meta.primary_key, - pk) - return query.execute() - - -OP.MATCH = 'MATCH' - -def _sqlite_regexp(regex, value): - return re.search(regex, value) is not None - - -class SqliteExtDatabase(SqliteDatabase): - def __init__(self, database, c_extensions=None, rank_functions=True, - hash_functions=False, regexp_function=False, - bloomfilter=False, json_contains=False, *args, **kwargs): - super(SqliteExtDatabase, self).__init__(database, *args, **kwargs) - self._row_factory = None - - if c_extensions and not CYTHON_SQLITE_EXTENSIONS: - raise ImproperlyConfigured('SqliteExtDatabase initialized with ' - 'C extensions, but shared library was ' - 'not found!') - prefer_c = CYTHON_SQLITE_EXTENSIONS and (c_extensions is not False) - if rank_functions: - if prefer_c: - register_rank_functions(self) - else: - self.register_function(bm25, 'fts_bm25') - self.register_function(rank, 'fts_rank') - self.register_function(bm25, 'fts_bm25f') # Fall back to bm25. 
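`_sqlite_regexp()` above backs the optional `regexp_function=True` flag on `SqliteExtDatabase`, which makes SQLite's REGEXP operator usable from peewee expressions; a sketch with a hypothetical model:

    from peewee import Model, TextField
    from playhouse.sqlite_ext import SqliteExtDatabase

    db = SqliteExtDatabase(':memory:', regexp_function=True)

    class Episode(Model):  # hypothetical model for illustration
        title = TextField()
        class Meta:
            database = db

    # Compiles to: ... WHERE ("title" REGEXP 'Season \d+')
    Episode.select().where(Episode.title.regexp(r'Season \d+'))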
- self.register_function(bm25, 'fts_lucene') - if hash_functions: - if not prefer_c: - raise ValueError('C extension required to register hash ' - 'functions.') - register_hash_functions(self) - if regexp_function: - self.register_function(_sqlite_regexp, 'regexp', 2) - if bloomfilter: - if not prefer_c: - raise ValueError('C extension required to use bloomfilter.') - register_bloomfilter(self) - if json_contains: - self.register_function(_json_contains, 'json_contains') - - self._c_extensions = prefer_c - - def _add_conn_hooks(self, conn): - super(SqliteExtDatabase, self)._add_conn_hooks(conn) - if self._row_factory: - conn.row_factory = self._row_factory - - def row_factory(self, fn): - self._row_factory = fn - - -if CYTHON_SQLITE_EXTENSIONS: - SQLITE_STATUS_MEMORY_USED = 0 - SQLITE_STATUS_PAGECACHE_USED = 1 - SQLITE_STATUS_PAGECACHE_OVERFLOW = 2 - SQLITE_STATUS_SCRATCH_USED = 3 - SQLITE_STATUS_SCRATCH_OVERFLOW = 4 - SQLITE_STATUS_MALLOC_SIZE = 5 - SQLITE_STATUS_PARSER_STACK = 6 - SQLITE_STATUS_PAGECACHE_SIZE = 7 - SQLITE_STATUS_SCRATCH_SIZE = 8 - SQLITE_STATUS_MALLOC_COUNT = 9 - SQLITE_DBSTATUS_LOOKASIDE_USED = 0 - SQLITE_DBSTATUS_CACHE_USED = 1 - SQLITE_DBSTATUS_SCHEMA_USED = 2 - SQLITE_DBSTATUS_STMT_USED = 3 - SQLITE_DBSTATUS_LOOKASIDE_HIT = 4 - SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5 - SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6 - SQLITE_DBSTATUS_CACHE_HIT = 7 - SQLITE_DBSTATUS_CACHE_MISS = 8 - SQLITE_DBSTATUS_CACHE_WRITE = 9 - SQLITE_DBSTATUS_DEFERRED_FKS = 10 - #SQLITE_DBSTATUS_CACHE_USED_SHARED = 11 - - def __status__(flag, return_highwater=False): - """ - Expose a sqlite3_status() call for a particular flag as a property of - the Database object. - """ - def getter(self): - result = sqlite_get_status(flag) - return result[1] if return_highwater else result - return property(getter) - - def __dbstatus__(flag, return_highwater=False, return_current=False): - """ - Expose a sqlite3_dbstatus() call for a particular flag as a property of - the Database instance. Unlike sqlite3_status(), the dbstatus properties - pertain to the current connection. 
- """ - def getter(self): - if self._state.conn is None: - raise ImproperlyConfigured('database connection not opened.') - result = sqlite_get_db_status(self._state.conn, flag) - if return_current: - return result[0] - return result[1] if return_highwater else result - return property(getter) - - class CSqliteExtDatabase(SqliteExtDatabase): - def __init__(self, *args, **kwargs): - self._conn_helper = None - self._commit_hook = self._rollback_hook = self._update_hook = None - self._replace_busy_handler = False - super(CSqliteExtDatabase, self).__init__(*args, **kwargs) - - def init(self, database, replace_busy_handler=False, **kwargs): - super(CSqliteExtDatabase, self).init(database, **kwargs) - self._replace_busy_handler = replace_busy_handler - - def _close(self, conn): - if self._commit_hook: - self._conn_helper.set_commit_hook(None) - if self._rollback_hook: - self._conn_helper.set_rollback_hook(None) - if self._update_hook: - self._conn_helper.set_update_hook(None) - return super(CSqliteExtDatabase, self)._close(conn) - - def _add_conn_hooks(self, conn): - super(CSqliteExtDatabase, self)._add_conn_hooks(conn) - self._conn_helper = ConnectionHelper(conn) - if self._commit_hook is not None: - self._conn_helper.set_commit_hook(self._commit_hook) - if self._rollback_hook is not None: - self._conn_helper.set_rollback_hook(self._rollback_hook) - if self._update_hook is not None: - self._conn_helper.set_update_hook(self._update_hook) - if self._replace_busy_handler: - timeout = self._timeout or 5 - self._conn_helper.set_busy_handler(timeout * 1000) - - def on_commit(self, fn): - self._commit_hook = fn - if not self.is_closed(): - self._conn_helper.set_commit_hook(fn) - return fn - - def on_rollback(self, fn): - self._rollback_hook = fn - if not self.is_closed(): - self._conn_helper.set_rollback_hook(fn) - return fn - - def on_update(self, fn): - self._update_hook = fn - if not self.is_closed(): - self._conn_helper.set_update_hook(fn) - return fn - - def changes(self): - return self._conn_helper.changes() - - @property - def last_insert_rowid(self): - return self._conn_helper.last_insert_rowid() - - @property - def autocommit(self): - return self._conn_helper.autocommit() - - def backup(self, destination, pages=None, name=None, progress=None): - return backup(self.connection(), destination.connection(), - pages=pages, name=name, progress=progress) - - def backup_to_file(self, filename, pages=None, name=None, - progress=None): - return backup_to_file(self.connection(), filename, pages=pages, - name=name, progress=progress) - - def blob_open(self, table, column, rowid, read_only=False): - return Blob(self, table, column, rowid, read_only) - - # Status properties. - memory_used = __status__(SQLITE_STATUS_MEMORY_USED) - malloc_size = __status__(SQLITE_STATUS_MALLOC_SIZE, True) - malloc_count = __status__(SQLITE_STATUS_MALLOC_COUNT) - pagecache_used = __status__(SQLITE_STATUS_PAGECACHE_USED) - pagecache_overflow = __status__(SQLITE_STATUS_PAGECACHE_OVERFLOW) - pagecache_size = __status__(SQLITE_STATUS_PAGECACHE_SIZE, True) - scratch_used = __status__(SQLITE_STATUS_SCRATCH_USED) - scratch_overflow = __status__(SQLITE_STATUS_SCRATCH_OVERFLOW) - scratch_size = __status__(SQLITE_STATUS_SCRATCH_SIZE, True) - - # Connection status properties. 
- lookaside_used = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_USED) - lookaside_hit = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_HIT, True) - lookaside_miss = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE, - True) - lookaside_miss_full = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL, - True) - cache_used = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED, False, True) - #cache_used_shared = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED_SHARED, - # False, True) - schema_used = __dbstatus__(SQLITE_DBSTATUS_SCHEMA_USED, False, True) - statement_used = __dbstatus__(SQLITE_DBSTATUS_STMT_USED, False, True) - cache_hit = __dbstatus__(SQLITE_DBSTATUS_CACHE_HIT, False, True) - cache_miss = __dbstatus__(SQLITE_DBSTATUS_CACHE_MISS, False, True) - cache_write = __dbstatus__(SQLITE_DBSTATUS_CACHE_WRITE, False, True) - - -def match(lhs, rhs): - return Expression(lhs, OP.MATCH, rhs) - -def _parse_match_info(buf): - # See http://sqlite.org/fts3.html#matchinfo - bufsize = len(buf) # Length in bytes. - return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)] - -def get_weights(ncol, raw_weights): - if not raw_weights: - return [1] * ncol - else: - weights = [0] * ncol - for i, weight in enumerate(raw_weights): - weights[i] = weight - return weights - -# Ranking implementation, which parse matchinfo. -def rank(raw_match_info, *raw_weights): - # Handle match_info called w/default args 'pcx' - based on the example rank - # function http://sqlite.org/fts3.html#appendix_a - match_info = _parse_match_info(raw_match_info) - score = 0.0 - - p, c = match_info[:2] - weights = get_weights(c, raw_weights) - - # matchinfo X value corresponds to, for each phrase in the search query, a - # list of 3 values for each column in the search table. - # So if we have a two-phrase search query and three columns of data, the - # following would be the layout: - # p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8] - # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17] - for phrase_num in range(p): - phrase_info_idx = 2 + (phrase_num * c * 3) - for col_num in range(c): - weight = weights[col_num] - if not weight: - continue - - col_idx = phrase_info_idx + (col_num * 3) - - # The idea is that we count the number of times the phrase appears - # in this column of the current row, compared to how many times it - # appears in this column across all rows. The ratio of these values - # provides a rough way to score based on "high value" terms. - row_hits = match_info[col_idx] - all_rows_hits = match_info[col_idx + 1] - if row_hits > 0: - score += weight * (float(row_hits) / all_rows_hits) - - return -score - -# Okapi BM25 ranking implementation (FTS4 only). -def bm25(raw_match_info, *args): - """ - Usage: - - # Format string *must* be pcnalx - # Second parameter to bm25 specifies the index of the column, on - # the table being queries. - bm25(matchinfo(document_tbl, 'pcnalx'), 1) AS rank - """ - match_info = _parse_match_info(raw_match_info) - K = 1.2 - B = 0.75 - score = 0.0 - - P_O, C_O, N_O, A_O = range(4) # Offsets into the matchinfo buffer. - term_count = match_info[P_O] # n - col_count = match_info[C_O] - total_docs = match_info[N_O] # N - L_O = A_O + col_count - X_O = L_O + col_count - - # Worked example of pcnalx for two columns and two phrases, 100 docs total. - # { - # p = 2 - # c = 2 - # n = 100 - # a0 = 4 -- avg number of tokens for col0, e.g. title - # a1 = 40 -- avg number of tokens for col1, e.g. 
body - # l0 = 5 -- curr doc has 5 tokens in col0 - # l1 = 30 -- curr doc has 30 tokens in col1 - # - # x000 -- hits this row for phrase0, col0 - # x001 -- hits all rows for phrase0, col0 - # x002 -- rows with phrase0 in col0 at least once - # - # x010 -- hits this row for phrase0, col1 - # x011 -- hits all rows for phrase0, col1 - # x012 -- rows with phrase0 in col1 at least once - # - # x100 -- hits this row for phrase1, col0 - # x101 -- hits all rows for phrase1, col0 - # x102 -- rows with phrase1 in col0 at least once - # - # x110 -- hits this row for phrase1, col1 - # x111 -- hits all rows for phrase1, col1 - # x112 -- rows with phrase1 in col1 at least once - # } - - weights = get_weights(col_count, args) - - for i in range(term_count): - for j in range(col_count): - weight = weights[j] - if weight == 0: - continue - - x = X_O + (3 * (j + i * col_count)) - term_frequency = float(match_info[x]) # f(qi, D) - docs_with_term = float(match_info[x + 2]) # n(qi) - - # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) - idf = math.log( - (total_docs - docs_with_term + 0.5) / - (docs_with_term + 0.5)) - if idf <= 0.0: - idf = 1e-6 - - doc_length = float(match_info[L_O + j]) # |D| - avg_length = float(match_info[A_O + j]) or 1. # avgdl - ratio = doc_length / avg_length - - num = term_frequency * (K + 1.0) - b_part = 1.0 - B + (B * ratio) - denom = term_frequency + (K * b_part) - - pc_score = idf * (num / denom) - score += (pc_score * weight) - - return -score - - -def _json_contains(src_json, obj_json): - stack = [] - try: - stack.append((json.loads(obj_json), json.loads(src_json))) - except: - # Invalid JSON! - return False - - while stack: - obj, src = stack.pop() - if isinstance(src, dict): - if isinstance(obj, dict): - for key in obj: - if key not in src: - return False - stack.append((obj[key], src[key])) - elif isinstance(obj, list): - for item in obj: - if item not in src: - return False - elif obj not in src: - return False - elif isinstance(src, list): - if isinstance(obj, dict): - return False - elif isinstance(obj, list): - try: - for i in range(len(obj)): - stack.append((obj[i], src[i])) - except IndexError: - return False - elif obj not in src: - return False - elif obj != src: - return False - return True diff --git a/libs/playhouse/sqlite_udf.py b/libs/playhouse/sqlite_udf.py deleted file mode 100644 index 28dbd8560..000000000 --- a/libs/playhouse/sqlite_udf.py +++ /dev/null @@ -1,522 +0,0 @@ -import datetime -import hashlib -import heapq -import math -import os -import random -import re -import sys -import threading -import zlib -try: - from collections import Counter -except ImportError: - Counter = None -try: - from urlparse import urlparse -except ImportError: - from urllib.parse import urlparse - -try: - from playhouse._sqlite_ext import TableFunction -except ImportError: - TableFunction = None - - -SQLITE_DATETIME_FORMATS = ( - '%Y-%m-%d %H:%M:%S', - '%Y-%m-%d %H:%M:%S.%f', - '%Y-%m-%d', - '%H:%M:%S', - '%H:%M:%S.%f', - '%H:%M') - -from peewee import format_date_time - -def format_date_time_sqlite(date_value): - return format_date_time(date_value, SQLITE_DATETIME_FORMATS) - -try: - from playhouse import _sqlite_udf as cython_udf -except ImportError: - cython_udf = None - - -# Group udf by function. 
-
-
-# Group udf by function.
-CONTROL_FLOW = 'control_flow'
-DATE = 'date'
-FILE = 'file'
-HELPER = 'helpers'
-MATH = 'math'
-STRING = 'string'
-
-AGGREGATE_COLLECTION = {}
-TABLE_FUNCTION_COLLECTION = {}
-UDF_COLLECTION = {}
-
-
-class synchronized_dict(dict):
-    def __init__(self, *args, **kwargs):
-        super(synchronized_dict, self).__init__(*args, **kwargs)
-        self._lock = threading.Lock()
-
-    def __getitem__(self, key):
-        with self._lock:
-            return super(synchronized_dict, self).__getitem__(key)
-
-    def __setitem__(self, key, value):
-        with self._lock:
-            return super(synchronized_dict, self).__setitem__(key, value)
-
-    def __delitem__(self, key):
-        with self._lock:
-            return super(synchronized_dict, self).__delitem__(key)
-
-
-STATE = synchronized_dict()
-SETTINGS = synchronized_dict()
-
-# Class and function decorators.
-def aggregate(*groups):
-    def decorator(klass):
-        for group in groups:
-            AGGREGATE_COLLECTION.setdefault(group, [])
-            AGGREGATE_COLLECTION[group].append(klass)
-        return klass
-    return decorator
-
-def table_function(*groups):
-    def decorator(klass):
-        for group in groups:
-            TABLE_FUNCTION_COLLECTION.setdefault(group, [])
-            TABLE_FUNCTION_COLLECTION[group].append(klass)
-        return klass
-    return decorator
-
-def udf(*groups):
-    def decorator(fn):
-        for group in groups:
-            UDF_COLLECTION.setdefault(group, [])
-            UDF_COLLECTION[group].append(fn)
-        return fn
-    return decorator
-
-# Register aggregates / functions with connection.
-def register_aggregate_groups(db, *groups):
-    seen = set()
-    for group in groups:
-        klasses = AGGREGATE_COLLECTION.get(group, ())
-        for klass in klasses:
-            name = getattr(klass, 'name', klass.__name__)
-            if name not in seen:
-                seen.add(name)
-                db.register_aggregate(klass, name)
-
-def register_table_function_groups(db, *groups):
-    seen = set()
-    for group in groups:
-        klasses = TABLE_FUNCTION_COLLECTION.get(group, ())
-        for klass in klasses:
-            if klass.name not in seen:
-                seen.add(klass.name)
-                db.register_table_function(klass)
-
-def register_udf_groups(db, *groups):
-    seen = set()
-    for group in groups:
-        functions = UDF_COLLECTION.get(group, ())
-        for function in functions:
-            name = function.__name__
-            if name not in seen:
-                seen.add(name)
-                db.register_function(function, name)
-
-def register_groups(db, *groups):
-    register_aggregate_groups(db, *groups)
-    register_table_function_groups(db, *groups)
-    register_udf_groups(db, *groups)
-
-def register_all(db):
-    register_aggregate_groups(db, *AGGREGATE_COLLECTION)
-    register_table_function_groups(db, *TABLE_FUNCTION_COLLECTION)
-    register_udf_groups(db, *UDF_COLLECTION)
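The group decorators above only collect callables into module-level dicts;
nothing touches SQLite until one of the register_*_groups() helpers hands a
group to a database, whose register_function()/register_aggregate() methods
the deleted code itself relies on. A sketch against a peewee
SqliteExtDatabase (path and UDF are illustrative):

    from playhouse.sqlite_ext import SqliteExtDatabase

    db = SqliteExtDatabase('/tmp/example.db')

    @udf(MATH)  # Collected into UDF_COLLECTION under the 'math' group.
    def double(n):
        return n * 2 if n is not None else None

    register_udf_groups(db, MATH)  # Attach every 'math' UDF to this database.
    db.connect()
    print(db.execute_sql('SELECT double(21)').fetchone())  # -> (42,)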
-
-
-# Begin actual user-defined functions and aggregates.
-
-# Scalar functions.
-@udf(CONTROL_FLOW)
-def if_then_else(cond, truthy, falsey=None):
-    if cond:
-        return truthy
-    return falsey
-
-@udf(DATE)
-def strip_tz(date_str):
-    date_str = date_str.replace('T', ' ')
-    tz_idx1 = date_str.find('+')
-    if tz_idx1 != -1:
-        return date_str[:tz_idx1]
-    tz_idx2 = date_str.find('-')
-    if tz_idx2 > 13:
-        return date_str[:tz_idx2]
-    return date_str
-
-@udf(DATE)
-def human_delta(nseconds, glue=', '):
-    parts = (
-        (86400 * 365, 'year'),
-        (86400 * 30, 'month'),
-        (86400 * 7, 'week'),
-        (86400, 'day'),
-        (3600, 'hour'),
-        (60, 'minute'),
-        (1, 'second'),
-    )
-    accum = []
-    for offset, name in parts:
-        val, nseconds = divmod(nseconds, offset)
-        if val:
-            suffix = val != 1 and 's' or ''
-            accum.append('%s %s%s' % (val, name, suffix))
-    if not accum:
-        return '0 seconds'
-    return glue.join(accum)
-
-@udf(FILE)
-def file_ext(filename):
-    try:
-        res = os.path.splitext(filename)
-    except ValueError:
-        return None
-    return res[1]
-
-@udf(FILE)
-def file_read(filename):
-    try:
-        with open(filename) as fh:
-            return fh.read()
-    except:
-        pass
-
-if sys.version_info[0] == 2:
-    @udf(HELPER)
-    def gzip(data, compression=9):
-        return buffer(zlib.compress(data, compression))
-
-    @udf(HELPER)
-    def gunzip(data):
-        return zlib.decompress(data)
-else:
-    @udf(HELPER)
-    def gzip(data, compression=9):
-        if isinstance(data, str):
-            data = bytes(data.encode('raw_unicode_escape'))
-        return zlib.compress(data, compression)
-
-    @udf(HELPER)
-    def gunzip(data):
-        return zlib.decompress(data)
-
-@udf(HELPER)
-def hostname(url):
-    parse_result = urlparse(url)
-    if parse_result:
-        return parse_result.netloc
-
-@udf(HELPER)
-def toggle(key):
-    key = key.lower()
-    STATE[key] = ret = not STATE.get(key)
-    return ret
-
-@udf(HELPER)
-def setting(key, value=None):
-    if value is None:
-        return SETTINGS.get(key)
-    else:
-        SETTINGS[key] = value
-        return value
-
-@udf(HELPER)
-def clear_settings():
-    SETTINGS.clear()
-
-@udf(HELPER)
-def clear_toggles():
-    STATE.clear()
-
-@udf(MATH)
-def randomrange(start, end=None, step=None):
-    if end is None:
-        start, end = 0, start
-    elif step is None:
-        step = 1
-    return random.randrange(start, end, step)
-
-@udf(MATH)
-def gauss_distribution(mean, sigma):
-    try:
-        return random.gauss(mean, sigma)
-    except ValueError:
-        return None
-
-@udf(MATH)
-def sqrt(n):
-    try:
-        return math.sqrt(n)
-    except ValueError:
-        return None
-
-@udf(MATH)
-def tonumber(s):
-    try:
-        return int(s)
-    except ValueError:
-        try:
-            return float(s)
-        except:
-            return None
-
-@udf(STRING)
-def substr_count(haystack, needle):
-    if not haystack or not needle:
-        return 0
-    return haystack.count(needle)
-
-@udf(STRING)
-def strip_chars(haystack, chars):
-    return haystack.strip(chars)
-
-def _hash(constructor, *args):
-    hash_obj = constructor()
-    for arg in args:
-        hash_obj.update(arg)
-    return hash_obj.hexdigest()
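Once registered, these scalar functions are callable from SQL like built-ins.
A standalone sketch that registers a few of them on a raw connection, doing
by hand what register_udf_groups() does through UDF_COLLECTION (the functions
are assumed importable from the module above):

    import sqlite3

    conn = sqlite3.connect(':memory:')
    for fn in (human_delta, tonumber, substr_count):
        conn.create_function(fn.__name__, -1, fn)  # Variadic registration.

    conn.execute("SELECT human_delta(90061)").fetchone()
    # -> ('1 day, 1 hour, 1 minute, 1 second',)
    conn.execute("SELECT tonumber('3.14'), substr_count('banana', 'an')").fetchone()
    # -> (3.14, 2)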
-
-# Aggregates.
-class _heap_agg(object):
-    def __init__(self):
-        self.heap = []
-        self.ct = 0
-
-    def process(self, value):
-        return value
-
-    def step(self, value):
-        self.ct += 1
-        heapq.heappush(self.heap, self.process(value))
-
-class _datetime_heap_agg(_heap_agg):
-    def process(self, value):
-        return format_date_time_sqlite(value)
-
-if sys.version_info[:2] == (2, 6):
-    def total_seconds(td):
-        return (td.seconds +
-                (td.days * 86400) +
-                (td.microseconds / (10.**6)))
-else:
-    total_seconds = lambda td: td.total_seconds()
-
-@aggregate(DATE)
-class mintdiff(_datetime_heap_agg):
-    def finalize(self):
-        dtp = min_diff = None
-        while self.heap:
-            if min_diff is None:
-                if dtp is None:
-                    dtp = heapq.heappop(self.heap)
-                    continue
-            dt = heapq.heappop(self.heap)
-            diff = dt - dtp
-            if min_diff is None or min_diff > diff:
-                min_diff = diff
-            dtp = dt
-        if min_diff is not None:
-            return total_seconds(min_diff)
-
-@aggregate(DATE)
-class avgtdiff(_datetime_heap_agg):
-    def finalize(self):
-        if self.ct < 1:
-            return
-        elif self.ct == 1:
-            return 0
-
-        total = ct = 0
-        dtp = None
-        while self.heap:
-            if total == 0:
-                if dtp is None:
-                    dtp = heapq.heappop(self.heap)
-                    continue
-
-            dt = heapq.heappop(self.heap)
-            diff = dt - dtp
-            ct += 1
-            total += total_seconds(diff)
-            dtp = dt
-
-        return float(total) / ct
-
-@aggregate(DATE)
-class duration(object):
-    def __init__(self):
-        self._min = self._max = None
-
-    def step(self, value):
-        dt = format_date_time_sqlite(value)
-        if self._min is None or dt < self._min:
-            self._min = dt
-        if self._max is None or dt > self._max:
-            self._max = dt
-
-    def finalize(self):
-        if self._min and self._max:
-            td = (self._max - self._min)
-            return total_seconds(td)
-        return None
-
-@aggregate(MATH)
-class mode(object):
-    if Counter:
-        def __init__(self):
-            self.items = Counter()
-
-        def step(self, *args):
-            self.items.update(args)
-
-        def finalize(self):
-            if self.items:
-                return self.items.most_common(1)[0][0]
-    else:
-        def __init__(self):
-            self.items = []
-
-        def step(self, item):
-            self.items.append(item)
-
-        def finalize(self):
-            if self.items:
-                return max(set(self.items), key=self.items.count)
-
-@aggregate(MATH)
-class minrange(_heap_agg):
-    def finalize(self):
-        if self.ct == 0:
-            return
-        elif self.ct == 1:
-            return 0
-
-        prev = min_diff = None
-
-        while self.heap:
-            if min_diff is None:
-                if prev is None:
-                    prev = heapq.heappop(self.heap)
-                    continue
-            curr = heapq.heappop(self.heap)
-            diff = curr - prev
-            if min_diff is None or min_diff > diff:
-                min_diff = diff
-            prev = curr
-        return min_diff
-
-@aggregate(MATH)
-class avgrange(_heap_agg):
-    def finalize(self):
-        if self.ct == 0:
-            return
-        elif self.ct == 1:
-            return 0
-
-        total = ct = 0
-        prev = None
-        while self.heap:
-            if total == 0:
-                if prev is None:
-                    prev = heapq.heappop(self.heap)
-                    continue
-
-            curr = heapq.heappop(self.heap)
-            diff = curr - prev
-            ct += 1
-            total += diff
-            prev = curr
-
-        return float(total) / ct
-
-@aggregate(MATH)
-class _range(object):
-    name = 'range'
-
-    def __init__(self):
-        self._min = self._max = None
-
-    def step(self, value):
-        if self._min is None or value < self._min:
-            self._min = value
-        if self._max is None or value > self._max:
-            self._max = value
-
-    def finalize(self):
-        if self._min is not None and self._max is not None:
-            return self._max - self._min
-        return None
-
-
-if cython_udf is not None:
-    damerau_levenshtein_dist = udf(STRING)(cython_udf.damerau_levenshtein_dist)
-    levenshtein_dist = udf(STRING)(cython_udf.levenshtein_dist)
-    str_dist = udf(STRING)(cython_udf.str_dist)
-    median = aggregate(MATH)(cython_udf.median)
-
-
-if TableFunction is not None:
-    @table_function(STRING)
-    class RegexSearch(TableFunction):
-        params = ['regex', 'search_string']
-        columns = ['match']
-        name = 'regex_search'
-
-        def initialize(self, regex=None, search_string=None):
-            self._iter = re.finditer(regex, search_string)
-
-        def iterate(self, idx):
-            return (next(self._iter).group(0),)
-
-    @table_function(DATE)
-    class DateSeries(TableFunction):
-        params = ['start', 'stop', 'step_seconds']
-        columns = ['date']
-        name = 'date_series'
-
-        def initialize(self, start, stop, step_seconds=86400):
-            self.start = format_date_time_sqlite(start)
-            self.stop = format_date_time_sqlite(stop)
-            step_seconds = int(step_seconds)
-            self.step_seconds = datetime.timedelta(seconds=step_seconds)
-
-            if (self.start.hour == 0 and
-                    self.start.minute == 0 and
-                    self.start.second == 0 and
-                    step_seconds >= 86400):
-                self.format = '%Y-%m-%d'
-            elif (self.start.year == 1900 and
-                    self.start.month == 1 and
-                    self.start.day == 1 and
-                    self.stop.year == 1900 and
-                    self.stop.month == 1 and
-                    self.stop.day == 1 and
-                    step_seconds < 86400):
-                self.format = '%H:%M:%S'
-            else:
-                self.format = '%Y-%m-%d %H:%M:%S'
-
-        def iterate(self, idx):
-            if self.start > self.stop:
-                raise StopIteration
-            current = self.start
-            self.start += self.step_seconds
-            return (current.strftime(self.format),)
diff --git a/libs/playhouse/sqliteq.py b/libs/playhouse/sqliteq.py
deleted file mode 100644
index bd213549d..000000000
--- a/libs/playhouse/sqliteq.py
+++ /dev/null
@@ -1,330 +0,0 @@
-import logging
-import weakref
-from threading import local as thread_local
-from threading import Event
-from threading import Thread
-try:
-    from Queue import Queue
-except ImportError:
-    from queue import Queue
-
-try:
-    import gevent
-    from gevent import Greenlet as GThread
-    from gevent.event import Event as GEvent
-    from gevent.local import local as greenlet_local
-    from gevent.queue import Queue as GQueue
-except ImportError:
-    GThread = GQueue = GEvent = None
-
-from peewee import SENTINEL
-from playhouse.sqlite_ext import SqliteExtDatabase
-
-
-logger = logging.getLogger('peewee.sqliteq')
-
-
-class ResultTimeout(Exception):
-    pass
-
-class WriterPaused(Exception):
-    pass
-
-class ShutdownException(Exception):
-    pass
-
-
-class AsyncCursor(object):
-    __slots__ = ('sql', 'params', 'commit', 'timeout',
-                 '_event', '_cursor', '_exc', '_idx', '_rows', '_ready')
-
-    def __init__(self, event, sql, params, commit, timeout):
-        self._event = event
-        self.sql = sql
-        self.params = params
-        self.commit = commit
-        self.timeout = timeout
-        self._cursor = self._exc = self._idx = self._rows = None
-        self._ready = False
-
-    def set_result(self, cursor, exc=None):
-        self._cursor = cursor
-        self._exc = exc
-        self._idx = 0
-        self._rows = cursor.fetchall() if exc is None else []
-        self._event.set()
-        return self
-
-    def _wait(self, timeout=None):
-        timeout = timeout if timeout is not None else self.timeout
-        if not self._event.wait(timeout=timeout) and timeout:
-            raise ResultTimeout('results not ready, timed out.')
-        if self._exc is not None:
-            raise self._exc
-        self._ready = True
-
-    def __iter__(self):
-        if not self._ready:
-            self._wait()
-        if self._exc is not None:
-            raise self._exc
-        return self
-
-    def next(self):
-        if not self._ready:
-            self._wait()
-        try:
-            obj = self._rows[self._idx]
-        except IndexError:
-            raise StopIteration
-        else:
-            self._idx += 1
-            return obj
-    __next__ = next
-
-    @property
-    def lastrowid(self):
-        if not self._ready:
-            self._wait()
-        return self._cursor.lastrowid
-
-    @property
-    def rowcount(self):
-        if not self._ready:
-            self._wait()
-        return self._cursor.rowcount
-
-    @property
-    def description(self):
-        return self._cursor.description
-
-    def close(self):
-        self._cursor.close()
-
-    def fetchall(self):
-        return list(self)  # Iterating implies waiting until populated.
-
-    def fetchone(self):
-        if not self._ready:
-            self._wait()
-        try:
-            return next(self)
-        except StopIteration:
-            return None
-
-SHUTDOWN = StopIteration
-PAUSE = object()
-UNPAUSE = object()
-
-
-class Writer(object):
-    __slots__ = ('database', 'queue')
-
-    def __init__(self, database, queue):
-        self.database = database
-        self.queue = queue
-
-    def run(self):
-        conn = self.database.connection()
-        try:
-            while True:
-                try:
-                    if conn is None:  # Paused.
-                        if self.wait_unpause():
-                            conn = self.database.connection()
-                    else:
-                        conn = self.loop(conn)
-                except ShutdownException:
-                    logger.info('writer received shutdown request, exiting.')
-                    return
-        finally:
-            if conn is not None:
-                self.database._close(conn)
-                self.database._state.reset()
-
-    def wait_unpause(self):
-        obj = self.queue.get()
-        if obj is UNPAUSE:
-            logger.info('writer unpaused - reconnecting to database.')
-            return True
-        elif obj is SHUTDOWN:
-            raise ShutdownException()
-        elif obj is PAUSE:
-            logger.error('writer received pause, but is already paused.')
-        else:
-            obj.set_result(None, WriterPaused())
-            logger.warning('writer paused, not handling %s', obj)
-
-    def loop(self, conn):
-        obj = self.queue.get()
-        if isinstance(obj, AsyncCursor):
-            self.execute(obj)
-        elif obj is PAUSE:
-            logger.info('writer paused - closing database connection.')
-            self.database._close(conn)
-            self.database._state.reset()
-            return
-        elif obj is UNPAUSE:
-            logger.error('writer received unpause, but is already running.')
-        elif obj is SHUTDOWN:
-            raise ShutdownException()
-        else:
-            logger.error('writer received unsupported object: %s', obj)
-        return conn
-
-    def execute(self, obj):
-        logger.debug('received query %s', obj.sql)
-        try:
-            cursor = self.database._execute(obj.sql, obj.params, obj.commit)
-        except Exception as execute_err:
-            cursor = None
-            exc = execute_err  # Py3 unbinds the except variable at block exit.
-        else:
-            exc = None
-        return obj.set_result(cursor, exc)
-
-
-class SqliteQueueDatabase(SqliteExtDatabase):
-    WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL '
-                              'journal mode when using this feature. WAL mode '
-                              'allows one or more readers to continue reading '
-                              'while another connection writes to the '
-                              'database.')
-
-    def __init__(self, database, use_gevent=False, autostart=True,
-                 queue_max_size=None, results_timeout=None, *args, **kwargs):
-        kwargs['check_same_thread'] = False
-
-        # Ensure that journal_mode is WAL. This value is passed to the parent
-        # class constructor below.
-        pragmas = self._validate_journal_mode(kwargs.pop('pragmas', None))
-
-        # Reference to execute_sql on the parent class. Since we've overridden
-        # execute_sql(), this is just a handy way to reference the real
-        # implementation.
-        Parent = super(SqliteQueueDatabase, self)
-        self._execute = Parent.execute_sql
-
-        # Call the parent class constructor with our modified pragmas.
-        Parent.__init__(database, pragmas=pragmas, *args, **kwargs)
-
-        self._autostart = autostart
-        self._results_timeout = results_timeout
-        self._is_stopped = True
-
-        # Get different objects depending on the threading implementation.
-        self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size)
-
-        # Create the writer thread, optionally starting it.
-        self._create_write_queue()
-        if self._autostart:
-            self.start()
-
-    def get_thread_impl(self, use_gevent):
-        return GreenletHelper if use_gevent else ThreadHelper
-
-    def _validate_journal_mode(self, pragmas=None):
-        if pragmas:
-            pdict = dict((k.lower(), v) for (k, v) in pragmas)
-            if pdict.get('journal_mode', 'wal').lower() != 'wal':
-                raise ValueError(self.WAL_MODE_ERROR_MESSAGE)
-
-            return [(k, v) for (k, v) in pragmas
-                    if k != 'journal_mode'] + [('journal_mode', 'wal')]
-        else:
-            return [('journal_mode', 'wal')]
-
-    def _create_write_queue(self):
-        self._write_queue = self._thread_helper.queue()
-
-    def queue_size(self):
-        return self._write_queue.qsize()
-
-    def execute_sql(self, sql, params=None, commit=SENTINEL, timeout=None):
-        if commit is SENTINEL:
-            commit = not sql.lower().startswith('select')
-
-        if not commit:
-            return self._execute(sql, params, commit=commit)
-
-        cursor = AsyncCursor(
-            event=self._thread_helper.event(),
-            sql=sql,
-            params=params,
-            commit=commit,
-            timeout=self._results_timeout if timeout is None else timeout)
-        self._write_queue.put(cursor)
-        return cursor
-
-    def start(self):
-        with self._lock:
-            if not self._is_stopped:
-                return False
-            def run():
-                writer = Writer(self, self._write_queue)
-                writer.run()
-
-            self._writer = self._thread_helper.thread(run)
-            self._writer.start()
-            self._is_stopped = False
-            return True
-
-    def stop(self):
-        logger.debug('environment stop requested.')
-        with self._lock:
-            if self._is_stopped:
-                return False
-            self._write_queue.put(SHUTDOWN)
-        self._writer.join()
-        self._is_stopped = True
-        return True
-
-    def is_stopped(self):
-        with self._lock:
-            return self._is_stopped
-
-    def pause(self):
-        with self._lock:
-            self._write_queue.put(PAUSE)
-
-    def unpause(self):
-        with self._lock:
-            self._write_queue.put(UNPAUSE)
-
-    def __unsupported__(self, *args, **kwargs):
-        raise ValueError('This method is not supported by %r.'
-                         % type(self))
-    atomic = transaction = savepoint = __unsupported__
-
-
-class ThreadHelper(object):
-    __slots__ = ('queue_max_size',)
-
-    def __init__(self, queue_max_size=None):
-        self.queue_max_size = queue_max_size
-
-    def event(self): return Event()
-
-    def queue(self, max_size=None):
-        max_size = max_size if max_size is not None else self.queue_max_size
-        return Queue(maxsize=max_size or 0)
-
-    def thread(self, fn, *args, **kwargs):
-        thread = Thread(target=fn, args=args, kwargs=kwargs)
-        thread.daemon = True
-        return thread
-
-
-class GreenletHelper(ThreadHelper):
-    __slots__ = ()
-
-    def event(self): return GEvent()
-
-    def queue(self, max_size=None):
-        max_size = max_size if max_size is not None else self.queue_max_size
-        return GQueue(maxsize=max_size or 0)
-
-    def thread(self, fn, *args, **kwargs):
-        def wrap(*a, **k):
-            gevent.sleep()
-            return fn(*a, **k)
-        return GThread(wrap, *args, **kwargs)
diff --git a/libs/playhouse/test_utils.py b/libs/playhouse/test_utils.py
deleted file mode 100644
index 333dc078b..000000000
--- a/libs/playhouse/test_utils.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from functools import wraps
-import logging
-
-
-logger = logging.getLogger('peewee')
-
-
-class _QueryLogHandler(logging.Handler):
-    def __init__(self, *args, **kwargs):
-        self.queries = []
-        logging.Handler.__init__(self, *args, **kwargs)
-
-    def emit(self, record):
-        self.queries.append(record)
-
-
-class count_queries(object):
-    def __init__(self, only_select=False):
-        self.only_select = only_select
-        self.count = 0
-
-    def get_queries(self):
-        return self._handler.queries
-
-    def __enter__(self):
-        self._handler = _QueryLogHandler()
-        logger.setLevel(logging.DEBUG)
-        logger.addHandler(self._handler)
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        logger.removeHandler(self._handler)
-        if self.only_select:
-            self.count = len([q for q in self._handler.queries
-                              if q.msg[0].startswith('SELECT ')])
-        else:
-            self.count = len(self._handler.queries)
-
-
-class assert_query_count(count_queries):
-    def __init__(self, expected, only_select=False):
-        super(assert_query_count, self).__init__(only_select=only_select)
-        self.expected = expected
-
-    def __call__(self, f):
-        @wraps(f)
-        def decorated(*args, **kwds):
-            with self:
-                ret = f(*args, **kwds)
-
-            self._assert_count()
-            return ret
-
-        return decorated
-
-    def _assert_count(self):
-        error_msg = '%s != %s' % (self.count, self.expected)
-        assert self.count == self.expected, error_msg
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        super(assert_query_count, self).__exit__(exc_type, exc_val, exc_tb)
-        self._assert_count()
diff --git a/libs/sqlite3worker.py b/libs/sqlite3worker.py
new file mode 100644
index 000000000..f7653000d
--- /dev/null
+++ b/libs/sqlite3worker.py
@@ -0,0 +1,198 @@
+# Copyright (c) 2014 Palantir Technologies
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+"""Thread safe sqlite3 interface."""
+
+__author__ = "Shawn Lee"
+__email__ = "shawnl@palantir.com"
+__license__ = "MIT"
+
+import logging
+try:
+    import queue as Queue  # module re-named in Python 3
+except ImportError:
+    import Queue
+import sqlite3
+import threading
+import time
+import uuid
+
+LOGGER = logging.getLogger('sqlite3worker')
+
+
+class Sqlite3Worker(threading.Thread):
+    """Sqlite thread safe object.
+
+    Example:
+        from sqlite3worker import Sqlite3Worker
+        sql_worker = Sqlite3Worker("/tmp/test.sqlite")
+        sql_worker.execute(
+            "CREATE TABLE tester (timestamp DATETIME, uuid TEXT)")
+        sql_worker.execute(
+            "INSERT into tester values (?, ?)", ("2010-01-01 13:00:00", "bow"))
+        sql_worker.execute(
+            "INSERT into tester values (?, ?)", ("2011-02-02 14:14:14", "dog"))
+        sql_worker.execute("SELECT * from tester")
+        sql_worker.close()
+    """
+    def __init__(self, file_name, max_queue_size=100):
+        """Automatically starts the thread.
+
+        Args:
+            file_name: The name of the file.
+            max_queue_size: The max queries that will be queued.
+        """
+        threading.Thread.__init__(self)
+        self.daemon = True
+        self.sqlite3_conn = sqlite3.connect(
+            file_name, check_same_thread=False,
+            detect_types=sqlite3.PARSE_DECLTYPES)
+        self.sqlite3_conn.row_factory = sqlite3.Row
+        self.sqlite3_cursor = self.sqlite3_conn.cursor()
+        self.sql_queue = Queue.Queue(maxsize=max_queue_size)
+        self.results = {}
+        self.max_queue_size = max_queue_size
+        self.exit_set = False
+        # Token that is put into queue when close() is called.
+        self.exit_token = str(uuid.uuid4())
+        self.start()
+        self.thread_running = True
+
+    def run(self):
+        """Thread loop.
+
+        This is an infinite loop. The iter method calls self.sql_queue.get()
+        which blocks if there are no values in the queue. As soon as values
+        are placed into the queue the process will continue.
+
+        If many executes happen at once it will churn through them all before
+        calling commit() to speed things up by reducing the number of times
+        commit is called.
+        """
+        LOGGER.debug("run: Thread started")
+        execute_count = 0
+        for token, query, values in iter(self.sql_queue.get, None):
+            LOGGER.debug("sql_queue: %s", self.sql_queue.qsize())
+            if token != self.exit_token:
+                LOGGER.debug("run: %s", query)
+                self.run_query(token, query, values)
+                execute_count += 1
+                # Let the executes build up a little before committing to disk
+                # to speed things up.
+                if (
+                        self.sql_queue.empty() or
+                        execute_count == self.max_queue_size):
+                    LOGGER.debug("run: commit")
+                    self.sqlite3_conn.commit()
+                    execute_count = 0
+            # Only exit if the queue is empty. Otherwise keep getting
+            # through the queue until it's empty.
+            if self.exit_set and self.sql_queue.empty():
+                self.sqlite3_conn.commit()
+                self.sqlite3_conn.close()
+                self.thread_running = False
+                return
+
+    def run_query(self, token, query, values):
+        """Run a query.
+
+        Args:
+            token: A uuid object of the query you want returned.
+            query: A sql query with ? placeholders for values.
+            values: A tuple of values to replace "?" in query.
+        """
+        if query.lower().strip().startswith("select"):
+            try:
+                self.sqlite3_cursor.execute(query, values)
+                self.results[token] = self.sqlite3_cursor.fetchall()
+            except sqlite3.Error as err:
+                # Put the error into the output queue since a response
+                # is required.
+                self.results[token] = (
+                    "Query returned error: %s: %s: %s" % (query, values, err))
+                LOGGER.error(
+                    "Query returned error: %s: %s: %s", query, values, err)
+        else:
+            try:
+                self.sqlite3_cursor.execute(query, values)
+            except sqlite3.Error as err:
+                LOGGER.error(
+                    "Query returned error: %s: %s: %s", query, values, err)
+
+    def close(self):
+        """Close down the thread and close the sqlite3 database file."""
+        self.exit_set = True
+        self.sql_queue.put((self.exit_token, "", ""), timeout=5)
+        # Sleep and check that the thread is done before returning.
+        while self.thread_running:
+            time.sleep(.01)  # Don't kill the CPU waiting.
+
+    @property
+    def queue_size(self):
+        """Return the queue size."""
+        return self.sql_queue.qsize()
+
+    def query_results(self, token):
+        """Get the query results for a specific token.
+
+        Args:
+            token: A uuid object of the query you want returned.
+
+        Returns:
+            Return the results of the query when it's executed by the thread.
+        """
+        delay = .001
+        while True:
+            if token in self.results:
+                return_val = self.results[token]
+                del self.results[token]
+                return return_val
+            # Double back on the delay to a max of 8 seconds. This prevents
+            # a long lived select statement from thrashing the CPU with this
+            # infinite loop as it's waiting for the query results.
+            LOGGER.debug("Sleeping: %s %s", delay, token)
+            time.sleep(delay)
+            if delay < 8:
+                delay += delay
+
+    def execute(self, query, values=None):
+        """Execute a query.
+
+        Args:
+            query: The sql string using ? for placeholders of dynamic values.
+            values: A tuple of values to be replaced into the ? of the query.
+
+        Returns:
+            If it's a select query it will return the results of the query.
+        """
+        if self.exit_set:
+            LOGGER.debug("Exit set, not running: %s", query)
+            return "Exit Called"
+        LOGGER.debug("execute: %s", query)
+        values = values or []
+        # A token to track this query with.
+        token = str(uuid.uuid4())
+        # If it's a select we queue it up with a token to mark the results
+        # into the output queue so we know what results are ours.
+        if query.lower().strip().startswith("select"):
+            self.sql_queue.put((token, query, values), timeout=5)
+            return self.query_results(token)
+        else:
+            self.sql_queue.put((token, query, values), timeout=5)
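The new module gives callers a single background thread that owns the only
sqlite3 connection: execute() enqueues (token, query, values) tuples, SELECTs
block in query_results() until the thread publishes rows for their token, and
writes are committed in batches by run(). A rough usage sketch (path and
schema are illustrative, not Bazarr's):

    from sqlite3worker import Sqlite3Worker

    worker = Sqlite3Worker('/tmp/example.db')
    worker.execute("CREATE TABLE IF NOT EXISTS kv (k TEXT, v TEXT)")

    # Writes return immediately; the worker thread batches the commits.
    worker.execute("INSERT INTO kv VALUES (?, ?)", ('updated', '1'))

    # SELECTs block until the worker publishes this query's rows; rows come
    # back as sqlite3.Row objects because of the row_factory set in __init__.
    for row in worker.execute("SELECT k, v FROM kv"):
        print(row['k'], row['v'])

    worker.close()  # Drain the queue, commit, and close the database file.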