From 2b9d892ca9151118917c2d682c9bde204a5c72fa Mon Sep 17 00:00:00 2001 From: morpheus65535 Date: Wed, 26 May 2021 16:47:14 -0400 Subject: [PATCH] Implemented Peewee ORM in replacement of raw SQL queries. --- bazarr/api.py | 661 ++- bazarr/create_db.sql | 88 - bazarr/database.py | 648 +-- bazarr/embedded_subs_reader.py | 45 +- bazarr/get_episodes.py | 83 +- bazarr/get_languages.py | 45 +- bazarr/get_movies.py | 65 +- bazarr/get_rootfolder.py | 68 +- bazarr/get_series.py | 50 +- bazarr/get_subtitle.py | 323 +- bazarr/init.py | 56 - bazarr/list_subtitles.py | 91 +- bazarr/main.py | 10 +- bazarr/notifier.py | 34 +- bazarr/signalr_client.py | 71 +- bazarr/utils.py | 84 +- libs/peewee.py | 7746 +++++++++++++++++++++++++ libs/playhouse/README.md | 48 + libs/playhouse/__init__.py | 0 libs/playhouse/_pysqlite/cache.h | 73 + libs/playhouse/_pysqlite/connection.h | 129 + libs/playhouse/_pysqlite/module.h | 58 + libs/playhouse/_sqlite_ext.pyx | 1595 +++++ libs/playhouse/_sqlite_udf.pyx | 137 + libs/playhouse/apsw_ext.py | 146 + libs/playhouse/cockroachdb.py | 207 + libs/playhouse/dataset.py | 451 ++ libs/playhouse/db_url.py | 130 + libs/playhouse/fields.py | 64 + libs/playhouse/flask_utils.py | 185 + libs/playhouse/hybrid.py | 53 + libs/playhouse/kv.py | 172 + libs/playhouse/migrate.py | 886 +++ libs/playhouse/mysql_ext.py | 49 + libs/playhouse/pool.py | 318 + libs/playhouse/postgres_ext.py | 493 ++ libs/playhouse/reflection.py | 833 +++ libs/playhouse/shortcuts.py | 252 + libs/playhouse/signals.py | 79 + libs/playhouse/sqlcipher_ext.py | 103 + libs/playhouse/sqlite_changelog.py | 123 + libs/playhouse/sqlite_ext.py | 1294 +++++ libs/playhouse/sqlite_udf.py | 536 ++ libs/playhouse/sqliteq.py | 331 ++ libs/playhouse/test_utils.py | 62 + libs/sqlite3worker.py | 219 - libs/version.txt | 1 + 47 files changed, 17878 insertions(+), 1317 deletions(-) delete mode 100644 bazarr/create_db.sql create mode 100644 libs/peewee.py create mode 100644 libs/playhouse/README.md create mode 100644 libs/playhouse/__init__.py create mode 100644 libs/playhouse/_pysqlite/cache.h create mode 100644 libs/playhouse/_pysqlite/connection.h create mode 100644 libs/playhouse/_pysqlite/module.h create mode 100644 libs/playhouse/_sqlite_ext.pyx create mode 100644 libs/playhouse/_sqlite_udf.pyx create mode 100644 libs/playhouse/apsw_ext.py create mode 100644 libs/playhouse/cockroachdb.py create mode 100644 libs/playhouse/dataset.py create mode 100644 libs/playhouse/db_url.py create mode 100644 libs/playhouse/fields.py create mode 100644 libs/playhouse/flask_utils.py create mode 100644 libs/playhouse/hybrid.py create mode 100644 libs/playhouse/kv.py create mode 100644 libs/playhouse/migrate.py create mode 100644 libs/playhouse/mysql_ext.py create mode 100644 libs/playhouse/pool.py create mode 100644 libs/playhouse/postgres_ext.py create mode 100644 libs/playhouse/reflection.py create mode 100644 libs/playhouse/shortcuts.py create mode 100644 libs/playhouse/signals.py create mode 100644 libs/playhouse/sqlcipher_ext.py create mode 100644 libs/playhouse/sqlite_changelog.py create mode 100644 libs/playhouse/sqlite_ext.py create mode 100644 libs/playhouse/sqlite_udf.py create mode 100644 libs/playhouse/sqliteq.py create mode 100644 libs/playhouse/test_utils.py delete mode 100644 libs/sqlite3worker.py diff --git a/bazarr/api.py b/bazarr/api.py index dc8e57fec..61da3c415 100644 --- a/bazarr/api.py +++ b/bazarr/api.py @@ -5,13 +5,16 @@ from datetime import timedelta from dateutil import rrule import pretty import time +import operator from 
operator import itemgetter +from functools import reduce import platform import re import json import hashlib import apprise import gc +from peewee import fn, Value from get_args import args from config import settings, base_url, save_settings, get_settings @@ -19,8 +22,10 @@ from logger import empty_log from init import * import logging -from database import database, get_exclusion_clause, get_profiles_list, get_desired_languages, get_profile_id_name, \ - get_audio_profile_languages, update_profile_id_list, convert_list_to_clause +from database import get_exclusion_clause, get_profiles_list, get_desired_languages, get_profile_id_name, \ + get_audio_profile_languages, update_profile_id_list, convert_list_to_clause, TableEpisodes, TableShows, \ + TableMovies, TableSettingsLanguages, TableSettingsNotifier, TableLanguagesProfiles, TableHistory, \ + TableHistoryMovie, TableBlacklist, TableBlacklistMovie from helper import path_mappings from get_languages import language_from_alpha2, language_from_alpha3, alpha2_from_alpha3, alpha3_from_alpha2 from get_subtitle import download_subtitle, series_download_subtitles, manual_search, manual_download_subtitle, \ @@ -70,7 +75,7 @@ def authenticate(actual_method): return wrapper -def postprocess(item: dict): +def postprocess(item): # Remove ffprobe_cache if 'ffprobe_cache' in item: del (item['ffprobe_cache']) @@ -310,15 +315,15 @@ class System(Resource): class Badges(Resource): @authenticate def get(self): - missing_episodes = database.execute("SELECT table_shows.tags, table_episodes.monitored, table_shows.seriesType " - "FROM table_episodes INNER JOIN table_shows on table_shows.sonarrSeriesId =" - " table_episodes.sonarrSeriesId WHERE missing_subtitles is not null AND " - "missing_subtitles != '[]'" + get_exclusion_clause('series')) - missing_episodes = len(missing_episodes) + episodes_conditions = [(TableEpisodes.missing_subtitles is not None), + (TableEpisodes.missing_subtitles != '[]')] + episodes_conditions += get_exclusion_clause('series') + missing_episodes = TableEpisodes.select().where(reduce(operator.and_, episodes_conditions)).count() - missing_movies = database.execute("SELECT tags, monitored FROM table_movies WHERE missing_subtitles is not " - "null AND missing_subtitles != '[]'" + get_exclusion_clause('movie')) - missing_movies = len(missing_movies) + movies_conditions = [(TableMovies.missing_subtitles is not None), + (TableMovies.missing_subtitles != '[]')] + movies_conditions += get_exclusion_clause('movie') + missing_movies = TableMovies.select().where(reduce(operator.and_, movies_conditions)).count() throttled_providers = len(eval(str(get_throttled_providers()))) @@ -336,7 +341,11 @@ class Badges(Resource): class Languages(Resource): @authenticate def get(self): - result = database.execute("SELECT name, code2, enabled FROM table_settings_languages ORDER BY name") + result = TableSettingsLanguages.select(TableSettingsLanguages.name, + TableSettingsLanguages.code2, + TableSettingsLanguages.enabled)\ + .order_by(TableSettingsLanguages.name).dicts() + result = list(result) for item in result: item['enabled'] = item['enabled'] == 1 return jsonify(result) @@ -376,16 +385,24 @@ class Searches(Resource): if query: if settings.general.getboolean('use_sonarr'): # Get matching series - series = database.execute("SELECT title, sonarrSeriesId, year FROM table_shows WHERE title LIKE ? 
" - "ORDER BY title ASC", ("%" + query + "%",)) - + series = TableShows.select(TableShows.title, + TableShows.sonarrSeriesId, + TableShows.year)\ + .where(TableShows.title.contains(query))\ + .order_by(TableShows.title)\ + .dicts() + series = list(series) search_list += series if settings.general.getboolean('use_radarr'): # Get matching movies - movies = database.execute("SELECT title, radarrId, year FROM table_movies WHERE title LIKE ? ORDER BY " - "title ASC", ("%" + query + "%",)) - + movies = TableMovies.select(TableMovies.title, + TableMovies.radarrId, + TableMovies.year) \ + .where(TableMovies.title.contains(query)) \ + .order_by(TableMovies.title) \ + .dicts() + movies = list(movies) search_list += movies return jsonify(search_list) @@ -396,7 +413,8 @@ class SystemSettings(Resource): def get(self): data = get_settings() - notifications = database.execute("SELECT * FROM table_settings_notifier ORDER BY name") + notifications = TableSettingsNotifier.select().order_by(TableSettingsNotifier.name).dicts() + notifications = list(notifications) for i, item in enumerate(notifications): item["enabled"] = item["enabled"] == 1 notifications[i] = item @@ -410,37 +428,51 @@ class SystemSettings(Resource): def post(self): enabled_languages = request.form.getlist('languages-enabled') if len(enabled_languages) != 0: - database.execute("UPDATE table_settings_languages SET enabled=0") + TableSettingsLanguages.update({ + TableSettingsLanguages.enabled: 0 + }).execute() for code in enabled_languages: - database.execute("UPDATE table_settings_languages SET enabled=1 WHERE code2=?",(code,)) + TableSettingsLanguages.update({ + TableSettingsLanguages.enabled: 1 + })\ + .where(TableSettingsLanguages.code2 == code)\ + .execute() event_stream("languages") languages_profiles = request.form.get('languages-profiles') if languages_profiles: - existing_ids = database.execute('SELECT profileId FROM table_languages_profiles') + existing_ids = TableLanguagesProfiles.select(TableLanguagesProfiles.profileId).dicts() + existing_ids = list(existing_ids) existing = [x['profileId'] for x in existing_ids] for item in json.loads(languages_profiles): if item['profileId'] in existing: # Update existing profiles - database.execute('UPDATE table_languages_profiles SET name = ?, cutoff = ?, items = ? 
' - 'WHERE profileId = ?', (item['name'], - item['cutoff'] if item['cutoff'] != 'null' else None, - json.dumps(item['items']), - item['profileId'])) + TableLanguagesProfiles.update({ + TableLanguagesProfiles.name: item['name'], + TableLanguagesProfiles.cutoff: item['cutoff'] if item['cutoff'] != 'null' else None, + TableLanguagesProfiles.items: json.dumps(item['items']) + })\ + .where(TableLanguagesProfiles.profileId == item['profileId'])\ + .execute() existing.remove(item['profileId']) else: # Add new profiles - database.execute('INSERT INTO table_languages_profiles (profileId, name, cutoff, items) ' - 'VALUES (?, ?, ?, ?)', (item['profileId'], - item['name'], - item['cutoff'] if item['cutoff'] != 'null' else None, - json.dumps(item['items']))) + TableLanguagesProfiles.insert({ + TableLanguagesProfiles.profileId: item['profileId'], + TableLanguagesProfiles.name: item['name'], + TableLanguagesProfiles.cutoff: item['cutoff'] if item['cutoff'] != 'null' else None, + TableLanguagesProfiles.items: json.dumps(item['items']) + }).execute() for profileId in existing: # Unassign this profileId from series and movies - database.execute('UPDATE table_shows SET profileId = null WHERE profileId = ?', (profileId,)) - database.execute('UPDATE table_movies SET profileId = null WHERE profileId = ?', (profileId,)) + TableShows.update({ + TableShows.profileId: None + }).where(TableShows.profileId == profileId).execute() + TableMovies.update({ + TableMovies.profileId: None + }).where(TableMovies.profileId == profileId).execute() # Remove deleted profiles - database.execute('DELETE FROM table_languages_profiles WHERE profileId = ?', (profileId,)) + TableLanguagesProfiles.delete().where(TableLanguagesProfiles.profileId == profileId).execute() update_profile_id_list() event_stream("languages") @@ -454,8 +486,10 @@ class SystemSettings(Resource): notifications = request.form.getlist('notifications-providers') for item in notifications: item = json.loads(item) - database.execute("UPDATE table_settings_notifier SET enabled = ?, url = ? WHERE name = ?", - (item['enabled'], item['url'], item['name'])) + TableSettingsNotifier.update({ + TableSettingsNotifier.enabled: item['enabled'], + TableSettingsNotifier.url: item['url'] + }).where(TableSettingsNotifier.name == item['name']).execute() save_settings(zip(request.form.keys(), request.form.listvalues())) event_stream("settings") @@ -574,35 +608,43 @@ class Series(Resource): length = request.args.get('length') or -1 seriesId = request.args.getlist('seriesid[]') - count = database.execute("SELECT COUNT(*) as count FROM table_shows", only_one=True)['count'] + count = TableShows.select().count() if len(seriesId) != 0: - result = database.execute( - f"SELECT * FROM table_shows WHERE sonarrSeriesId in {convert_list_to_clause(seriesId)} ORDER BY sortTitle ASC") + result = TableShows.select()\ + .where(TableShows.sonarrSeriesId.in_(seriesId))\ + .order_by(TableShows.sortTitle).dicts() else: - result = database.execute("SELECT * FROM table_shows ORDER BY sortTitle ASC LIMIT ? OFFSET ?" - , (length, start)) + result = TableShows.select().order_by(TableShows.sortTitle).limit(length).offset(start).dicts() + + result = list(result) for item in result: postprocessSeries(item) # Add missing subtitles episode count - episodeMissingCount = database.execute("SELECT table_shows.tags, table_episodes.monitored, " - "table_shows.seriesType FROM table_episodes INNER JOIN table_shows " - "on table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId " - "WHERE table_episodes.sonarrSeriesId=? 
AND missing_subtitles is not " - "null AND missing_subtitles != '[]'" + - get_exclusion_clause('series'), (item['sonarrSeriesId'],)) - episodeMissingCount = len(episodeMissingCount) + episodes_missing_conditions = [(TableEpisodes.sonarrSeriesId == item['sonarrSeriesId']), + (TableEpisodes.missing_subtitles != '[]')] + episodes_missing_conditions += get_exclusion_clause('series') + + episodeMissingCount = TableEpisodes.select(TableShows.tags, + TableEpisodes.monitored, + TableShows.seriesType)\ + .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where(reduce(operator.and_, episodes_missing_conditions))\ + .count() item.update({"episodeMissingCount": episodeMissingCount}) # Add episode count - episodeFileCount = database.execute("SELECT table_shows.tags, table_episodes.monitored, " - "table_shows.seriesType FROM table_episodes INNER JOIN table_shows on " - "table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE " - "table_episodes.sonarrSeriesId=?" + get_exclusion_clause('series'), - (item['sonarrSeriesId'],)) - episodeFileCount = len(episodeFileCount) + episodes_count_conditions = [(TableEpisodes.sonarrSeriesId == item['sonarrSeriesId'])] + episodes_count_conditions += get_exclusion_clause('series') + + episodeFileCount = TableEpisodes.select(TableShows.tags, + TableEpisodes.monitored, + TableShows.seriesType)\ + .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where(reduce(operator.and_, episodes_count_conditions))\ + .count() item.update({"episodeFileCount": episodeFileCount}) return jsonify(data=result, total=count) @@ -624,7 +666,11 @@ class Series(Resource): except Exception: return '', 400 - database.execute("UPDATE table_shows SET profileId=? WHERE sonarrSeriesId=?", (profileId, seriesId)) + TableShows.update({ + TableShows.profileId: profileId + })\ + .where(TableShows.sonarrSeriesId == seriesId)\ + .execute() list_missing_subtitles(no=seriesId, send_event=False) @@ -657,14 +703,16 @@ class Episodes(Resource): episodeId = request.args.getlist('episodeid[]') if len(episodeId) > 0: - result = database.execute(f"SELECT * FROM table_episodes WHERE sonarrEpisodeId in {convert_list_to_clause(episodeId)}") + result = TableEpisodes.select().where(TableEpisodes.sonarrEpisodeId.in_(episodeId)).dicts() elif len(seriesId) > 0: - result = database.execute("SELECT * FROM table_episodes " - f"WHERE sonarrSeriesId in {convert_list_to_clause(seriesId)} ORDER BY season DESC, " - "episode DESC") + result = TableEpisodes.select()\ + .where(TableEpisodes.sonarrSeriesId.in_(seriesId))\ + .order_by(TableEpisodes.season.desc(), TableEpisodes.episode.desc())\ + .dicts() else: return "Series or Episode ID not provided", 400 + result = list(result) for item in result: postprocessEpisode(item) @@ -679,9 +727,13 @@ class EpisodesSubtitles(Resource): def patch(self): sonarrSeriesId = request.args.get('seriesid') sonarrEpisodeId = request.args.get('episodeid') - episodeInfo = database.execute( - "SELECT title, path, scene_name, audio_language FROM table_episodes WHERE sonarrEpisodeId=?", - (sonarrEpisodeId,), only_one=True) + episodeInfo = TableEpisodes.select(TableEpisodes.title, + TableEpisodes.path, + TableEpisodes.scene_name, + TableEpisodes.audio_language)\ + .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)\ + .dicts()\ + .get() title = episodeInfo['title'] episodePath = path_mappings.path_replace(episodeInfo['path']) @@ -735,9 +787,13 @@ class EpisodesSubtitles(Resource): def post(self): sonarrSeriesId = 
request.args.get('seriesid') sonarrEpisodeId = request.args.get('episodeid') - episodeInfo = database.execute( - "SELECT title, path, scene_name, audio_language FROM table_episodes WHERE sonarrEpisodeId=?", - (sonarrEpisodeId,), only_one=True) + episodeInfo = TableEpisodes.select(TableEpisodes.title, + TableEpisodes.path, + TableEpisodes.scene_name, + TableEpisodes.audio_language)\ + .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)\ + .dicts()\ + .get() title = episodeInfo['title'] episodePath = path_mappings.path_replace(episodeInfo['path']) @@ -789,9 +845,13 @@ class EpisodesSubtitles(Resource): def delete(self): sonarrSeriesId = request.args.get('seriesid') sonarrEpisodeId = request.args.get('episodeid') - episodeInfo = database.execute( - "SELECT title, path, scene_name, audio_language FROM table_episodes WHERE sonarrEpisodeId=?", - (sonarrEpisodeId,), only_one=True) + episodeInfo = TableEpisodes.select(TableEpisodes.title, + TableEpisodes.path, + TableEpisodes.scene_name, + TableEpisodes.audio_language)\ + .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)\ + .dicts()\ + .get() episodePath = path_mappings.path_replace(episodeInfo['path']) @@ -819,13 +879,16 @@ class Movies(Resource): length = request.args.get('length') or -1 radarrId = request.args.getlist('radarrid[]') - count = database.execute("SELECT COUNT(*) as count FROM table_movies", only_one=True)['count'] + count = TableMovies.select().count() if len(radarrId) != 0: - result = database.execute(f"SELECT * FROM table_movies WHERE radarrId in {convert_list_to_clause(radarrId)} ORDER BY sortTitle ASC") + result = TableMovies.select()\ + .where(TableMovies.radarrId.in_(radarrId))\ + .order_by(TableMovies.sortTitle)\ + .dicts() else: - result = database.execute("SELECT * FROM table_movies ORDER BY sortTitle ASC LIMIT ? OFFSET ?", - (length, start)) + result = TableMovies.select().order_by(TableMovies.sortTitle).limit(length).offset(start).dicts() + result = list(result) for item in result: postprocessMovie(item) @@ -848,7 +911,11 @@ class Movies(Resource): except Exception: return '', 400 - database.execute("UPDATE table_movies SET profileId=? 
WHERE radarrId=?", (profileId, radarrId)) + TableMovies.update({ + TableMovies.profileId: profileId + })\ + .where(TableMovies.radarrId == radarrId)\ + .execute() list_missing_subtitles_movies(no=radarrId) @@ -885,8 +952,13 @@ class MoviesSubtitles(Resource): # Download radarrId = request.args.get('radarrid') - movieInfo = database.execute("SELECT title, path, sceneName, audio_language FROM table_movies WHERE radarrId=?", - (radarrId,), only_one=True) + movieInfo = TableMovies.select(TableMovies.title, + TableMovies.path, + TableMovies.sceneName, + TableMovies.audio_language)\ + .where(TableMovies.radarrId == radarrId)\ + .dicts()\ + .get() moviePath = path_mappings.path_replace_movie(movieInfo['path']) sceneName = movieInfo['sceneName'] @@ -940,8 +1012,13 @@ class MoviesSubtitles(Resource): # Upload # TODO: Support Multiply Upload radarrId = request.args.get('radarrid') - movieInfo = database.execute("SELECT title, path, sceneName, audio_language FROM table_movies WHERE radarrId=?", - (radarrId,), only_one=True) + movieInfo = TableMovies.select(TableMovies.title, + TableMovies.path, + TableMovies.sceneName, + TableMovies.audio_language) \ + .where(TableMovies.radarrId == radarrId) \ + .dicts() \ + .get() moviePath = path_mappings.path_replace_movie(movieInfo['path']) sceneName = movieInfo['sceneName'] @@ -992,7 +1069,10 @@ class MoviesSubtitles(Resource): def delete(self): # Delete radarrId = request.args.get('radarrid') - movieInfo = database.execute("SELECT path FROM table_movies WHERE radarrId=?", (radarrId,), only_one=True) + movieInfo = TableMovies.select(TableMovies.path) \ + .where(TableMovies.radarrId == radarrId) \ + .dicts() \ + .get() moviePath = path_mappings.path_replace_movie(movieInfo['path']) @@ -1044,8 +1124,13 @@ class ProviderMovies(Resource): def get(self): # Manual Search radarrId = request.args.get('radarrid') - movieInfo = database.execute("SELECT title, path, sceneName, profileId FROM table_movies WHERE radarrId=?", - (radarrId,), only_one=True) + movieInfo = TableMovies.select(TableMovies.title, + TableMovies.path, + TableMovies.sceneName, + TableMovies.profileId) \ + .where(TableMovies.radarrId == radarrId) \ + .dicts() \ + .get() title = movieInfo['title'] moviePath = path_mappings.path_replace_movie(movieInfo['path']) @@ -1066,8 +1151,13 @@ class ProviderMovies(Resource): def post(self): # Manual Download radarrId = request.args.get('radarrid') - movieInfo = database.execute("SELECT title, path, sceneName, audio_language FROM table_movies WHERE radarrId=?", - (radarrId,), only_one=True) + movieInfo = TableMovies.select(TableMovies.title, + TableMovies.path, + TableMovies.sceneName, + TableMovies.audio_language) \ + .where(TableMovies.radarrId == radarrId) \ + .dicts() \ + .get() title = movieInfo['title'] moviePath = path_mappings.path_replace_movie(movieInfo['path']) @@ -1121,19 +1211,19 @@ class ProviderEpisodes(Resource): def get(self): # Manual Search sonarrEpisodeId = request.args.get('episodeid') - episodeInfo = database.execute( - "SELECT title, path, scene_name, audio_language, sonarrSeriesId FROM table_episodes WHERE sonarrEpisodeId=?", - (sonarrEpisodeId,), only_one=True) + episodeInfo = TableEpisodes.select(TableEpisodes.title, + TableEpisodes.path, + TableEpisodes.scene_name, + TableShows.profileId) \ + .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \ + .dicts() \ + .get() title = episodeInfo['title'] episodePath = 
path_mappings.path_replace(episodeInfo['path']) sceneName = episodeInfo['scene_name'] - seriesId = episodeInfo['sonarrSeriesId'] - - seriesInfo = database.execute("SELECT profileId FROM table_shows WHERE sonarrSeriesId=?", (seriesId,), - only_one=True) - - profileId = seriesInfo['profileId'] + profileId = episodeInfo['profileId'] if sceneName is None: sceneName = "None" providers_list = get_providers() @@ -1150,8 +1240,12 @@ class ProviderEpisodes(Resource): # Manual Download sonarrSeriesId = request.args.get('seriesid') sonarrEpisodeId = request.args.get('episodeid') - episodeInfo = database.execute("SELECT title, path, scene_name FROM table_episodes WHERE sonarrEpisodeId=?", - (sonarrEpisodeId,), only_one=True) + episodeInfo = TableEpisodes.select(TableEpisodes.title, + TableEpisodes.path, + TableEpisodes.scene_name) \ + .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \ + .dicts() \ + .get() title = episodeInfo['title'] episodePath = path_mappings.path_replace(episodeInfo['path']) @@ -1218,14 +1312,22 @@ class EpisodesHistory(Resource): else: query_actions = [1, 3] - upgradable_episodes = database.execute( - "SELECT video_path, MAX(timestamp) as timestamp, score, table_shows.tags, table_episodes.monitored, " - "table_shows.seriesType FROM table_history INNER JOIN table_episodes on " - "table_episodes.sonarrEpisodeId = table_history.sonarrEpisodeId INNER JOIN table_shows on " - "table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE action IN (" + - ','.join(map(str, query_actions)) + ") AND timestamp > ? AND score is not null" + - get_exclusion_clause('series') + " GROUP BY table_history.video_path", (minimum_timestamp,)) - + upgradable_episodes_conditions = [(TableHistory.action.in_(query_actions)), + (TableHistory.timestamp > minimum_timestamp), + (TableHistory.score is not None)] + upgradable_episodes_conditions += get_exclusion_clause('series') + upgradable_episodes = TableHistory.select(TableHistory.video_path, + fn.MAX(TableHistory.timestamp).alias('timestamp'), + TableHistory.score, + TableShows.tags, + TableEpisodes.monitored, + TableShows.seriesType)\ + .join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId))\ + .join(TableShows, on=(TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where(reduce(operator.and_, upgradable_episodes_conditions))\ + .group_by(TableHistory.video_path)\ + .dicts() + upgradable_episodes = list(upgradable_episodes) for upgradable_episode in upgradable_episodes: if upgradable_episode['timestamp'] > minimum_timestamp: try: @@ -1236,24 +1338,38 @@ class EpisodesHistory(Resource): if int(upgradable_episode['score']) < 360: upgradable_episodes_not_perfect.append(upgradable_episode) - # TODO: Find a better solution - query_limit = "" + query_conditions = [(TableEpisodes.title is not None)] if episodeid: - query_limit = f"AND table_episodes.sonarrEpisodeId={episodeid}" - - episode_history = database.execute("SELECT table_shows.title as seriesTitle, table_episodes.monitored, " - "table_episodes.season || 'x' || table_episodes.episode as episode_number, " - "table_episodes.title as episodeTitle, table_history.timestamp, table_history.subs_id, " - "table_history.description, table_history.sonarrSeriesId, table_episodes.path, " - "table_history.language, table_history.score, table_shows.tags, table_history.action, " - "table_history.subtitles_path, table_history.sonarrEpisodeId, table_history.provider, " - "table_shows.seriesType FROM table_history LEFT JOIN table_shows on " - 
"table_shows.sonarrSeriesId = table_history.sonarrSeriesId LEFT JOIN table_episodes on " - "table_episodes.sonarrEpisodeId = table_history.sonarrEpisodeId WHERE " - "table_episodes.title is not NULL " + query_limit + " ORDER BY timestamp DESC LIMIT ? OFFSET ?", - (length, start)) - - blacklist_db = database.execute("SELECT provider, subs_id FROM table_blacklist ") + query_conditions.append((TableEpisodes.sonarrEpisodeId == episodeid)) + query_condition = reduce(operator.and_, query_conditions) + episode_history = TableHistory.select(TableShows.title.alias('seriesTitle'), + TableEpisodes.monitored, + TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'), + TableEpisodes.title.alias('episodeTitle'), + TableHistory.timestamp, + TableHistory.subs_id, + TableHistory.description, + TableHistory.sonarrSeriesId, + TableEpisodes.path, + TableHistory.language, + TableHistory.score, + TableShows.tags, + TableHistory.action, + TableHistory.subtitles_path, + TableHistory.sonarrEpisodeId, + TableHistory.provider, + TableShows.seriesType)\ + .join(TableShows, on=(TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId))\ + .where(query_condition)\ + .order_by(TableHistory.timestamp.desc())\ + .limit(length)\ + .offset(start)\ + .dicts() + episode_history = list(episode_history) + + blacklist_db = TableBlacklist.select(TableBlacklist.provider, TableBlacklist.subs_id).dicts() + blacklist_db = list(blacklist_db) for item in episode_history: # Mark episode as upgradable or not @@ -1281,14 +1397,14 @@ class EpisodesHistory(Resource): item.update({"blacklisted": False}) if item['action'] not in [0, 4, 5]: for blacklisted_item in blacklist_db: - if blacklisted_item['provider'] == item['provider'] and blacklisted_item['subs_id'] == item[ - 'subs_id']: + if blacklisted_item['provider'] == item['provider'] and \ + blacklisted_item['subs_id'] == item['subs_id']: item.update({"blacklisted": True}) break - count = database.execute("SELECT COUNT(*) as count FROM table_history LEFT JOIN table_episodes " - "on table_episodes.sonarrEpisodeId = table_history.sonarrEpisodeId WHERE " - "table_episodes.title is not NULL", only_one=True)['count'] + count = TableHistory.select()\ + .join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId))\ + .where(TableEpisodes.title is not None).count() return jsonify(data=episode_history, total=count) @@ -1312,11 +1428,20 @@ class MoviesHistory(Resource): else: query_actions = [1, 3] - upgradable_movies = database.execute( - "SELECT video_path, MAX(timestamp) as timestamp, score, tags, monitored FROM table_history_movie " - "INNER JOIN table_movies on table_movies.radarrId=table_history_movie.radarrId WHERE action IN (" + - ','.join(map(str, query_actions)) + ") AND timestamp > ? 
AND score is not NULL" + - get_exclusion_clause('movie') + " GROUP BY video_path", (minimum_timestamp,)) + upgradable_movies_conditions = [(TableHistoryMovie.action.in_(query_actions)), + (TableHistoryMovie.timestamp > minimum_timestamp), + (TableHistoryMovie.score is not None)] + upgradable_movies_conditions += get_exclusion_clause('movie') + upgradable_movies = TableHistoryMovie.select(TableHistoryMovie.video_path, + fn.MAX(TableHistoryMovie.timestamp).alias('timestamp'), + TableHistoryMovie.score, + TableMovies.tags, + TableMovies.monitored)\ + .join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId))\ + .where(reduce(operator.and_, upgradable_movies_conditions))\ + .group_by(TableHistoryMovie.video_path)\ + .dicts() + upgradable_movies = list(upgradable_movies) for upgradable_movie in upgradable_movies: if upgradable_movie['timestamp'] > minimum_timestamp: @@ -1328,22 +1453,34 @@ class MoviesHistory(Resource): if int(upgradable_movie['score']) < 120: upgradable_movies_not_perfect.append(upgradable_movie) - # TODO: Find a better solution - query_limit = "" + query_conditions = [(TableMovies is not None)] if radarrid: - query_limit = f"AND table_movies.radarrid={radarrid}" - - movie_history = database.execute( - "SELECT table_history_movie.action, table_movies.title, table_history_movie.timestamp, " - "table_history_movie.description, table_history_movie.radarrId, table_movies.monitored," - "table_history_movie.video_path as path, table_history_movie.language, table_movies.tags, " - "table_history_movie.score, table_history_movie.subs_id, table_history_movie.provider, " - "table_history_movie.subtitles_path, table_history_movie.subtitles_path FROM " - "table_history_movie LEFT JOIN table_movies on table_movies.radarrId = " - "table_history_movie.radarrId WHERE table_movies.title is not NULL " + query_limit + " ORDER BY timestamp DESC LIMIT ? 
OFFSET ?", - (length, start)) - - blacklist_db = database.execute("SELECT provider, subs_id FROM table_blacklist_movie") + query_conditions.append((TableMovies.radarrId == radarrid)) + query_condition = reduce(operator.and_, query_conditions) + + movie_history = TableHistoryMovie.select(TableHistoryMovie.action, + TableMovies.title, + TableHistoryMovie.timestamp, + TableHistoryMovie.description, + TableHistoryMovie.radarrId, + TableMovies.monitored, + TableHistoryMovie.video_path.alias('path'), + TableHistoryMovie.language, + TableMovies.tags, + TableHistoryMovie.score, + TableHistoryMovie.subs_id, + TableHistoryMovie.provider, + TableHistoryMovie.subtitles_path)\ + .join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId))\ + .where(query_condition)\ + .order_by(TableHistoryMovie.timestamp.desc())\ + .limit(length)\ + .offset(start)\ + .dicts() + movie_history = list(movie_history) + + blacklist_db = TableBlacklistMovie.select(TableBlacklistMovie.provider, TableBlacklistMovie.subs_id).dicts() + blacklist_db = list(blacklist_db) for item in movie_history: # Mark movies as upgradable or not @@ -1375,9 +1512,10 @@ class MoviesHistory(Resource): item.update({"blacklisted": True}) break - count = database.execute("SELECT COUNT(*) as count FROM table_history_movie LEFT JOIN table_movies on " - "table_movies.radarrId = table_history_movie.radarrId WHERE table_movies.title " - "is not NULL", only_one=True)['count'] + count = TableHistoryMovie.select()\ + .join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId))\ + .where(TableMovies.title is not None)\ + .count() return jsonify(data=movie_history, total=count) @@ -1390,38 +1528,56 @@ class HistoryStats(Resource): provider = request.args.get('provider') or 'All' language = request.args.get('language') or 'All' - history_where_clause = " WHERE id" - # timeframe must be in ['week', 'month', 'trimester', 'year'] if timeframe == 'year': - days = 364 + delay = 364 * 24 * 60 * 60 elif timeframe == 'trimester': - days = 90 + delay = 90 * 24 * 60 * 60 elif timeframe == 'month': - days = 30 + delay = 30 * 24 * 60 * 60 elif timeframe == 'week': - days = 6 + delay = 6 * 24 * 60 * 60 + + now = time.time() + past = now - delay + + history_where_clauses = [(TableHistory.timestamp.between(past, now))] + history_where_clauses_movie = [(TableHistoryMovie.timestamp.between(past, now))] - history_where_clause += " AND datetime(timestamp, 'unixepoch') BETWEEN datetime('now', '-" + str(days) + \ - " days') AND datetime('now', 'localtime')" if action != 'All': - history_where_clause += " AND action = " + action + history_where_clauses.append((TableHistory.action == action)) + history_where_clauses_movie.append((TableHistoryMovie.action == action)) else: - history_where_clause += " AND action IN (1,2,3)" + history_where_clauses.append((TableHistory.action.in_([1, 2, 3]))) + history_where_clauses_movie.append((TableHistoryMovie.action.in_([1, 2, 3]))) + if provider != 'All': - history_where_clause += " AND provider = '" + provider + "'" - if language != 'All': - history_where_clause += " AND language = '" + language + "'" + history_where_clauses.append((TableHistory.provider == provider)) + history_where_clauses_movie.append((TableHistoryMovie.provider == provider)) - data_series = database.execute("SELECT strftime ('%Y-%m-%d',datetime(timestamp, 'unixepoch')) as date, " - "COUNT(id) as count FROM table_history" + history_where_clause + - " GROUP BY strftime ('%Y-%m-%d',datetime(timestamp, 'unixepoch'))") - data_movies = 
database.execute("SELECT strftime ('%Y-%m-%d',datetime(timestamp, 'unixepoch')) as date, " - "COUNT(id) as count FROM table_history_movie" + history_where_clause + - " GROUP BY strftime ('%Y-%m-%d',datetime(timestamp, 'unixepoch'))") + if language != 'All': + history_where_clauses.append((TableHistory.language == language)) + history_where_clauses_movie.append((TableHistoryMovie.language == language)) + + history_where_clause = reduce(operator.and_, history_where_clauses) + history_where_clause_movie = reduce(operator.and_, history_where_clauses_movie) + + data_series = TableHistory.select(fn.strftime('%Y-%m-%d', TableHistory.timestamp, 'unixepoch').alias('date'), + fn.COUNT(TableHistory.id).alias('count'))\ + .where(history_where_clause) \ + .group_by(fn.strftime('%Y-%m-%d', TableHistory.timestamp, 'unixepoch'))\ + .dicts() + data_series = list(data_series) + + data_movies = TableHistoryMovie.select(fn.strftime('%Y-%m-%d', TableHistoryMovie.timestamp, 'unixepoch').alias('date'), + fn.COUNT(TableHistoryMovie.id).alias('count')) \ + .where(history_where_clause_movie) \ + .group_by(fn.strftime('%Y-%m-%d', TableHistoryMovie.timestamp, 'unixepoch')) \ + .dicts() + data_movies = list(data_movies) for dt in rrule.rrule(rrule.DAILY, - dtstart=datetime.datetime.now() - datetime.timedelta(days=days), + dtstart=datetime.datetime.now() - datetime.timedelta(seconds=delay), until=datetime.datetime.now()): if not any(d['date'] == dt.strftime('%Y-%m-%d') for d in data_series): data_series.append({'date': dt.strftime('%Y-%m-%d'), 'count': 0}) @@ -1439,37 +1595,59 @@ class EpisodesWanted(Resource): @authenticate def get(self): episodeid = request.args.getlist('episodeid[]') + + wanted_conditions = [(TableEpisodes.missing_subtitles != '[]')] if len(episodeid) > 0: - data = database.execute("SELECT table_shows.title as seriesTitle, table_episodes.monitored, " - "table_episodes.season || 'x' || table_episodes.episode as episode_number, " - "table_episodes.title as episodeTitle, table_episodes.missing_subtitles, " - "table_episodes.sonarrSeriesId, " - "table_episodes.sonarrEpisodeId, table_episodes.scene_name as sceneName, table_shows.tags, " - "table_episodes.failedAttempts, table_shows.seriesType FROM table_episodes INNER JOIN " - "table_shows on table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE " - "table_episodes.missing_subtitles != '[]'" + get_exclusion_clause('series') + - f" AND sonarrEpisodeId in {convert_list_to_clause(episodeid)}") - pass + wanted_conditions.append((TableEpisodes.sonarrEpisodeId in episodeid)) + wanted_conditions += get_exclusion_clause('series') + wanted_condition = reduce(operator.and_, wanted_conditions) + + if len(episodeid) > 0: + data = TableEpisodes.select(TableShows.title.alias('seriesTitle'), + TableEpisodes.monitored, + TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'), + TableEpisodes.title.alias('episodeTitle'), + TableEpisodes.missing_subtitles, + TableEpisodes.sonarrSeriesId, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.scene_name.alias('sceneName'), + TableShows.tags, + TableEpisodes.failedAttempts, + TableShows.seriesType)\ + .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where(wanted_condition)\ + .dicts() else: start = request.args.get('start') or 0 length = request.args.get('length') or -1 - data = database.execute("SELECT table_shows.title as seriesTitle, table_episodes.monitored, " - "table_episodes.season || 'x' || table_episodes.episode as episode_number, " - 
"table_episodes.title as episodeTitle, table_episodes.missing_subtitles, " - "table_episodes.sonarrSeriesId, " - "table_episodes.sonarrEpisodeId, table_episodes.scene_name as sceneName, table_shows.tags, " - "table_episodes.failedAttempts, table_shows.seriesType FROM table_episodes INNER JOIN " - "table_shows on table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE " - "table_episodes.missing_subtitles != '[]'" + get_exclusion_clause('series') + - " ORDER BY table_episodes._rowid_ DESC LIMIT ? OFFSET ?", (length, start)) + data = TableEpisodes.select(TableShows.title.alias('seriesTitle'), + TableEpisodes.monitored, + TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'), + TableEpisodes.title.alias('episodeTitle'), + TableEpisodes.missing_subtitles, + TableEpisodes.sonarrSeriesId, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.scene_name.alias('sceneName'), + TableShows.tags, + TableEpisodes.failedAttempts, + TableShows.seriesType)\ + .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where(wanted_condition)\ + .limit(length)\ + .offset(start)\ + .dicts() + data = list(data) for item in data: postprocessEpisode(item) - count = database.execute("SELECT COUNT(*) as count, table_shows.tags, table_shows.seriesType FROM " - "table_episodes INNER JOIN table_shows on table_shows.sonarrSeriesId = " - "table_episodes.sonarrSeriesId WHERE missing_subtitles != '[]'" + - get_exclusion_clause('series'), only_one=True)['count'] + count_conditions = [(TableEpisodes.missing_subtitles != '[]')] + count_conditions += get_exclusion_clause('series') + count = TableEpisodes.select(TableShows.tags, + TableShows.seriesType)\ + .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where(reduce(operator.and_, count_conditions))\ + .count() return jsonify(data=data, total=count) @@ -1479,25 +1657,48 @@ class MoviesWanted(Resource): @authenticate def get(self): radarrid = request.args.getlist("radarrid[]") + + wanted_conditions = [(TableMovies.missing_subtitles != '[]')] if len(radarrid) > 0: - result = database.execute("SELECT title, missing_subtitles, radarrId, sceneName, " - "failedAttempts, tags, monitored FROM table_movies WHERE missing_subtitles != '[]'" + - get_exclusion_clause('movie') + - f" AND radarrId in {convert_list_to_clause(radarrid)}") - pass + wanted_conditions.append((TableMovies.radarrId.in_(radarrid))) + wanted_conditions += get_exclusion_clause('movie') + wanted_condition = reduce(operator.and_, wanted_conditions) + + if len(radarrid) > 0: + result = TableMovies.select(TableMovies.title, + TableMovies.missing_subtitles, + TableMovies.radarrId, + TableMovies.sceneName, + TableMovies.failedAttempts, + TableMovies.tags, + TableMovies.monitored)\ + .where(wanted_condition)\ + .dicts() else: start = request.args.get('start') or 0 length = request.args.get('length') or -1 - result = database.execute("SELECT title, missing_subtitles, radarrId, sceneName, " - "failedAttempts, tags, monitored FROM table_movies WHERE missing_subtitles != '[]'" + - get_exclusion_clause('movie') + - " ORDER BY _rowid_ DESC LIMIT ? 
OFFSET ?", (length, start)) + result = TableMovies.select(TableMovies.title, + TableMovies.missing_subtitles, + TableMovies.radarrId, + TableMovies.sceneName, + TableMovies.failedAttempts, + TableMovies.tags, + TableMovies.monitored)\ + .where(wanted_condition)\ + .order_by(TableMovies.radarrId.desc())\ + .limit(length)\ + .offset(start)\ + .dicts() + result = list(result) for item in result: postprocessMovie(item) - count = database.execute("SELECT COUNT(*) as count FROM table_movies WHERE missing_subtitles != '[]'" + - get_exclusion_clause('movie'), only_one=True)['count'] + count_conditions = [(TableMovies.missing_subtitles != '[]')] + count_conditions += get_exclusion_clause('movie') + count = TableMovies.select()\ + .where(reduce(operator.and_, count_conditions))\ + .count() return jsonify(data=result, total=count) @@ -1511,14 +1712,21 @@ class EpisodesBlacklist(Resource): start = request.args.get('start') or 0 length = request.args.get('length') or -1 - data = database.execute("SELECT table_shows.title as seriesTitle, table_episodes.season || 'x' || " - "table_episodes.episode as episode_number, table_episodes.title as episodeTitle, " - "table_episodes.sonarrSeriesId, table_blacklist.provider, table_blacklist.subs_id, " - "table_blacklist.language, table_blacklist.timestamp FROM table_blacklist INNER JOIN " - "table_episodes on table_episodes.sonarrEpisodeId = table_blacklist.sonarr_episode_id " - "INNER JOIN table_shows on table_shows.sonarrSeriesId = " - "table_blacklist.sonarr_series_id ORDER BY table_blacklist.timestamp DESC LIMIT ? " - "OFFSET ?", (length, start)) + data = TableBlacklist.select(TableShows.title.alias('seriesTitle'), + TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'), + TableEpisodes.title.alias('episodeTitle'), + TableEpisodes.sonarrSeriesId, + TableBlacklist.provider, + TableBlacklist.subs_id, + TableBlacklist.language, + TableBlacklist.timestamp)\ + .join(TableEpisodes, on=(TableBlacklist.sonarr_episode_id == TableEpisodes.sonarrEpisodeId))\ + .join(TableShows, on=(TableBlacklist.sonarr_series_id == TableShows.sonarrSeriesId))\ + .order_by(TableBlacklist.timestamp.desc())\ + .limit(length)\ + .offset(start)\ + .dicts() + data = list(data) for item in data: # Make timestamp pretty @@ -1537,8 +1745,10 @@ class EpisodesBlacklist(Resource): subs_id = request.form.get('subs_id') language = request.form.get('language') - episodeInfo = database.execute("SELECT path FROM table_episodes WHERE sonarrEpisodeId=?", - (sonarr_episode_id,), only_one=True) + episodeInfo = TableEpisodes.select(TableEpisodes.path)\ + .where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id)\ + .dicts()\ + .get() media_path = episodeInfo['path'] subtitles_path = request.form.get('subtitles_path') @@ -1580,12 +1790,18 @@ class MoviesBlacklist(Resource): start = request.args.get('start') or 0 length = request.args.get('length') or -1 - data = database.execute("SELECT table_movies.title, table_movies.radarrId, table_blacklist_movie.provider, " - "table_blacklist_movie.subs_id, table_blacklist_movie.language, " - "table_blacklist_movie.timestamp FROM table_blacklist_movie INNER JOIN " - "table_movies on table_movies.radarrId = table_blacklist_movie.radarr_id " - "ORDER BY table_blacklist_movie.timestamp DESC LIMIT ? 
" - "OFFSET ?", (length, start)) + data = TableBlacklistMovie.select(TableMovies.title, + TableMovies.radarrId, + TableBlacklistMovie.provider, + TableBlacklistMovie.subs_id, + TableBlacklistMovie.language, + TableBlacklistMovie.timestamp)\ + .join(TableMovies, on=(TableBlacklistMovie.radarr_id == TableMovies.radarrId))\ + .order_by(TableBlacklistMovie.timestamp.desc())\ + .limit(length)\ + .offset(start)\ + .dicts() + data = list(data) for item in data: postprocessMovie(item) @@ -1606,7 +1822,7 @@ class MoviesBlacklist(Resource): forced = False hi = False - data = database.execute("SELECT path FROM table_movies WHERE radarrId=?", (radarr_id,), only_one=True) + data = TableMovies.select(TableMovies.path).where(TableMovies.radarrId == radarr_id).dicts().get() media_path = data['path'] subtitles_path = request.form.get('subtitles_path') @@ -1649,13 +1865,14 @@ class Subtitles(Resource): if media_type == 'episode': subtitles_path = path_mappings.path_replace(subtitles_path) - metadata = database.execute("SELECT path, sonarrSeriesId FROM table_episodes" - " WHERE sonarrEpisodeId = ?", (id,), only_one=True) + metadata = TableEpisodes.select(TableEpisodes.path, TableEpisodes.sonarrSeriesId)\ + .where(TableEpisodes.sonarrEpisodeId == id)\ + .dicts()\ + .get() video_path = path_mappings.path_replace(metadata['path']) else: subtitles_path = path_mappings.path_replace_movie(subtitles_path) - metadata = database.execute("SELECT path FROM table_movies WHERE radarrId = ?", - (id,), only_one=True) + metadata = TableMovies.select(TableMovies.path).where(TableMovies.radarrId == id).dicts().get() video_path = path_mappings.path_replace_movie(metadata['path']) if action == 'sync': diff --git a/bazarr/create_db.sql b/bazarr/create_db.sql deleted file mode 100644 index dc2188e44..000000000 --- a/bazarr/create_db.sql +++ /dev/null @@ -1,88 +0,0 @@ -BEGIN TRANSACTION; -CREATE TABLE "table_shows" ( - `tvdbId` INTEGER NOT NULL UNIQUE, - `title` TEXT NOT NULL, - `path` TEXT NOT NULL UNIQUE, - `languages` TEXT, - `hearing_impaired` TEXT, - `sonarrSeriesId` INTEGER NOT NULL UNIQUE, - `overview` TEXT, - `poster` TEXT, - `fanart` TEXT, - `audio_language` "text", - `sortTitle` "text", - PRIMARY KEY(`tvdbId`) -); -CREATE TABLE "table_settings_providers" ( - `name` TEXT NOT NULL UNIQUE, - `enabled` INTEGER, - `username` "text", - `password` "text", - PRIMARY KEY(`name`) -); -CREATE TABLE "table_settings_notifier" ( - `name` TEXT, - `url` TEXT, - `enabled` INTEGER, - PRIMARY KEY(`name`) -); -CREATE TABLE "table_settings_languages" ( - `code3` TEXT NOT NULL UNIQUE, - `code2` TEXT, - `name` TEXT NOT NULL, - `enabled` INTEGER, - `code3b` TEXT, - PRIMARY KEY(`code3`) -); -CREATE TABLE "table_history" ( - `id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT UNIQUE, - `action` INTEGER NOT NULL, - `sonarrSeriesId` INTEGER NOT NULL, - `sonarrEpisodeId` INTEGER NOT NULL, - `timestamp` INTEGER NOT NULL, - `description` TEXT NOT NULL -); -CREATE TABLE "table_episodes" ( - `sonarrSeriesId` INTEGER NOT NULL, - `sonarrEpisodeId` INTEGER NOT NULL UNIQUE, - `title` TEXT NOT NULL, - `path` TEXT NOT NULL, - `season` INTEGER NOT NULL, - `episode` INTEGER NOT NULL, - `subtitles` TEXT, - `missing_subtitles` TEXT, - `scene_name` TEXT, - `monitored` TEXT, - `failedAttempts` "text" -); -CREATE TABLE "table_movies" ( - `tmdbId` TEXT NOT NULL UNIQUE, - `title` TEXT NOT NULL, - `path` TEXT NOT NULL UNIQUE, - `languages` TEXT, - `subtitles` TEXT, - `missing_subtitles` TEXT, - `hearing_impaired` TEXT, - `radarrId` INTEGER NOT NULL UNIQUE, - `overview` 
TEXT, - `poster` TEXT, - `fanart` TEXT, - `audio_language` "text", - `sceneName` TEXT, - `monitored` TEXT, - `failedAttempts` "text", - PRIMARY KEY(`tmdbId`) -); -CREATE TABLE "table_history_movie" ( - `id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT UNIQUE, - `action` INTEGER NOT NULL, - `radarrId` INTEGER NOT NULL, - `timestamp` INTEGER NOT NULL, - `description` TEXT NOT NULL -); -CREATE TABLE "system" ( - `configured` TEXT, - `updated` TEXT -); -INSERT INTO `system` (configured, updated) VALUES ('0', '0'); -COMMIT; diff --git a/bazarr/database.py b/bazarr/database.py index 1e33edc78..edeadb285 100644 --- a/bazarr/database.py +++ b/bazarr/database.py @@ -1,78 +1,258 @@ -# coding=utf-8 - import os +import atexit +import json import ast -import sqlite3 import logging -import json -import re - -from sqlite3worker import Sqlite3Worker +from peewee import * +from playhouse.sqliteq import SqliteQueueDatabase +from playhouse.shortcuts import model_to_dict -from get_args import args from helper import path_mappings from config import settings, get_array_from +from get_args import args -global profile_id_list -profile_id_list = [] - - -def db_init(): - if not os.path.exists(os.path.join(args.config_dir, 'db', 'bazarr.db')): - # Get SQL script from file - fd = open(os.path.join(os.path.dirname(__file__), 'create_db.sql'), 'r') - script = fd.read() - # Close SQL script file - fd.close() - # Open database connection - db = sqlite3.connect(os.path.join(args.config_dir, 'db', 'bazarr.db'), timeout=30) - c = db.cursor() - # Execute script and commit change to database - c.executescript(script) - # Close database connection - db.close() - logging.info('BAZARR Database created successfully') - - -database = Sqlite3Worker(os.path.join(args.config_dir, 'db', 'bazarr.db'), max_queue_size=256, as_dict=True) - - -class SqliteDictConverter: - def __init__(self): - self.keys_insert = tuple() - self.keys_update = tuple() - self.values = tuple() - self.question_marks = tuple() - - def convert(self, values_dict): - if type(values_dict) is dict: - self.keys_insert = tuple() - self.keys_update = tuple() - self.values = tuple() - self.question_marks = tuple() - - temp_keys = list() - temp_values = list() - for item in values_dict.items(): - temp_keys.append(item[0]) - temp_values.append(item[1]) - self.keys_insert = ','.join(temp_keys) - self.keys_update = ','.join([k + '=?' 
for k in temp_keys]) - self.values = tuple(temp_values) - self.question_marks = ','.join(list('?'*len(values_dict))) - return self - else: - pass - - -dict_converter = SqliteDictConverter() +database = SqliteQueueDatabase(os.path.join(args.config_dir, 'db', 'bazarr.db'), + use_gevent=True, + autostart=True, + queue_max_size=256) + + +@atexit.register +def _stop_worker_threads(): + database.stop() + + +class UnknownField(object): + def __init__(self, *_, **__): pass + + +class BaseModel(Model): + class Meta: + database = database + + +class System(BaseModel): + configured = TextField(null=True) + updated = TextField(null=True) + + class Meta: + table_name = 'system' + primary_key = False + + +class TableBlacklist(BaseModel): + language = TextField(null=True) + provider = TextField(null=True) + sonarr_episode_id = IntegerField(null=True) + sonarr_series_id = IntegerField(null=True) + subs_id = TextField(null=True) + timestamp = IntegerField(null=True) + + class Meta: + table_name = 'table_blacklist' + primary_key = False + + +class TableBlacklistMovie(BaseModel): + language = TextField(null=True) + provider = TextField(null=True) + radarr_id = IntegerField(null=True) + subs_id = TextField(null=True) + timestamp = IntegerField(null=True) + + class Meta: + table_name = 'table_blacklist_movie' + primary_key = False + + +class TableEpisodes(BaseModel): + audio_codec = TextField(null=True) + audio_language = TextField(null=True) + episode = IntegerField() + episode_file_id = IntegerField(null=True) + failedAttempts = TextField(null=True) + ffprobe_cache = BlobField(null=True) + file_size = IntegerField(default=0, null=True) + format = TextField(null=True) + missing_subtitles = TextField(null=True) + monitored = TextField(null=True) + path = TextField() + resolution = TextField(null=True) + scene_name = TextField(null=True) + season = IntegerField() + sonarrEpisodeId = IntegerField(unique=True) + sonarrSeriesId = IntegerField() + subtitles = TextField(null=True) + title = TextField() + video_codec = TextField(null=True) + + class Meta: + table_name = 'table_episodes' + primary_key = False + + +class TableHistory(BaseModel): + action = IntegerField() + description = TextField() + id = AutoField() + language = TextField(null=True) + provider = TextField(null=True) + score = TextField(null=True) + sonarrEpisodeId = IntegerField() + sonarrSeriesId = IntegerField() + subs_id = TextField(null=True) + subtitles_path = TextField(null=True) + timestamp = IntegerField() + video_path = TextField(null=True) + + class Meta: + table_name = 'table_history' + + +class TableHistoryMovie(BaseModel): + action = IntegerField() + description = TextField() + id = AutoField() + language = TextField(null=True) + provider = TextField(null=True) + radarrId = IntegerField() + score = TextField(null=True) + subs_id = TextField(null=True) + subtitles_path = TextField(null=True) + timestamp = IntegerField() + video_path = TextField(null=True) + + class Meta: + table_name = 'table_history_movie' + + +class TableLanguagesProfiles(BaseModel): + cutoff = IntegerField(null=True) + items = TextField() + name = TextField() + profileId = AutoField() + + class Meta: + table_name = 'table_languages_profiles' + + +class TableMovies(BaseModel): + alternativeTitles = TextField(null=True) + audio_codec = TextField(null=True) + audio_language = TextField(null=True) + failedAttempts = TextField(null=True) + fanart = TextField(null=True) + ffprobe_cache = BlobField(null=True) + file_size = IntegerField(default=0, null=True) + format = 
TextField(null=True) + imdbId = TextField(null=True) + missing_subtitles = TextField(null=True) + monitored = TextField(null=True) + movie_file_id = IntegerField(null=True) + overview = TextField(null=True) + path = TextField(unique=True) + poster = TextField(null=True) + profileId = IntegerField(null=True) + radarrId = IntegerField(unique=True) + resolution = TextField(null=True) + sceneName = TextField(null=True) + sortTitle = TextField(null=True) + subtitles = TextField(null=True) + tags = TextField(null=True) + title = TextField() + tmdbId = TextField(primary_key=True) + video_codec = TextField(null=True) + year = TextField(null=True) + + class Meta: + table_name = 'table_movies' + + +class TableMoviesRootfolder(BaseModel): + accessible = IntegerField(null=True) + error = TextField(null=True) + id = IntegerField(null=True) + path = TextField(null=True) + + class Meta: + table_name = 'table_movies_rootfolder' + primary_key = False + + +class TableSettingsLanguages(BaseModel): + code2 = TextField(null=True) + code3 = TextField(primary_key=True) + code3b = TextField(null=True) + enabled = IntegerField(null=True) + name = TextField() + + class Meta: + table_name = 'table_settings_languages' + + +class TableSettingsNotifier(BaseModel): + enabled = IntegerField(null=True) + name = TextField(null=True, primary_key=True) + url = TextField(null=True) + + class Meta: + table_name = 'table_settings_notifier' + + +class TableShows(BaseModel): + alternateTitles = TextField(null=True) + audio_language = TextField(null=True) + fanart = TextField(null=True) + imdbId = TextField(default='""', null=True) + overview = TextField(null=True) + path = TextField(unique=True) + poster = TextField(null=True) + profileId = IntegerField(null=True) + seriesType = TextField(null=True) + sonarrSeriesId = IntegerField(unique=True) + sortTitle = TextField(null=True) + tags = TextField(null=True) + title = TextField() + tvdbId = AutoField() + year = TextField(null=True) + + class Meta: + table_name = 'table_shows' + + +class TableShowsRootfolder(BaseModel): + accessible = IntegerField(null=True) + error = TextField(null=True) + id = IntegerField(null=True) + path = TextField(null=True) + + class Meta: + table_name = 'table_shows_rootfolder' + primary_key = False + + +# Create tables if they don't exists. 
+database.create_tables([System, + TableBlacklist, + TableBlacklistMovie, + TableEpisodes, + TableHistory, + TableHistoryMovie, + TableLanguagesProfiles, + TableMovies, + TableMoviesRootfolder, + TableSettingsLanguages, + TableSettingsNotifier, + TableShows, + TableShowsRootfolder]) class SqliteDictPathMapper: def __init__(self): pass - def path_replace(self, values_dict): + @staticmethod + def path_replace(values_dict): if type(values_dict) is list: for item in values_dict: item['path'] = path_mappings.path_replace(item['path']) @@ -81,7 +261,8 @@ class SqliteDictPathMapper: else: return path_mappings.path_replace(values_dict) - def path_replace_movie(self, values_dict): + @staticmethod + def path_replace_movie(values_dict): if type(values_dict) is list: for item in values_dict: item['path'] = path_mappings.path_replace_movie(item['path']) @@ -94,293 +275,49 @@ class SqliteDictPathMapper: dict_mapper = SqliteDictPathMapper() -def db_upgrade(): - columnToAdd = [ - ['table_shows', 'year', 'text'], - ['table_shows', 'alternateTitles', 'text'], - ['table_shows', 'tags', 'text', '[]'], - ['table_shows', 'seriesType', 'text', ''], - ['table_shows', 'imdbId', 'text', ''], - ['table_shows', 'profileId', 'integer'], - ['table_episodes', 'format', 'text'], - ['table_episodes', 'resolution', 'text'], - ['table_episodes', 'video_codec', 'text'], - ['table_episodes', 'audio_codec', 'text'], - ['table_episodes', 'episode_file_id', 'integer'], - ['table_episodes', 'audio_language', 'text'], - ['table_episodes', 'file_size', 'integer', '0'], - ['table_episodes', 'ffprobe_cache', 'blob'], - ['table_movies', 'sortTitle', 'text'], - ['table_movies', 'year', 'text'], - ['table_movies', 'alternativeTitles', 'text'], - ['table_movies', 'format', 'text'], - ['table_movies', 'resolution', 'text'], - ['table_movies', 'video_codec', 'text'], - ['table_movies', 'audio_codec', 'text'], - ['table_movies', 'imdbId', 'text'], - ['table_movies', 'movie_file_id', 'integer'], - ['table_movies', 'tags', 'text', '[]'], - ['table_movies', 'profileId', 'integer'], - ['table_movies', 'file_size', 'integer', '0'], - ['table_movies', 'ffprobe_cache', 'blob'], - ['table_history', 'video_path', 'text'], - ['table_history', 'language', 'text'], - ['table_history', 'provider', 'text'], - ['table_history', 'score', 'text'], - ['table_history', 'subs_id', 'text'], - ['table_history', 'subtitles_path', 'text'], - ['table_history_movie', 'video_path', 'text'], - ['table_history_movie', 'language', 'text'], - ['table_history_movie', 'provider', 'text'], - ['table_history_movie', 'score', 'text'], - ['table_history_movie', 'subs_id', 'text'], - ['table_history_movie', 'subtitles_path', 'text'] - ] - - for column in columnToAdd: - try: - # Check if column already exist in table - columns_dict = database.execute('''PRAGMA table_info('{0}')'''.format(column[0])) - columns_names_list = [x['name'] for x in columns_dict] - if column[1] in columns_names_list: - continue - - # Creating the missing column - if len(column) == 3: - database.execute('''ALTER TABLE {0} ADD COLUMN "{1}" "{2}"'''.format(column[0], column[1], column[2])) - else: - database.execute('''ALTER TABLE {0} ADD COLUMN "{1}" "{2}" DEFAULT "{3}"'''.format(column[0], column[1], column[2], column[3])) - logging.debug('BAZARR Database upgrade process added column {0} to table {1}.'.format(column[1], column[0])) - except: - pass - - # Create blacklist tables - database.execute("CREATE TABLE IF NOT EXISTS table_blacklist (sonarr_series_id integer, sonarr_episode_id integer, " - 
"timestamp integer, provider text, subs_id text, language text)") - database.execute("CREATE TABLE IF NOT EXISTS table_blacklist_movie (radarr_id integer, timestamp integer, " - "provider text, subs_id text, language text)") - - # Create rootfolder tables - database.execute("CREATE TABLE IF NOT EXISTS table_shows_rootfolder (id integer, path text, accessible integer, " - "error text)") - database.execute("CREATE TABLE IF NOT EXISTS table_movies_rootfolder (id integer, path text, accessible integer, " - "error text)") - - # Create languages profiles table and populate it - lang_table_content = database.execute("SELECT * FROM table_languages_profiles") - if isinstance(lang_table_content, list): - lang_table_exist = True - else: - lang_table_exist = False - database.execute("CREATE TABLE IF NOT EXISTS table_languages_profiles (" - "profileId INTEGER NOT NULL PRIMARY KEY, name TEXT NOT NULL, " - "cutoff INTEGER, items TEXT NOT NULL)") - - if not lang_table_exist: - series_default = [] - try: - for language in ast.literal_eval(settings.general.serie_default_language): - if settings.general.serie_default_forced == 'Both': - series_default.append([language, 'True', settings.general.serie_default_hi]) - series_default.append([language, 'False', settings.general.serie_default_hi]) - else: - series_default.append([language, settings.general.serie_default_forced, - settings.general.serie_default_hi]) - except ValueError: - pass - - movies_default = [] - try: - for language in ast.literal_eval(settings.general.movie_default_language): - if settings.general.movie_default_forced == 'Both': - movies_default.append([language, 'True', settings.general.movie_default_hi]) - movies_default.append([language, 'False', settings.general.movie_default_hi]) - else: - movies_default.append([language, settings.general.movie_default_forced, - settings.general.movie_default_hi]) - except ValueError: - pass - - profiles_to_create = database.execute("SELECT DISTINCT languages, hearing_impaired, forced " - "FROM (SELECT languages, hearing_impaired, forced FROM table_shows " - "UNION ALL SELECT languages, hearing_impaired, forced FROM table_movies) " - "a WHERE languages NOT null and languages NOT IN ('None', '[]')") - - if isinstance(profiles_to_create, list): - for profile in profiles_to_create: - profile_items = [] - languages_list = ast.literal_eval(profile['languages']) - for i, language in enumerate(languages_list, 1): - if profile['forced'] == 'Both': - profile_items.append({'id': i, 'language': language, 'forced': 'True', - 'hi': profile['hearing_impaired'], 'audio_exclude': 'False'}) - profile_items.append({'id': i, 'language': language, 'forced': 'False', - 'hi': profile['hearing_impaired'], 'audio_exclude': 'False'}) - else: - profile_items.append({'id': i, 'language': language, 'forced': profile['forced'], - 'hi': profile['hearing_impaired'], 'audio_exclude': 'False'}) - - # Create profiles - new_profile_name = profile['languages'] + ' (' + profile['hearing_impaired'] + '/' + profile['forced'] + ')' - database.execute("INSERT INTO table_languages_profiles (name, cutoff, items) VALUES(" - "?,null,?)", (new_profile_name, json.dumps(profile_items),)) - created_profile_id = database.execute("SELECT profileId FROM table_languages_profiles WHERE name = ?", - (new_profile_name,), only_one=True)['profileId'] - # Assign profiles to series and movies - database.execute("UPDATE table_shows SET profileId = ? WHERE languages = ? AND hearing_impaired = ? 
AND " - "forced = ?", (created_profile_id, profile['languages'], profile['hearing_impaired'], - profile['forced'])) - database.execute("UPDATE table_movies SET profileId = ? WHERE languages = ? AND hearing_impaired = ? AND " - "forced = ?", (created_profile_id, profile['languages'], profile['hearing_impaired'], - profile['forced'])) - - # Save new defaults - profile_items_list = [] - for item in profile_items: - profile_items_list.append([item['language'], item['forced'], item['hi']]) - try: - if created_profile_id and profile_items_list == series_default: - settings.general.serie_default_profile = str(created_profile_id) - except: - pass - - try: - if created_profile_id and profile_items_list == movies_default: - settings.general.movie_default_profile = str(created_profile_id) - except: - pass - - # null languages, forced and hearing_impaired for all series and movies - database.execute("UPDATE table_shows SET languages = null, forced = null, hearing_impaired = null") - database.execute("UPDATE table_movies SET languages = null, forced = null, hearing_impaired = null") - - # Force series, episodes and movies sync with Sonarr to get all the audio track from video files - # Set environment variable that is going to be use during the init process to run sync once Bazarr is ready. - os.environ['BAZARR_AUDIO_PROFILES_MIGRATION'] = '1' - - columnToRemove = [ - ['table_shows', 'languages'], - ['table_shows', 'hearing_impaired'], - ['table_shows', 'forced'], - ['table_shows', 'sizeOnDisk'], - ['table_episodes', 'file_ffprobe'], - ['table_movies', 'languages'], - ['table_movies', 'hearing_impaired'], - ['table_movies', 'forced'], - ['table_movies', 'file_ffprobe'], - ] - - for column in columnToRemove: - try: - # Check if column still exist in table - columns_dict = database.execute('''PRAGMA table_info('{0}')'''.format(column[0])) - columns_names_list = [x['name'] for x in columns_dict] - if column[1] not in columns_names_list: - continue - - table_name = column[0] - column_name = column[1] - tables_query = database.execute("SELECT name FROM sqlite_master WHERE type = 'table'") - tables = [table['name'] for table in tables_query] - if table_name not in tables: - # Table doesn't exist in database. Skipping. - continue - - columns_dict = database.execute('''PRAGMA table_info('{0}')'''.format(column[0])) - columns_names_list = [x['name'] for x in columns_dict] - if column_name in columns_names_list: - columns_names_list.remove(column_name) - columns_names_string = ', '.join(columns_names_list) - if not columns_names_list: - logging.debug("BAZARR No more columns in {}. We won't create an empty table. 
" - "Exiting.".format(table_name)) - continue - else: - logging.debug("BAZARR Column {} doesn't exist in {}".format(column_name, table_name)) - continue - - # get original sql statement used to create the table - original_sql_statement = database.execute("SELECT sql FROM sqlite_master WHERE type='table' AND " - "name='{}'".format(table_name))[0]['sql'] - # pretty format sql statement - original_sql_statement = original_sql_statement.replace('\n, ', ',\n\t') - original_sql_statement = original_sql_statement.replace('", "', '",\n\t"') - original_sql_statement = original_sql_statement.rstrip(')') + '\n' - - # generate sql statement for temp table - table_regex = re.compile(r"CREATE TABLE \"{}\"".format(table_name)) - column_regex = re.compile(r".+\"{}\".+\n".format(column_name)) - new_sql_statement = table_regex.sub("CREATE TABLE \"{}_temp\"".format(table_name), original_sql_statement) - new_sql_statement = column_regex.sub("", new_sql_statement).rstrip('\n').rstrip(',') + '\n)' - - # remove leftover temp table from previous execution - database.execute('DROP TABLE IF EXISTS {}_temp'.format(table_name)) - - # create new temp table - create_error = database.execute(new_sql_statement) - if create_error: - logging.debug('BAZARR cannot create temp table.') - continue - - # validate if row insertion worked as expected - new_table_rows = database.execute('INSERT INTO {0}_temp({1}) SELECT {1} FROM {0}'.format(table_name, - columns_names_string)) - previous_table_rows = database.execute('SELECT COUNT(*) as count FROM {}'.format(table_name), - only_one=True)['count'] - if new_table_rows == previous_table_rows: - drop_error = database.execute('DROP TABLE {}'.format(table_name)) - if drop_error: - logging.debug('BAZARR cannot drop {} table before renaming the temp table'.format(table_name)) - continue - else: - rename_error = database.execute('ALTER TABLE {0}_temp RENAME TO {0}'.format(table_name)) - if rename_error: - logging.debug('BAZARR cannot rename {}_temp table'.format(table_name)) - else: - logging.debug('BAZARR cannot insert existing rows to {} table.'.format(table_name)) - continue - except: - pass - - -def get_exclusion_clause(type): - where_clause = '' - if type == 'series': +def get_exclusion_clause(exclusion_type): + where_clause = [] + if exclusion_type == 'series': tagsList = ast.literal_eval(settings.sonarr.excluded_tags) for tag in tagsList: - where_clause += ' AND table_shows.tags NOT LIKE "%\'' + tag + '\'%"' + where_clause.append(~(TableShows.tags ** tag)) else: tagsList = ast.literal_eval(settings.radarr.excluded_tags) for tag in tagsList: - where_clause += ' AND table_movies.tags NOT LIKE "%\'' + tag + '\'%"' + where_clause.append(~(TableMovies.tags ** tag)) - if type == 'series': + if exclusion_type == 'series': monitoredOnly = settings.sonarr.getboolean('only_monitored') if monitoredOnly: - where_clause += ' AND table_episodes.monitored = "True"' + where_clause.append((TableEpisodes.monitored == 'True')) else: monitoredOnly = settings.radarr.getboolean('only_monitored') if monitoredOnly: - where_clause += ' AND table_movies.monitored = "True"' + where_clause.append((TableMovies.monitored == 'True')) - if type == 'series': + if exclusion_type == 'series': typesList = get_array_from(settings.sonarr.excluded_series_types) - for type in typesList: - where_clause += ' AND table_shows.seriesType != "' + type + '"' + for item in typesList: + where_clause.append((TableShows.seriesType != item)) return where_clause def update_profile_id_list(): global profile_id_list - profile_id_list = 
database.execute("SELECT profileId, name, cutoff, items FROM table_languages_profiles") - + profile_id_list = TableLanguagesProfiles.select(TableLanguagesProfiles.profileId, + TableLanguagesProfiles.name, + TableLanguagesProfiles.cutoff, + TableLanguagesProfiles.items).dicts() + profile_id_list = list(profile_id_list) for profile in profile_id_list: - profile['items'] = json.loads(profile['items']) + profile['items'] = json.loads(profile['items']) def get_profiles_list(profile_id=None): - if not len(profile_id_list): + try: + len(profile_id_list) + except NameError: update_profile_id_list() if profile_id and profile_id != 'null': @@ -452,50 +389,29 @@ def get_audio_profile_languages(series_id=None, episode_id=None, movie_id=None): audio_languages = [] if series_id: - audio_languages_list_str = database.execute("SELECT audio_language FROM table_shows WHERE sonarrSeriesId=?", - (series_id,), only_one=True)['audio_language'] - try: - audio_languages_list = ast.literal_eval(audio_languages_list_str) - except ValueError: - pass - else: - for language in audio_languages_list: - audio_languages.append( - {"name": language, - "code2": alpha2_from_language(language) or None, - "code3": alpha3_from_language(language) or None} - ) + audio_languages_list_str = TableShows.get(TableShows.sonarrSeriesId == series_id).audio_language elif episode_id: - audio_languages_list_str = database.execute("SELECT audio_language FROM table_episodes WHERE sonarrEpisodeId=?", - (episode_id,), only_one=True)['audio_language'] - try: - audio_languages_list = ast.literal_eval(audio_languages_list_str) - except ValueError: - pass - else: - for language in audio_languages_list: - audio_languages.append( - {"name": language, - "code2": alpha2_from_language(language) or None, - "code3": alpha3_from_language(language) or None} - ) + audio_languages_list_str = TableEpisodes.get(TableEpisodes.sonarrEpisodeId == episode_id).audio_language elif movie_id: - audio_languages_list_str = database.execute("SELECT audio_language FROM table_movies WHERE radarrId=?", - (movie_id,), only_one=True)['audio_language'] - try: - audio_languages_list = ast.literal_eval(audio_languages_list_str) - except ValueError: - pass - else: - for language in audio_languages_list: - audio_languages.append( - {"name": language, - "code2": alpha2_from_language(language) or None, - "code3": alpha3_from_language(language) or None} - ) + audio_languages_list_str = TableMovies.get(TableMovies.radarrId == movie_id).audio_language + else: + return audio_languages + + try: + audio_languages_list = ast.literal_eval(audio_languages_list_str) + except ValueError: + pass + else: + for language in audio_languages_list: + audio_languages.append( + {"name": language, + "code2": alpha2_from_language(language) or None, + "code3": alpha3_from_language(language) or None} + ) return audio_languages + def convert_list_to_clause(arr: list): if isinstance(arr, list): return f"({','.join(str(x) for x in arr)})" diff --git a/bazarr/embedded_subs_reader.py b/bazarr/embedded_subs_reader.py index 4f013676d..853719176 100644 --- a/bazarr/embedded_subs_reader.py +++ b/bazarr/embedded_subs_reader.py @@ -7,7 +7,7 @@ from knowit import api import enzyme from enzyme.exceptions import MalformedMKVError from enzyme.exceptions import MalformedMKVError -from database import database +from database import TableEpisodes, TableMovies _FFPROBE_SPECIAL_LANGS = { "zho": { @@ -77,27 +77,30 @@ def parse_video_metadata(file, file_size, episode_file_id=None, movie_file_id=No # Get the actual cache value 
form database if episode_file_id: - cache_key = database.execute('SELECT ffprobe_cache FROM table_episodes WHERE episode_file_id=? AND file_size=?', - (episode_file_id, file_size), only_one=True) + cache_key = TableEpisodes.select(TableEpisodes.ffprobe_cache)\ + .where((TableEpisodes.episode_file_id == episode_file_id) & + (TableEpisodes.file_size == file_size))\ + .dicts()\ + .get() elif movie_file_id: - cache_key = database.execute('SELECT ffprobe_cache FROM table_movies WHERE movie_file_id=? AND file_size=?', - (movie_file_id, file_size), only_one=True) + cache_key = TableMovies.select(TableMovies.ffprobe_cache)\ + .where((TableMovies.movie_file_id == movie_file_id) & + (TableMovies.file_size == file_size))\ + .dicts()\ + .get() else: cache_key = None # check if we have a value for that cache key - if not isinstance(cache_key, dict): - return data + try: + # Unpickle ffprobe cache + cached_value = pickle.loads(cache_key['ffprobe_cache']) + except: + pass else: - try: - # Unpickle ffprobe cache - cached_value = pickle.loads(cache_key['ffprobe_cache']) - except: - pass - else: - # Check if file size and file id matches and if so, we return the cached value - if cached_value['file_size'] == file_size and cached_value['file_id'] in [episode_file_id, movie_file_id]: - return cached_value + # Check if file size and file id matches and if so, we return the cached value + if cached_value['file_size'] == file_size and cached_value['file_id'] in [episode_file_id, movie_file_id]: + return cached_value # if not, we retrieve the metadata from the file from utils import get_binary @@ -122,9 +125,11 @@ def parse_video_metadata(file, file_size, episode_file_id=None, movie_file_id=No # we write to db the result and return the newly cached ffprobe dict if episode_file_id: - database.execute('UPDATE table_episodes SET ffprobe_cache=? WHERE episode_file_id=?', - (pickle.dumps(data, pickle.HIGHEST_PROTOCOL), episode_file_id)) + TableEpisodes.update({TableEpisodes.ffprobe_cache: pickle.dumps(data, pickle.HIGHEST_PROTOCOL)})\ + .where(TableEpisodes.episode_file_id == episode_file_id)\ + .execute() elif movie_file_id: - database.execute('UPDATE table_movies SET ffprobe_cache=? 
WHERE movie_file_id=?', - (pickle.dumps(data, pickle.HIGHEST_PROTOCOL), movie_file_id)) + TableMovies.update({TableMovies.ffprobe_cache: pickle.dumps(data, pickle.HIGHEST_PROTOCOL)})\ + .where(TableMovies.movie_file_id == movie_file_id)\ + .execute() return data diff --git a/bazarr/get_episodes.py b/bazarr/get_episodes.py index 02b5e7c57..4d8938731 100644 --- a/bazarr/get_episodes.py +++ b/bazarr/get_episodes.py @@ -3,7 +3,7 @@ import os import requests import logging -from database import database, dict_converter, get_exclusion_clause +from database import get_exclusion_clause, TableEpisodes, TableShows from config import settings, url_sonarr from helper import path_mappings @@ -24,11 +24,11 @@ def sync_episodes(series_id=None, send_event=True): apikey_sonarr = settings.sonarr.apikey # Get current episodes id in DB - if series_id: - current_episodes_db = database.execute("SELECT sonarrEpisodeId, path, sonarrSeriesId FROM table_episodes WHERE " - "sonarrSeriesId = ?", (series_id,)) - else: - current_episodes_db = database.execute("SELECT sonarrEpisodeId, path, sonarrSeriesId FROM table_episodes") + current_episodes_db = TableEpisodes.select(TableEpisodes.sonarrEpisodeId, + TableEpisodes.path, + TableEpisodes.sonarrSeriesId)\ + .where((TableEpisodes.sonarrSeriesId == series_id) if series_id else None)\ + .dicts() current_episodes_db_list = [x['sonarrEpisodeId'] for x in current_episodes_db] @@ -82,17 +82,31 @@ def sync_episodes(series_id=None, send_event=True): removed_episodes = list(set(current_episodes_db_list) - set(current_episodes_sonarr)) for removed_episode in removed_episodes: - episode_to_delete = database.execute("SELECT sonarrSeriesId, sonarrEpisodeId FROM table_episodes WHERE " - "sonarrEpisodeId=?", (removed_episode,), only_one=True) - database.execute("DELETE FROM table_episodes WHERE sonarrEpisodeId=?", (removed_episode,)) + episode_to_delete = TableEpisodes.select(TableEpisodes.sonarrSeriesId, TableEpisodes.sonarrEpisodeId)\ + .where(TableEpisodes.sonarrEpisodeId == removed_episode)\ + .dicts()\ + .get() + TableEpisodes.delete().where(TableEpisodes.sonarrEpisodeId == removed_episode).execute() if send_event: event_stream(type='episode', action='delete', payload=episode_to_delete['sonarrEpisodeId']) # Update existing episodes in DB episode_in_db_list = [] - episodes_in_db = database.execute("SELECT sonarrSeriesId, sonarrEpisodeId, title, path, season, episode, " - "scene_name, monitored, format, resolution, video_codec, audio_codec, " - "episode_file_id, audio_language, file_size FROM table_episodes") + episodes_in_db = TableEpisodes.select(TableEpisodes.sonarrSeriesId, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.title, + TableEpisodes.path, + TableEpisodes.season, + TableEpisodes.episode, + TableEpisodes.scene_name, + TableEpisodes.monitored, + TableEpisodes.format, + TableEpisodes.resolution, + TableEpisodes.video_codec, + TableEpisodes.audio_codec, + TableEpisodes.episode_file_id, + TableEpisodes.audio_language, + TableEpisodes.file_size).dicts() for item in episodes_in_db: episode_in_db_list.append(item) @@ -100,19 +114,15 @@ episodes_to_update_list = [i for i in episodes_to_update if i not in episode_in_db_list] for updated_episode in episodes_to_update_list: - query = dict_converter.convert(updated_episode) - database.execute('''UPDATE table_episodes SET ''' + query.keys_update + ''' WHERE sonarrEpisodeId = ?''', - query.values + (updated_episode['sonarrEpisodeId'],)) + 
TableEpisodes.update(updated_episode).where(TableEpisodes.sonarrEpisodeId == + updated_episode['sonarrEpisodeId']).execute() altered_episodes.append([updated_episode['sonarrEpisodeId'], updated_episode['path'], updated_episode['sonarrSeriesId']]) # Insert new episodes in DB for added_episode in episodes_to_add: - query = dict_converter.convert(added_episode) - result = database.execute( - '''INSERT OR IGNORE INTO table_episodes(''' + query.keys_insert + ''') VALUES(''' + query.question_marks + - ''')''', query.values) + result = TableEpisodes.insert(added_episode).on_conflict(action='IGNORE').execute() if result > 0: altered_episodes.append([added_episode['sonarrEpisodeId'], added_episode['path'], @@ -134,8 +144,10 @@ def sync_one_episode(episode_id): logging.debug('BAZARR syncing this specific episode from Sonarr: {}'.format(episode_id)) # Check if there's a row in database for this episode ID - existing_episode = database.execute('SELECT path FROM table_episodes WHERE sonarrEpisodeId = ?', (episode_id,), - only_one=True) + existing_episode = TableEpisodes.select(TableEpisodes.path)\ + .where(TableEpisodes.sonarrEpisodeId == episode_id)\ + .dicts()\ + .get() try: # Get episode data from sonarr api @@ -156,38 +168,34 @@ def sync_one_episode(episode_id): # Remove episode from DB if not episode and existing_episode: - database.execute("DELETE FROM table_episodes WHERE sonarrEpisodeId=?", (episode_id,)) + TableEpisodes.delete().where(TableEpisodes.sonarrEpisodeId == episode_id).execute() event_stream(type='episode', action='delete', payload=int(episode_id)) logging.debug('BAZARR deleted this episode from the database:{}'.format(path_mappings.path_replace( - existing_episode['path']))) + existing_episode['path']))) return # Update existing episodes in DB elif episode and existing_episode: - query = dict_converter.convert(episode) - database.execute('''UPDATE table_episodes SET ''' + query.keys_update + ''' WHERE sonarrEpisodeId = ?''', - query.values + (episode['sonarrEpisodeId'],)) + TableEpisodes.update(episode).where(TableEpisodes.sonarrEpisodeId == episode.sonarrEpisodeId).execute() event_stream(type='episode', action='update', payload=int(episode_id)) logging.debug('BAZARR updated this episode into the database:{}'.format(path_mappings.path_replace( - episode['path']))) + episode.path))) # Insert new episodes in DB elif episode and not existing_episode: - query = dict_converter.convert(episode) - database.execute('''INSERT OR IGNORE INTO table_episodes(''' + query.keys_insert + ''') VALUES(''' + - query.question_marks + ''')''', query.values) + TableEpisodes.insert(episode).on_conflict(action='IGNORE').execute() event_stream(type='episode', action='update', payload=int(episode_id)) logging.debug('BAZARR inserted this episode into the database:{}'.format(path_mappings.path_replace( - episode['path']))) + episode.path))) # Storing existing subtitles logging.debug('BAZARR storing subtitles for this episode: {}'.format(path_mappings.path_replace( - episode['path']))) - store_subtitles(episode['path'], path_mappings.path_replace(episode['path'])) + episode.path))) + store_subtitles(episode.path, path_mappings.path_replace(episode.path)) # Downloading missing subtitles logging.debug('BAZARR downloading missing subtitles for this episode: {}'.format(path_mappings.path_replace( - episode['path']))) + episode.path))) episode_download_subtitles(episode_id) @@ -248,9 +256,7 @@ def episodeParser(episode): if 'name' in item: audio_language.append(item['name']) else: - audio_language = 
database.execute("SELECT audio_language FROM table_shows WHERE " - "sonarrSeriesId=?", (episode['seriesId'],), - only_one=True)['audio_language'] + audio_language = TableShows.get(TableShows.sonarrSeriesId == episode['seriesId']).audio_language if 'mediaInfo' in episode['episodeFile']: if 'videoCodec' in episode['episodeFile']['mediaInfo']: @@ -317,8 +323,9 @@ def get_series_from_sonarr_api(series_id, url, apikey_sonarr): logging.exception("BAZARR Error trying to get series from Sonarr.") return else: + series_json = [] if series_id: - series_json = list(r.json()) + series_json.append(r.json()) else: series_json = r.json() series_list = [] diff --git a/bazarr/get_languages.py b/bazarr/get_languages.py index a38a90215..8c43bfc18 100644 --- a/bazarr/get_languages.py +++ b/bazarr/get_languages.py @@ -3,7 +3,7 @@ import pycountry from subzero.language import Language -from database import database +from database import database, TableSettingsLanguages def load_language_in_db(): @@ -13,21 +13,35 @@ if hasattr(lang, 'alpha_2')] # Insert languages in database table - database.execute("INSERT OR IGNORE INTO table_settings_languages (code3, code2, name) VALUES (?, ?, ?)", - langs, execute_many=True) - - database.execute("INSERT OR IGNORE INTO table_settings_languages (code3, code2, name) " - "VALUES ('pob', 'pb', 'Brazilian Portuguese')") - - database.execute("INSERT OR IGNORE INTO table_settings_languages (code3, code2, name) " - "VALUES ('zht', 'zt', 'Chinese Traditional')") + TableSettingsLanguages.insert_many(langs, + fields=[TableSettingsLanguages.code3, TableSettingsLanguages.code2, + TableSettingsLanguages.name]) \ + .on_conflict(action='IGNORE') \ + .execute() + + TableSettingsLanguages.insert({TableSettingsLanguages.code3: 'pob', TableSettingsLanguages.code2: 'pb', + TableSettingsLanguages.name: 'Brazilian Portuguese'}) \ + .on_conflict(action='IGNORE') \ + .execute() + + # update/insert chinese languages + TableSettingsLanguages.update({TableSettingsLanguages.name: 'Chinese Simplified'}) \ + .where(TableSettingsLanguages.code2 == 'zt')\ + .execute() + TableSettingsLanguages.insert({TableSettingsLanguages.code3: 'zht', TableSettingsLanguages.code2: 'zt', + TableSettingsLanguages.name: 'Chinese Traditional'}) \ + .on_conflict(action='IGNORE')\ + .execute() langs = [[lang.bibliographic, lang.alpha_3] for lang in pycountry.languages if hasattr(lang, 'alpha_2') and hasattr(lang, 'bibliographic')] # Update languages in database table - database.execute("UPDATE table_settings_languages SET code3b=? 
WHERE code3=?", langs, execute_many=True) + for lang in langs: + TableSettingsLanguages.update({TableSettingsLanguages.code3b: lang[0]}) \ + .where(TableSettingsLanguages.code3 == lang[1]) \ + .execute() # Create languages dictionary for faster conversion than calling database create_languages_dict() @@ -35,10 +49,10 @@ def load_language_in_db(): def create_languages_dict(): global languages_dict - #replace chinese by chinese simplified - database.execute("UPDATE table_settings_languages SET name = 'Chinese Simplified' WHERE code3 = 'zho'") - - languages_dict = database.execute("SELECT name, code2, code3, code3b FROM table_settings_languages") + languages_dict = TableSettingsLanguages.select(TableSettingsLanguages.name, + TableSettingsLanguages.code2, + TableSettingsLanguages.code3, + TableSettingsLanguages.code3b).dicts() def language_from_alpha2(lang): @@ -68,7 +82,8 @@ def alpha3_from_language(lang): def get_language_set(): - languages = database.execute("SELECT code3 FROM table_settings_languages WHERE enabled=1") + languages = TableSettingsLanguages.select(TableSettingsLanguages.code3) \ + .where(TableSettingsLanguages.enabled == 1).dicts() language_set = set() diff --git a/bazarr/get_movies.py b/bazarr/get_movies.py index 0708697a0..3dff42bbb 100644 --- a/bazarr/get_movies.py +++ b/bazarr/get_movies.py @@ -3,6 +3,8 @@ import os import requests import logging +import operator +from functools import reduce from config import settings, url_radarr from helper import path_mappings @@ -11,7 +13,7 @@ from list_subtitles import store_subtitles_movie, movies_full_scan_subtitles from get_rootfolder import check_radarr_rootfolder from get_subtitle import movies_download_subtitles -from database import database, dict_converter, get_exclusion_clause +from database import get_exclusion_clause, TableMovies from event_handler import event_stream, show_progress, hide_progress headers = {"User-Agent": os.environ["SZ_USER_AGENT"]} @@ -50,7 +52,7 @@ def update_movies(send_event=True): return else: # Get current movies in DB - current_movies_db = database.execute("SELECT tmdbId, path, radarrId FROM table_movies") + current_movies_db = TableMovies.select(TableMovies.tmdbId, TableMovies.path, TableMovies.radarrId).dicts() current_movies_db_list = [x['tmdbId'] for x in current_movies_db] @@ -101,14 +103,31 @@ def update_movies(send_event=True): removed_movies = list(set(current_movies_db_list) - set(current_movies_radarr)) for removed_movie in removed_movies: - database.execute("DELETE FROM table_movies WHERE tmdbId=?", (removed_movie,)) + TableMovies.delete().where(TableMovies.tmdbId == removed_movie).execute() # Update movies in DB movies_in_db_list = [] - movies_in_db = database.execute("SELECT radarrId, title, path, tmdbId, overview, poster, fanart, " - "audio_language, sceneName, monitored, sortTitle, year, " - "alternativeTitles, format, resolution, video_codec, audio_codec, imdbId," - "movie_file_id, tags, file_size FROM table_movies") + movies_in_db = TableMovies.select(TableMovies.radarrId, + TableMovies.title, + TableMovies.path, + TableMovies.tmdbId, + TableMovies.overview, + TableMovies.poster, + TableMovies.fanart, + TableMovies.audio_language, + TableMovies.sceneName, + TableMovies.monitored, + TableMovies.sortTitle, + TableMovies.year, + TableMovies.alternativeTitles, + TableMovies.format, + TableMovies.resolution, + TableMovies.video_codec, + TableMovies.audio_codec, + TableMovies.imdbId, + TableMovies.movie_file_id, + TableMovies.tags, + TableMovies.file_size).dicts() for item in 
movies_in_db: movies_in_db_list.append(item) @@ -116,9 +135,7 @@ def update_movies(send_event=True): movies_to_update_list = [i for i in movies_to_update if i not in movies_in_db_list] for updated_movie in movies_to_update_list: - query = dict_converter.convert(updated_movie) - database.execute('''UPDATE table_movies SET ''' + query.keys_update + ''' WHERE tmdbId = ?''', - query.values + (updated_movie['tmdbId'],)) + TableMovies.update(updated_movie).where(TableMovies.tmdbId == updated_movie['tmdbId']).execute() altered_movies.append([updated_movie['tmdbId'], updated_movie['path'], updated_movie['radarrId'], @@ -126,10 +143,7 @@ def update_movies(send_event=True): # Insert new movies in DB for added_movie in movies_to_add: - query = dict_converter.convert(added_movie) - result = database.execute( - '''INSERT OR IGNORE INTO table_movies(''' + query.keys_insert + ''') VALUES(''' + - query.question_marks + ''')''', query.values) + result = TableMovies.insert(added_movie).on_conflict(action='IGNORE').execute() if result > 0: altered_movies.append([added_movie['tmdbId'], added_movie['path'], @@ -151,8 +165,9 @@ def update_movies(send_event=True): if len(altered_movies) <= 5: logging.debug("BAZARR No more than 5 movies were added during this sync then we'll search for subtitles.") for altered_movie in altered_movies: - data = database.execute("SELECT * FROM table_movies WHERE radarrId = ?" + - get_exclusion_clause('movie'), (altered_movie[2],), only_one=True) + conditions = [(TableMovies.radarrId == altered_movie[2])] + conditions += get_exclusion_clause('movie') + data = TableMovies.get(reduce(operator.and_, conditions)) if data: movies_download_subtitles(data['radarrId']) else: @@ -165,15 +180,15 @@ def update_one_movie(movie_id, action): logging.debug('BAZARR syncing this specific movie from Radarr: {}'.format(movie_id)) # Check if there's a row in database for this movie ID - existing_movie = database.execute('SELECT path FROM table_movies WHERE radarrId = ?', (movie_id,), only_one=True) + existing_movie = TableMovies.get_or_none(TableMovies.radarrId == movie_id) # Remove movie from DB if action == 'deleted': if existing_movie: - database.execute("DELETE FROM table_movies WHERE radarrId=?", (movie_id,)) + TableMovies.delete().where(TableMovies.radarrId == movie_id).execute() event_stream(type='movie', action='delete', payload=int(movie_id)) logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie( - existing_movie['path']))) + existing_movie.path))) return radarr_version = get_radarr_version() @@ -215,26 +230,22 @@ def update_one_movie(movie_id, action): # Remove movie from DB if not movie and existing_movie: - database.execute("DELETE FROM table_movies WHERE radarrId=?", (movie_id,)) + TableMovies.delete().where(TableMovies.radarrId == movie_id).execute() event_stream(type='movie', action='delete', payload=int(movie_id)) logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie( - existing_movie['path']))) + existing_movie.path))) return # Update existing movie in DB elif movie and existing_movie: - query = dict_converter.convert(movie) - database.execute('''UPDATE table_movies SET ''' + query.keys_update + ''' WHERE radarrId = ?''', - query.values + (movie['radarrId'],)) + TableMovies.update(movie).where(TableMovies.radarrId == movie['radarrId']).execute() event_stream(type='movie', action='update', payload=int(movie_id)) logging.debug('BAZARR updated this movie into the 
database:{}'.format(path_mappings.path_replace_movie( movie['path']))) # Insert new movie in DB elif movie and not existing_movie: - query = dict_converter.convert(movie) - database.execute('''INSERT OR IGNORE INTO table_movies(''' + query.keys_insert + ''') VALUES(''' + - query.question_marks + ''')''', query.values) + TableMovies.insert(movie).on_conflict(action='IGNORE').execute() event_stream(type='movie', action='update', payload=int(movie_id)) logging.debug('BAZARR inserted this movie into the database:{}'.format(path_mappings.path_replace_movie( movie['path']))) diff --git a/bazarr/get_rootfolder.py b/bazarr/get_rootfolder.py index 625ab523e..f1704ec32 100644 --- a/bazarr/get_rootfolder.py +++ b/bazarr/get_rootfolder.py @@ -6,7 +6,7 @@ import logging from config import settings, url_sonarr, url_radarr from helper import path_mappings -from database import database +from database import TableShowsRootfolder, TableMoviesRootfolder headers = {"User-Agent": os.environ["SZ_USER_AGENT"]} @@ -32,7 +32,7 @@ def get_sonarr_rootfolder(): else: for folder in rootfolder.json(): sonarr_rootfolder.append({'id': folder['id'], 'path': folder['path']}) - db_rootfolder = database.execute('SELECT id, path FROM table_shows_rootfolder') + db_rootfolder = TableShowsRootfolder.select(TableShowsRootfolder.id, TableShowsRootfolder.path).dicts() rootfolder_to_remove = [x for x in db_rootfolder if not next((item for item in sonarr_rootfolder if item['id'] == x['id']), False)] rootfolder_to_update = [x for x in sonarr_rootfolder if @@ -41,26 +41,37 @@ def get_sonarr_rootfolder(): next((item for item in db_rootfolder if item['id'] == x['id']), False)] for item in rootfolder_to_remove: - database.execute('DELETE FROM table_shows_rootfolder WHERE id = ?', (item['id'],)) + TableShowsRootfolder.delete().where(TableShowsRootfolder.id == item['id']).execute() for item in rootfolder_to_update: - database.execute('UPDATE table_shows_rootfolder SET path=? WHERE id = ?', (item['path'], item['id'])) + TableShowsRootfolder.update({TableShowsRootfolder.path: item['path']})\ + .where(TableShowsRootfolder.id == item['id'])\ + .execute() for item in rootfolder_to_insert: - database.execute('INSERT INTO table_shows_rootfolder (id, path) VALUES (?, ?)', (item['id'], item['path'])) + TableShowsRootfolder.insert({TableShowsRootfolder.id: item['id'], TableShowsRootfolder.path: item['path']})\ + .execute() def check_sonarr_rootfolder(): get_sonarr_rootfolder() - rootfolder = database.execute('SELECT id, path FROM table_shows_rootfolder') + rootfolder = TableShowsRootfolder.select(TableShowsRootfolder.id, TableShowsRootfolder.path).dicts() for item in rootfolder: if not os.path.isdir(path_mappings.path_replace(item['path'])): - database.execute("UPDATE table_shows_rootfolder SET accessible = 0, error = 'This Sonarr root directory " - "does not seems to be accessible by Bazarr. Please check path mapping.' WHERE id = ?", - (item['id'],)) + TableShowsRootfolder.update({TableShowsRootfolder.accessible: 0, + TableShowsRootfolder.error: 'This Sonarr root directory does not seems to ' + 'be accessible by Bazarr. 
Please check path ' + 'mapping.'})\ + .where(TableShowsRootfolder.id == item['id'])\ + .execute() elif not os.access(path_mappings.path_replace(item['path']), os.W_OK): - database.execute("UPDATE table_shows_rootfolder SET accessible = 0, error = 'Bazarr cannot write to " - "this directory' WHERE id = ?", (item['id'],)) + TableShowsRootfolder.update({TableShowsRootfolder.accessible: 0, + TableShowsRootfolder.error: 'Bazarr cannot write to this directory.'}) \ + .where(TableShowsRootfolder.id == item['id']) \ + .execute() else: - database.execute("UPDATE table_shows_rootfolder SET accessible = 1, error = '' WHERE id = ?", (item['id'],)) + TableShowsRootfolder.update({TableShowsRootfolder.accessible: 1, + TableShowsRootfolder.error: ''}) \ + .where(TableShowsRootfolder.id == item['id']) \ + .execute() def get_radarr_rootfolder(): @@ -84,7 +95,7 @@ def get_radarr_rootfolder(): else: for folder in rootfolder.json(): radarr_rootfolder.append({'id': folder['id'], 'path': folder['path']}) - db_rootfolder = database.execute('SELECT id, path FROM table_movies_rootfolder') + db_rootfolder = TableMoviesRootfolder.select(TableMoviesRootfolder.id, TableMoviesRootfolder.path).dicts() rootfolder_to_remove = [x for x in db_rootfolder if not next((item for item in radarr_rootfolder if item['id'] == x['id']), False)] rootfolder_to_update = [x for x in radarr_rootfolder if @@ -93,24 +104,33 @@ def get_radarr_rootfolder(): next((item for item in db_rootfolder if item['id'] == x['id']), False)] for item in rootfolder_to_remove: - database.execute('DELETE FROM table_movies_rootfolder WHERE id = ?', (item['id'],)) + TableMoviesRootfolder.delete().where(TableMoviesRootfolder.id == item['id']).execute() for item in rootfolder_to_update: - database.execute('UPDATE table_movies_rootfolder SET path=? WHERE id = ?', (item['path'], item['id'])) + TableMoviesRootfolder.update({TableMoviesRootfolder.path: item['path']})\ + .where(TableMoviesRootfolder.id == item['id']).execute() for item in rootfolder_to_insert: - database.execute('INSERT INTO table_movies_rootfolder (id, path) VALUES (?, ?)', (item['id'], item['path'])) + TableMoviesRootfolder.insert({TableMoviesRootfolder.id: item['id'], + TableMoviesRootfolder.path: item['path']}).execute() def check_radarr_rootfolder(): get_radarr_rootfolder() - rootfolder = database.execute('SELECT id, path FROM table_movies_rootfolder') + rootfolder = TableMoviesRootfolder.select(TableMoviesRootfolder.id, TableMoviesRootfolder.path).dicts() for item in rootfolder: if not os.path.isdir(path_mappings.path_replace_movie(item['path'])): - database.execute("UPDATE table_movies_rootfolder SET accessible = 0, error = 'This Radarr root directory " - "does not seems to be accessible by Bazarr. Please check path mapping.' WHERE id = ?", - (item['id'],)) + TableMoviesRootfolder.update({TableMoviesRootfolder.accessible: 0, + TableMoviesRootfolder.error: 'This Radarr root directory does not seems to ' + 'be accessible by Bazarr. 
Please check path ' + 'mapping.'}) \ + .where(TableMoviesRootfolder.id == item['id']) \ + .execute() elif not os.access(path_mappings.path_replace_movie(item['path']), os.W_OK): - database.execute("UPDATE table_movies_rootfolder SET accessible = 0, error = 'Bazarr cannot write to " - "this directory' WHERE id = ?", (item['id'],)) + TableMoviesRootfolder.update({TableMoviesRootfolder.accessible: 0, + TableMoviesRootfolder.error: 'Bazarr cannot write to this directory'}) \ + .where(TableMoviesRootfolder.id == item['id']) \ + .execute() else: - database.execute("UPDATE table_movies_rootfolder SET accessible = 1, error = '' WHERE id = ?", - (item['id'],)) + TableMoviesRootfolder.update({TableMoviesRootfolder.accessible: 1, + TableMoviesRootfolder.error: ''}) \ + .where(TableMoviesRootfolder.id == item['id']) \ + .execute() diff --git a/bazarr/get_series.py b/bazarr/get_series.py index e3dc1030d..1b9c6fe08 100644 --- a/bazarr/get_series.py +++ b/bazarr/get_series.py @@ -7,7 +7,7 @@ import logging from config import settings, url_sonarr from list_subtitles import list_missing_subtitles from get_rootfolder import check_sonarr_rootfolder -from database import database, dict_converter +from database import TableShows from utils import get_sonarr_version from helper import path_mappings from event_handler import event_stream, show_progress, hide_progress @@ -40,7 +40,7 @@ def update_series(send_event=True): return else: # Get current shows in DB - current_shows_db = database.execute("SELECT sonarrSeriesId FROM table_shows") + current_shows_db = TableShows.select(TableShows.sonarrSeriesId).dicts() current_shows_db_list = [x['sonarrSeriesId'] for x in current_shows_db] current_shows_sonarr = [] @@ -81,15 +81,26 @@ def update_series(send_event=True): removed_series = list(set(current_shows_db_list) - set(current_shows_sonarr)) for series in removed_series: - database.execute("DELETE FROM table_shows WHERE sonarrSeriesId=?",(series,)) + TableShows.delete().where(TableShows.sonarrSeriesId == series).execute() if send_event: event_stream(type='series', action='delete', payload=series) # Update existing series in DB series_in_db_list = [] - series_in_db = database.execute("SELECT title, path, tvdbId, sonarrSeriesId, overview, poster, fanart, " - "audio_language, sortTitle, year, alternateTitles, tags, seriesType, imdbId " - "FROM table_shows") + series_in_db = TableShows.select(TableShows.title, + TableShows.path, + TableShows.tvdbId, + TableShows.sonarrSeriesId, + TableShows.overview, + TableShows.poster, + TableShows.fanart, + TableShows.audio_language, + TableShows.sortTitle, + TableShows.year, + TableShows.alternateTitles, + TableShows.tags, + TableShows.seriesType, + TableShows.imdbId).dicts() for item in series_in_db: series_in_db_list.append(item) @@ -97,18 +108,14 @@ def update_series(send_event=True): series_to_update_list = [i for i in series_to_update if i not in series_in_db_list] for updated_series in series_to_update_list: - query = dict_converter.convert(updated_series) - database.execute('''UPDATE table_shows SET ''' + query.keys_update + ''' WHERE sonarrSeriesId = ?''', - query.values + (updated_series['sonarrSeriesId'],)) + TableShows.update(updated_series).where(TableShows.sonarrSeriesId == + updated_series['sonarrSeriesId']).execute() if send_event: event_stream(type='series', payload=updated_series['sonarrSeriesId']) # Insert new series in DB for added_series in series_to_add: - query = dict_converter.convert(added_series) - result = database.execute( - '''INSERT OR IGNORE INTO 
table_shows(''' + query.keys_insert + ''') VALUES(''' + - query.question_marks + ''')''', query.values) + result = TableShows.insert(added_series).on_conflict(action='IGNORE').execute() if result: list_missing_subtitles(no=added_series['sonarrSeriesId']) else: @@ -125,8 +132,7 @@ def update_one_series(series_id, action): logging.debug('BAZARR syncing this specific series from RSonarr: {}'.format(series_id)) # Check if there's a row in database for this series ID - existing_series = database.execute('SELECT path FROM table_shows WHERE sonarrSeriesId = ?', (series_id,), - only_one=True) + existing_series = TableShows.get_or_none(TableShows.sonarrSeriesId == series_id) sonarr_version = get_sonarr_version() serie_default_enabled = settings.general.getboolean('serie_default_enabled') @@ -149,7 +155,7 @@ def update_one_series(series_id, action): series_data = get_series_from_sonarr_api(url=url_sonarr(), apikey_sonarr=settings.sonarr.apikey, sonarr_series_id=series_id) except requests.exceptions.HTTPError: - database.execute("DELETE FROM table_shows WHERE sonarrSeriesId=?", (series_id,)) + TableShows.delete().where(TableShows.sonarrSeriesId == series_id).execute() event_stream(type='series', action='delete', payload=int(series_id)) return @@ -170,26 +176,22 @@ def update_one_series(series_id, action): # Remove series from DB if action == 'deleted': - database.execute("DELETE FROM table_shows WHERE sonarrSeriesId=?", (series_id,)) + TableShows.delete().where(TableShows.sonarrSeriesId == series_id).execute() event_stream(type='series', action='delete', payload=int(series_id)) logging.debug('BAZARR deleted this series from the database:{}'.format(path_mappings.path_replace( - existing_series['path']))) + existing_series.path))) return # Update existing series in DB elif action == 'updated' and existing_series: - query = dict_converter.convert(series) - database.execute('''UPDATE table_shows SET ''' + query.keys_update + ''' WHERE sonarrSeriesId = ?''', - query.values + (series['sonarrSeriesId'],)) + TableShows.update(series).where(TableShows.sonarrSeriesId == series['sonarrSeriesId']).execute() event_stream(type='series', action='update', payload=int(series_id)) logging.debug('BAZARR updated this series into the database:{}'.format(path_mappings.path_replace( series['path']))) # Insert new series in DB elif action == 'updated' and not existing_series: - query = dict_converter.convert(series) - database.execute('''INSERT OR IGNORE INTO table_shows(''' + query.keys_insert + ''') VALUES(''' + - query.question_marks + ''')''', query.values) + TableShows.insert(series).on_conflict(action='IGNORE').execute() event_stream(type='series', action='update', payload=int(series_id)) logging.debug('BAZARR inserted this series into the database:{}'.format(path_mappings.path_replace( series['path']))) diff --git a/bazarr/get_subtitle.py b/bazarr/get_subtitle.py index 4aa16ab6b..2fd84faf9 100644 --- a/bazarr/get_subtitle.py +++ b/bazarr/get_subtitle.py @@ -11,6 +11,9 @@ import codecs import re import subliminal import copy +import operator +from functools import reduce +from peewee import fn from datetime import datetime, timedelta from subzero.language import Language from subzero.video import parse_video @@ -31,8 +34,8 @@ from get_providers import get_providers, get_providers_auth, provider_throttle, from knowit import api from subsyncer import subsync from guessit import guessit -from database import database, dict_mapper, get_exclusion_clause, get_profiles_list, get_audio_profile_languages, \ - 
get_desired_languages +from database import dict_mapper, get_exclusion_clause, get_profiles_list, get_audio_profile_languages, \ + get_desired_languages, TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie from event_handler import event_stream, show_progress, hide_progress from embedded_subs_reader import parse_video_metadata @@ -253,10 +256,11 @@ def download_subtitle(path, language, audio_language, hi, forced, providers, pro downloaded_provider + " with a score of " + str(percent_score) + "%." if media_type == 'series': - episode_metadata = database.execute("SELECT sonarrSeriesId, sonarrEpisodeId FROM " - "table_episodes WHERE path = ?", - (path_mappings.path_replace_reverse(path),), - only_one=True) + episode_metadata = TableEpisodes.select(TableEpisodes.sonarrSeriesId, + TableEpisodes.sonarrEpisodeId)\ + .where(TableEpisodes.path == path_mappings.path_replace_reverse(path))\ + .dicts()\ + .get() series_id = episode_metadata['sonarrSeriesId'] episode_id = episode_metadata['sonarrEpisodeId'] sync_subtitles(video_path=path, srt_path=downloaded_path, @@ -265,9 +269,10 @@ def download_subtitle(path, language, audio_language, hi, forced, providers, pro sonarr_series_id=episode_metadata['sonarrSeriesId'], sonarr_episode_id=episode_metadata['sonarrEpisodeId']) else: - movie_metadata = database.execute("SELECT radarrId FROM table_movies WHERE path = ?", - (path_mappings.path_replace_reverse_movie(path),), - only_one=True) + movie_metadata = TableMovies.select(TableMovies.radarrId)\ + .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\ + .dicts()\ + .get() series_id = "" episode_id = movie_metadata['radarrId'] sync_subtitles(video_path=path, srt_path=downloaded_path, @@ -580,10 +585,11 @@ def manual_download_subtitle(path, language, audio_language, hi, forced, subtitl downloaded_provider + " with a score of " + str(score) + "% using manual search." 
if media_type == 'series': - episode_metadata = database.execute("SELECT sonarrSeriesId, sonarrEpisodeId FROM " - "table_episodes WHERE path = ?", - (path_mappings.path_replace_reverse(path),), - only_one=True) + episode_metadata = TableEpisodes.select(TableEpisodes.sonarrSeriesId, + TableEpisodes.sonarrEpisodeId)\ + .where(TableEpisodes.path == path_mappings.path_replace_reverse(path))\ + .dicts()\ + .get() series_id = episode_metadata['sonarrSeriesId'] episode_id = episode_metadata['sonarrEpisodeId'] sync_subtitles(video_path=path, srt_path=downloaded_path, @@ -592,9 +598,10 @@ def manual_download_subtitle(path, language, audio_language, hi, forced, subtitl sonarr_series_id=episode_metadata['sonarrSeriesId'], sonarr_episode_id=episode_metadata['sonarrEpisodeId']) else: - movie_metadata = database.execute("SELECT radarrId FROM table_movies WHERE path = ?", - (path_mappings.path_replace_reverse_movie(path),), - only_one=True) + movie_metadata = TableMovies.select(TableMovies.radarrId)\ + .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\ + .dicts()\ + .get() series_id = "" episode_id = movie_metadata['radarrId'] sync_subtitles(video_path=path, srt_path=downloaded_path, @@ -710,18 +717,20 @@ def manual_upload_subtitle(path, language, forced, title, scene_name, media_type audio_language_code3 = alpha3_from_language(audio_language) if media_type == 'series': - episode_metadata = database.execute("SELECT sonarrSeriesId, sonarrEpisodeId FROM table_episodes WHERE path = ?", - (path_mappings.path_replace_reverse(path),), - only_one=True) + episode_metadata = TableEpisodes.select(TableEpisodes.sonarrSeriesId, TableEpisodes.sonarrEpisodeId)\ + .where(TableEpisodes.path == path_mappings.path_replace_reverse(path))\ + .dicts()\ + .get() series_id = episode_metadata['sonarrSeriesId'] episode_id = episode_metadata['sonarrEpisodeId'] sync_subtitles(video_path=path, srt_path=subtitle_path, srt_lang=uploaded_language_code2, media_type=media_type, percent_score=100, sonarr_series_id=episode_metadata['sonarrSeriesId'], sonarr_episode_id=episode_metadata['sonarrEpisodeId']) else: - movie_metadata = database.execute("SELECT radarrId FROM table_movies WHERE path = ?", - (path_mappings.path_replace_reverse_movie(path),), - only_one=True) + movie_metadata = TableMovies.select(TableMovies.radarrId)\ + .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\ + .dicts()\ + .get() series_id = "" episode_id = movie_metadata['radarrId'] sync_subtitles(video_path=path, srt_path=subtitle_path, srt_lang=uploaded_language_code2, media_type=media_type, @@ -746,13 +755,24 @@ def manual_upload_subtitle(path, language, forced, title, scene_name, media_type def series_download_subtitles(no): - episodes_details = database.execute("SELECT table_episodes.path, table_episodes.missing_subtitles, monitored, " - "table_episodes.sonarrEpisodeId, table_episodes.scene_name, table_shows.tags, " - "table_shows.seriesType, table_episodes.audio_language, table_shows.title, " - "table_episodes.season, table_episodes.episode, table_episodes.title as episodeTitle " - "FROM table_episodes INNER JOIN table_shows on table_shows.sonarrSeriesId = " - "table_episodes.sonarrSeriesId WHERE table_episodes.sonarrSeriesId=? 
and " - "missing_subtitles!='[]'" + get_exclusion_clause('series'), (no,)) + conditions = [(TableEpisodes.sonarrSeriesId == no), + (TableEpisodes.missing_subtitles != '[]')] + conditions += get_exclusion_clause('series') + episodes_details = TableEpisodes.select(TableEpisodes.path, + TableEpisodes.missing_subtitles, + TableEpisodes.monitored, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.scene_name, + TableShows.tags, + TableShows.seriesType, + TableEpisodes.audio_language, + TableShows.title, + TableEpisodes.season, + TableEpisodes.episode, + TableEpisodes.title.alias('episodeTitle'))\ + .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where(reduce(operator.and_, conditions))\ + .dicts() if not episodes_details: logging.debug("BAZARR no episode for that sonarrSeriesId can be found in database:", str(no)) return @@ -821,16 +841,23 @@ def series_download_subtitles(no): hide_progress(id='series_search_progress_{}'.format(no)) - - def episode_download_subtitles(no): - episodes_details = database.execute("SELECT table_episodes.path, table_episodes.missing_subtitles, monitored, " - "table_episodes.sonarrEpisodeId, table_episodes.scene_name, table_shows.tags, " - "table_shows.title, table_shows.sonarrSeriesId, table_episodes.audio_language, " - "table_shows.seriesType FROM table_episodes LEFT JOIN table_shows on " - "table_episodes.sonarrSeriesId = table_shows.sonarrSeriesId WHERE sonarrEpisodeId=?" + - get_exclusion_clause('series'), (no,)) + conditions = [(TableEpisodes.sonarrEpisodeId == no)] + conditions += get_exclusion_clause('series') + episodes_details = TableEpisodes.select(TableEpisodes.path, + TableEpisodes.missing_subtitles, + TableEpisodes.monitored, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.scene_name, + TableShows.tags, + TableShows.title, + TableShows.sonarrSeriesId, + TableEpisodes.audio_language, + TableShows.seriesType)\ + .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where(reduce(operator.and_, conditions))\ + .dicts() if not episodes_details: logging.debug("BAZARR no episode with that sonarrEpisodeId can be found in database:", str(no)) return @@ -882,9 +909,18 @@ def episode_download_subtitles(no): def movies_download_subtitles(no): - movies = database.execute( - "SELECT path, missing_subtitles, audio_language, radarrId, sceneName, title, tags, " - "monitored FROM table_movies WHERE radarrId=?" + get_exclusion_clause('movie'), (no,)) + conditions = [(TableMovies.radarrId == no)] + conditions += get_exclusion_clause('movie') + movies = TableMovies.select(TableMovies.path, + TableMovies.missing_subtitles, + TableMovies.audio_language, + TableMovies.radarrId, + TableMovies.sceneName, + TableMovies.title, + TableMovies.tags, + TableMovies.monitored)\ + .where(reduce(operator.and_, conditions))\ + .dicts() if not len(movies): logging.debug("BAZARR no movie with that radarrId can be found in database:", str(no)) return @@ -955,14 +991,19 @@ def movies_download_subtitles(no): def wanted_download_subtitles(path): - episodes_details = database.execute("SELECT table_episodes.path, table_episodes.missing_subtitles, " - "table_episodes.sonarrEpisodeId, table_episodes.sonarrSeriesId, " - "table_episodes.audio_language, table_episodes.scene_name," - "table_episodes.failedAttempts, table_shows.title " - "FROM table_episodes LEFT JOIN table_shows on " - "table_episodes.sonarrSeriesId = table_shows.sonarrSeriesId " - "WHERE table_episodes.path=? 
and table_episodes.missing_subtitles!='[]'", - (path_mappings.path_replace_reverse(path),)) + episodes_details = TableEpisodes.select(TableEpisodes.path, + TableEpisodes.missing_subtitles, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.sonarrSeriesId, + TableEpisodes.audio_language, + TableEpisodes.scene_name, + TableEpisodes.failedAttempts, + TableShows.title)\ + .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where((TableEpisodes.path == path_mappings.path_replace_reverse(path)) & + (TableEpisodes.missing_subtitles != '[]'))\ + .dicts() + episodes_details = list(episodes_details) providers_list = get_providers() providers_auth = get_providers_auth() @@ -980,8 +1021,9 @@ if language not in att: attempt.append([language, time.time()]) - database.execute("UPDATE table_episodes SET failedAttempts=? WHERE sonarrEpisodeId=?", - (str(attempt), episode['sonarrEpisodeId'])) + TableEpisodes.update({TableEpisodes.failedAttempts: str(attempt)})\ + .where(TableEpisodes.sonarrEpisodeId == episode['sonarrEpisodeId'])\ + .execute() for i in range(len(attempt)): if attempt[i][0] == language: @@ -1028,10 +1070,17 @@ def wanted_download_subtitles_movie(path): - movies_details = database.execute( - "SELECT path, missing_subtitles, radarrId, audio_language, sceneName, " - "failedAttempts, title FROM table_movies WHERE path = ? " - "AND missing_subtitles != '[]'", (path_mappings.path_replace_reverse_movie(path),)) + movies_details = TableMovies.select(TableMovies.path, + TableMovies.missing_subtitles, + TableMovies.radarrId, + TableMovies.audio_language, + TableMovies.sceneName, + TableMovies.failedAttempts, + TableMovies.title)\ + .where((TableMovies.path == path_mappings.path_replace_reverse_movie(path)) & + (TableMovies.missing_subtitles != '[]'))\ + .dicts() + movies_details = list(movies_details) providers_list = get_providers() providers_auth = get_providers_auth() @@ -1049,8 +1098,9 @@ if language not in att: attempt.append([language, time.time()]) - database.execute("UPDATE table_movies SET failedAttempts=? 
WHERE radarrId=?", - (str(attempt), movie['radarrId'])) + TableMovies.update({TableMovies.failedAttempts: str(attempt)})\ + .where(TableMovies.radarrId == movie['radarrId'])\ + .execute() for i in range(len(attempt)): if attempt[i][0] == language: @@ -1097,11 +1147,20 @@ def wanted_search_missing_subtitles_series(): - episodes = database.execute("SELECT table_episodes.path, table_shows.tags, table_episodes.monitored, " - "table_shows.title, table_episodes.season, table_episodes.episode, table_episodes.title" - " as episodeTitle, table_shows.seriesType FROM table_episodes INNER JOIN table_shows on" - " table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE missing_subtitles !=" - " '[]'" + get_exclusion_clause('series')) + conditions = [(TableEpisodes.missing_subtitles != '[]')] + conditions += get_exclusion_clause('series') + episodes = TableEpisodes.select(TableEpisodes.path, + TableShows.tags, + TableEpisodes.monitored, + TableShows.title, + TableEpisodes.season, + TableEpisodes.episode, + TableEpisodes.title.alias('episodeTitle'), + TableShows.seriesType)\ + .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where(reduce(operator.and_, conditions))\ + .dicts() + episodes = list(episodes) # path_replace dict_mapper.path_replace(episodes) @@ -1135,8 +1194,15 @@ def wanted_search_missing_subtitles_movies(): - movies = database.execute("SELECT path, tags, monitored, title FROM table_movies WHERE missing_subtitles != '[]'" + - get_exclusion_clause('movie')) + conditions = [(TableMovies.missing_subtitles != '[]')] + conditions += get_exclusion_clause('movie') + movies = TableMovies.select(TableMovies.path, + TableMovies.tags, + TableMovies.monitored, + TableMovies.title)\ + .where(reduce(operator.and_, conditions))\ + .dicts() + movies = list(movies) # path_replace dict_mapper.path_replace_movie(movies) @@ -1194,16 +1260,25 @@ def convert_to_guessit(guessit_key, attr_from_db): def refine_from_db(path, video): if isinstance(video, Episode): - data = database.execute( - "SELECT table_shows.title as seriesTitle, table_episodes.season, table_episodes.episode, " - "table_episodes.title as episodeTitle, table_shows.year, table_shows.tvdbId, " - "table_shows.alternateTitles, table_episodes.format, table_episodes.resolution, " - "table_episodes.video_codec, table_episodes.audio_codec, table_episodes.path, table_shows.imdbId " - "FROM table_episodes INNER JOIN table_shows on " - "table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId " - "WHERE table_episodes.path = ?", (path_mappings.path_replace_reverse(path),), only_one=True) - - if data: + data = TableEpisodes.select(TableShows.title.alias('seriesTitle'), + TableEpisodes.season, + TableEpisodes.episode, + TableEpisodes.title.alias('episodeTitle'), + TableShows.year, + TableShows.tvdbId, + TableShows.alternateTitles, + TableEpisodes.format, + TableEpisodes.resolution, + TableEpisodes.video_codec, + TableEpisodes.audio_codec, + TableEpisodes.path, + TableShows.imdbId)\ + .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where((TableEpisodes.path == path_mappings.path_replace_reverse(path)))\ + .dicts() + + if len(data): + data = data[0] video.series = re.sub(r'\s(\(\d\d\d\d\))', '', data['seriesTitle']) video.season = int(data['season']) video.episode = int(data['episode']) @@ -1224,11 +1299,19 @@ if not video.audio_codec: if data['audio_codec']: 
video.audio_codec = convert_to_guessit('audio_codec', data['audio_codec']) elif isinstance(video, Movie): - data = database.execute("SELECT title, year, alternativeTitles, format, resolution, video_codec, audio_codec, " - "imdbId FROM table_movies WHERE path = ?", - (path_mappings.path_replace_reverse_movie(path),), only_one=True) - - if data: + data = TableMovies.select(TableMovies.title, + TableMovies.year, + TableMovies.alternativeTitles, + TableMovies.format, + TableMovies.resolution, + TableMovies.video_codec, + TableMovies.audio_codec, + TableMovies.imdbId)\ + .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\ + .dicts() + + if len(data): + data = data[0] video.title = re.sub(r'\s(\(\d\d\d\d\))', '', data['title']) # Commented out because Radarr provided so much bad year # if data['year']: @@ -1250,11 +1333,15 @@ def refine_from_db(path, video): def refine_from_ffprobe(path, video): if isinstance(video, Movie): - file_id = database.execute("SELECT movie_file_id FROM table_shows WHERE path = ?", - (path_mappings.path_replace_reverse_movie(path),), only_one=True) + file_id = TableMovies.select(TableMovies.movie_file_id, TableMovies.file_size)\ + .where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\ + .dicts()\ + .get() else: - file_id = database.execute("SELECT episode_file_id, file_size FROM table_episodes WHERE path = ?", - (path_mappings.path_replace_reverse(path),), only_one=True) + file_id = TableEpisodes.select(TableEpisodes.episode_file_id, TableEpisodes.file_size)\ + .where(TableEpisodes.path == path_mappings.path_replace_reverse(path))\ + .dicts()\ + .get() if not isinstance(file_id, dict): return video @@ -1312,21 +1399,33 @@ def upgrade_subtitles(): query_actions = [1, 3] if settings.general.getboolean('use_sonarr'): - upgradable_episodes = database.execute("SELECT table_history.video_path, table_history.language, " - "table_history.score, table_shows.tags, table_shows.profileId, " - "table_episodes.audio_language, table_episodes.scene_name, " - "table_episodes.title, table_episodes.sonarrSeriesId, table_history.action, " - "table_history.subtitles_path, table_episodes.sonarrEpisodeId, " - "MAX(table_history.timestamp) as timestamp, table_episodes.monitored, " - "table_episodes.season, table_episodes.episode, table_shows.title as seriesTitle, " - "table_shows.seriesType FROM table_history INNER JOIN table_shows on " - "table_shows.sonarrSeriesId = table_history.sonarrSeriesId INNER JOIN " - "table_episodes on table_episodes.sonarrEpisodeId = " - "table_history.sonarrEpisodeId WHERE action IN " - "(" + ','.join(map(str, query_actions)) + ") AND timestamp > ? 
AND " - "score is not null" + get_exclusion_clause('series') + " GROUP BY " - "table_history.video_path, table_history.language", - (minimum_timestamp,)) + upgradable_episodes_conditions = [(TableHistory.action << query_actions), + (TableHistory.timestamp > minimum_timestamp), + (TableHistory.score is not None)] + upgradable_episodes_conditions += get_exclusion_clause('series') + upgradable_episodes = TableHistory.select(TableHistory.video_path, + TableHistory.language, + TableHistory.score, + TableShows.tags, + TableShows.profileId, + TableEpisodes.audio_language, + TableEpisodes.scene_name, + TableEpisodes.title, + TableEpisodes.sonarrSeriesId, + TableHistory.action, + TableHistory.subtitles_path, + TableEpisodes.sonarrEpisodeId, + fn.MAX(TableHistory.timestamp).alias('timestamp'), + TableEpisodes.monitored, + TableEpisodes.season, + TableEpisodes.episode, + TableShows.title.alias('seriesTitle'), + TableShows.seriesType)\ + .join(TableShows, on=(TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId))\ + .where(reduce(operator.and_, upgradable_episodes_conditions))\ + .group_by(TableHistory.video_path, TableHistory.language)\ + .dicts() upgradable_episodes_not_perfect = [] for upgradable_episode in upgradable_episodes: if upgradable_episode['timestamp'] > minimum_timestamp: @@ -1347,17 +1446,25 @@ def upgrade_subtitles(): count_episode_to_upgrade = len(episodes_to_upgrade) if settings.general.getboolean('use_radarr'): - upgradable_movies = database.execute("SELECT table_history_movie.video_path, table_history_movie.language, " - "table_history_movie.score, table_movies.profileId, table_history_movie.action, " - "table_history_movie.subtitles_path, table_movies.audio_language, " - "table_movies.sceneName, table_movies.title, table_movies.radarrId, " - "MAX(table_history_movie.timestamp) as timestamp, table_movies.tags, " - "table_movies.monitored FROM table_history_movie INNER JOIN table_movies " - "on table_movies.radarrId = table_history_movie.radarrId WHERE action IN " - "(" + ','.join(map(str, query_actions)) + ") AND timestamp > ? 
AND score " - "is not null" + get_exclusion_clause('movie') + " GROUP BY " - "table_history_movie.video_path, table_history_movie.language", - (minimum_timestamp,)) + upgradable_movies_conditions = [(TableHistoryMovie.action << query_actions), + (TableHistoryMovie.timestamp > minimum_timestamp), + (TableHistoryMovie.score is not None)] + upgradable_movies_conditions += get_exclusion_clause('movie') + upgradable_movies = TableHistoryMovie.select(TableHistoryMovie.video_path, + TableHistoryMovie.language, + TableHistoryMovie.score, + TableMovies.profileId, + TableHistoryMovie.action, + TableHistoryMovie.subtitles_path, + TableMovies.audio_language, + TableMovies.sceneName, + fn.MAX(TableHistoryMovie.timestamp).alias('timestamp'), + TableMovies.tags, + TableMovies.radarrId)\ + .join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId))\ + .where(reduce(operator.and_, upgradable_movies_conditions))\ + .group_by(TableHistoryMovie.video_path, TableHistoryMovie.language)\ + .dicts() upgradable_movies_not_perfect = [] for upgradable_movie in upgradable_movies: if upgradable_movie['timestamp'] > minimum_timestamp: diff --git a/bazarr/init.py b/bazarr/init.py index acdba946d..176f03bf5 100644 --- a/bazarr/init.py +++ b/bazarr/init.py @@ -126,27 +126,6 @@ if os.path.isfile(package_info_file): with open(os.path.join(args.config_dir, 'config', 'config.ini'), 'w+') as handle: settings.write(handle) -# create database file -if not os.path.exists(os.path.join(args.config_dir, 'db', 'bazarr.db')): - import sqlite3 - # Get SQL script from file - fd = open(os.path.join(os.path.dirname(__file__), 'create_db.sql'), 'r') - script = fd.read() - # Close SQL script file - fd.close() - # Open database connection - db = sqlite3.connect(os.path.join(args.config_dir, 'db', 'bazarr.db'), timeout=30) - c = db.cursor() - # Execute script and commit change to database - c.executescript(script) - # Close database connection - db.close() - logging.info('BAZARR Database created successfully') - -# upgrade database schema -from database import db_upgrade -db_upgrade() - # Configure dogpile file caching for Subliminal request register_cache_backend("subzero.cache.file", "subzero.cache_backends.file", "SZFileBackend") subliminal.region.configure('subzero.cache.file', expiration_time=datetime.timedelta(days=30), @@ -175,41 +154,6 @@ with open(os.path.normpath(os.path.join(args.config_dir, 'config', 'config.ini') settings.write(handle) -# Commenting out the password reset process as it could be having unwanted effects and most of the users have already -# moved to new password hashing algorithm. - -# Reset form login password for Bazarr after migration from 0.8.x to 0.9. Password will be equal to username. -# if settings.auth.type == 'form' and \ -# os.path.exists(os.path.normpath(os.path.join(args.config_dir, 'config', 'users.json'))): -# username = False -# with open(os.path.normpath(os.path.join(args.config_dir, 'config', 'users.json'))) as json_file: -# try: -# data = json.load(json_file) -# username = next(iter(data)) -# except: -# logging.error('BAZARR is unable to migrate credentials. 
You should disable login by modifying config.ini ' -# 'file and settings [auth]-->type = None') -# if username: -# settings.auth.username = username -# settings.auth.password = hashlib.md5(username.encode('utf-8')).hexdigest() -# with open(os.path.join(args.config_dir, 'config', 'config.ini'), 'w+') as handle: -# settings.write(handle) -# os.remove(os.path.normpath(os.path.join(args.config_dir, 'config', 'users.json'))) -# os.remove(os.path.normpath(os.path.join(args.config_dir, 'config', 'roles.json'))) -# os.remove(os.path.normpath(os.path.join(args.config_dir, 'config', 'register.json'))) -# logging.info('BAZARR your login credentials have been migrated successfully and your password is now equal ' -# 'to your username. Please change it as soon as possible in settings.') -# else: -# if os.path.exists(os.path.normpath(os.path.join(args.config_dir, 'config', 'users.json'))): -# try: -# os.remove(os.path.normpath(os.path.join(args.config_dir, 'config', 'users.json'))) -# os.remove(os.path.normpath(os.path.join(args.config_dir, 'config', 'roles.json'))) -# os.remove(os.path.normpath(os.path.join(args.config_dir, 'config', 'register.json'))) -# except: -# logging.error("BAZARR cannot delete those file. Please do it manually: users.json, roles.json, " -# "register.json") - - def init_binaries(): from utils import get_binary exe = get_binary("unrar") diff --git a/bazarr/list_subtitles.py b/bazarr/list_subtitles.py index 4597c02c2..185348143 100644 --- a/bazarr/list_subtitles.py +++ b/bazarr/list_subtitles.py @@ -9,7 +9,7 @@ from guess_language import guess_language from subliminal_patch import core, search_external_subtitles from subzero.language import Language -from database import database, get_profiles_list, get_profile_cutoff +from database import get_profiles_list, get_profile_cutoff, TableEpisodes, TableShows, TableMovies from get_languages import alpha2_from_alpha3, language_from_alpha2, get_language_set from config import settings from helper import path_mappings, get_subtitle_destination_folder @@ -31,8 +31,10 @@ def store_subtitles(original_path, reversed_path): if settings.general.getboolean('use_embedded_subs'): logging.debug("BAZARR is trying to index embedded subtitles.") try: - item = database.execute('SELECT file_size, episode_file_id FROM table_episodes ' - 'WHERE path = ?', (original_path,), only_one=True) + item = TableEpisodes.select(TableEpisodes.episode_file_id, TableEpisodes.file_size)\ + .where(TableEpisodes.path == original_path)\ + .dicts()\ + .get() subtitle_languages = embedded_subs_reader(reversed_path, file_size=item['file_size'], episode_file_id=item['episode_file_id']) @@ -125,10 +127,12 @@ def store_subtitles(original_path, reversed_path): logging.debug("BAZARR external subtitles detected: " + language_str) actual_subtitles.append([language_str, path_mappings.path_replace_reverse(subtitle_path)]) - database.execute("UPDATE table_episodes SET subtitles=? 
WHERE path=?", - (str(actual_subtitles), original_path)) - matching_episodes = database.execute("SELECT sonarrEpisodeId, sonarrSeriesId FROM table_episodes WHERE path=?", - (original_path,)) + TableEpisodes.update({TableEpisodes.subtitles: str(actual_subtitles)})\ + .where(TableEpisodes.path == original_path)\ + .execute() + matching_episodes = TableEpisodes.select(TableEpisodes.sonarrEpisodeId, TableEpisodes.sonarrSeriesId)\ + .where(TableEpisodes.path == original_path)\ + .dicts() for episode in matching_episodes: if episode: @@ -151,8 +155,10 @@ def store_subtitles_movie(original_path, reversed_path): if settings.general.getboolean('use_embedded_subs'): logging.debug("BAZARR is trying to index embedded subtitles.") try: - item = database.execute('SELECT file_size, movie_file_id FROM table_movies ' - 'WHERE path = ?', (original_path,), only_one=True) + item = TableMovies.select(TableMovies.movie_file_id, TableMovies.file_size)\ + .where(TableMovies.path == original_path)\ + .dicts()\ + .get() subtitle_languages = embedded_subs_reader(reversed_path, file_size=item['file_size'], movie_file_id=item['movie_file_id']) @@ -237,9 +243,10 @@ def store_subtitles_movie(original_path, reversed_path): logging.debug("BAZARR external subtitles detected: " + language_str) actual_subtitles.append([language_str, path_mappings.path_replace_reverse_movie(subtitle_path)]) - database.execute("UPDATE table_movies SET subtitles=? WHERE path=?", - (str(actual_subtitles), original_path)) - matching_movies = database.execute("SELECT radarrId FROM table_movies WHERE path=?", (original_path,)) + TableMovies.update({TableMovies.subtitles: str(actual_subtitles)})\ + .where(TableMovies.path == original_path)\ + .execute() + matching_movies = TableMovies.select(TableMovies.radarrId).where(TableMovies.path == original_path).dicts() for movie in matching_movies: if movie: @@ -257,16 +264,19 @@ def store_subtitles_movie(original_path, reversed_path): def list_missing_subtitles(no=None, epno=None, send_event=True): if epno is not None: - episodes_subtitles_clause = " WHERE table_episodes.sonarrEpisodeId=" + str(epno) + episodes_subtitles_clause = (TableEpisodes.sonarrEpisodeId == epno) elif no is not None: - episodes_subtitles_clause = " WHERE table_episodes.sonarrSeriesId=" + str(no) + episodes_subtitles_clause = (TableEpisodes.sonarrSeriesId == no) else: - episodes_subtitles_clause = "" - episodes_subtitles = database.execute("SELECT table_shows.sonarrSeriesId, table_episodes.sonarrEpisodeId, " - "table_episodes.subtitles, table_shows.profileId, " - "table_episodes.audio_language FROM table_episodes " - "LEFT JOIN table_shows on table_episodes.sonarrSeriesId = " - "table_shows.sonarrSeriesId" + episodes_subtitles_clause) + episodes_subtitles_clause = () + episodes_subtitles = TableEpisodes.select(TableShows.sonarrSeriesId, + TableEpisodes.sonarrEpisodeId, + TableEpisodes.subtitles, + TableShows.profileId, + TableEpisodes.audio_language)\ + .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\ + .where(episodes_subtitles_clause)\ + .dicts() if isinstance(episodes_subtitles, str): logging.error("BAZARR list missing subtitles query to DB returned this instead of rows: " + episodes_subtitles) return @@ -361,27 +371,26 @@ def list_missing_subtitles(no=None, epno=None, send_event=True): missing_subtitles_text = str(missing_subtitles_output_list) - database.execute("UPDATE table_episodes SET missing_subtitles=? 
WHERE sonarrEpisodeId=?", - (missing_subtitles_text, episode_subtitles['sonarrEpisodeId'])) + TableEpisodes.update({TableEpisodes.missing_subtitles: missing_subtitles_text})\ + .where(TableEpisodes.sonarrEpisodeId == episode_subtitles['sonarrEpisodeId'])\ + .execute() if send_event: event_stream(type='episode', payload=episode_subtitles['sonarrEpisodeId']) event_stream(type='badges') -def list_missing_subtitles_movies(no=None, epno=None, send_event=True): - if no is not None: - movies_subtitles_clause = " WHERE radarrId=" + str(no) - else: - movies_subtitles_clause = "" - - movies_subtitles = database.execute("SELECT radarrId, subtitles, profileId, audio_language FROM table_movies" + - movies_subtitles_clause) +def list_missing_subtitles_movies(no=None, send_event=True): + movies_subtitles = TableMovies.select(TableMovies.radarrId, + TableMovies.subtitles, + TableMovies.profileId, + TableMovies.audio_language)\ + .where((TableMovies.radarrId == no) if no else None)\ + .dicts() if isinstance(movies_subtitles, str): logging.error("BAZARR list missing subtitles query to DB returned this instead of rows: " + movies_subtitles) return - use_embedded_subs = settings.general.getboolean('use_embedded_subs') for movie_subtitles in movies_subtitles: @@ -470,8 +479,9 @@ def list_missing_subtitles_movies(no=None, epno=None, send_event=True): missing_subtitles_text = str(missing_subtitles_output_list) - database.execute("UPDATE table_movies SET missing_subtitles=? WHERE radarrId=?", - (missing_subtitles_text, movie_subtitles['radarrId'])) + TableMovies.update({TableMovies.missing_subtitles: missing_subtitles_text})\ + .where(TableMovies.radarrId == movie_subtitles['radarrId'])\ + .execute() if send_event: event_stream(type='movie', payload=movie_subtitles['radarrId']) @@ -479,7 +489,7 @@ def list_missing_subtitles_movies(no=None, epno=None, send_event=True): def series_full_scan_subtitles(): - episodes = database.execute("SELECT path FROM table_episodes") + episodes = TableEpisodes.select(TableEpisodes.path).dicts() count_episodes = len(episodes) for i, episode in enumerate(episodes, 1): @@ -502,7 +512,7 @@ def series_full_scan_subtitles(): def movies_full_scan_subtitles(): - movies = database.execute("SELECT path FROM table_movies") + movies = TableMovies.select(TableMovies.path).dicts() count_movies = len(movies) for i, movie in enumerate(movies, 1): @@ -525,15 +535,20 @@ def movies_full_scan_subtitles(): def series_scan_subtitles(no): - episodes = database.execute("SELECT path FROM table_episodes WHERE sonarrSeriesId=? ORDER BY sonarrEpisodeId", - (no,)) + episodes = TableEpisodes.select(TableEpisodes.path)\ + .where(TableEpisodes.sonarrSeriesId == no)\ + .order_by(TableEpisodes.sonarrEpisodeId)\ + .dicts() for episode in episodes: store_subtitles(episode['path'], path_mappings.path_replace(episode['path'])) def movies_scan_subtitles(no): - movies = database.execute("SELECT path FROM table_movies WHERE radarrId=? 
ORDER BY radarrId", (no,)) + movies = TableMovies.select(TableMovies.path)\ + .where(TableMovies.radarrId == no)\ + .order_by(TableMovies.radarrId)\ + .dicts() for movie in movies: store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path'])) diff --git a/bazarr/main.py b/bazarr/main.py index ff8e5d13d..0f029b326 100644 --- a/bazarr/main.py +++ b/bazarr/main.py @@ -26,7 +26,7 @@ from get_args import args from config import settings, url_sonarr, url_radarr, configure_proxy_func, base_url from init import * -from database import database +from database import System from notifier import update_notifier @@ -53,7 +53,7 @@ check_releases() configure_proxy_func() # Reset the updated once Bazarr have been restarted after an update -database.execute("UPDATE system SET updated='0'") +System.update({System.updated: '0'}).execute() # Load languages in database load_language_in_db() @@ -100,14 +100,14 @@ def catch_all(path): auth = False try: - updated = database.execute("SELECT updated FROM system", only_one=True)['updated'] + updated = System.select().where(updated='1') except: updated = False inject = dict() inject["baseUrl"] = base_url inject["canUpdate"] = not args.no_update - inject["hasUpdate"] = updated != '0' + inject["hasUpdate"] = len(updated) if auth: inject["apiKey"] = settings.auth.apikey @@ -164,7 +164,7 @@ def movies_images(url): def configured(): - database.execute("UPDATE system SET configured = 1") + System.update({System.configured: '1'}).execute() @check_login diff --git a/bazarr/notifier.py b/bazarr/notifier.py index d521dab80..ab8f945e1 100644 --- a/bazarr/notifier.py +++ b/bazarr/notifier.py @@ -3,7 +3,7 @@ import apprise import logging -from database import database +from database import TableSettingsNotifier, TableEpisodes, TableShows, TableMovies def update_notifier(): @@ -16,7 +16,7 @@ def update_notifier(): notifiers_new = [] notifiers_old = [] - notifiers_current_db = database.execute("SELECT name FROM table_settings_notifier") + notifiers_current_db = TableSettingsNotifier.select(TableSettingsNotifier.name).dicts() notifiers_current = [] for notifier in notifiers_current_db: @@ -24,41 +24,51 @@ def update_notifier(): for x in results['schemas']: if [x['service_name']] not in notifiers_current: - notifiers_new.append([x['service_name'], 0]) + notifiers_new.append({'name': x['service_name'], 'enabled': 0}) logging.debug('Adding new notifier agent: ' + x['service_name']) else: notifiers_old.append([x['service_name']]) notifiers_to_delete = [item for item in notifiers_current if item not in notifiers_old] - database.execute("INSERT INTO table_settings_notifier (name, enabled) VALUES (?, ?)", notifiers_new, - execute_many=True) + TableSettingsNotifier.insert_many(notifiers_new).execute() - database.execute("DELETE FROM table_settings_notifier WHERE name=?", notifiers_to_delete, execute_many=True) + for item in notifiers_to_delete: + TableSettingsNotifier.delete().where(TableSettingsNotifier.name == item).execute() def get_notifier_providers(): - providers = database.execute("SELECT name, url FROM table_settings_notifier WHERE enabled=1") + providers = TableSettingsNotifier.select(TableSettingsNotifier.name, + TableSettingsNotifier.enabled)\ + .where(TableSettingsNotifier.enabled == 1)\ + .dicts() return providers def get_series(sonarr_series_id): - data = database.execute("SELECT title, year FROM table_shows WHERE sonarrSeriesId=?", (sonarr_series_id,), - only_one=True) + data = TableShows.select(TableShows.title, TableShows.year)\ + 
.where(TableShows.sonarrSeriesId == sonarr_series_id)\ + .dicts()\ + .get() return {'title': data['title'], 'year': data['year']} def get_episode_name(sonarr_episode_id): - data = database.execute("SELECT title, season, episode FROM table_episodes WHERE sonarrEpisodeId=?", - (sonarr_episode_id,), only_one=True) + data = TableEpisodes.select(TableEpisodes.title, TableEpisodes.season, TableEpisodes.episode)\ + .where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id)\ + .dicts()\ + .get() return data['title'], data['season'], data['episode'] def get_movie(radarr_id): - data = database.execute("SELECT title, year FROM table_movies WHERE radarrId=?", (radarr_id,), only_one=True) + data = TableMovies.select(TableMovies.title, TableMovies.year)\ + .where(TableMovies.radarrId == radarr_id)\ + .dicts()\ + .get() return {'title': data['title'], 'year': data['year']} diff --git a/bazarr/signalr_client.py b/bazarr/signalr_client.py index eb027ca19..3917a3c1a 100644 --- a/bazarr/signalr_client.py +++ b/bazarr/signalr_client.py @@ -56,8 +56,9 @@ class SonarrSignalrClient: logging.info('BAZARR SignalR client for Sonarr is now disconnected.') def restart(self): - if self.connection.is_open: - self.stop(log=False) + if self.connection: + if self.connection.is_open: + self.stop(log=False) if settings.general.getboolean('use_sonarr'): self.start() @@ -94,8 +95,9 @@ class RadarrSignalrClient: self.connection.stop() def restart(self): - if self.connection.transport.state.value in [0, 1, 2]: - self.stop() + if self.connection: + if self.connection.transport.state.value in [0, 1, 2]: + self.stop() if settings.general.getboolean('use_radarr'): self.start() @@ -131,35 +133,38 @@ class RadarrSignalrClient: def dispatcher(data): - topic = media_id = action = None - episodesChanged = None - if isinstance(data, dict): - topic = data['name'] - try: - media_id = data['body']['resource']['id'] - action = data['body']['action'] - if 'episodesChanged' in data['body']['resource']: - episodesChanged = data['body']['resource']['episodesChanged'] - except KeyError: - return - elif isinstance(data, list): - topic = data[0]['name'] - try: - media_id = data[0]['body']['resource']['id'] - action = data[0]['body']['action'] - except KeyError: - return - - if topic == 'series': - update_one_series(series_id=media_id, action=action) - if episodesChanged: - # this will happen if an episode monitored status is changed. - sync_episodes(series_id=media_id, send_event=True) - elif topic == 'episode': - sync_one_episode(episode_id=media_id) - elif topic == 'movie': - update_one_movie(movie_id=media_id, action=action) - else: + try: + topic = media_id = action = None + episodesChanged = None + if isinstance(data, dict): + topic = data['name'] + try: + media_id = data['body']['resource']['id'] + action = data['body']['action'] + if 'episodesChanged' in data['body']['resource']: + episodesChanged = data['body']['resource']['episodesChanged'] + except KeyError: + return + elif isinstance(data, list): + topic = data[0]['name'] + try: + media_id = data[0]['body']['resource']['id'] + action = data[0]['body']['action'] + except KeyError: + return + + if topic == 'series': + update_one_series(series_id=media_id, action=action) + if episodesChanged: + # this will happen if an episode monitored status is changed. 
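The notifier helpers above switch single-row lookups to the .dicts().get() pattern. As a hedged sketch of that pattern (the Show model below is hypothetical, not the actual Bazarr schema): .dicts() returns rows as plain dictionaries, and .get() returns the first match or raises <Model>.DoesNotExist, so callers that previously tested for an empty result may want to catch that exception.

    from peewee import Model, SqliteDatabase, IntegerField, TextField

    db = SqliteDatabase(':memory:')

    class Show(Model):
        sonarrSeriesId = IntegerField(primary_key=True)
        title = TextField()
        year = TextField(null=True)

        class Meta:
            database = db

    def get_series_sketch(series_id):
        try:
            # .dicts() returns the row as a dictionary; .get() returns the
            # first match or raises Show.DoesNotExist when nothing matches.
            data = (Show.select(Show.title, Show.year)
                        .where(Show.sonarrSeriesId == series_id)
                        .dicts()
                        .get())
        except Show.DoesNotExist:
            return None
        return {'title': data['title'], 'year': data['year']}

    db.create_tables([Show])
    Show.create(sonarrSeriesId=1, title='Some Show', year='2021')
    print(get_series_sketch(1))    # {'title': 'Some Show', 'year': '2021'}
    print(get_series_sketch(99))   # None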
+ sync_episodes(series_id=media_id, send_event=True) + elif topic == 'episode': + sync_one_episode(episode_id=media_id) + elif topic == 'movie': + update_one_movie(movie_id=media_id, action=action) + except Exception as e: + logging.debug('BAZARR an exception occurred while parsing SignalR feed: {}'.format(repr(e))) + finally: return diff --git a/bazarr/utils.py b/bazarr/utils.py index 7c58de412..f29a101a8 100644 --- a/bazarr/utils.py +++ b/bazarr/utils.py @@ -13,7 +13,8 @@ import stat from whichcraft import which from get_args import args from config import settings, url_sonarr, url_radarr -from database import database +from database import TableHistory, TableHistoryMovie, TableBlacklist, TableBlacklistMovie, TableShowsRootfolder, \ + TableMoviesRootfolder from event_handler import event_stream from get_languages import alpha2_from_alpha3, language_from_alpha3, language_from_alpha2, alpha3_from_alpha2 from helper import path_mappings @@ -37,51 +38,83 @@ class BinaryNotFound(Exception): def history_log(action, sonarr_series_id, sonarr_episode_id, description, video_path=None, language=None, provider=None, score=None, subs_id=None, subtitles_path=None): - database.execute("INSERT INTO table_history (action, sonarrSeriesId, sonarrEpisodeId, timestamp, description," - "video_path, language, provider, score, subs_id, subtitles_path) VALUES (?,?,?,?,?,?,?,?,?,?,?)", - (action, sonarr_series_id, sonarr_episode_id, time.time(), description, video_path, language, - provider, score, subs_id, subtitles_path)) + TableHistory.insert({ + TableHistory.action: action, + TableHistory.sonarrSeriesId: sonarr_series_id, + TableHistory.sonarrEpisodeId: sonarr_episode_id, + TableHistory.timestamp: time.time(), + TableHistory.description: description, + TableHistory.video_path: video_path, + TableHistory.language: language, + TableHistory.provider: provider, + TableHistory.score: score, + TableHistory.subs_id: subs_id, + TableHistory.subtitles_path: subtitles_path + }).execute() event_stream(type='episode-history') def blacklist_log(sonarr_series_id, sonarr_episode_id, provider, subs_id, language): - database.execute("INSERT INTO table_blacklist (sonarr_series_id, sonarr_episode_id, timestamp, provider, " - "subs_id, language) VALUES (?,?,?,?,?,?)", - (sonarr_series_id, sonarr_episode_id, time.time(), provider, subs_id, language)) + TableBlacklist.insert({ + TableBlacklist.sonarr_series_id: sonarr_series_id, + TableBlacklist.sonarr_episode_id: sonarr_episode_id, + TableBlacklist.timestamp: time.time(), + TableBlacklist.provider: provider, + TableBlacklist.subs_id: subs_id, + TableBlacklist.language: language + }).execute() event_stream(type='episode-blacklist') def blacklist_delete(provider, subs_id): - database.execute("DELETE FROM table_blacklist WHERE provider=? 
AND subs_id=?", (provider, subs_id)) + TableBlacklist.delete().where((TableBlacklist.provider == provider) and + (TableBlacklist.subs_id == subs_id))\ + .execute() event_stream(type='episode-blacklist', action='delete') def blacklist_delete_all(): - database.execute("DELETE FROM table_blacklist") + TableBlacklist.delete().execute() event_stream(type='episode-blacklist', action='delete') def history_log_movie(action, radarr_id, description, video_path=None, language=None, provider=None, score=None, subs_id=None, subtitles_path=None): - database.execute("INSERT INTO table_history_movie (action, radarrId, timestamp, description, video_path, language, " - "provider, score, subs_id, subtitles_path) VALUES (?,?,?,?,?,?,?,?,?,?)", - (action, radarr_id, time.time(), description, video_path, language, provider, score, subs_id, subtitles_path)) + TableHistoryMovie.insert({ + TableHistoryMovie.action: action, + TableHistoryMovie.radarrId: radarr_id, + TableHistoryMovie.timestamp: time.time(), + TableHistoryMovie.description: description, + TableHistoryMovie.video_path: video_path, + TableHistoryMovie.language: language, + TableHistoryMovie.provider: provider, + TableHistoryMovie.score: score, + TableHistoryMovie.subs_id: subs_id, + TableHistoryMovie.subtitles_path: subtitles_path + }).execute() event_stream(type='movie-history') def blacklist_log_movie(radarr_id, provider, subs_id, language): - database.execute("INSERT INTO table_blacklist_movie (radarr_id, timestamp, provider, subs_id, language) " - "VALUES (?,?,?,?,?)", (radarr_id, time.time(), provider, subs_id, language)) + TableBlacklistMovie.insert({ + TableBlacklistMovie.radarr_id: radarr_id, + TableBlacklistMovie.timestamp: time.time(), + TableBlacklistMovie.provider: provider, + TableBlacklistMovie.subs_id: subs_id, + TableBlacklistMovie.language: language + }).execute() event_stream(type='movie-blacklist') def blacklist_delete_movie(provider, subs_id): - database.execute("DELETE FROM table_blacklist_movie WHERE provider=? 
AND subs_id=?", (provider, subs_id)) + TableBlacklistMovie.delete().where((TableBlacklistMovie.provider == provider) and + (TableBlacklistMovie.subs_id == subs_id))\ + .execute() event_stream(type='movie-blacklist', action='delete') def blacklist_delete_all_movie(): - database.execute("DELETE FROM table_blacklist_movie") + TableBlacklistMovie.delete().execute() event_stream(type='movie-blacklist', action='delete') @@ -167,9 +200,9 @@ def get_binary(name): def get_blacklist(media_type): if media_type == 'series': - blacklist_db = database.execute("SELECT provider, subs_id FROM table_blacklist") + blacklist_db = TableBlacklist.select(TableBlacklist.provider, TableBlacklist.subs_id).dicts() else: - blacklist_db = database.execute("SELECT provider, subs_id FROM table_blacklist_movie") + blacklist_db = TableBlacklistMovie.select(TableBlacklistMovie.provider, TableBlacklistMovie.subs_id).dicts() blacklist_list = [] for item in blacklist_db: @@ -427,15 +460,22 @@ def get_health_issues(): # get Sonarr rootfolder issues if settings.general.getboolean('use_sonarr'): - rootfolder = database.execute('SELECT path, accessible, error FROM table_shows_rootfolder WHERE accessible = 0') + rootfolder = TableShowsRootfolder.select(TableShowsRootfolder.path, + TableShowsRootfolder.accessible, + TableShowsRootfolder.error)\ + .where(TableShowsRootfolder.accessible == 0)\ + .dicts() for item in rootfolder: health_issues.append({'object': path_mappings.path_replace(item['path']), 'issue': item['error']}) # get Radarr rootfolder issues if settings.general.getboolean('use_radarr'): - rootfolder = database.execute('SELECT path, accessible, error FROM table_movies_rootfolder ' - 'WHERE accessible = 0') + rootfolder = TableMoviesRootfolder.select(TableMoviesRootfolder.path, + TableMoviesRootfolder.accessible, + TableMoviesRootfolder.error)\ + .where(TableMoviesRootfolder.accessible == 0)\ + .dicts() for item in rootfolder: health_issues.append({'object': path_mappings.path_replace_movie(item['path']), 'issue': item['error']}) diff --git a/libs/peewee.py b/libs/peewee.py new file mode 100644 index 000000000..ad5eb7f28 --- /dev/null +++ b/libs/peewee.py @@ -0,0 +1,7746 @@ +from bisect import bisect_left +from bisect import bisect_right +from contextlib import contextmanager +from copy import deepcopy +from functools import wraps +from inspect import isclass +import calendar +import collections +import datetime +import decimal +import hashlib +import itertools +import logging +import operator +import re +import socket +import struct +import sys +import threading +import time +import uuid +import warnings +try: + from collections.abc import Mapping +except ImportError: + from collections import Mapping + +try: + from pysqlite3 import dbapi2 as pysq3 +except ImportError: + try: + from pysqlite2 import dbapi2 as pysq3 + except ImportError: + pysq3 = None +try: + import sqlite3 +except ImportError: + sqlite3 = pysq3 +else: + if pysq3 and pysq3.sqlite_version_info >= sqlite3.sqlite_version_info: + sqlite3 = pysq3 +try: + from psycopg2cffi import compat + compat.register() +except ImportError: + pass +try: + import psycopg2 + from psycopg2 import extensions as pg_extensions + try: + from psycopg2 import errors as pg_errors + except ImportError: + pg_errors = None +except ImportError: + psycopg2 = pg_errors = None +try: + from psycopg2.extras import register_uuid as pg_register_uuid + pg_register_uuid() +except Exception: + pass + +mysql_passwd = False +try: + import pymysql as mysql +except ImportError: + try: + import MySQLdb 
as mysql + mysql_passwd = True + except ImportError: + mysql = None + + +__version__ = '3.14.4' +__all__ = [ + 'AsIs', + 'AutoField', + 'BareField', + 'BigAutoField', + 'BigBitField', + 'BigIntegerField', + 'BinaryUUIDField', + 'BitField', + 'BlobField', + 'BooleanField', + 'Case', + 'Cast', + 'CharField', + 'Check', + 'chunked', + 'Column', + 'CompositeKey', + 'Context', + 'Database', + 'DatabaseError', + 'DatabaseProxy', + 'DataError', + 'DateField', + 'DateTimeField', + 'DecimalField', + 'DeferredForeignKey', + 'DeferredThroughModel', + 'DJANGO_MAP', + 'DoesNotExist', + 'DoubleField', + 'DQ', + 'EXCLUDED', + 'Field', + 'FixedCharField', + 'FloatField', + 'fn', + 'ForeignKeyField', + 'IdentityField', + 'ImproperlyConfigured', + 'Index', + 'IntegerField', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'IPField', + 'JOIN', + 'ManyToManyField', + 'Model', + 'ModelIndex', + 'MySQLDatabase', + 'NotSupportedError', + 'OP', + 'OperationalError', + 'PostgresqlDatabase', + 'PrimaryKeyField', # XXX: Deprecated, change to AutoField. + 'prefetch', + 'ProgrammingError', + 'Proxy', + 'QualifiedNames', + 'SchemaManager', + 'SmallIntegerField', + 'Select', + 'SQL', + 'SqliteDatabase', + 'Table', + 'TextField', + 'TimeField', + 'TimestampField', + 'Tuple', + 'UUIDField', + 'Value', + 'ValuesList', + 'Window', +] + +try: # Python 2.7+ + from logging import NullHandler +except ImportError: + class NullHandler(logging.Handler): + def emit(self, record): + pass + +logger = logging.getLogger('peewee') +logger.addHandler(NullHandler()) + + +if sys.version_info[0] == 2: + text_type = unicode + bytes_type = str + buffer_type = buffer + izip_longest = itertools.izip_longest + callable_ = callable + multi_types = (list, tuple, frozenset, set) + exec('def reraise(tp, value, tb=None): raise tp, value, tb') + def print_(s): + sys.stdout.write(s) + sys.stdout.write('\n') +else: + import builtins + try: + from collections.abc import Callable + except ImportError: + from collections import Callable + from functools import reduce + callable_ = lambda c: isinstance(c, Callable) + text_type = str + bytes_type = bytes + buffer_type = memoryview + basestring = str + long = int + multi_types = (list, tuple, frozenset, set, range) + print_ = getattr(builtins, 'print') + izip_longest = itertools.zip_longest + def reraise(tp, value, tb=None): + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + + +if sqlite3: + sqlite3.register_adapter(decimal.Decimal, str) + sqlite3.register_adapter(datetime.date, str) + sqlite3.register_adapter(datetime.time, str) + __sqlite_version__ = sqlite3.sqlite_version_info +else: + __sqlite_version__ = (0, 0, 0) + + +__date_parts__ = set(('year', 'month', 'day', 'hour', 'minute', 'second')) + +# Sqlite does not support the `date_part` SQL function, so we will define an +# implementation in python. 
+__sqlite_datetime_formats__ = ( + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d', + '%H:%M:%S', + '%H:%M:%S.%f', + '%H:%M') + +__sqlite_date_trunc__ = { + 'year': '%Y-01-01 00:00:00', + 'month': '%Y-%m-01 00:00:00', + 'day': '%Y-%m-%d 00:00:00', + 'hour': '%Y-%m-%d %H:00:00', + 'minute': '%Y-%m-%d %H:%M:00', + 'second': '%Y-%m-%d %H:%M:%S'} + +__mysql_date_trunc__ = __sqlite_date_trunc__.copy() +__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i:00' +__mysql_date_trunc__['second'] = '%Y-%m-%d %H:%i:%S' + +def _sqlite_date_part(lookup_type, datetime_string): + assert lookup_type in __date_parts__ + if not datetime_string: + return + dt = format_date_time(datetime_string, __sqlite_datetime_formats__) + return getattr(dt, lookup_type) + +def _sqlite_date_trunc(lookup_type, datetime_string): + assert lookup_type in __sqlite_date_trunc__ + if not datetime_string: + return + dt = format_date_time(datetime_string, __sqlite_datetime_formats__) + return dt.strftime(__sqlite_date_trunc__[lookup_type]) + + +def __deprecated__(s): + warnings.warn(s, DeprecationWarning) + + +class attrdict(dict): + def __getattr__(self, attr): + try: + return self[attr] + except KeyError: + raise AttributeError(attr) + def __setattr__(self, attr, value): self[attr] = value + def __iadd__(self, rhs): self.update(rhs); return self + def __add__(self, rhs): d = attrdict(self); d.update(rhs); return d + +SENTINEL = object() + +#: Operations for use in SQL expressions. +OP = attrdict( + AND='AND', + OR='OR', + ADD='+', + SUB='-', + MUL='*', + DIV='/', + BIN_AND='&', + BIN_OR='|', + XOR='#', + MOD='%', + EQ='=', + LT='<', + LTE='<=', + GT='>', + GTE='>=', + NE='!=', + IN='IN', + NOT_IN='NOT IN', + IS='IS', + IS_NOT='IS NOT', + LIKE='LIKE', + ILIKE='ILIKE', + BETWEEN='BETWEEN', + REGEXP='REGEXP', + IREGEXP='IREGEXP', + CONCAT='||', + BITWISE_NEGATION='~') + +# To support "django-style" double-underscore filters, create a mapping between +# operation name and operation code, e.g. "__eq" == OP.EQ. +DJANGO_MAP = attrdict({ + 'eq': operator.eq, + 'lt': operator.lt, + 'lte': operator.le, + 'gt': operator.gt, + 'gte': operator.ge, + 'ne': operator.ne, + 'in': operator.lshift, + 'is': lambda l, r: Expression(l, OP.IS, r), + 'like': lambda l, r: Expression(l, OP.LIKE, r), + 'ilike': lambda l, r: Expression(l, OP.ILIKE, r), + 'regexp': lambda l, r: Expression(l, OP.REGEXP, r), +}) + +#: Mapping of field type to the data-type supported by the database. Databases +#: may override or add to this list. +FIELD = attrdict( + AUTO='INTEGER', + BIGAUTO='BIGINT', + BIGINT='BIGINT', + BLOB='BLOB', + BOOL='SMALLINT', + CHAR='CHAR', + DATE='DATE', + DATETIME='DATETIME', + DECIMAL='DECIMAL', + DEFAULT='', + DOUBLE='REAL', + FLOAT='REAL', + INT='INTEGER', + SMALLINT='SMALLINT', + TEXT='TEXT', + TIME='TIME', + UUID='TEXT', + UUIDB='BLOB', + VARCHAR='VARCHAR') + +#: Join helpers (for convenience) -- all join types are supported, this object +#: is just to help avoid introducing errors by using strings everywhere. +JOIN = attrdict( + INNER='INNER JOIN', + LEFT_OUTER='LEFT OUTER JOIN', + RIGHT_OUTER='RIGHT OUTER JOIN', + FULL='FULL JOIN', + FULL_OUTER='FULL OUTER JOIN', + CROSS='CROSS JOIN', + NATURAL='NATURAL JOIN', + LATERAL='LATERAL', + LEFT_LATERAL='LEFT JOIN LATERAL') + +# Row representations. +ROW = attrdict( + TUPLE=1, + DICT=2, + NAMED_TUPLE=3, + CONSTRUCTOR=4, + MODEL=5) + +SCOPE_NORMAL = 1 +SCOPE_SOURCE = 2 +SCOPE_VALUES = 4 +SCOPE_CTE = 8 +SCOPE_COLUMN = 16 + +# Rules for parentheses around subqueries in compound select. 
+CSQ_PARENTHESES_NEVER = 0 +CSQ_PARENTHESES_ALWAYS = 1 +CSQ_PARENTHESES_UNNESTED = 2 + +# Regular expressions used to convert class names to snake-case table names. +# First regex handles acronym followed by word or initial lower-word followed +# by a capitalized word. e.g. APIResponse -> API_Response / fooBar -> foo_Bar. +# Second regex handles the normal case of two title-cased words. +SNAKE_CASE_STEP1 = re.compile('(.)_*([A-Z][a-z]+)') +SNAKE_CASE_STEP2 = re.compile('([a-z0-9])_*([A-Z])') + +# Helper functions that are used in various parts of the codebase. +MODEL_BASE = '_metaclass_helper_' + +def with_metaclass(meta, base=object): + return meta(MODEL_BASE, (base,), {}) + +def merge_dict(source, overrides): + merged = source.copy() + if overrides: + merged.update(overrides) + return merged + +def quote(path, quote_chars): + if len(path) == 1: + return path[0].join(quote_chars) + return '.'.join([part.join(quote_chars) for part in path]) + +is_model = lambda o: isclass(o) and issubclass(o, Model) + +def ensure_tuple(value): + if value is not None: + return value if isinstance(value, (list, tuple)) else (value,) + +def ensure_entity(value): + if value is not None: + return value if isinstance(value, Node) else Entity(value) + +def make_snake_case(s): + first = SNAKE_CASE_STEP1.sub(r'\1_\2', s) + return SNAKE_CASE_STEP2.sub(r'\1_\2', first).lower() + +def chunked(it, n): + marker = object() + for group in (list(g) for g in izip_longest(*[iter(it)] * n, + fillvalue=marker)): + if group[-1] is marker: + del group[group.index(marker):] + yield group + + +class _callable_context_manager(object): + def __call__(self, fn): + @wraps(fn) + def inner(*args, **kwargs): + with self: + return fn(*args, **kwargs) + return inner + + +class Proxy(object): + """ + Create a proxy or placeholder for another object. + """ + __slots__ = ('obj', '_callbacks') + + def __init__(self): + self._callbacks = [] + self.initialize(None) + + def initialize(self, obj): + self.obj = obj + for callback in self._callbacks: + callback(obj) + + def attach_callback(self, callback): + self._callbacks.append(callback) + return callback + + def passthrough(method): + def inner(self, *args, **kwargs): + if self.obj is None: + raise AttributeError('Cannot use uninitialized Proxy.') + return getattr(self.obj, method)(*args, **kwargs) + return inner + + # Allow proxy to be used as a context-manager. + __enter__ = passthrough('__enter__') + __exit__ = passthrough('__exit__') + + def __getattr__(self, attr): + if self.obj is None: + raise AttributeError('Cannot use uninitialized Proxy.') + return getattr(self.obj, attr) + + def __setattr__(self, attr, value): + if attr not in self.__slots__: + raise AttributeError('Cannot set attribute on proxy.') + return super(Proxy, self).__setattr__(attr, value) + + +class DatabaseProxy(Proxy): + """ + Proxy implementation specifically for proxying `Database` objects. + """ + def connection_context(self): + return ConnectionContext(self) + def atomic(self, *args, **kwargs): + return _atomic(self, *args, **kwargs) + def manual_commit(self): + return _manual(self) + def transaction(self, *args, **kwargs): + return _transaction(self, *args, **kwargs) + def savepoint(self): + return _savepoint(self) + + +class ModelDescriptor(object): pass + + +# SQL Generation. + + +class AliasManager(object): + __slots__ = ('_counter', '_current_index', '_mapping') + + def __init__(self): + # A list of dictionaries containing mappings at various depths. 
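The Proxy and DatabaseProxy classes above let model definitions be written before the concrete database is configured. A brief sketch of the documented usage pattern; the model names and the ':memory:' path are placeholders, not part of this patch.

    from peewee import DatabaseProxy, SqliteDatabase, Model, TextField

    database_proxy = DatabaseProxy()   # placeholder until configuration is read

    class BaseModel(Model):
        class Meta:
            database = database_proxy

    class Setting(BaseModel):
        key = TextField(unique=True)
        value = TextField(null=True)

    # Later, once the real database path is known (':memory:' used here):
    database_proxy.initialize(SqliteDatabase(':memory:'))
    database_proxy.create_tables([Setting])
    Setting.create(key='debug', value='False')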
+ self._counter = 0 + self._current_index = 0 + self._mapping = [] + self.push() + + @property + def mapping(self): + return self._mapping[self._current_index - 1] + + def add(self, source): + if source not in self.mapping: + self._counter += 1 + self[source] = 't%d' % self._counter + return self.mapping[source] + + def get(self, source, any_depth=False): + if any_depth: + for idx in reversed(range(self._current_index)): + if source in self._mapping[idx]: + return self._mapping[idx][source] + return self.add(source) + + def __getitem__(self, source): + return self.get(source) + + def __setitem__(self, source, alias): + self.mapping[source] = alias + + def push(self): + self._current_index += 1 + if self._current_index > len(self._mapping): + self._mapping.append({}) + + def pop(self): + if self._current_index == 1: + raise ValueError('Cannot pop() from empty alias manager.') + self._current_index -= 1 + + +class State(collections.namedtuple('_State', ('scope', 'parentheses', + 'settings'))): + def __new__(cls, scope=SCOPE_NORMAL, parentheses=False, **kwargs): + return super(State, cls).__new__(cls, scope, parentheses, kwargs) + + def __call__(self, scope=None, parentheses=None, **kwargs): + # Scope and settings are "inherited" (parentheses is not, however). + scope = self.scope if scope is None else scope + + # Try to avoid unnecessary dict copying. + if kwargs and self.settings: + settings = self.settings.copy() # Copy original settings dict. + settings.update(kwargs) # Update copy with overrides. + elif kwargs: + settings = kwargs + else: + settings = self.settings + return State(scope, parentheses, **settings) + + def __getattr__(self, attr_name): + return self.settings.get(attr_name) + + +def __scope_context__(scope): + @contextmanager + def inner(self, **kwargs): + with self(scope=scope, **kwargs): + yield self + return inner + + +class Context(object): + __slots__ = ('stack', '_sql', '_values', 'alias_manager', 'state') + + def __init__(self, **settings): + self.stack = [] + self._sql = [] + self._values = [] + self.alias_manager = AliasManager() + self.state = State(**settings) + + def as_new(self): + return Context(**self.state.settings) + + def column_sort_key(self, item): + return item[0].get_sort_key(self) + + @property + def scope(self): + return self.state.scope + + @property + def parentheses(self): + return self.state.parentheses + + @property + def subquery(self): + return self.state.subquery + + def __call__(self, **overrides): + if overrides and overrides.get('scope') == self.scope: + del overrides['scope'] + + self.stack.append(self.state) + self.state = self.state(**overrides) + return self + + scope_normal = __scope_context__(SCOPE_NORMAL) + scope_source = __scope_context__(SCOPE_SOURCE) + scope_values = __scope_context__(SCOPE_VALUES) + scope_cte = __scope_context__(SCOPE_CTE) + scope_column = __scope_context__(SCOPE_COLUMN) + + def __enter__(self): + if self.parentheses: + self.literal('(') + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.parentheses: + self.literal(')') + self.state = self.stack.pop() + + @contextmanager + def push_alias(self): + self.alias_manager.push() + yield + self.alias_manager.pop() + + def sql(self, obj): + if isinstance(obj, (Node, Context)): + return obj.__sql__(self) + elif is_model(obj): + return obj._meta.table.__sql__(self) + else: + return self.sql(Value(obj)) + + def literal(self, keyword): + self._sql.append(keyword) + return self + + def value(self, value, converter=None, add_param=True): + if converter: + 
value = converter(value) + elif converter is None and self.state.converter: + # Explicitly check for None so that "False" can be used to signify + # that no conversion should be applied. + value = self.state.converter(value) + + if isinstance(value, Node): + with self(converter=None): + return self.sql(value) + elif is_model(value): + # Under certain circumstances, we could end-up treating a model- + # class itself as a value. This check ensures that we drop the + # table alias into the query instead of trying to parameterize a + # model (for instance, passing a model as a function argument). + with self.scope_column(): + return self.sql(value) + + self._values.append(value) + return self.literal(self.state.param or '?') if add_param else self + + def __sql__(self, ctx): + ctx._sql.extend(self._sql) + ctx._values.extend(self._values) + return ctx + + def parse(self, node): + return self.sql(node).query() + + def query(self): + return ''.join(self._sql), self._values + + +def query_to_string(query): + # NOTE: this function is not exported by default as it might be misused -- + # and this misuse could lead to sql injection vulnerabilities. This + # function is intended for debugging or logging purposes ONLY. + db = getattr(query, '_database', None) + if db is not None: + ctx = db.get_sql_context() + else: + ctx = Context() + + sql, params = ctx.sql(query).query() + if not params: + return sql + + param = ctx.state.param or '?' + if param == '?': + sql = sql.replace('?', '%s') + + return sql % tuple(map(_query_val_transform, params)) + +def _query_val_transform(v): + # Interpolate parameters. + if isinstance(v, (text_type, datetime.datetime, datetime.date, + datetime.time)): + v = "'%s'" % v + elif isinstance(v, bytes_type): + try: + v = v.decode('utf8') + except UnicodeDecodeError: + v = v.decode('raw_unicode_escape') + v = "'%s'" % v + elif isinstance(v, int): + v = '%s' % int(v) # Also handles booleans -> 1 or 0. + elif v is None: + v = 'NULL' + else: + v = str(v) + return v + + +# AST. + + +class Node(object): + _coerce = True + + def clone(self): + obj = self.__class__.__new__(self.__class__) + obj.__dict__ = self.__dict__.copy() + return obj + + def __sql__(self, ctx): + raise NotImplementedError + + @staticmethod + def copy(method): + def inner(self, *args, **kwargs): + clone = self.clone() + method(clone, *args, **kwargs) + return clone + return inner + + def coerce(self, _coerce=True): + if _coerce != self._coerce: + clone = self.clone() + clone._coerce = _coerce + return clone + return self + + def is_alias(self): + return False + + def unwrap(self): + return self + + +class ColumnFactory(object): + __slots__ = ('node',) + + def __init__(self, node): + self.node = node + + def __getattr__(self, attr): + return Column(self.node, attr) + + +class _DynamicColumn(object): + __slots__ = () + + def __get__(self, instance, instance_type=None): + if instance is not None: + return ColumnFactory(instance) # Implements __getattr__(). + return self + + +class _ExplicitColumn(object): + __slots__ = () + + def __get__(self, instance, instance_type=None): + if instance is not None: + raise AttributeError( + '%s specifies columns explicitly, and does not support ' + 'dynamic column lookups.' 
% instance) + return self + + +class Source(Node): + c = _DynamicColumn() + + def __init__(self, alias=None): + super(Source, self).__init__() + self._alias = alias + + @Node.copy + def alias(self, name): + self._alias = name + + def select(self, *columns): + if not columns: + columns = (SQL('*'),) + return Select((self,), columns) + + def join(self, dest, join_type=JOIN.INNER, on=None): + return Join(self, dest, join_type, on) + + def left_outer_join(self, dest, on=None): + return Join(self, dest, JOIN.LEFT_OUTER, on) + + def cte(self, name, recursive=False, columns=None, materialized=None): + return CTE(name, self, recursive=recursive, columns=columns, + materialized=materialized) + + def get_sort_key(self, ctx): + if self._alias: + return (self._alias,) + return (ctx.alias_manager[self],) + + def apply_alias(self, ctx): + # If we are defining the source, include the "AS alias" declaration. An + # alias is created for the source if one is not already defined. + if ctx.scope == SCOPE_SOURCE: + if self._alias: + ctx.alias_manager[self] = self._alias + ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) + return ctx + + def apply_column(self, ctx): + if self._alias: + ctx.alias_manager[self] = self._alias + return ctx.sql(Entity(ctx.alias_manager[self])) + + +class _HashableSource(object): + def __init__(self, *args, **kwargs): + super(_HashableSource, self).__init__(*args, **kwargs) + self._update_hash() + + @Node.copy + def alias(self, name): + self._alias = name + self._update_hash() + + def _update_hash(self): + self._hash = self._get_hash() + + def _get_hash(self): + return hash((self.__class__, self._path, self._alias)) + + def __hash__(self): + return self._hash + + def __eq__(self, other): + if isinstance(other, _HashableSource): + return self._hash == other._hash + return Expression(self, OP.EQ, other) + + def __ne__(self, other): + if isinstance(other, _HashableSource): + return self._hash != other._hash + return Expression(self, OP.NE, other) + + def _e(op): + def inner(self, rhs): + return Expression(self, op, rhs) + return inner + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + + +def __bind_database__(meth): + @wraps(meth) + def inner(self, *args, **kwargs): + result = meth(self, *args, **kwargs) + if self._database: + return result.bind(self._database) + return result + return inner + + +def __join__(join_type=JOIN.INNER, inverted=False): + def method(self, other): + if inverted: + self, other = other, self + return Join(self, other, join_type=join_type) + return method + + +class BaseTable(Source): + __and__ = __join__(JOIN.INNER) + __add__ = __join__(JOIN.LEFT_OUTER) + __sub__ = __join__(JOIN.RIGHT_OUTER) + __or__ = __join__(JOIN.FULL_OUTER) + __mul__ = __join__(JOIN.CROSS) + __rand__ = __join__(JOIN.INNER, inverted=True) + __radd__ = __join__(JOIN.LEFT_OUTER, inverted=True) + __rsub__ = __join__(JOIN.RIGHT_OUTER, inverted=True) + __ror__ = __join__(JOIN.FULL_OUTER, inverted=True) + __rmul__ = __join__(JOIN.CROSS, inverted=True) + + +class _BoundTableContext(_callable_context_manager): + def __init__(self, table, database): + self.table = table + self.database = database + + def __enter__(self): + self._orig_database = self.table._database + self.table.bind(self.database) + if self.table._model is not None: + self.table._model.bind(self.database) + return self.table + + def __exit__(self, exc_type, exc_val, exc_tb): + self.table.bind(self._orig_database) + if self.table._model is not None: + 
self.table._model.bind(self._orig_database) + + +class Table(_HashableSource, BaseTable): + def __init__(self, name, columns=None, primary_key=None, schema=None, + alias=None, _model=None, _database=None): + self.__name__ = name + self._columns = columns + self._primary_key = primary_key + self._schema = schema + self._path = (schema, name) if schema else (name,) + self._model = _model + self._database = _database + super(Table, self).__init__(alias=alias) + + # Allow tables to restrict what columns are available. + if columns is not None: + self.c = _ExplicitColumn() + for column in columns: + setattr(self, column, Column(self, column)) + + if primary_key: + col_src = self if self._columns else self.c + self.primary_key = getattr(col_src, primary_key) + else: + self.primary_key = None + + def clone(self): + # Ensure a deep copy of the column instances. + return Table( + self.__name__, + columns=self._columns, + primary_key=self._primary_key, + schema=self._schema, + alias=self._alias, + _model=self._model, + _database=self._database) + + def bind(self, database=None): + self._database = database + return self + + def bind_ctx(self, database=None): + return _BoundTableContext(self, database) + + def _get_hash(self): + return hash((self.__class__, self._path, self._alias, self._model)) + + @__bind_database__ + def select(self, *columns): + if not columns and self._columns: + columns = [Column(self, column) for column in self._columns] + return Select((self,), columns) + + @__bind_database__ + def insert(self, insert=None, columns=None, **kwargs): + if kwargs: + insert = {} if insert is None else insert + src = self if self._columns else self.c + for key, value in kwargs.items(): + insert[getattr(src, key)] = value + return Insert(self, insert=insert, columns=columns) + + @__bind_database__ + def replace(self, insert=None, columns=None, **kwargs): + return (self + .insert(insert=insert, columns=columns) + .on_conflict('REPLACE')) + + @__bind_database__ + def update(self, update=None, **kwargs): + if kwargs: + update = {} if update is None else update + for key, value in kwargs.items(): + src = self if self._columns else self.c + update[getattr(src, key)] = value + return Update(self, update=update) + + @__bind_database__ + def delete(self): + return Delete(self) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_VALUES: + # Return the quoted table name. + return ctx.sql(Entity(*self._path)) + + if self._alias: + ctx.alias_manager[self] = self._alias + + if ctx.scope == SCOPE_SOURCE: + # Define the table and its alias. + return self.apply_alias(ctx.sql(Entity(*self._path))) + else: + # Refer to the table using the alias. 
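The Table class above can also build queries without declaring a Model, which is occasionally useful for ad-hoc inspection. A small sketch against a throw-away in-memory table; the table and column names are illustrative only.

    from peewee import SqliteDatabase, Table

    db = SqliteDatabase(':memory:')
    db.execute_sql('CREATE TABLE person (id INTEGER PRIMARY KEY, name TEXT)')

    person = Table('person', columns=('id', 'name')).bind(db)
    person.insert(name='Ada').execute()

    for row in person.select(person.name).where(person.id == 1).dicts():
        print(row)   # {'name': 'Ada'}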
+ return self.apply_column(ctx) + + +class Join(BaseTable): + def __init__(self, lhs, rhs, join_type=JOIN.INNER, on=None, alias=None): + super(Join, self).__init__(alias=alias) + self.lhs = lhs + self.rhs = rhs + self.join_type = join_type + self._on = on + + def on(self, predicate): + self._on = predicate + return self + + def __sql__(self, ctx): + (ctx + .sql(self.lhs) + .literal(' %s ' % self.join_type) + .sql(self.rhs)) + if self._on is not None: + ctx.literal(' ON ').sql(self._on) + return ctx + + +class ValuesList(_HashableSource, BaseTable): + def __init__(self, values, columns=None, alias=None): + self._values = values + self._columns = columns + super(ValuesList, self).__init__(alias=alias) + + def _get_hash(self): + return hash((self.__class__, id(self._values), self._alias)) + + @Node.copy + def columns(self, *names): + self._columns = names + + def __sql__(self, ctx): + if self._alias: + ctx.alias_manager[self] = self._alias + + if ctx.scope == SCOPE_SOURCE or ctx.scope == SCOPE_NORMAL: + with ctx(parentheses=not ctx.parentheses): + ctx = (ctx + .literal('VALUES ') + .sql(CommaNodeList([ + EnclosedNodeList(row) for row in self._values]))) + + if ctx.scope == SCOPE_SOURCE: + ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) + if self._columns: + entities = [Entity(c) for c in self._columns] + ctx.sql(EnclosedNodeList(entities)) + else: + ctx.sql(Entity(ctx.alias_manager[self])) + + return ctx + + +class CTE(_HashableSource, Source): + def __init__(self, name, query, recursive=False, columns=None, + materialized=None): + self._alias = name + self._query = query + self._recursive = recursive + self._materialized = materialized + if columns is not None: + columns = [Entity(c) if isinstance(c, basestring) else c + for c in columns] + self._columns = columns + query._cte_list = () + super(CTE, self).__init__(alias=name) + + def select_from(self, *columns): + if not columns: + raise ValueError('select_from() must specify one or more columns ' + 'from the CTE to select.') + + query = (Select((self,), columns) + .with_cte(self) + .bind(self._query._database)) + try: + query = query.objects(self._query.model) + except AttributeError: + pass + return query + + def _get_hash(self): + return hash((self.__class__, self._alias, id(self._query))) + + def union_all(self, rhs): + clone = self._query.clone() + return CTE(self._alias, clone + rhs, self._recursive, self._columns) + __add__ = union_all + + def union(self, rhs): + clone = self._query.clone() + return CTE(self._alias, clone | rhs, self._recursive, self._columns) + __or__ = union + + def __sql__(self, ctx): + if ctx.scope != SCOPE_CTE: + return ctx.sql(Entity(self._alias)) + + with ctx.push_alias(): + ctx.alias_manager[self] = self._alias + ctx.sql(Entity(self._alias)) + + if self._columns: + ctx.literal(' ').sql(EnclosedNodeList(self._columns)) + ctx.literal(' AS ') + + if self._materialized: + ctx.literal('MATERIALIZED ') + elif self._materialized is False: + ctx.literal('NOT MATERIALIZED ') + + with ctx.scope_normal(parentheses=True): + ctx.sql(self._query) + return ctx + + +class ColumnBase(Node): + _converter = None + + @Node.copy + def converter(self, converter=None): + self._converter = converter + + def alias(self, alias): + if alias: + return Alias(self, alias) + return self + + def unalias(self): + return self + + def cast(self, as_type): + return Cast(self, as_type) + + def asc(self, collation=None, nulls=None): + return Asc(self, collation=collation, nulls=nulls) + __pos__ = asc + + def desc(self, collation=None, 
nulls=None): + return Desc(self, collation=collation, nulls=nulls) + __neg__ = desc + + def __invert__(self): + return Negated(self) + + def _e(op, inv=False): + """ + Lightweight factory which returns a method that builds an Expression + consisting of the left-hand and right-hand operands, using `op`. + """ + def inner(self, rhs): + if inv: + return Expression(rhs, op, self) + return Expression(self, op, rhs) + return inner + __and__ = _e(OP.AND) + __or__ = _e(OP.OR) + + __add__ = _e(OP.ADD) + __sub__ = _e(OP.SUB) + __mul__ = _e(OP.MUL) + __div__ = __truediv__ = _e(OP.DIV) + __xor__ = _e(OP.XOR) + __radd__ = _e(OP.ADD, inv=True) + __rsub__ = _e(OP.SUB, inv=True) + __rmul__ = _e(OP.MUL, inv=True) + __rdiv__ = __rtruediv__ = _e(OP.DIV, inv=True) + __rand__ = _e(OP.AND, inv=True) + __ror__ = _e(OP.OR, inv=True) + __rxor__ = _e(OP.XOR, inv=True) + + def __eq__(self, rhs): + op = OP.IS if rhs is None else OP.EQ + return Expression(self, op, rhs) + def __ne__(self, rhs): + op = OP.IS_NOT if rhs is None else OP.NE + return Expression(self, op, rhs) + + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + __lshift__ = _e(OP.IN) + __rshift__ = _e(OP.IS) + __mod__ = _e(OP.LIKE) + __pow__ = _e(OP.ILIKE) + + bin_and = _e(OP.BIN_AND) + bin_or = _e(OP.BIN_OR) + in_ = _e(OP.IN) + not_in = _e(OP.NOT_IN) + regexp = _e(OP.REGEXP) + + # Special expressions. + def is_null(self, is_null=True): + op = OP.IS if is_null else OP.IS_NOT + return Expression(self, op, None) + + def _escape_like_expr(self, s, template): + if s.find('_') >= 0 or s.find('%') >= 0 or s.find('\\') >= 0: + s = s.replace('\\', '\\\\').replace('_', '\\_').replace('%', '\\%') + return NodeList((template % s, SQL('ESCAPE'), '\\')) + return template % s + def contains(self, rhs): + if isinstance(rhs, Node): + rhs = Expression('%', OP.CONCAT, + Expression(rhs, OP.CONCAT, '%')) + else: + rhs = self._escape_like_expr(rhs, '%%%s%%') + return Expression(self, OP.ILIKE, rhs) + def startswith(self, rhs): + if isinstance(rhs, Node): + rhs = Expression(rhs, OP.CONCAT, '%') + else: + rhs = self._escape_like_expr(rhs, '%s%%') + return Expression(self, OP.ILIKE, rhs) + def endswith(self, rhs): + if isinstance(rhs, Node): + rhs = Expression('%', OP.CONCAT, rhs) + else: + rhs = self._escape_like_expr(rhs, '%%%s') + return Expression(self, OP.ILIKE, rhs) + def between(self, lo, hi): + return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi))) + def concat(self, rhs): + return StringExpression(self, OP.CONCAT, rhs) + def regexp(self, rhs): + return Expression(self, OP.REGEXP, rhs) + def iregexp(self, rhs): + return Expression(self, OP.IREGEXP, rhs) + def __getitem__(self, item): + if isinstance(item, slice): + if item.start is None or item.stop is None: + raise ValueError('BETWEEN range must have both a start- and ' + 'end-point.') + return self.between(item.start, item.stop) + return self == item + + def distinct(self): + return NodeList((SQL('DISTINCT'), self)) + + def collate(self, collation): + return NodeList((self, SQL('COLLATE %s' % collation))) + + def get_sort_key(self, ctx): + return () + + +class Column(ColumnBase): + def __init__(self, source, name): + self.source = source + self.name = name + + def get_sort_key(self, ctx): + if ctx.scope == SCOPE_VALUES: + return (self.name,) + else: + return self.source.get_sort_key(ctx) + (self.name,) + + def __hash__(self): + return hash((self.source, self.name)) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_VALUES: + return ctx.sql(Entity(self.name)) + else: + 
with ctx.scope_column(): + return ctx.sql(self.source).literal('.').sql(Entity(self.name)) + + +class WrappedNode(ColumnBase): + def __init__(self, node): + self.node = node + self._coerce = getattr(node, '_coerce', True) + self._converter = getattr(node, '_converter', None) + + def is_alias(self): + return self.node.is_alias() + + def unwrap(self): + return self.node.unwrap() + + +class EntityFactory(object): + __slots__ = ('node',) + def __init__(self, node): + self.node = node + def __getattr__(self, attr): + return Entity(self.node, attr) + + +class _DynamicEntity(object): + __slots__ = () + def __get__(self, instance, instance_type=None): + if instance is not None: + return EntityFactory(instance._alias) # Implements __getattr__(). + return self + + +class Alias(WrappedNode): + c = _DynamicEntity() + + def __init__(self, node, alias): + super(Alias, self).__init__(node) + self._alias = alias + + def __hash__(self): + return hash(self._alias) + + def alias(self, alias=None): + if alias is None: + return self.node + else: + return Alias(self.node, alias) + + def unalias(self): + return self.node + + def is_alias(self): + return True + + def __sql__(self, ctx): + if ctx.scope == SCOPE_SOURCE: + return (ctx + .sql(self.node) + .literal(' AS ') + .sql(Entity(self._alias))) + else: + return ctx.sql(Entity(self._alias)) + + +class Negated(WrappedNode): + def __invert__(self): + return self.node + + def __sql__(self, ctx): + return ctx.literal('NOT ').sql(self.node) + + +class BitwiseMixin(object): + def __and__(self, other): + return self.bin_and(other) + + def __or__(self, other): + return self.bin_or(other) + + def __sub__(self, other): + return self.bin_and(other.bin_negated()) + + def __invert__(self): + return BitwiseNegated(self) + + +class BitwiseNegated(BitwiseMixin, WrappedNode): + def __invert__(self): + return self.node + + def __sql__(self, ctx): + if ctx.state.operations: + op_sql = ctx.state.operations.get(self.op, self.op) + else: + op_sql = self.op + return ctx.literal(op_sql).sql(self.node) + + +class Value(ColumnBase): + def __init__(self, value, converter=None, unpack=True): + self.value = value + self.converter = converter + self.multi = unpack and isinstance(self.value, multi_types) + if self.multi: + self.values = [] + for item in self.value: + if isinstance(item, Node): + self.values.append(item) + else: + self.values.append(Value(item, self.converter)) + + def __sql__(self, ctx): + if self.multi: + # For multi-part values (e.g. lists of IDs). 
+ return ctx.sql(EnclosedNodeList(self.values)) + + return ctx.value(self.value, self.converter) + + +def AsIs(value): + return Value(value, unpack=False) + + +class Cast(WrappedNode): + def __init__(self, node, cast): + super(Cast, self).__init__(node) + self._cast = cast + self._coerce = False + + def __sql__(self, ctx): + return (ctx + .literal('CAST(') + .sql(self.node) + .literal(' AS %s)' % self._cast)) + + +class Ordering(WrappedNode): + def __init__(self, node, direction, collation=None, nulls=None): + super(Ordering, self).__init__(node) + self.direction = direction + self.collation = collation + self.nulls = nulls + if nulls and nulls.lower() not in ('first', 'last'): + raise ValueError('Ordering nulls= parameter must be "first" or ' + '"last", got: %s' % nulls) + + def collate(self, collation=None): + return Ordering(self.node, self.direction, collation) + + def _null_ordering_case(self, nulls): + if nulls.lower() == 'last': + ifnull, notnull = 1, 0 + elif nulls.lower() == 'first': + ifnull, notnull = 0, 1 + else: + raise ValueError('unsupported value for nulls= ordering.') + return Case(None, ((self.node.is_null(), ifnull),), notnull) + + def __sql__(self, ctx): + if self.nulls and not ctx.state.nulls_ordering: + ctx.sql(self._null_ordering_case(self.nulls)).literal(', ') + + ctx.sql(self.node).literal(' %s' % self.direction) + if self.collation: + ctx.literal(' COLLATE %s' % self.collation) + if self.nulls and ctx.state.nulls_ordering: + ctx.literal(' NULLS %s' % self.nulls) + return ctx + + +def Asc(node, collation=None, nulls=None): + return Ordering(node, 'ASC', collation, nulls) + + +def Desc(node, collation=None, nulls=None): + return Ordering(node, 'DESC', collation, nulls) + + +class Expression(ColumnBase): + def __init__(self, lhs, op, rhs, flat=False): + self.lhs = lhs + self.op = op + self.rhs = rhs + self.flat = flat + + def __sql__(self, ctx): + overrides = {'parentheses': not self.flat, 'in_expr': True} + + # First attempt to unwrap the node on the left-hand-side, so that we + # can get at the underlying Field if one is present. + node = raw_node = self.lhs + if isinstance(raw_node, WrappedNode): + node = raw_node.unwrap() + + # Set up the appropriate converter if we have a field on the left side. + if isinstance(node, Field) and raw_node._coerce: + overrides['converter'] = node.db_value + overrides['is_fk_expr'] = isinstance(node, ForeignKeyField) + else: + overrides['converter'] = None + + if ctx.state.operations: + op_sql = ctx.state.operations.get(self.op, self.op) + else: + op_sql = self.op + + with ctx(**overrides): + # Postgresql reports an error for IN/NOT IN (), so convert to + # the equivalent boolean expression. 
+ op_in = self.op == OP.IN or self.op == OP.NOT_IN + if op_in and ctx.as_new().parse(self.rhs)[0] == '()': + return ctx.literal('0 = 1' if self.op == OP.IN else '1 = 1') + + return (ctx + .sql(self.lhs) + .literal(' %s ' % op_sql) + .sql(self.rhs)) + + +class StringExpression(Expression): + def __add__(self, rhs): + return self.concat(rhs) + def __radd__(self, lhs): + return StringExpression(lhs, OP.CONCAT, self) + + +class Entity(ColumnBase): + def __init__(self, *path): + self._path = [part.replace('"', '""') for part in path if part] + + def __getattr__(self, attr): + return Entity(*self._path + [attr]) + + def get_sort_key(self, ctx): + return tuple(self._path) + + def __hash__(self): + return hash((self.__class__.__name__, tuple(self._path))) + + def __sql__(self, ctx): + return ctx.literal(quote(self._path, ctx.state.quote or '""')) + + +class SQL(ColumnBase): + def __init__(self, sql, params=None): + self.sql = sql + self.params = params + + def __sql__(self, ctx): + ctx.literal(self.sql) + if self.params: + for param in self.params: + ctx.value(param, False, add_param=False) + return ctx + + +def Check(constraint, name=None): + check = SQL('CHECK (%s)' % constraint) + if not name: + return check + return NodeList((SQL('CONSTRAINT'), Entity(name), check)) + + +class Function(ColumnBase): + def __init__(self, name, arguments, coerce=True, python_value=None): + self.name = name + self.arguments = arguments + self._filter = None + self._order_by = None + self._python_value = python_value + if name and name.lower() in ('sum', 'count', 'cast', 'array_agg'): + self._coerce = False + else: + self._coerce = coerce + + def __getattr__(self, attr): + def decorator(*args, **kwargs): + return Function(attr, args, **kwargs) + return decorator + + @Node.copy + def filter(self, where=None): + self._filter = where + + @Node.copy + def order_by(self, *ordering): + self._order_by = ordering + + @Node.copy + def python_value(self, func=None): + self._python_value = func + + def over(self, partition_by=None, order_by=None, start=None, end=None, + frame_type=None, window=None, exclude=None): + if isinstance(partition_by, Window) and window is None: + window = partition_by + + if window is not None: + node = WindowAlias(window) + else: + node = Window(partition_by=partition_by, order_by=order_by, + start=start, end=end, frame_type=frame_type, + exclude=exclude, _inline=True) + return NodeList((self, SQL('OVER'), node)) + + def __sql__(self, ctx): + ctx.literal(self.name) + if not len(self.arguments): + ctx.literal('()') + else: + args = self.arguments + + # If this is an ordered aggregate, then we will modify the last + # argument to append the ORDER BY ... clause. We do this to avoid + # double-wrapping any expression args in parentheses, as NodeList + # has a special check (hack) in place to work around this. + if self._order_by: + args = list(args) + args[-1] = NodeList((args[-1], SQL('ORDER BY'), + CommaNodeList(self._order_by))) + + with ctx(in_function=True, function_arg_count=len(self.arguments)): + ctx.sql(EnclosedNodeList([ + (arg if isinstance(arg, Node) else Value(arg, False)) + for arg in args])) + + if self._filter: + ctx.literal(' FILTER (WHERE ').sql(self._filter).literal(')') + return ctx + + +fn = Function(None, None) + + +class Window(Node): + # Frame start/end and frame exclusion. + CURRENT_ROW = SQL('CURRENT ROW') + GROUP = SQL('GROUP') + TIES = SQL('TIES') + NO_OTHERS = SQL('NO OTHERS') + + # Frame types. 
+ GROUPS = 'GROUPS' + RANGE = 'RANGE' + ROWS = 'ROWS' + + def __init__(self, partition_by=None, order_by=None, start=None, end=None, + frame_type=None, extends=None, exclude=None, alias=None, + _inline=False): + super(Window, self).__init__() + if start is not None and not isinstance(start, SQL): + start = SQL(start) + if end is not None and not isinstance(end, SQL): + end = SQL(end) + + self.partition_by = ensure_tuple(partition_by) + self.order_by = ensure_tuple(order_by) + self.start = start + self.end = end + if self.start is None and self.end is not None: + raise ValueError('Cannot specify WINDOW end without start.') + self._alias = alias or 'w' + self._inline = _inline + self.frame_type = frame_type + self._extends = extends + self._exclude = exclude + + def alias(self, alias=None): + self._alias = alias or 'w' + return self + + @Node.copy + def as_range(self): + self.frame_type = Window.RANGE + + @Node.copy + def as_rows(self): + self.frame_type = Window.ROWS + + @Node.copy + def as_groups(self): + self.frame_type = Window.GROUPS + + @Node.copy + def extends(self, window=None): + self._extends = window + + @Node.copy + def exclude(self, frame_exclusion=None): + if isinstance(frame_exclusion, basestring): + frame_exclusion = SQL(frame_exclusion) + self._exclude = frame_exclusion + + @staticmethod + def following(value=None): + if value is None: + return SQL('UNBOUNDED FOLLOWING') + return SQL('%d FOLLOWING' % value) + + @staticmethod + def preceding(value=None): + if value is None: + return SQL('UNBOUNDED PRECEDING') + return SQL('%d PRECEDING' % value) + + def __sql__(self, ctx): + if ctx.scope != SCOPE_SOURCE and not self._inline: + ctx.literal(self._alias) + ctx.literal(' AS ') + + with ctx(parentheses=True): + parts = [] + if self._extends is not None: + ext = self._extends + if isinstance(ext, Window): + ext = SQL(ext._alias) + elif isinstance(ext, basestring): + ext = SQL(ext) + parts.append(ext) + if self.partition_by: + parts.extend(( + SQL('PARTITION BY'), + CommaNodeList(self.partition_by))) + if self.order_by: + parts.extend(( + SQL('ORDER BY'), + CommaNodeList(self.order_by))) + if self.start is not None and self.end is not None: + frame = self.frame_type or 'ROWS' + parts.extend(( + SQL('%s BETWEEN' % frame), + self.start, + SQL('AND'), + self.end)) + elif self.start is not None: + parts.extend((SQL(self.frame_type or 'ROWS'), self.start)) + elif self.frame_type is not None: + parts.append(SQL('%s UNBOUNDED PRECEDING' % self.frame_type)) + if self._exclude is not None: + parts.extend((SQL('EXCLUDE'), self._exclude)) + ctx.sql(NodeList(parts)) + return ctx + + +class WindowAlias(Node): + def __init__(self, window): + self.window = window + + def alias(self, window_alias): + self.window._alias = window_alias + return self + + def __sql__(self, ctx): + return ctx.literal(self.window._alias or 'w') + + +class ForUpdate(Node): + def __init__(self, expr, of=None, nowait=None): + expr = 'FOR UPDATE' if expr is True else expr + if expr.lower().endswith('nowait'): + expr = expr[:-7] # Strip off the "nowait" bit. 
+ nowait = True + + self._expr = expr + if of is not None and not isinstance(of, (list, set, tuple)): + of = (of,) + self._of = of + self._nowait = nowait + + def __sql__(self, ctx): + ctx.literal(self._expr) + if self._of is not None: + ctx.literal(' OF ').sql(CommaNodeList(self._of)) + if self._nowait: + ctx.literal(' NOWAIT') + return ctx + + +def Case(predicate, expression_tuples, default=None): + clauses = [SQL('CASE')] + if predicate is not None: + clauses.append(predicate) + for expr, value in expression_tuples: + clauses.extend((SQL('WHEN'), expr, SQL('THEN'), value)) + if default is not None: + clauses.extend((SQL('ELSE'), default)) + clauses.append(SQL('END')) + return NodeList(clauses) + + +class NodeList(ColumnBase): + def __init__(self, nodes, glue=' ', parens=False): + self.nodes = nodes + self.glue = glue + self.parens = parens + if parens and len(self.nodes) == 1 and \ + isinstance(self.nodes[0], Expression) and \ + not self.nodes[0].flat: + # Hack to avoid double-parentheses. + self.nodes = (self.nodes[0].clone(),) + self.nodes[0].flat = True + + def __sql__(self, ctx): + n_nodes = len(self.nodes) + if n_nodes == 0: + return ctx.literal('()') if self.parens else ctx + with ctx(parentheses=self.parens): + for i in range(n_nodes - 1): + ctx.sql(self.nodes[i]) + ctx.literal(self.glue) + ctx.sql(self.nodes[n_nodes - 1]) + return ctx + + +def CommaNodeList(nodes): + return NodeList(nodes, ', ') + + +def EnclosedNodeList(nodes): + return NodeList(nodes, ', ', True) + + +class _Namespace(Node): + __slots__ = ('_name',) + def __init__(self, name): + self._name = name + def __getattr__(self, attr): + return NamespaceAttribute(self, attr) + __getitem__ = __getattr__ + +class NamespaceAttribute(ColumnBase): + def __init__(self, namespace, attribute): + self._namespace = namespace + self._attribute = attribute + + def __sql__(self, ctx): + return (ctx + .literal(self._namespace._name + '.') + .sql(Entity(self._attribute))) + +EXCLUDED = _Namespace('EXCLUDED') + + +class DQ(ColumnBase): + def __init__(self, **query): + super(DQ, self).__init__() + self.query = query + self._negated = False + + @Node.copy + def __invert__(self): + self._negated = not self._negated + + def clone(self): + node = DQ(**self.query) + node._negated = self._negated + return node + +#: Represent a row tuple. +Tuple = lambda *a: EnclosedNodeList(a) + + +class QualifiedNames(WrappedNode): + def __sql__(self, ctx): + with ctx.scope_column(): + return ctx.sql(self.node) + + +def qualify_names(node): + # Search a node heirarchy to ensure that any column-like objects are + # referenced using fully-qualified names. 
+ if isinstance(node, Expression): + return node.__class__(qualify_names(node.lhs), node.op, + qualify_names(node.rhs), node.flat) + elif isinstance(node, ColumnBase): + return QualifiedNames(node) + return node + + +class OnConflict(Node): + def __init__(self, action=None, update=None, preserve=None, where=None, + conflict_target=None, conflict_where=None, + conflict_constraint=None): + self._action = action + self._update = update + self._preserve = ensure_tuple(preserve) + self._where = where + if conflict_target is not None and conflict_constraint is not None: + raise ValueError('only one of "conflict_target" and ' + '"conflict_constraint" may be specified.') + self._conflict_target = ensure_tuple(conflict_target) + self._conflict_where = conflict_where + self._conflict_constraint = conflict_constraint + + def get_conflict_statement(self, ctx, query): + return ctx.state.conflict_statement(self, query) + + def get_conflict_update(self, ctx, query): + return ctx.state.conflict_update(self, query) + + @Node.copy + def preserve(self, *columns): + self._preserve = columns + + @Node.copy + def update(self, _data=None, **kwargs): + if _data and kwargs and not isinstance(_data, dict): + raise ValueError('Cannot mix data with keyword arguments in the ' + 'OnConflict update method.') + _data = _data or {} + if kwargs: + _data.update(kwargs) + self._update = _data + + @Node.copy + def where(self, *expressions): + if self._where is not None: + expressions = (self._where,) + expressions + self._where = reduce(operator.and_, expressions) + + @Node.copy + def conflict_target(self, *constraints): + self._conflict_constraint = None + self._conflict_target = constraints + + @Node.copy + def conflict_where(self, *expressions): + if self._conflict_where is not None: + expressions = (self._conflict_where,) + expressions + self._conflict_where = reduce(operator.and_, expressions) + + @Node.copy + def conflict_constraint(self, constraint): + self._conflict_constraint = constraint + self._conflict_target = None + + +def database_required(method): + @wraps(method) + def inner(self, database=None, *args, **kwargs): + database = self._database if database is None else database + if not database: + raise InterfaceError('Query must be bound to a database in order ' + 'to call "%s".' % method.__name__) + return method(self, database, *args, **kwargs) + return inner + +# BASE QUERY INTERFACE. 
+ +class BaseQuery(Node): + default_row_type = ROW.DICT + + def __init__(self, _database=None, **kwargs): + self._database = _database + self._cursor_wrapper = None + self._row_type = None + self._constructor = None + super(BaseQuery, self).__init__(**kwargs) + + def bind(self, database=None): + self._database = database + return self + + def clone(self): + query = super(BaseQuery, self).clone() + query._cursor_wrapper = None + return query + + @Node.copy + def dicts(self, as_dict=True): + self._row_type = ROW.DICT if as_dict else None + return self + + @Node.copy + def tuples(self, as_tuple=True): + self._row_type = ROW.TUPLE if as_tuple else None + return self + + @Node.copy + def namedtuples(self, as_namedtuple=True): + self._row_type = ROW.NAMED_TUPLE if as_namedtuple else None + return self + + @Node.copy + def objects(self, constructor=None): + self._row_type = ROW.CONSTRUCTOR if constructor else None + self._constructor = constructor + return self + + def _get_cursor_wrapper(self, cursor): + row_type = self._row_type or self.default_row_type + + if row_type == ROW.DICT: + return DictCursorWrapper(cursor) + elif row_type == ROW.TUPLE: + return CursorWrapper(cursor) + elif row_type == ROW.NAMED_TUPLE: + return NamedTupleCursorWrapper(cursor) + elif row_type == ROW.CONSTRUCTOR: + return ObjectCursorWrapper(cursor, self._constructor) + else: + raise ValueError('Unrecognized row type: "%s".' % row_type) + + def __sql__(self, ctx): + raise NotImplementedError + + def sql(self): + if self._database: + context = self._database.get_sql_context() + else: + context = Context() + return context.parse(self) + + @database_required + def execute(self, database): + return self._execute(database) + + def _execute(self, database): + raise NotImplementedError + + def iterator(self, database=None): + return iter(self.execute(database).iterator()) + + def _ensure_execution(self): + if not self._cursor_wrapper: + if not self._database: + raise ValueError('Query has not been executed.') + self.execute() + + def __iter__(self): + self._ensure_execution() + return iter(self._cursor_wrapper) + + def __getitem__(self, value): + self._ensure_execution() + if isinstance(value, slice): + index = value.stop + else: + index = value + if index is not None: + index = index + 1 if index >= 0 else 0 + self._cursor_wrapper.fill_cache(index) + return self._cursor_wrapper.row_cache[value] + + def __len__(self): + self._ensure_execution() + return len(self._cursor_wrapper) + + def __str__(self): + return query_to_string(self) + + +class RawQuery(BaseQuery): + def __init__(self, sql=None, params=None, **kwargs): + super(RawQuery, self).__init__(**kwargs) + self._sql = sql + self._params = params + + def __sql__(self, ctx): + ctx.literal(self._sql) + if self._params: + for param in self._params: + ctx.value(param, add_param=False) + return ctx + + def _execute(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self) + self._cursor_wrapper = self._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + +class Query(BaseQuery): + def __init__(self, where=None, order_by=None, limit=None, offset=None, + **kwargs): + super(Query, self).__init__(**kwargs) + self._where = where + self._order_by = order_by + self._limit = limit + self._offset = offset + + self._cte_list = None + + @Node.copy + def with_cte(self, *cte_list): + self._cte_list = cte_list + + @Node.copy + def where(self, *expressions): + if self._where is not None: + expressions = (self._where,) + expressions + self._where = 
reduce(operator.and_, expressions) + + @Node.copy + def orwhere(self, *expressions): + if self._where is not None: + expressions = (self._where,) + expressions + self._where = reduce(operator.or_, expressions) + + @Node.copy + def order_by(self, *values): + self._order_by = values + + @Node.copy + def order_by_extend(self, *values): + self._order_by = ((self._order_by or ()) + values) or None + + @Node.copy + def limit(self, value=None): + self._limit = value + + @Node.copy + def offset(self, value=None): + self._offset = value + + @Node.copy + def paginate(self, page, paginate_by=20): + if page > 0: + page -= 1 + self._limit = paginate_by + self._offset = page * paginate_by + + def _apply_ordering(self, ctx): + if self._order_by: + (ctx + .literal(' ORDER BY ') + .sql(CommaNodeList(self._order_by))) + if self._limit is not None or (self._offset is not None and + ctx.state.limit_max): + limit = ctx.state.limit_max if self._limit is None else self._limit + ctx.literal(' LIMIT ').sql(limit) + if self._offset is not None: + ctx.literal(' OFFSET ').sql(self._offset) + return ctx + + def __sql__(self, ctx): + if self._cte_list: + # The CTE scope is only used at the very beginning of the query, + # when we are describing the various CTEs we will be using. + recursive = any(cte._recursive for cte in self._cte_list) + + # Explicitly disable the "subquery" flag here, so as to avoid + # unnecessary parentheses around subsequent selects. + with ctx.scope_cte(subquery=False): + (ctx + .literal('WITH RECURSIVE ' if recursive else 'WITH ') + .sql(CommaNodeList(self._cte_list)) + .literal(' ')) + return ctx + + +def __compound_select__(operation, inverted=False): + def method(self, other): + if inverted: + self, other = other, self + return CompoundSelectQuery(self, operation, other) + return method + + +class SelectQuery(Query): + union_all = __add__ = __compound_select__('UNION ALL') + union = __or__ = __compound_select__('UNION') + intersect = __and__ = __compound_select__('INTERSECT') + except_ = __sub__ = __compound_select__('EXCEPT') + __radd__ = __compound_select__('UNION ALL', inverted=True) + __ror__ = __compound_select__('UNION', inverted=True) + __rand__ = __compound_select__('INTERSECT', inverted=True) + __rsub__ = __compound_select__('EXCEPT', inverted=True) + + def select_from(self, *columns): + if not columns: + raise ValueError('select_from() must specify one or more columns.') + + query = (Select((self,), columns) + .bind(self._database)) + if getattr(self, 'model', None) is not None: + # Bind to the sub-select's model type, if defined. 
+ query = query.objects(self.model) + return query + + +class SelectBase(_HashableSource, Source, SelectQuery): + def _get_hash(self): + return hash((self.__class__, self._alias or id(self))) + + def _execute(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self) + self._cursor_wrapper = self._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + @database_required + def peek(self, database, n=1): + rows = self.execute(database)[:n] + if rows: + return rows[0] if n == 1 else rows + + @database_required + def first(self, database, n=1): + if self._limit != n: + self._limit = n + self._cursor_wrapper = None + return self.peek(database, n=n) + + @database_required + def scalar(self, database, as_tuple=False): + row = self.tuples().peek(database) + return row[0] if row and not as_tuple else row + + @database_required + def count(self, database, clear_limit=False): + clone = self.order_by().alias('_wrapped') + if clear_limit: + clone._limit = clone._offset = None + try: + if clone._having is None and clone._group_by is None and \ + clone._windows is None and clone._distinct is None and \ + clone._simple_distinct is not True: + clone = clone.select(SQL('1')) + except AttributeError: + pass + return Select([clone], [fn.COUNT(SQL('1'))]).scalar(database) + + @database_required + def exists(self, database): + clone = self.columns(SQL('1')) + clone._limit = 1 + clone._offset = None + return bool(clone.scalar()) + + @database_required + def get(self, database): + self._cursor_wrapper = None + try: + return self.execute(database)[0] + except IndexError: + pass + + +# QUERY IMPLEMENTATIONS. + + +class CompoundSelectQuery(SelectBase): + def __init__(self, lhs, op, rhs): + super(CompoundSelectQuery, self).__init__() + self.lhs = lhs + self.op = op + self.rhs = rhs + + @property + def _returning(self): + return self.lhs._returning + + @database_required + def exists(self, database): + query = Select((self.limit(1),), (SQL('1'),)).bind(database) + return bool(query.scalar()) + + def _get_query_key(self): + return (self.lhs.get_query_key(), self.rhs.get_query_key()) + + def _wrap_parens(self, ctx, subq): + csq_setting = ctx.state.compound_select_parentheses + + if not csq_setting or csq_setting == CSQ_PARENTHESES_NEVER: + return False + elif csq_setting == CSQ_PARENTHESES_ALWAYS: + return True + elif csq_setting == CSQ_PARENTHESES_UNNESTED: + if ctx.state.in_expr or ctx.state.in_function: + # If this compound select query is being used inside an + # expression, e.g., an IN or EXISTS(). + return False + + # If the query on the left or right is itself a compound select + # query, then we do not apply parentheses. However, if it is a + # regular SELECT query, we will apply parentheses. + return not isinstance(subq, CompoundSelectQuery) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_COLUMN: + return self.apply_column(ctx) + + # Call parent method to handle any CTEs. + super(CompoundSelectQuery, self).__sql__(ctx) + + outer_parens = ctx.subquery or (ctx.scope == SCOPE_SOURCE) + with ctx(parentheses=outer_parens): + # Should the left-hand query be wrapped in parentheses? + lhs_parens = self._wrap_parens(ctx, self.lhs) + with ctx.scope_normal(parentheses=lhs_parens, subquery=False): + ctx.sql(self.lhs) + ctx.literal(' %s ' % self.op) + with ctx.push_alias(): + # Should the right-hand query be wrapped in parentheses? 
+ rhs_parens = self._wrap_parens(ctx, self.rhs) + with ctx.scope_normal(parentheses=rhs_parens, subquery=False): + ctx.sql(self.rhs) + + # Apply ORDER BY, LIMIT, OFFSET. We use the "values" scope so that + # entity names are not fully-qualified. This is a bit of a hack, as + # we're relying on the logic in Column.__sql__() to not fully + # qualify column names. + with ctx.scope_values(): + self._apply_ordering(ctx) + + return self.apply_alias(ctx) + + +class Select(SelectBase): + def __init__(self, from_list=None, columns=None, group_by=None, + having=None, distinct=None, windows=None, for_update=None, + for_update_of=None, nowait=None, lateral=None, **kwargs): + super(Select, self).__init__(**kwargs) + self._from_list = (list(from_list) if isinstance(from_list, tuple) + else from_list) or [] + self._returning = columns + self._group_by = group_by + self._having = having + self._windows = None + self._for_update = for_update # XXX: consider reorganizing. + self._for_update_of = for_update_of + self._for_update_nowait = nowait + self._lateral = lateral + + self._distinct = self._simple_distinct = None + if distinct: + if isinstance(distinct, bool): + self._simple_distinct = distinct + else: + self._distinct = distinct + + self._cursor_wrapper = None + + def clone(self): + clone = super(Select, self).clone() + if clone._from_list: + clone._from_list = list(clone._from_list) + return clone + + @Node.copy + def columns(self, *columns, **kwargs): + self._returning = columns + select = columns + + @Node.copy + def select_extend(self, *columns): + self._returning = tuple(self._returning) + columns + + @Node.copy + def from_(self, *sources): + self._from_list = list(sources) + + @Node.copy + def join(self, dest, join_type=JOIN.INNER, on=None): + if not self._from_list: + raise ValueError('No sources to join on.') + item = self._from_list.pop() + self._from_list.append(Join(item, dest, join_type, on)) + + @Node.copy + def group_by(self, *columns): + grouping = [] + for column in columns: + if isinstance(column, Table): + if not column._columns: + raise ValueError('Cannot pass a table to group_by() that ' + 'does not have columns explicitly ' + 'declared.') + grouping.extend([getattr(column, col_name) + for col_name in column._columns]) + else: + grouping.append(column) + self._group_by = grouping + + def group_by_extend(self, *values): + """@Node.copy used from group_by() call""" + group_by = tuple(self._group_by or ()) + values + return self.group_by(*group_by) + + @Node.copy + def having(self, *expressions): + if self._having is not None: + expressions = (self._having,) + expressions + self._having = reduce(operator.and_, expressions) + + @Node.copy + def distinct(self, *columns): + if len(columns) == 1 and (columns[0] is True or columns[0] is False): + self._simple_distinct = columns[0] + else: + self._simple_distinct = False + self._distinct = columns + + @Node.copy + def window(self, *windows): + self._windows = windows if windows else None + + @Node.copy + def for_update(self, for_update=True, of=None, nowait=None): + if not for_update and (of is not None or nowait): + for_update = True + self._for_update = for_update + self._for_update_of = of + self._for_update_nowait = nowait + + @Node.copy + def lateral(self, lateral=True): + self._lateral = lateral + + def _get_query_key(self): + return self._alias + + def __sql_selection__(self, ctx, is_subquery=False): + return ctx.sql(CommaNodeList(self._returning)) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_COLUMN: + return 
self.apply_column(ctx) + + if self._lateral and ctx.scope == SCOPE_SOURCE: + ctx.literal('LATERAL ') + + is_subquery = ctx.subquery + state = { + 'converter': None, + 'in_function': False, + 'parentheses': is_subquery or (ctx.scope == SCOPE_SOURCE), + 'subquery': True, + } + if ctx.state.in_function and ctx.state.function_arg_count == 1: + state['parentheses'] = False + + with ctx.scope_normal(**state): + # Defer calling parent SQL until here. This ensures that any CTEs + # for this query will be properly nested if this query is a + # sub-select or is used in an expression. See GH#1809 for example. + super(Select, self).__sql__(ctx) + + ctx.literal('SELECT ') + if self._simple_distinct or self._distinct is not None: + ctx.literal('DISTINCT ') + if self._distinct: + (ctx + .literal('ON ') + .sql(EnclosedNodeList(self._distinct)) + .literal(' ')) + + with ctx.scope_source(): + ctx = self.__sql_selection__(ctx, is_subquery) + + if self._from_list: + with ctx.scope_source(parentheses=False): + ctx.literal(' FROM ').sql(CommaNodeList(self._from_list)) + + if self._where is not None: + ctx.literal(' WHERE ').sql(self._where) + + if self._group_by: + ctx.literal(' GROUP BY ').sql(CommaNodeList(self._group_by)) + + if self._having is not None: + ctx.literal(' HAVING ').sql(self._having) + + if self._windows is not None: + ctx.literal(' WINDOW ') + ctx.sql(CommaNodeList(self._windows)) + + # Apply ORDER BY, LIMIT, OFFSET. + self._apply_ordering(ctx) + + if self._for_update: + if not ctx.state.for_update: + raise ValueError('FOR UPDATE specified but not supported ' + 'by database.') + ctx.literal(' ') + ctx.sql(ForUpdate(self._for_update, self._for_update_of, + self._for_update_nowait)) + + # If the subquery is inside a function -or- we are evaluating a + # subquery on either side of an expression w/o an explicit alias, do + # not generate an alias + AS clause. + if ctx.state.in_function or (ctx.state.in_expr and + self._alias is None): + return ctx + + return self.apply_alias(ctx) + + +class _WriteQuery(Query): + def __init__(self, table, returning=None, **kwargs): + self.table = table + self._returning = returning + self._return_cursor = True if returning else False + super(_WriteQuery, self).__init__(**kwargs) + + @Node.copy + def returning(self, *returning): + self._returning = returning + self._return_cursor = True if returning else False + + def apply_returning(self, ctx): + if self._returning: + with ctx.scope_source(): + ctx.literal(' RETURNING ').sql(CommaNodeList(self._returning)) + return ctx + + def _execute(self, database): + if self._returning: + cursor = self.execute_returning(database) + else: + cursor = database.execute(self) + return self.handle_result(database, cursor) + + def execute_returning(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self) + self._cursor_wrapper = self._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + def handle_result(self, database, cursor): + if self._return_cursor: + return cursor + return database.rows_affected(cursor) + + def _set_table_alias(self, ctx): + ctx.alias_manager[self.table] = self.table.__name__ + + def __sql__(self, ctx): + super(_WriteQuery, self).__sql__(ctx) + # We explicitly set the table alias to the table's name, which ensures + # that if a sub-select references a column on the outer table, we won't + # assign it a new alias (e.g. t2) but will refer to it as table.column. 
+ self._set_table_alias(ctx) + return ctx + + +class Update(_WriteQuery): + def __init__(self, table, update=None, **kwargs): + super(Update, self).__init__(table, **kwargs) + self._update = update + self._from = None + + @Node.copy + def from_(self, *sources): + self._from = sources + + def __sql__(self, ctx): + super(Update, self).__sql__(ctx) + + with ctx.scope_values(subquery=True): + ctx.literal('UPDATE ') + + expressions = [] + for k, v in sorted(self._update.items(), key=ctx.column_sort_key): + if not isinstance(v, Node): + if isinstance(k, Field): + v = k.to_value(v) + else: + v = Value(v, unpack=False) + if not isinstance(v, Value): + v = qualify_names(v) + expressions.append(NodeList((k, SQL('='), v))) + + (ctx + .sql(self.table) + .literal(' SET ') + .sql(CommaNodeList(expressions))) + + if self._from: + with ctx.scope_source(parentheses=False): + ctx.literal(' FROM ').sql(CommaNodeList(self._from)) + + if self._where: + with ctx.scope_normal(): + ctx.literal(' WHERE ').sql(self._where) + self._apply_ordering(ctx) + return self.apply_returning(ctx) + + +class Insert(_WriteQuery): + SIMPLE = 0 + QUERY = 1 + MULTI = 2 + class DefaultValuesException(Exception): pass + + def __init__(self, table, insert=None, columns=None, on_conflict=None, + **kwargs): + super(Insert, self).__init__(table, **kwargs) + self._insert = insert + self._columns = columns + self._on_conflict = on_conflict + self._query_type = None + + def where(self, *expressions): + raise NotImplementedError('INSERT queries cannot have a WHERE clause.') + + @Node.copy + def on_conflict_ignore(self, ignore=True): + self._on_conflict = OnConflict('IGNORE') if ignore else None + + @Node.copy + def on_conflict_replace(self, replace=True): + self._on_conflict = OnConflict('REPLACE') if replace else None + + @Node.copy + def on_conflict(self, *args, **kwargs): + self._on_conflict = (OnConflict(*args, **kwargs) if (args or kwargs) + else None) + + def _simple_insert(self, ctx): + if not self._insert: + raise self.DefaultValuesException('Error: no data to insert.') + return self._generate_insert((self._insert,), ctx) + + def get_default_data(self): + return {} + + def get_default_columns(self): + if self.table._columns: + return [getattr(self.table, col) for col in self.table._columns + if col != self.table._primary_key] + + def _generate_insert(self, insert, ctx): + rows_iter = iter(insert) + columns = self._columns + + # Load and organize column defaults (if provided). + defaults = self.get_default_data() + + # First figure out what columns are being inserted (if they weren't + # specified explicitly). Resulting columns are normalized and ordered. + if not columns: + try: + row = next(rows_iter) + except StopIteration: + raise self.DefaultValuesException('Error: no rows to insert.') + + if not isinstance(row, Mapping): + columns = self.get_default_columns() + if columns is None: + raise ValueError('Bulk insert must specify columns.') + else: + # Infer column names from the dict of data being inserted. + accum = [] + for column in row: + if isinstance(column, basestring): + column = getattr(self.table, column) + accum.append(column) + + # Add any columns present in the default data that are not + # accounted for by the dictionary of row data. 
+ column_set = set(accum) + for col in (set(defaults) - column_set): + accum.append(col) + + columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx)) + rows_iter = itertools.chain(iter((row,)), rows_iter) + else: + clean_columns = [] + seen = set() + for column in columns: + if isinstance(column, basestring): + column_obj = getattr(self.table, column) + else: + column_obj = column + clean_columns.append(column_obj) + seen.add(column_obj) + + columns = clean_columns + for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)): + if col not in seen: + columns.append(col) + + nullable_columns = set() + value_lookups = {} + for column in columns: + lookups = [column, column.name] + if isinstance(column, Field): + if column.name != column.column_name: + lookups.append(column.column_name) + if column.null: + nullable_columns.add(column) + value_lookups[column] = lookups + + ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ') + columns_converters = [ + (column, column.db_value if isinstance(column, Field) else None) + for column in columns] + + all_values = [] + for row in rows_iter: + values = [] + is_dict = isinstance(row, Mapping) + for i, (column, converter) in enumerate(columns_converters): + try: + if is_dict: + # The logic is a bit convoluted, but in order to be + # flexible in what we accept (dict keyed by + # column/field, field name, or underlying column name), + # we try accessing the row data dict using each + # possible key. If no match is found, throw an error. + for lookup in value_lookups[column]: + try: + val = row[lookup] + except KeyError: pass + else: break + else: + raise KeyError + else: + val = row[i] + except (KeyError, IndexError): + if column in defaults: + val = defaults[column] + if callable_(val): + val = val() + elif column in nullable_columns: + val = None + else: + raise ValueError('Missing value for %s.' 
% column.name) + + if not isinstance(val, Node): + val = Value(val, converter=converter, unpack=False) + values.append(val) + + all_values.append(EnclosedNodeList(values)) + + if not all_values: + raise self.DefaultValuesException('Error: no data to insert.') + + with ctx.scope_values(subquery=True): + return ctx.sql(CommaNodeList(all_values)) + + def _query_insert(self, ctx): + return (ctx + .sql(EnclosedNodeList(self._columns)) + .literal(' ') + .sql(self._insert)) + + def _default_values(self, ctx): + if not self._database: + return ctx.literal('DEFAULT VALUES') + return self._database.default_values_insert(ctx) + + def __sql__(self, ctx): + super(Insert, self).__sql__(ctx) + with ctx.scope_values(): + stmt = None + if self._on_conflict is not None: + stmt = self._on_conflict.get_conflict_statement(ctx, self) + + (ctx + .sql(stmt or SQL('INSERT')) + .literal(' INTO ') + .sql(self.table) + .literal(' ')) + + if isinstance(self._insert, Mapping) and not self._columns: + try: + self._simple_insert(ctx) + except self.DefaultValuesException: + self._default_values(ctx) + self._query_type = Insert.SIMPLE + elif isinstance(self._insert, (SelectQuery, SQL)): + self._query_insert(ctx) + self._query_type = Insert.QUERY + else: + self._generate_insert(self._insert, ctx) + self._query_type = Insert.MULTI + + if self._on_conflict is not None: + update = self._on_conflict.get_conflict_update(ctx, self) + if update is not None: + ctx.literal(' ').sql(update) + + return self.apply_returning(ctx) + + def _execute(self, database): + if self._returning is None and database.returning_clause \ + and self.table._primary_key: + self._returning = (self.table._primary_key,) + try: + return super(Insert, self)._execute(database) + except self.DefaultValuesException: + pass + + def handle_result(self, database, cursor): + if self._return_cursor: + return cursor + if self._query_type != Insert.SIMPLE and not self._returning: + return database.rows_affected(cursor) + return database.last_insert_id(cursor, self._query_type) + + +class Delete(_WriteQuery): + def __sql__(self, ctx): + super(Delete, self).__sql__(ctx) + + with ctx.scope_values(subquery=True): + ctx.literal('DELETE FROM ').sql(self.table) + if self._where is not None: + with ctx.scope_normal(): + ctx.literal(' WHERE ').sql(self._where) + + self._apply_ordering(ctx) + return self.apply_returning(ctx) + + +class Index(Node): + def __init__(self, name, table, expressions, unique=False, safe=False, + where=None, using=None): + self._name = name + self._table = Entity(table) if not isinstance(table, Table) else table + self._expressions = expressions + self._where = where + self._unique = unique + self._safe = safe + self._using = using + + @Node.copy + def safe(self, _safe=True): + self._safe = _safe + + @Node.copy + def where(self, *expressions): + if self._where is not None: + expressions = (self._where,) + expressions + self._where = reduce(operator.and_, expressions) + + @Node.copy + def using(self, _using=None): + self._using = _using + + def __sql__(self, ctx): + statement = 'CREATE UNIQUE INDEX ' if self._unique else 'CREATE INDEX ' + with ctx.scope_values(subquery=True): + ctx.literal(statement) + if self._safe: + ctx.literal('IF NOT EXISTS ') + + # Sqlite uses CREATE INDEX . ON , whereas most + # others use: CREATE INDEX ON .
. + if ctx.state.index_schema_prefix and \ + isinstance(self._table, Table) and self._table._schema: + index_name = Entity(self._table._schema, self._name) + table_name = Entity(self._table.__name__) + else: + index_name = Entity(self._name) + table_name = self._table + + ctx.sql(index_name) + if self._using is not None and \ + ctx.state.index_using_precedes_table: + ctx.literal(' USING %s' % self._using) # MySQL style. + + (ctx + .literal(' ON ') + .sql(table_name) + .literal(' ')) + + if self._using is not None and not \ + ctx.state.index_using_precedes_table: + ctx.literal('USING %s ' % self._using) # Postgres/default. + + ctx.sql(EnclosedNodeList([ + SQL(expr) if isinstance(expr, basestring) else expr + for expr in self._expressions])) + if self._where is not None: + ctx.literal(' WHERE ').sql(self._where) + + return ctx + + +class ModelIndex(Index): + def __init__(self, model, fields, unique=False, safe=True, where=None, + using=None, name=None): + self._model = model + if name is None: + name = self._generate_name_from_fields(model, fields) + if using is None: + for field in fields: + if isinstance(field, Field) and hasattr(field, 'index_type'): + using = field.index_type + super(ModelIndex, self).__init__( + name=name, + table=model._meta.table, + expressions=fields, + unique=unique, + safe=safe, + where=where, + using=using) + + def _generate_name_from_fields(self, model, fields): + accum = [] + for field in fields: + if isinstance(field, basestring): + accum.append(field.split()[0]) + else: + if isinstance(field, Node) and not isinstance(field, Field): + field = field.unwrap() + if isinstance(field, Field): + accum.append(field.column_name) + + if not accum: + raise ValueError('Unable to generate a name for the index, please ' + 'explicitly specify a name.') + + clean_field_names = re.sub(r'[^\w]+', '', '_'.join(accum)) + meta = model._meta + prefix = meta.name if meta.legacy_table_names else meta.table_name + return _truncate_constraint_name('_'.join((prefix, clean_field_names))) + + +def _truncate_constraint_name(constraint, maxlen=64): + if len(constraint) > maxlen: + name_hash = hashlib.md5(constraint.encode('utf-8')).hexdigest() + constraint = '%s_%s' % (constraint[:(maxlen - 8)], name_hash[:7]) + return constraint + + +# DB-API 2.0 EXCEPTIONS. + + +class PeeweeException(Exception): + def __init__(self, *args): + if args and isinstance(args[0], Exception): + self.orig, args = args[0], args[1:] + super(PeeweeException, self).__init__(*args) +class ImproperlyConfigured(PeeweeException): pass +class DatabaseError(PeeweeException): pass +class DataError(DatabaseError): pass +class IntegrityError(DatabaseError): pass +class InterfaceError(PeeweeException): pass +class InternalError(DatabaseError): pass +class NotSupportedError(DatabaseError): pass +class OperationalError(DatabaseError): pass +class ProgrammingError(DatabaseError): pass + + +class ExceptionWrapper(object): + __slots__ = ('exceptions',) + def __init__(self, exceptions): + self.exceptions = exceptions + def __enter__(self): pass + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + return + # psycopg2.8 shits out a million cute error types. Try to catch em all. 
+ if pg_errors is not None and exc_type.__name__ not in self.exceptions \ + and issubclass(exc_type, pg_errors.Error): + exc_type = exc_type.__bases__[0] + if exc_type.__name__ in self.exceptions: + new_type = self.exceptions[exc_type.__name__] + exc_args = exc_value.args + reraise(new_type, new_type(exc_value, *exc_args), traceback) + + +EXCEPTIONS = { + 'ConstraintError': IntegrityError, + 'DatabaseError': DatabaseError, + 'DataError': DataError, + 'IntegrityError': IntegrityError, + 'InterfaceError': InterfaceError, + 'InternalError': InternalError, + 'NotSupportedError': NotSupportedError, + 'OperationalError': OperationalError, + 'ProgrammingError': ProgrammingError, + 'TransactionRollbackError': OperationalError} + +__exception_wrapper__ = ExceptionWrapper(EXCEPTIONS) + + +# DATABASE INTERFACE AND CONNECTION MANAGEMENT. + + +IndexMetadata = collections.namedtuple( + 'IndexMetadata', + ('name', 'sql', 'columns', 'unique', 'table')) +ColumnMetadata = collections.namedtuple( + 'ColumnMetadata', + ('name', 'data_type', 'null', 'primary_key', 'table', 'default')) +ForeignKeyMetadata = collections.namedtuple( + 'ForeignKeyMetadata', + ('column', 'dest_table', 'dest_column', 'table')) +ViewMetadata = collections.namedtuple('ViewMetadata', ('name', 'sql')) + + +class _ConnectionState(object): + def __init__(self, **kwargs): + super(_ConnectionState, self).__init__(**kwargs) + self.reset() + + def reset(self): + self.closed = True + self.conn = None + self.ctx = [] + self.transactions = [] + + def set_connection(self, conn): + self.conn = conn + self.closed = False + self.ctx = [] + self.transactions = [] + + +class _ConnectionLocal(_ConnectionState, threading.local): pass +class _NoopLock(object): + __slots__ = () + def __enter__(self): return self + def __exit__(self, exc_type, exc_val, exc_tb): pass + + +class ConnectionContext(_callable_context_manager): + __slots__ = ('db',) + def __init__(self, db): self.db = db + def __enter__(self): + if self.db.is_closed(): + self.db.connect() + def __exit__(self, exc_type, exc_val, exc_tb): self.db.close() + + +class Database(_callable_context_manager): + context_class = Context + field_types = {} + operations = {} + param = '?' + quote = '""' + server_version = None + + # Feature toggles. + commit_select = False + compound_select_parentheses = CSQ_PARENTHESES_NEVER + for_update = False + index_schema_prefix = False + index_using_precedes_table = False + limit_max = None + nulls_ordering = False + returning_clause = False + safe_create_index = True + safe_drop_index = True + sequences = False + truncate_table = True + + def __init__(self, database, thread_safe=True, autorollback=False, + field_types=None, operations=None, autocommit=None, + autoconnect=True, **kwargs): + self._field_types = merge_dict(FIELD, self.field_types) + self._operations = merge_dict(OP, self.operations) + if field_types: + self._field_types.update(field_types) + if operations: + self._operations.update(operations) + + self.autoconnect = autoconnect + self.autorollback = autorollback + self.thread_safe = thread_safe + if thread_safe: + self._state = _ConnectionLocal() + self._lock = threading.RLock() + else: + self._state = _ConnectionState() + self._lock = _NoopLock() + + if autocommit is not None: + __deprecated__('Peewee no longer uses the "autocommit" option, as ' + 'the semantics now require it to always be True. 
' + 'Because some database-drivers also use the ' + '"autocommit" parameter, you are receiving a ' + 'warning so you may update your code and remove ' + 'the parameter, as in the future, specifying ' + 'autocommit could impact the behavior of the ' + 'database driver you are using.') + + self.connect_params = {} + self.init(database, **kwargs) + + def init(self, database, **kwargs): + if not self.is_closed(): + self.close() + self.database = database + self.connect_params.update(kwargs) + self.deferred = not bool(database) + + def __enter__(self): + if self.is_closed(): + self.connect() + ctx = self.atomic() + self._state.ctx.append(ctx) + ctx.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + ctx = self._state.ctx.pop() + try: + ctx.__exit__(exc_type, exc_val, exc_tb) + finally: + if not self._state.ctx: + self.close() + + def connection_context(self): + return ConnectionContext(self) + + def _connect(self): + raise NotImplementedError + + def connect(self, reuse_if_open=False): + with self._lock: + if self.deferred: + raise InterfaceError('Error, database must be initialized ' + 'before opening a connection.') + if not self._state.closed: + if reuse_if_open: + return False + raise OperationalError('Connection already opened.') + + self._state.reset() + with __exception_wrapper__: + self._state.set_connection(self._connect()) + if self.server_version is None: + self._set_server_version(self._state.conn) + self._initialize_connection(self._state.conn) + return True + + def _initialize_connection(self, conn): + pass + + def _set_server_version(self, conn): + self.server_version = 0 + + def close(self): + with self._lock: + if self.deferred: + raise InterfaceError('Error, database must be initialized ' + 'before opening a connection.') + if self.in_transaction(): + raise OperationalError('Attempting to close database while ' + 'transaction is open.') + is_open = not self._state.closed + try: + if is_open: + with __exception_wrapper__: + self._close(self._state.conn) + finally: + self._state.reset() + return is_open + + def _close(self, conn): + conn.close() + + def is_closed(self): + return self._state.closed + + def is_connection_usable(self): + return not self._state.closed + + def connection(self): + if self.is_closed(): + self.connect() + return self._state.conn + + def cursor(self, commit=None): + if self.is_closed(): + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') + return self._state.conn.cursor() + + def execute_sql(self, sql, params=None, commit=SENTINEL): + logger.debug((sql, params)) + if commit is SENTINEL: + if self.in_transaction(): + commit = False + elif self.commit_select: + commit = True + else: + commit = not sql[:6].lower().startswith('select') + + with __exception_wrapper__: + cursor = self.cursor(commit) + try: + cursor.execute(sql, params or ()) + except Exception: + if self.autorollback and not self.in_transaction(): + self.rollback() + raise + else: + if commit and not self.in_transaction(): + self.commit() + return cursor + + def execute(self, query, commit=SENTINEL, **context_options): + ctx = self.get_sql_context(**context_options) + sql, params = ctx.sql(query).query() + return self.execute_sql(sql, params, commit=commit) + + def get_context_options(self): + return { + 'field_types': self._field_types, + 'operations': self._operations, + 'param': self.param, + 'quote': self.quote, + 'compound_select_parentheses': self.compound_select_parentheses, + 'conflict_statement': 
self.conflict_statement, + 'conflict_update': self.conflict_update, + 'for_update': self.for_update, + 'index_schema_prefix': self.index_schema_prefix, + 'index_using_precedes_table': self.index_using_precedes_table, + 'limit_max': self.limit_max, + 'nulls_ordering': self.nulls_ordering, + } + + def get_sql_context(self, **context_options): + context = self.get_context_options() + if context_options: + context.update(context_options) + return self.context_class(**context) + + def conflict_statement(self, on_conflict, query): + raise NotImplementedError + + def conflict_update(self, on_conflict, query): + raise NotImplementedError + + def _build_on_conflict_update(self, on_conflict, query): + if on_conflict._conflict_target: + stmt = SQL('ON CONFLICT') + target = EnclosedNodeList([ + Entity(col) if isinstance(col, basestring) else col + for col in on_conflict._conflict_target]) + if on_conflict._conflict_where is not None: + target = NodeList([target, SQL('WHERE'), + on_conflict._conflict_where]) + else: + stmt = SQL('ON CONFLICT ON CONSTRAINT') + target = on_conflict._conflict_constraint + if isinstance(target, basestring): + target = Entity(target) + + updates = [] + if on_conflict._preserve: + for column in on_conflict._preserve: + excluded = NodeList((SQL('EXCLUDED'), ensure_entity(column)), + glue='.') + expression = NodeList((ensure_entity(column), SQL('='), + excluded)) + updates.append(expression) + + if on_conflict._update: + for k, v in on_conflict._update.items(): + if not isinstance(v, Node): + # Attempt to resolve string field-names to their respective + # field object, to apply data-type conversions. + if isinstance(k, basestring): + k = getattr(query.table, k) + if isinstance(k, Field): + v = k.to_value(v) + else: + v = Value(v, unpack=False) + else: + v = QualifiedNames(v) + updates.append(NodeList((ensure_entity(k), SQL('='), v))) + + parts = [stmt, target, SQL('DO UPDATE SET'), CommaNodeList(updates)] + if on_conflict._where: + parts.extend((SQL('WHERE'), QualifiedNames(on_conflict._where))) + + return NodeList(parts) + + def last_insert_id(self, cursor, query_type=None): + return cursor.lastrowid + + def rows_affected(self, cursor): + return cursor.rowcount + + def default_values_insert(self, ctx): + return ctx.literal('DEFAULT VALUES') + + def session_start(self): + with self._lock: + return self.transaction().__enter__() + + def session_commit(self): + with self._lock: + try: + txn = self.pop_transaction() + except IndexError: + return False + txn.commit(begin=self.in_transaction()) + return True + + def session_rollback(self): + with self._lock: + try: + txn = self.pop_transaction() + except IndexError: + return False + txn.rollback(begin=self.in_transaction()) + return True + + def in_transaction(self): + return bool(self._state.transactions) + + def push_transaction(self, transaction): + self._state.transactions.append(transaction) + + def pop_transaction(self): + return self._state.transactions.pop() + + def transaction_depth(self): + return len(self._state.transactions) + + def top_transaction(self): + if self._state.transactions: + return self._state.transactions[-1] + + def atomic(self, *args, **kwargs): + return _atomic(self, *args, **kwargs) + + def manual_commit(self): + return _manual(self) + + def transaction(self, *args, **kwargs): + return _transaction(self, *args, **kwargs) + + def savepoint(self): + return _savepoint(self) + + def begin(self): + if self.is_closed(): + self.connect() + + def commit(self): + with __exception_wrapper__: + return 
self._state.conn.commit() + + def rollback(self): + with __exception_wrapper__: + return self._state.conn.rollback() + + def batch_commit(self, it, n): + for group in chunked(it, n): + with self.atomic(): + for obj in group: + yield obj + + def table_exists(self, table_name, schema=None): + return table_name in self.get_tables(schema=schema) + + def get_tables(self, schema=None): + raise NotImplementedError + + def get_indexes(self, table, schema=None): + raise NotImplementedError + + def get_columns(self, table, schema=None): + raise NotImplementedError + + def get_primary_keys(self, table, schema=None): + raise NotImplementedError + + def get_foreign_keys(self, table, schema=None): + raise NotImplementedError + + def sequence_exists(self, seq): + raise NotImplementedError + + def create_tables(self, models, **options): + for model in sort_models(models): + model.create_table(**options) + + def drop_tables(self, models, **kwargs): + for model in reversed(sort_models(models)): + model.drop_table(**kwargs) + + def extract_date(self, date_part, date_field): + raise NotImplementedError + + def truncate_date(self, date_part, date_field): + raise NotImplementedError + + def to_timestamp(self, date_field): + raise NotImplementedError + + def from_timestamp(self, date_field): + raise NotImplementedError + + def random(self): + return fn.random() + + def bind(self, models, bind_refs=True, bind_backrefs=True): + for model in models: + model.bind(self, bind_refs=bind_refs, bind_backrefs=bind_backrefs) + + def bind_ctx(self, models, bind_refs=True, bind_backrefs=True): + return _BoundModelsContext(models, self, bind_refs, bind_backrefs) + + def get_noop_select(self, ctx): + return ctx.sql(Select().columns(SQL('0')).where(SQL('0'))) + + +def __pragma__(name): + def __get__(self): + return self.pragma(name) + def __set__(self, value): + return self.pragma(name, value) + return property(__get__, __set__) + + +class SqliteDatabase(Database): + field_types = { + 'BIGAUTO': FIELD.AUTO, + 'BIGINT': FIELD.INT, + 'BOOL': FIELD.INT, + 'DOUBLE': FIELD.FLOAT, + 'SMALLINT': FIELD.INT, + 'UUID': FIELD.TEXT} + operations = { + 'LIKE': 'GLOB', + 'ILIKE': 'LIKE'} + index_schema_prefix = True + limit_max = -1 + server_version = __sqlite_version__ + truncate_table = False + + def __init__(self, database, *args, **kwargs): + self._pragmas = kwargs.pop('pragmas', ()) + super(SqliteDatabase, self).__init__(database, *args, **kwargs) + self._aggregates = {} + self._collations = {} + self._functions = {} + self._window_functions = {} + self._table_functions = [] + self._extensions = set() + self._attached = {} + self.register_function(_sqlite_date_part, 'date_part', 2) + self.register_function(_sqlite_date_trunc, 'date_trunc', 2) + self.nulls_ordering = self.server_version >= (3, 30, 0) + + def init(self, database, pragmas=None, timeout=5, **kwargs): + if pragmas is not None: + self._pragmas = pragmas + if isinstance(self._pragmas, dict): + self._pragmas = list(self._pragmas.items()) + self._timeout = timeout + super(SqliteDatabase, self).init(database, **kwargs) + + def _set_server_version(self, conn): + pass + + def _connect(self): + if sqlite3 is None: + raise ImproperlyConfigured('SQLite driver not installed!') + conn = sqlite3.connect(self.database, timeout=self._timeout, + isolation_level=None, **self.connect_params) + try: + self._add_conn_hooks(conn) + except: + conn.close() + raise + return conn + + def _add_conn_hooks(self, conn): + if self._attached: + self._attach_databases(conn) + if self._pragmas: + 
self._set_pragmas(conn) + self._load_aggregates(conn) + self._load_collations(conn) + self._load_functions(conn) + if self.server_version >= (3, 25, 0): + self._load_window_functions(conn) + if self._table_functions: + for table_function in self._table_functions: + table_function.register(conn) + if self._extensions: + self._load_extensions(conn) + + def _set_pragmas(self, conn): + cursor = conn.cursor() + for pragma, value in self._pragmas: + cursor.execute('PRAGMA %s = %s;' % (pragma, value)) + cursor.close() + + def _attach_databases(self, conn): + cursor = conn.cursor() + for name, db in self._attached.items(): + cursor.execute('ATTACH DATABASE "%s" AS "%s"' % (db, name)) + cursor.close() + + def pragma(self, key, value=SENTINEL, permanent=False, schema=None): + if schema is not None: + key = '"%s".%s' % (schema, key) + sql = 'PRAGMA %s' % key + if value is not SENTINEL: + sql += ' = %s' % (value or 0) + if permanent: + pragmas = dict(self._pragmas or ()) + pragmas[key] = value + self._pragmas = list(pragmas.items()) + elif permanent: + raise ValueError('Cannot specify a permanent pragma without value') + row = self.execute_sql(sql).fetchone() + if row: + return row[0] + + cache_size = __pragma__('cache_size') + foreign_keys = __pragma__('foreign_keys') + journal_mode = __pragma__('journal_mode') + journal_size_limit = __pragma__('journal_size_limit') + mmap_size = __pragma__('mmap_size') + page_size = __pragma__('page_size') + read_uncommitted = __pragma__('read_uncommitted') + synchronous = __pragma__('synchronous') + wal_autocheckpoint = __pragma__('wal_autocheckpoint') + + @property + def timeout(self): + return self._timeout + + @timeout.setter + def timeout(self, seconds): + if self._timeout == seconds: + return + + self._timeout = seconds + if not self.is_closed(): + # PySQLite multiplies user timeout by 1000, but the unit of the + # timeout PRAGMA is actually milliseconds. 
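For context, a minimal sketch of how the pragma and timeout plumbing above is typically driven; the database path and pragma values below are placeholders, not taken from this patch:

    from peewee import SqliteDatabase

    db = SqliteDatabase('/tmp/example.db',
                        pragmas={'journal_mode': 'wal', 'cache_size': -1024 * 64})
    db.connect()
    db.timeout = 10          # re-issues PRAGMA busy_timeout=10000 (milliseconds)
    print(db.journal_mode)   # pragma property generated by __pragma__() above
    db.close()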
+ self.execute_sql('PRAGMA busy_timeout=%d;' % (seconds * 1000)) + + def _load_aggregates(self, conn): + for name, (klass, num_params) in self._aggregates.items(): + conn.create_aggregate(name, num_params, klass) + + def _load_collations(self, conn): + for name, fn in self._collations.items(): + conn.create_collation(name, fn) + + def _load_functions(self, conn): + for name, (fn, num_params) in self._functions.items(): + conn.create_function(name, num_params, fn) + + def _load_window_functions(self, conn): + for name, (klass, num_params) in self._window_functions.items(): + conn.create_window_function(name, num_params, klass) + + def register_aggregate(self, klass, name=None, num_params=-1): + self._aggregates[name or klass.__name__.lower()] = (klass, num_params) + if not self.is_closed(): + self._load_aggregates(self.connection()) + + def aggregate(self, name=None, num_params=-1): + def decorator(klass): + self.register_aggregate(klass, name, num_params) + return klass + return decorator + + def register_collation(self, fn, name=None): + name = name or fn.__name__ + def _collation(*args): + expressions = args + (SQL('collate %s' % name),) + return NodeList(expressions) + fn.collation = _collation + self._collations[name] = fn + if not self.is_closed(): + self._load_collations(self.connection()) + + def collation(self, name=None): + def decorator(fn): + self.register_collation(fn, name) + return fn + return decorator + + def register_function(self, fn, name=None, num_params=-1): + self._functions[name or fn.__name__] = (fn, num_params) + if not self.is_closed(): + self._load_functions(self.connection()) + + def func(self, name=None, num_params=-1): + def decorator(fn): + self.register_function(fn, name, num_params) + return fn + return decorator + + def register_window_function(self, klass, name=None, num_params=-1): + name = name or klass.__name__.lower() + self._window_functions[name] = (klass, num_params) + if not self.is_closed(): + self._load_window_functions(self.connection()) + + def window_function(self, name=None, num_params=-1): + def decorator(klass): + self.register_window_function(klass, name, num_params) + return klass + return decorator + + def register_table_function(self, klass, name=None): + if name is not None: + klass.name = name + self._table_functions.append(klass) + if not self.is_closed(): + klass.register(self.connection()) + + def table_function(self, name=None): + def decorator(klass): + self.register_table_function(klass, name) + return klass + return decorator + + def unregister_aggregate(self, name): + del(self._aggregates[name]) + + def unregister_collation(self, name): + del(self._collations[name]) + + def unregister_function(self, name): + del(self._functions[name]) + + def unregister_window_function(self, name): + del(self._window_functions[name]) + + def unregister_table_function(self, name): + for idx, klass in enumerate(self._table_functions): + if klass.name == name: + break + else: + return False + self._table_functions.pop(idx) + return True + + def _load_extensions(self, conn): + conn.enable_load_extension(True) + for extension in self._extensions: + conn.load_extension(extension) + + def load_extension(self, extension): + self._extensions.add(extension) + if not self.is_closed(): + conn = self.connection() + conn.enable_load_extension(True) + conn.load_extension(extension) + + def unload_extension(self, extension): + self._extensions.remove(extension) + + def attach(self, filename, name): + if name in self._attached: + if self._attached[name] == 
filename: + return False + raise OperationalError('schema "%s" already attached.' % name) + + self._attached[name] = filename + if not self.is_closed(): + self.execute_sql('ATTACH DATABASE "%s" AS "%s"' % (filename, name)) + return True + + def detach(self, name): + if name not in self._attached: + return False + + del self._attached[name] + if not self.is_closed(): + self.execute_sql('DETACH DATABASE "%s"' % name) + return True + + def begin(self, lock_type=None): + statement = 'BEGIN %s' % lock_type if lock_type else 'BEGIN' + self.execute_sql(statement, commit=False) + + def get_tables(self, schema=None): + schema = schema or 'main' + cursor = self.execute_sql('SELECT name FROM "%s".sqlite_master WHERE ' + 'type=? ORDER BY name' % schema, ('table',)) + return [row for row, in cursor.fetchall()] + + def get_views(self, schema=None): + sql = ('SELECT name, sql FROM "%s".sqlite_master WHERE type=? ' + 'ORDER BY name') % (schema or 'main') + return [ViewMetadata(*row) for row in self.execute_sql(sql, ('view',))] + + def get_indexes(self, table, schema=None): + schema = schema or 'main' + query = ('SELECT name, sql FROM "%s".sqlite_master ' + 'WHERE tbl_name = ? AND type = ? ORDER BY name') % schema + cursor = self.execute_sql(query, (table, 'index')) + index_to_sql = dict(cursor.fetchall()) + + # Determine which indexes have a unique constraint. + unique_indexes = set() + cursor = self.execute_sql('PRAGMA "%s".index_list("%s")' % + (schema, table)) + for row in cursor.fetchall(): + name = row[1] + is_unique = int(row[2]) == 1 + if is_unique: + unique_indexes.add(name) + + # Retrieve the indexed columns. + index_columns = {} + for index_name in sorted(index_to_sql): + cursor = self.execute_sql('PRAGMA "%s".index_info("%s")' % + (schema, index_name)) + index_columns[index_name] = [row[2] for row in cursor.fetchall()] + + return [ + IndexMetadata( + name, + index_to_sql[name], + index_columns[name], + name in unique_indexes, + table) + for name in sorted(index_to_sql)] + + def get_columns(self, table, schema=None): + cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % + (schema or 'main', table)) + return [ColumnMetadata(r[1], r[2], not r[3], bool(r[5]), table, r[4]) + for r in cursor.fetchall()] + + def get_primary_keys(self, table, schema=None): + cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % + (schema or 'main', table)) + return [row[1] for row in filter(lambda r: r[-1], cursor.fetchall())] + + def get_foreign_keys(self, table, schema=None): + cursor = self.execute_sql('PRAGMA "%s".foreign_key_list("%s")' % + (schema or 'main', table)) + return [ForeignKeyMetadata(row[3], row[2], row[4], table) + for row in cursor.fetchall()] + + def get_binary_type(self): + return sqlite3.Binary + + def conflict_statement(self, on_conflict, query): + action = on_conflict._action.lower() if on_conflict._action else '' + if action and action not in ('nothing', 'update'): + return SQL('INSERT OR %s' % on_conflict._action.upper()) + + def conflict_update(self, oc, query): + # Sqlite prior to 3.24.0 does not support Postgres-style upsert. 
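As a concrete illustration of the upsert path that conflict_update() handles below, here is a hedged sketch against a hypothetical Counter model with a unique key column (requires SQLite 3.24.0+):

    from peewee import SqliteDatabase, Model, CharField, IntegerField

    db = SqliteDatabase(':memory:')

    class Counter(Model):
        key = CharField(unique=True)
        value = IntegerField(default=0)
        class Meta:
            database = db

    db.create_tables([Counter])
    (Counter
     .insert(key='views', value=1)
     .on_conflict(conflict_target=[Counter.key],
                  update={Counter.value: Counter.value + 1})
     .execute())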
+ if self.server_version < (3, 24, 0) and \ + any((oc._preserve, oc._update, oc._where, oc._conflict_target, + oc._conflict_constraint)): + raise ValueError('SQLite does not support specifying which values ' + 'to preserve or update.') + + action = oc._action.lower() if oc._action else '' + if action and action not in ('nothing', 'update', ''): + return + + if action == 'nothing': + return SQL('ON CONFLICT DO NOTHING') + elif not oc._update and not oc._preserve: + raise ValueError('If you are not performing any updates (or ' + 'preserving any INSERTed values), then the ' + 'conflict resolution action should be set to ' + '"NOTHING".') + elif oc._conflict_constraint: + raise ValueError('SQLite does not support specifying named ' + 'constraints for conflict resolution.') + elif not oc._conflict_target: + raise ValueError('SQLite requires that a conflict target be ' + 'specified when doing an upsert.') + + return self._build_on_conflict_update(oc, query) + + def extract_date(self, date_part, date_field): + return fn.date_part(date_part, date_field, python_value=int) + + def truncate_date(self, date_part, date_field): + return fn.date_trunc(date_part, date_field, + python_value=simple_date_time) + + def to_timestamp(self, date_field): + return fn.strftime('%s', date_field).cast('integer') + + def from_timestamp(self, date_field): + return fn.datetime(date_field, 'unixepoch') + + +class PostgresqlDatabase(Database): + field_types = { + 'AUTO': 'SERIAL', + 'BIGAUTO': 'BIGSERIAL', + 'BLOB': 'BYTEA', + 'BOOL': 'BOOLEAN', + 'DATETIME': 'TIMESTAMP', + 'DECIMAL': 'NUMERIC', + 'DOUBLE': 'DOUBLE PRECISION', + 'UUID': 'UUID', + 'UUIDB': 'BYTEA'} + operations = {'REGEXP': '~', 'IREGEXP': '~*'} + param = '%s' + + commit_select = True + compound_select_parentheses = CSQ_PARENTHESES_ALWAYS + for_update = True + nulls_ordering = True + returning_clause = True + safe_create_index = False + sequences = True + + def init(self, database, register_unicode=True, encoding=None, + isolation_level=None, **kwargs): + self._register_unicode = register_unicode + self._encoding = encoding + self._isolation_level = isolation_level + super(PostgresqlDatabase, self).init(database, **kwargs) + + def _connect(self): + if psycopg2 is None: + raise ImproperlyConfigured('Postgres driver not installed!') + conn = psycopg2.connect(database=self.database, **self.connect_params) + if self._register_unicode: + pg_extensions.register_type(pg_extensions.UNICODE, conn) + pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn) + if self._encoding: + conn.set_client_encoding(self._encoding) + if self._isolation_level: + conn.set_isolation_level(self._isolation_level) + return conn + + def _set_server_version(self, conn): + self.server_version = conn.server_version + if self.server_version >= 90600: + self.safe_create_index = True + + def is_connection_usable(self): + if self._state.closed: + return False + + # Returns True if we are idle, running a command, or in an active + # connection. If the connection is in an error state or the connection + # is otherwise unusable, return False. 
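A brief sketch of wiring up the Postgres backend defined here; the connection parameters are placeholders and are forwarded to psycopg2.connect() through connect_params:

    from peewee import PostgresqlDatabase

    db = PostgresqlDatabase('app', user='postgres', password='secret',
                            host='127.0.0.1', port=5432)
    db.connect()
    print(db.get_tables())   # backed by the information_schema queries below
    db.close()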
+ txn_status = self._state.conn.get_transaction_status() + return txn_status < pg_extensions.TRANSACTION_STATUS_INERROR + + def last_insert_id(self, cursor, query_type=None): + try: + return cursor if query_type != Insert.SIMPLE else cursor[0][0] + except (IndexError, KeyError, TypeError): + pass + + def get_tables(self, schema=None): + query = ('SELECT tablename FROM pg_catalog.pg_tables ' + 'WHERE schemaname = %s ORDER BY tablename') + cursor = self.execute_sql(query, (schema or 'public',)) + return [table for table, in cursor.fetchall()] + + def get_views(self, schema=None): + query = ('SELECT viewname, definition FROM pg_catalog.pg_views ' + 'WHERE schemaname = %s ORDER BY viewname') + cursor = self.execute_sql(query, (schema or 'public',)) + return [ViewMetadata(view_name, sql.strip(' \t;')) + for (view_name, sql) in cursor.fetchall()] + + def get_indexes(self, table, schema=None): + query = """ + SELECT + i.relname, idxs.indexdef, idx.indisunique, + array_to_string(ARRAY( + SELECT pg_get_indexdef(idx.indexrelid, k + 1, TRUE) + FROM generate_subscripts(idx.indkey, 1) AS k + ORDER BY k), ',') + FROM pg_catalog.pg_class AS t + INNER JOIN pg_catalog.pg_index AS idx ON t.oid = idx.indrelid + INNER JOIN pg_catalog.pg_class AS i ON idx.indexrelid = i.oid + INNER JOIN pg_catalog.pg_indexes AS idxs ON + (idxs.tablename = t.relname AND idxs.indexname = i.relname) + WHERE t.relname = %s AND t.relkind = %s AND idxs.schemaname = %s + ORDER BY idx.indisunique DESC, i.relname;""" + cursor = self.execute_sql(query, (table, 'r', schema or 'public')) + return [IndexMetadata(name, sql.rstrip(' ;'), columns.split(','), + is_unique, table) + for name, sql, is_unique, columns in cursor.fetchall()] + + def get_columns(self, table, schema=None): + query = """ + SELECT column_name, is_nullable, data_type, column_default + FROM information_schema.columns + WHERE table_name = %s AND table_schema = %s + ORDER BY ordinal_position""" + cursor = self.execute_sql(query, (table, schema or 'public')) + pks = set(self.get_primary_keys(table, schema)) + return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) + for name, null, dt, df in cursor.fetchall()] + + def get_primary_keys(self, table, schema=None): + query = """ + SELECT kc.column_name + FROM information_schema.table_constraints AS tc + INNER JOIN information_schema.key_column_usage AS kc ON ( + tc.table_name = kc.table_name AND + tc.table_schema = kc.table_schema AND + tc.constraint_name = kc.constraint_name) + WHERE + tc.constraint_type = %s AND + tc.table_name = %s AND + tc.table_schema = %s""" + ctype = 'PRIMARY KEY' + cursor = self.execute_sql(query, (ctype, table, schema or 'public')) + return [pk for pk, in cursor.fetchall()] + + def get_foreign_keys(self, table, schema=None): + sql = """ + SELECT DISTINCT + kcu.column_name, ccu.table_name, ccu.column_name + FROM information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON (tc.constraint_name = kcu.constraint_name AND + tc.constraint_schema = kcu.constraint_schema AND + tc.table_name = kcu.table_name AND + tc.table_schema = kcu.table_schema) + JOIN information_schema.constraint_column_usage AS ccu + ON (ccu.constraint_name = tc.constraint_name AND + ccu.constraint_schema = tc.constraint_schema) + WHERE + tc.constraint_type = 'FOREIGN KEY' AND + tc.table_name = %s AND + tc.table_schema = %s""" + cursor = self.execute_sql(sql, (table, schema or 'public')) + return [ForeignKeyMetadata(row[0], row[1], row[2], table) + for row in cursor.fetchall()] + + def 
sequence_exists(self, sequence): + res = self.execute_sql(""" + SELECT COUNT(*) FROM pg_class, pg_namespace + WHERE relkind='S' + AND pg_class.relnamespace = pg_namespace.oid + AND relname=%s""", (sequence,)) + return bool(res.fetchone()[0]) + + def get_binary_type(self): + return psycopg2.Binary + + def conflict_statement(self, on_conflict, query): + return + + def conflict_update(self, oc, query): + action = oc._action.lower() if oc._action else '' + if action in ('ignore', 'nothing'): + return SQL('ON CONFLICT DO NOTHING') + elif action and action != 'update': + raise ValueError('The only supported actions for conflict ' + 'resolution with Postgresql are "ignore" or ' + '"update".') + elif not oc._update and not oc._preserve: + raise ValueError('If you are not performing any updates (or ' + 'preserving any INSERTed values), then the ' + 'conflict resolution action should be set to ' + '"IGNORE".') + elif not (oc._conflict_target or oc._conflict_constraint): + raise ValueError('Postgres requires that a conflict target be ' + 'specified when doing an upsert.') + + return self._build_on_conflict_update(oc, query) + + def extract_date(self, date_part, date_field): + return fn.EXTRACT(NodeList((date_part, SQL('FROM'), date_field))) + + def truncate_date(self, date_part, date_field): + return fn.DATE_TRUNC(date_part, date_field) + + def to_timestamp(self, date_field): + return self.extract_date('EPOCH', date_field) + + def from_timestamp(self, date_field): + # Ironically, here, Postgres means "to the Postgresql timestamp type". + return fn.to_timestamp(date_field) + + def get_noop_select(self, ctx): + return ctx.sql(Select().columns(SQL('0')).where(SQL('false'))) + + def set_time_zone(self, timezone): + self.execute_sql('set time zone "%s";' % timezone) + + +class MySQLDatabase(Database): + field_types = { + 'AUTO': 'INTEGER AUTO_INCREMENT', + 'BIGAUTO': 'BIGINT AUTO_INCREMENT', + 'BOOL': 'BOOL', + 'DECIMAL': 'NUMERIC', + 'DOUBLE': 'DOUBLE PRECISION', + 'FLOAT': 'FLOAT', + 'UUID': 'VARCHAR(40)', + 'UUIDB': 'VARBINARY(16)'} + operations = { + 'LIKE': 'LIKE BINARY', + 'ILIKE': 'LIKE', + 'REGEXP': 'REGEXP BINARY', + 'IREGEXP': 'REGEXP', + 'XOR': 'XOR'} + param = '%s' + quote = '``' + + commit_select = True + compound_select_parentheses = CSQ_PARENTHESES_UNNESTED + for_update = True + index_using_precedes_table = True + limit_max = 2 ** 64 - 1 + safe_create_index = False + safe_drop_index = False + sql_mode = 'PIPES_AS_CONCAT' + + def init(self, database, **kwargs): + params = { + 'charset': 'utf8', + 'sql_mode': self.sql_mode, + 'use_unicode': True} + params.update(kwargs) + if 'password' in params and mysql_passwd: + params['passwd'] = params.pop('password') + super(MySQLDatabase, self).init(database, **params) + + def _connect(self): + if mysql is None: + raise ImproperlyConfigured('MySQL driver not installed!') + conn = mysql.connect(db=self.database, **self.connect_params) + return conn + + def _set_server_version(self, conn): + try: + version_raw = conn.server_version + except AttributeError: + version_raw = conn.get_server_info() + self.server_version = self._extract_server_version(version_raw) + + def _extract_server_version(self, version): + version = version.lower() + if 'maria' in version: + match_obj = re.search(r'(1\d\.\d+\.\d+)', version) + else: + match_obj = re.search(r'(\d\.\d+\.\d+)', version) + if match_obj is not None: + return tuple(int(num) for num in match_obj.groups()[0].split('.')) + + warnings.warn('Unable to determine MySQL version: "%s"' % version) + return (0, 0, 0) # 
Unable to determine version! + + def default_values_insert(self, ctx): + return ctx.literal('() VALUES ()') + + def get_tables(self, schema=None): + query = ('SELECT table_name FROM information_schema.tables ' + 'WHERE table_schema = DATABASE() AND table_type != %s ' + 'ORDER BY table_name') + return [table for table, in self.execute_sql(query, ('VIEW',))] + + def get_views(self, schema=None): + query = ('SELECT table_name, view_definition ' + 'FROM information_schema.views ' + 'WHERE table_schema = DATABASE() ORDER BY table_name') + cursor = self.execute_sql(query) + return [ViewMetadata(*row) for row in cursor.fetchall()] + + def get_indexes(self, table, schema=None): + cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) + unique = set() + indexes = {} + for row in cursor.fetchall(): + if not row[1]: + unique.add(row[2]) + indexes.setdefault(row[2], []) + indexes[row[2]].append(row[4]) + return [IndexMetadata(name, None, indexes[name], name in unique, table) + for name in indexes] + + def get_columns(self, table, schema=None): + sql = """ + SELECT column_name, is_nullable, data_type, column_default + FROM information_schema.columns + WHERE table_name = %s AND table_schema = DATABASE()""" + cursor = self.execute_sql(sql, (table,)) + pks = set(self.get_primary_keys(table)) + return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) + for name, null, dt, df in cursor.fetchall()] + + def get_primary_keys(self, table, schema=None): + cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) + return [row[4] for row in + filter(lambda row: row[2] == 'PRIMARY', cursor.fetchall())] + + def get_foreign_keys(self, table, schema=None): + query = """ + SELECT column_name, referenced_table_name, referenced_column_name + FROM information_schema.key_column_usage + WHERE table_name = %s + AND table_schema = DATABASE() + AND referenced_table_name IS NOT NULL + AND referenced_column_name IS NOT NULL""" + cursor = self.execute_sql(query, (table,)) + return [ + ForeignKeyMetadata(column, dest_table, dest_column, table) + for column, dest_table, dest_column in cursor.fetchall()] + + def get_binary_type(self): + return mysql.Binary + + def conflict_statement(self, on_conflict, query): + if not on_conflict._action: return + + action = on_conflict._action.lower() + if action == 'replace': + return SQL('REPLACE') + elif action == 'ignore': + return SQL('INSERT IGNORE') + elif action != 'update': + raise ValueError('Un-supported action for conflict resolution. ' + 'MySQL supports REPLACE, IGNORE and UPDATE.') + + def conflict_update(self, on_conflict, query): + if on_conflict._where or on_conflict._conflict_target or \ + on_conflict._conflict_constraint: + raise ValueError('MySQL does not support the specification of ' + 'where clauses or conflict targets for conflict ' + 'resolution.') + + updates = [] + if on_conflict._preserve: + # Here we need to determine which function to use, which varies + # depending on the MySQL server version. MySQL and MariaDB prior to + # 10.3.3 use "VALUES", while MariaDB 10.3.3+ use "VALUE". 
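For illustration, the MySQL-flavoured upsert emitted by conflict_update() can be exercised roughly as follows, assuming a hypothetical User model with a unique username column:

    (User
     .insert(username='alice', login_count=1)
     .on_conflict(preserve=[User.login_count])   # rendered as ON DUPLICATE KEY UPDATE
     .execute())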
+ version = self.server_version or (0,) + if version[0] == 10 and version >= (10, 3, 3): + VALUE_FN = fn.VALUE + else: + VALUE_FN = fn.VALUES + + for column in on_conflict._preserve: + entity = ensure_entity(column) + expression = NodeList(( + ensure_entity(column), + SQL('='), + VALUE_FN(entity))) + updates.append(expression) + + if on_conflict._update: + for k, v in on_conflict._update.items(): + if not isinstance(v, Node): + # Attempt to resolve string field-names to their respective + # field object, to apply data-type conversions. + if isinstance(k, basestring): + k = getattr(query.table, k) + if isinstance(k, Field): + v = k.to_value(v) + else: + v = Value(v, unpack=False) + updates.append(NodeList((ensure_entity(k), SQL('='), v))) + + if updates: + return NodeList((SQL('ON DUPLICATE KEY UPDATE'), + CommaNodeList(updates))) + + def extract_date(self, date_part, date_field): + return fn.EXTRACT(NodeList((SQL(date_part), SQL('FROM'), date_field))) + + def truncate_date(self, date_part, date_field): + return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part], + python_value=simple_date_time) + + def to_timestamp(self, date_field): + return fn.UNIX_TIMESTAMP(date_field) + + def from_timestamp(self, date_field): + return fn.FROM_UNIXTIME(date_field) + + def random(self): + return fn.rand() + + def get_noop_select(self, ctx): + return ctx.literal('DO 0') + + +# TRANSACTION CONTROL. + + +class _manual(_callable_context_manager): + def __init__(self, db): + self.db = db + + def __enter__(self): + top = self.db.top_transaction() + if top is not None and not isinstance(top, _manual): + raise ValueError('Cannot enter manual commit block while a ' + 'transaction is active.') + self.db.push_transaction(self) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.db.pop_transaction() is not self: + raise ValueError('Transaction stack corrupted while exiting ' + 'manual commit block.') + + +class _atomic(_callable_context_manager): + def __init__(self, db, *args, **kwargs): + self.db = db + self._transaction_args = (args, kwargs) + + def __enter__(self): + if self.db.transaction_depth() == 0: + args, kwargs = self._transaction_args + self._helper = self.db.transaction(*args, **kwargs) + elif isinstance(self.db.top_transaction(), _manual): + raise ValueError('Cannot enter atomic commit block while in ' + 'manual commit mode.') + else: + self._helper = self.db.savepoint() + return self._helper.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return self._helper.__exit__(exc_type, exc_val, exc_tb) + + +class _transaction(_callable_context_manager): + def __init__(self, db, *args, **kwargs): + self.db = db + self._begin_args = (args, kwargs) + + def _begin(self): + args, kwargs = self._begin_args + self.db.begin(*args, **kwargs) + + def commit(self, begin=True): + self.db.commit() + if begin: + self._begin() + + def rollback(self, begin=True): + self.db.rollback() + if begin: + self._begin() + + def __enter__(self): + if self.db.transaction_depth() == 0: + self._begin() + self.db.push_transaction(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + if exc_type: + self.rollback(False) + elif self.db.transaction_depth() == 1: + try: + self.commit(False) + except: + self.rollback(False) + raise + finally: + self.db.pop_transaction() + + +class _savepoint(_callable_context_manager): + def __init__(self, db, sid=None): + self.db = db + self.sid = sid or 's' + uuid.uuid4().hex + self.quoted_sid = self.sid.join(self.db.quote) + + def _begin(self): + 
self.db.execute_sql('SAVEPOINT %s;' % self.quoted_sid) + + def commit(self, begin=True): + self.db.execute_sql('RELEASE SAVEPOINT %s;' % self.quoted_sid) + if begin: self._begin() + + def rollback(self): + self.db.execute_sql('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) + + def __enter__(self): + self._begin() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type: + self.rollback() + else: + try: + self.commit(begin=False) + except: + self.rollback() + raise + + +# CURSOR REPRESENTATIONS. + + +class CursorWrapper(object): + def __init__(self, cursor): + self.cursor = cursor + self.count = 0 + self.index = 0 + self.initialized = False + self.populated = False + self.row_cache = [] + + def __iter__(self): + if self.populated: + return iter(self.row_cache) + return ResultIterator(self) + + def __getitem__(self, item): + if isinstance(item, slice): + stop = item.stop + if stop is None or stop < 0: + self.fill_cache() + else: + self.fill_cache(stop) + return self.row_cache[item] + elif isinstance(item, int): + self.fill_cache(item if item > 0 else 0) + return self.row_cache[item] + else: + raise ValueError('CursorWrapper only supports integer and slice ' + 'indexes.') + + def __len__(self): + self.fill_cache() + return self.count + + def initialize(self): + pass + + def iterate(self, cache=True): + row = self.cursor.fetchone() + if row is None: + self.populated = True + self.cursor.close() + raise StopIteration + elif not self.initialized: + self.initialize() # Lazy initialization. + self.initialized = True + self.count += 1 + result = self.process_row(row) + if cache: + self.row_cache.append(result) + return result + + def process_row(self, row): + return row + + def iterator(self): + """Efficient one-pass iteration over the result set.""" + while True: + try: + yield self.iterate(False) + except StopIteration: + return + + def fill_cache(self, n=0): + n = n or float('Inf') + if n < 0: + raise ValueError('Negative values are not supported.') + + iterator = ResultIterator(self) + iterator.index = self.count + while not self.populated and (n > self.count): + try: + iterator.next() + except StopIteration: + break + + +class DictCursorWrapper(CursorWrapper): + def _initialize_columns(self): + description = self.cursor.description + self.columns = [t[0][t[0].find('.') + 1:].strip('")') + for t in description] + self.ncols = len(description) + + initialize = _initialize_columns + + def _row_to_dict(self, row): + result = {} + for i in range(self.ncols): + result.setdefault(self.columns[i], row[i]) # Do not overwrite. 
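Stepping back to the transaction-control helpers above, a minimal sketch of how they compose; db is any Database instance and do_something_risky() is a hypothetical helper:

    with db.atomic() as outer_txn:        # BEGIN
        with db.atomic() as savepoint:    # nested block becomes a SAVEPOINT
            do_something_risky()
            savepoint.rollback()          # rolls back to the savepoint only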
+ return result + + process_row = _row_to_dict + + +class NamedTupleCursorWrapper(CursorWrapper): + def initialize(self): + description = self.cursor.description + self.tuple_class = collections.namedtuple( + 'Row', + [col[0][col[0].find('.') + 1:].strip('"') for col in description]) + + def process_row(self, row): + return self.tuple_class(*row) + + +class ObjectCursorWrapper(DictCursorWrapper): + def __init__(self, cursor, constructor): + super(ObjectCursorWrapper, self).__init__(cursor) + self.constructor = constructor + + def process_row(self, row): + row_dict = self._row_to_dict(row) + return self.constructor(**row_dict) + + +class ResultIterator(object): + def __init__(self, cursor_wrapper): + self.cursor_wrapper = cursor_wrapper + self.index = 0 + + def __iter__(self): + return self + + def next(self): + if self.index < self.cursor_wrapper.count: + obj = self.cursor_wrapper.row_cache[self.index] + elif not self.cursor_wrapper.populated: + self.cursor_wrapper.iterate() + obj = self.cursor_wrapper.row_cache[self.index] + else: + raise StopIteration + self.index += 1 + return obj + + __next__ = next + +# FIELDS + +class FieldAccessor(object): + def __init__(self, model, field, name): + self.model = model + self.field = field + self.name = name + + def __get__(self, instance, instance_type=None): + if instance is not None: + return instance.__data__.get(self.name) + return self.field + + def __set__(self, instance, value): + instance.__data__[self.name] = value + instance._dirty.add(self.name) + + +class ForeignKeyAccessor(FieldAccessor): + def __init__(self, model, field, name): + super(ForeignKeyAccessor, self).__init__(model, field, name) + self.rel_model = field.rel_model + + def get_rel_instance(self, instance): + value = instance.__data__.get(self.name) + if value is not None or self.name in instance.__rel__: + if self.name not in instance.__rel__ and self.field.lazy_load: + obj = self.rel_model.get(self.field.rel_field == value) + instance.__rel__[self.name] = obj + return instance.__rel__.get(self.name, value) + elif not self.field.null: + raise self.rel_model.DoesNotExist + return value + + def __get__(self, instance, instance_type=None): + if instance is not None: + return self.get_rel_instance(instance) + return self.field + + def __set__(self, instance, obj): + if isinstance(obj, self.rel_model): + instance.__data__[self.name] = getattr(obj, self.field.rel_field.name) + instance.__rel__[self.name] = obj + else: + fk_value = instance.__data__.get(self.name) + instance.__data__[self.name] = obj + if obj != fk_value and self.name in instance.__rel__: + del instance.__rel__[self.name] + instance._dirty.add(self.name) + + +class BackrefAccessor(object): + def __init__(self, field): + self.field = field + self.model = field.rel_model + self.rel_model = field.model + + def __get__(self, instance, instance_type=None): + if instance is not None: + dest = self.field.rel_field.name + return (self.rel_model + .select() + .where(self.field == getattr(instance, dest))) + return self + + +class ObjectIdAccessor(object): + """Gives direct access to the underlying id""" + def __init__(self, field): + self.field = field + + def __get__(self, instance, instance_type=None): + if instance is not None: + value = instance.__data__.get(self.field.name) + # Pull the object-id from the related object if it is not set. 
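The accessor classes above are what give model instances their familiar attribute behaviour; a hedged sketch with hypothetical Tweet/User models where Tweet.user = ForeignKeyField(User, backref='tweets'):

    tweet = Tweet.get_by_id(1)
    tweet.user_id        # raw column value via ObjectIdAccessor, no extra query
    tweet.user           # ForeignKeyAccessor lazily loads the related User
    tweet.user.tweets    # BackrefAccessor returns a SELECT over the backref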
+ if value is None and self.field.name in instance.__rel__: + rel_obj = instance.__rel__[self.field.name] + value = getattr(rel_obj, self.field.rel_field.name) + return value + return self.field + + def __set__(self, instance, value): + setattr(instance, self.field.name, value) + + +class Field(ColumnBase): + _field_counter = 0 + _order = 0 + accessor_class = FieldAccessor + auto_increment = False + default_index_type = None + field_type = 'DEFAULT' + unpack = True + + def __init__(self, null=False, index=False, unique=False, column_name=None, + default=None, primary_key=False, constraints=None, + sequence=None, collation=None, unindexed=False, choices=None, + help_text=None, verbose_name=None, index_type=None, + db_column=None, _hidden=False): + if db_column is not None: + __deprecated__('"db_column" has been deprecated in favor of ' + '"column_name" for Field objects.') + column_name = db_column + + self.null = null + self.index = index + self.unique = unique + self.column_name = column_name + self.default = default + self.primary_key = primary_key + self.constraints = constraints # List of column constraints. + self.sequence = sequence # Name of sequence, e.g. foo_id_seq. + self.collation = collation + self.unindexed = unindexed + self.choices = choices + self.help_text = help_text + self.verbose_name = verbose_name + self.index_type = index_type or self.default_index_type + self._hidden = _hidden + + # Used internally for recovering the order in which Fields were defined + # on the Model class. + Field._field_counter += 1 + self._order = Field._field_counter + self._sort_key = (self.primary_key and 1 or 2), self._order + + def __hash__(self): + return hash(self.name + '.' + self.model.__name__) + + def __repr__(self): + if hasattr(self, 'model') and getattr(self, 'name', None): + return '<%s: %s.%s>' % (type(self).__name__, + self.model.__name__, + self.name) + return '<%s: (unbound)>' % type(self).__name__ + + def bind(self, model, name, set_attribute=True): + self.model = model + self.name = self.safe_name = name + self.column_name = self.column_name or name + if set_attribute: + setattr(model, name, self.accessor_class(model, self, name)) + + @property + def column(self): + return Column(self.model._meta.table, self.column_name) + + def adapt(self, value): + return value + + def db_value(self, value): + return value if value is None else self.adapt(value) + + def python_value(self, value): + return value if value is None else self.adapt(value) + + def to_value(self, value): + return Value(value, self.db_value, unpack=False) + + def get_sort_key(self, ctx): + return self._sort_key + + def __sql__(self, ctx): + return ctx.sql(self.column) + + def get_modifiers(self): + pass + + def ddl_datatype(self, ctx): + if ctx and ctx.state.field_types: + column_type = ctx.state.field_types.get(self.field_type, + self.field_type) + else: + column_type = self.field_type + + modifiers = self.get_modifiers() + if column_type and modifiers: + modifier_literal = ', '.join([str(m) for m in modifiers]) + return SQL('%s(%s)' % (column_type, modifier_literal)) + else: + return SQL(column_type) + + def ddl(self, ctx): + accum = [Entity(self.column_name)] + data_type = self.ddl_datatype(ctx) + if data_type: + accum.append(data_type) + if self.unindexed: + accum.append(SQL('UNINDEXED')) + if not self.null: + accum.append(SQL('NOT NULL')) + if self.primary_key: + accum.append(SQL('PRIMARY KEY')) + if self.sequence: + accum.append(SQL("DEFAULT NEXTVAL('%s')" % self.sequence)) + if self.constraints: + 
accum.extend(self.constraints) + if self.collation: + accum.append(SQL('COLLATE %s' % self.collation)) + return NodeList(accum) + + +class IntegerField(Field): + field_type = 'INT' + + def adapt(self, value): + try: + return int(value) + except ValueError: + return value + + +class BigIntegerField(IntegerField): + field_type = 'BIGINT' + + +class SmallIntegerField(IntegerField): + field_type = 'SMALLINT' + + +class AutoField(IntegerField): + auto_increment = True + field_type = 'AUTO' + + def __init__(self, *args, **kwargs): + if kwargs.get('primary_key') is False: + raise ValueError('%s must always be a primary key.' % type(self)) + kwargs['primary_key'] = True + super(AutoField, self).__init__(*args, **kwargs) + + +class BigAutoField(AutoField): + field_type = 'BIGAUTO' + + +class IdentityField(AutoField): + field_type = 'INT GENERATED BY DEFAULT AS IDENTITY' + + def __init__(self, generate_always=False, **kwargs): + if generate_always: + self.field_type = 'INT GENERATED ALWAYS AS IDENTITY' + super(IdentityField, self).__init__(**kwargs) + + +class PrimaryKeyField(AutoField): + def __init__(self, *args, **kwargs): + __deprecated__('"PrimaryKeyField" has been renamed to "AutoField". ' + 'Please update your code accordingly as this will be ' + 'completely removed in a subsequent release.') + super(PrimaryKeyField, self).__init__(*args, **kwargs) + + +class FloatField(Field): + field_type = 'FLOAT' + + def adapt(self, value): + try: + return float(value) + except ValueError: + return value + + +class DoubleField(FloatField): + field_type = 'DOUBLE' + + +class DecimalField(Field): + field_type = 'DECIMAL' + + def __init__(self, max_digits=10, decimal_places=5, auto_round=False, + rounding=None, *args, **kwargs): + self.max_digits = max_digits + self.decimal_places = decimal_places + self.auto_round = auto_round + self.rounding = rounding or decimal.DefaultContext.rounding + self._exp = decimal.Decimal(10) ** (-self.decimal_places) + super(DecimalField, self).__init__(*args, **kwargs) + + def get_modifiers(self): + return [self.max_digits, self.decimal_places] + + def db_value(self, value): + D = decimal.Decimal + if not value: + return value if value is None else D(0) + if self.auto_round: + decimal_value = D(text_type(value)) + return decimal_value.quantize(self._exp, rounding=self.rounding) + return value + + def python_value(self, value): + if value is not None: + if isinstance(value, decimal.Decimal): + return value + return decimal.Decimal(text_type(value)) + + +class _StringField(Field): + def adapt(self, value): + if isinstance(value, text_type): + return value + elif isinstance(value, bytes_type): + return value.decode('utf-8') + return text_type(value) + + def __add__(self, other): return StringExpression(self, OP.CONCAT, other) + def __radd__(self, other): return StringExpression(other, OP.CONCAT, self) + + +class CharField(_StringField): + field_type = 'VARCHAR' + + def __init__(self, max_length=255, *args, **kwargs): + self.max_length = max_length + super(CharField, self).__init__(*args, **kwargs) + + def get_modifiers(self): + return self.max_length and [self.max_length] or None + + +class FixedCharField(CharField): + field_type = 'CHAR' + + def python_value(self, value): + value = super(FixedCharField, self).python_value(value) + if value: + value = value.strip() + return value + + +class TextField(_StringField): + field_type = 'TEXT' + + +class BlobField(Field): + field_type = 'BLOB' + + def _db_hook(self, database): + if database is None: + self._constructor = bytearray + 
else: + self._constructor = database.get_binary_type() + + def bind(self, model, name, set_attribute=True): + self._constructor = bytearray + if model._meta.database: + if isinstance(model._meta.database, Proxy): + model._meta.database.attach_callback(self._db_hook) + else: + self._db_hook(model._meta.database) + + # Attach a hook to the model metadata; in the event the database is + # changed or set at run-time, we will be sure to apply our callback and + # use the proper data-type for our database driver. + model._meta._db_hooks.append(self._db_hook) + return super(BlobField, self).bind(model, name, set_attribute) + + def db_value(self, value): + if isinstance(value, text_type): + value = value.encode('raw_unicode_escape') + if isinstance(value, bytes_type): + return self._constructor(value) + return value + + +class BitField(BitwiseMixin, BigIntegerField): + def __init__(self, *args, **kwargs): + kwargs.setdefault('default', 0) + super(BitField, self).__init__(*args, **kwargs) + self.__current_flag = 1 + + def flag(self, value=None): + if value is None: + value = self.__current_flag + self.__current_flag <<= 1 + else: + self.__current_flag = value << 1 + + class FlagDescriptor(ColumnBase): + def __init__(self, field, value): + self._field = field + self._value = value + super(FlagDescriptor, self).__init__() + def clear(self): + return self._field.bin_and(~self._value) + def set(self): + return self._field.bin_or(self._value) + def __get__(self, instance, instance_type=None): + if instance is None: + return self + value = getattr(instance, self._field.name) or 0 + return (value & self._value) != 0 + def __set__(self, instance, is_set): + if is_set not in (True, False): + raise ValueError('Value must be either True or False') + value = getattr(instance, self._field.name) or 0 + if is_set: + value |= self._value + else: + value &= ~self._value + setattr(instance, self._field.name, value) + def __sql__(self, ctx): + return ctx.sql(self._field.bin_and(self._value) != 0) + return FlagDescriptor(self, value) + + +class BigBitFieldData(object): + def __init__(self, instance, name): + self.instance = instance + self.name = name + value = self.instance.__data__.get(self.name) + if not value: + value = bytearray() + elif not isinstance(value, bytearray): + value = bytearray(value) + self._buffer = self.instance.__data__[self.name] = value + + def _ensure_length(self, idx): + byte_num, byte_offset = divmod(idx, 8) + cur_size = len(self._buffer) + if cur_size <= byte_num: + self._buffer.extend(b'\x00' * ((byte_num + 1) - cur_size)) + return byte_num, byte_offset + + def set_bit(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + self._buffer[byte_num] |= (1 << byte_offset) + + def clear_bit(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + self._buffer[byte_num] &= ~(1 << byte_offset) + + def toggle_bit(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + self._buffer[byte_num] ^= (1 << byte_offset) + return bool(self._buffer[byte_num] & (1 << byte_offset)) + + def is_set(self, idx): + byte_num, byte_offset = self._ensure_length(idx) + return bool(self._buffer[byte_num] & (1 << byte_offset)) + + def __repr__(self): + return repr(self._buffer) + + +class BigBitFieldAccessor(FieldAccessor): + def __get__(self, instance, instance_type=None): + if instance is None: + return self.field + return BigBitFieldData(instance, self.name) + def __set__(self, instance, value): + if isinstance(value, memoryview): + value = value.tobytes() + elif isinstance(value, 
buffer_type): + value = bytes(value) + elif isinstance(value, bytearray): + value = bytes_type(value) + elif isinstance(value, BigBitFieldData): + value = bytes_type(value._buffer) + elif isinstance(value, text_type): + value = value.encode('utf-8') + elif not isinstance(value, bytes_type): + raise ValueError('Value must be either a bytes, memoryview or ' + 'BigBitFieldData instance.') + super(BigBitFieldAccessor, self).__set__(instance, value) + + +class BigBitField(BlobField): + accessor_class = BigBitFieldAccessor + + def __init__(self, *args, **kwargs): + kwargs.setdefault('default', bytes_type) + super(BigBitField, self).__init__(*args, **kwargs) + + def db_value(self, value): + return bytes_type(value) if value is not None else value + + +class UUIDField(Field): + field_type = 'UUID' + + def db_value(self, value): + if isinstance(value, basestring) and len(value) == 32: + # Hex string. No transformation is necessary. + return value + elif isinstance(value, bytes) and len(value) == 16: + # Allow raw binary representation. + value = uuid.UUID(bytes=value) + if isinstance(value, uuid.UUID): + return value.hex + try: + return uuid.UUID(value).hex + except: + return value + + def python_value(self, value): + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(value) if value is not None else None + + +class BinaryUUIDField(BlobField): + field_type = 'UUIDB' + + def db_value(self, value): + if isinstance(value, bytes) and len(value) == 16: + # Raw binary value. No transformation is necessary. + return self._constructor(value) + elif isinstance(value, basestring) and len(value) == 32: + # Allow hex string representation. + value = uuid.UUID(hex=value) + if isinstance(value, uuid.UUID): + return self._constructor(value.bytes) + elif value is not None: + raise ValueError('value for binary UUID field must be UUID(), ' + 'a hexadecimal string, or a bytes object.') + + def python_value(self, value): + if isinstance(value, uuid.UUID): + return value + elif isinstance(value, memoryview): + value = value.tobytes() + elif value and not isinstance(value, bytes): + value = bytes(value) + return uuid.UUID(bytes=value) if value is not None else None + + +def _date_part(date_part): + def dec(self): + return self.model._meta.database.extract_date(date_part, self) + return dec + +def format_date_time(value, formats, post_process=None): + post_process = post_process or (lambda x: x) + for fmt in formats: + try: + return post_process(datetime.datetime.strptime(value, fmt)) + except ValueError: + pass + return value + +def simple_date_time(value): + try: + return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') + except (TypeError, ValueError): + return value + + +class _BaseFormattedField(Field): + formats = None + + def __init__(self, formats=None, *args, **kwargs): + if formats is not None: + self.formats = formats + super(_BaseFormattedField, self).__init__(*args, **kwargs) + + +class DateTimeField(_BaseFormattedField): + field_type = 'DATETIME' + formats = [ + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d', + ] + + def adapt(self, value): + if value and isinstance(value, basestring): + return format_date_time(value, self.formats) + return value + + def to_timestamp(self): + return self.model._meta.database.to_timestamp(self) + + def truncate(self, part): + return self.model._meta.database.truncate_date(part, self) + + year = property(_date_part('year')) + month = property(_date_part('month')) + day = property(_date_part('day')) + hour = property(_date_part('hour')) + 
minute = property(_date_part('minute')) + second = property(_date_part('second')) + + +class DateField(_BaseFormattedField): + field_type = 'DATE' + formats = [ + '%Y-%m-%d', + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f', + ] + + def adapt(self, value): + if value and isinstance(value, basestring): + pp = lambda x: x.date() + return format_date_time(value, self.formats, pp) + elif value and isinstance(value, datetime.datetime): + return value.date() + return value + + def to_timestamp(self): + return self.model._meta.database.to_timestamp(self) + + def truncate(self, part): + return self.model._meta.database.truncate_date(part, self) + + year = property(_date_part('year')) + month = property(_date_part('month')) + day = property(_date_part('day')) + + +class TimeField(_BaseFormattedField): + field_type = 'TIME' + formats = [ + '%H:%M:%S.%f', + '%H:%M:%S', + '%H:%M', + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d %H:%M:%S', + ] + + def adapt(self, value): + if value: + if isinstance(value, basestring): + pp = lambda x: x.time() + return format_date_time(value, self.formats, pp) + elif isinstance(value, datetime.datetime): + return value.time() + if value is not None and isinstance(value, datetime.timedelta): + return (datetime.datetime.min + value).time() + return value + + hour = property(_date_part('hour')) + minute = property(_date_part('minute')) + second = property(_date_part('second')) + + +def _timestamp_date_part(date_part): + def dec(self): + db = self.model._meta.database + expr = ((self / Value(self.resolution, converter=False)) + if self.resolution > 1 else self) + return db.extract_date(date_part, db.from_timestamp(expr)) + return dec + + +class TimestampField(BigIntegerField): + # Support second -> microsecond resolution. + valid_resolutions = [10**i for i in range(7)] + + def __init__(self, *args, **kwargs): + self.resolution = kwargs.pop('resolution', None) + + if not self.resolution: + self.resolution = 1 + elif self.resolution in range(2, 7): + self.resolution = 10 ** self.resolution + elif self.resolution not in self.valid_resolutions: + raise ValueError('TimestampField resolution must be one of: %s' % + ', '.join(str(i) for i in self.valid_resolutions)) + self.ticks_to_microsecond = 1000000 // self.resolution + + self.utc = kwargs.pop('utc', False) or False + dflt = datetime.datetime.utcnow if self.utc else datetime.datetime.now + kwargs.setdefault('default', dflt) + super(TimestampField, self).__init__(*args, **kwargs) + + def local_to_utc(self, dt): + # Convert naive local datetime into naive UTC, e.g.: + # 2019-03-01T12:00:00 (local=US/Central) -> 2019-03-01T18:00:00. + # 2019-05-01T12:00:00 (local=US/Central) -> 2019-05-01T17:00:00. + # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00. + return datetime.datetime(*time.gmtime(time.mktime(dt.timetuple()))[:6]) + + def utc_to_local(self, dt): + # Convert a naive UTC datetime into local time, e.g.: + # 2019-03-01T18:00:00 (local=US/Central) -> 2019-03-01T12:00:00. + # 2019-05-01T17:00:00 (local=US/Central) -> 2019-05-01T12:00:00. + # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00. + ts = calendar.timegm(dt.utctimetuple()) + return datetime.datetime.fromtimestamp(ts) + + def get_timestamp(self, value): + if self.utc: + # If utc-mode is on, then we assume all naive datetimes are in UTC. 
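To make the resolution/utc handling above concrete, a small sketch of a hypothetical model using TimestampField; the value is stored as an integer count of ticks:

    from peewee import Model, TimestampField, SqliteDatabase

    db = SqliteDatabase(':memory:')

    class Event(Model):
        # millisecond ticks since the epoch; naive datetimes are treated as UTC
        created = TimestampField(resolution=1000, utc=True)
        class Meta:
            database = db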
+ return calendar.timegm(value.utctimetuple()) + else: + return time.mktime(value.timetuple()) + + def db_value(self, value): + if value is None: + return + + if isinstance(value, datetime.datetime): + pass + elif isinstance(value, datetime.date): + value = datetime.datetime(value.year, value.month, value.day) + else: + return int(round(value * self.resolution)) + + timestamp = self.get_timestamp(value) + if self.resolution > 1: + timestamp += (value.microsecond * .000001) + timestamp *= self.resolution + return int(round(timestamp)) + + def python_value(self, value): + if value is not None and isinstance(value, (int, float, long)): + if self.resolution > 1: + value, ticks = divmod(value, self.resolution) + microseconds = int(ticks * self.ticks_to_microsecond) + else: + microseconds = 0 + + if self.utc: + value = datetime.datetime.utcfromtimestamp(value) + else: + value = datetime.datetime.fromtimestamp(value) + + if microseconds: + value = value.replace(microsecond=microseconds) + + return value + + def from_timestamp(self): + expr = ((self / Value(self.resolution, converter=False)) + if self.resolution > 1 else self) + return self.model._meta.database.from_timestamp(expr) + + year = property(_timestamp_date_part('year')) + month = property(_timestamp_date_part('month')) + day = property(_timestamp_date_part('day')) + hour = property(_timestamp_date_part('hour')) + minute = property(_timestamp_date_part('minute')) + second = property(_timestamp_date_part('second')) + + +class IPField(BigIntegerField): + def db_value(self, val): + if val is not None: + return struct.unpack('!I', socket.inet_aton(val))[0] + + def python_value(self, val): + if val is not None: + return socket.inet_ntoa(struct.pack('!I', val)) + + +class BooleanField(Field): + field_type = 'BOOL' + adapt = bool + + +class BareField(Field): + def __init__(self, adapt=None, *args, **kwargs): + super(BareField, self).__init__(*args, **kwargs) + if adapt is not None: + self.adapt = adapt + + def ddl_datatype(self, ctx): + return + + +class ForeignKeyField(Field): + accessor_class = ForeignKeyAccessor + + def __init__(self, model, field=None, backref=None, on_delete=None, + on_update=None, deferrable=None, _deferred=None, + rel_model=None, to_field=None, object_id_name=None, + lazy_load=True, constraint_name=None, related_name=None, + *args, **kwargs): + kwargs.setdefault('index', True) + + super(ForeignKeyField, self).__init__(*args, **kwargs) + + if rel_model is not None: + __deprecated__('"rel_model" has been deprecated in favor of ' + '"model" for ForeignKeyField objects.') + model = rel_model + if to_field is not None: + __deprecated__('"to_field" has been deprecated in favor of ' + '"field" for ForeignKeyField objects.') + field = to_field + if related_name is not None: + __deprecated__('"related_name" has been deprecated in favor of ' + '"backref" for Field objects.') + backref = related_name + + self._is_self_reference = model == 'self' + self.rel_model = model + self.rel_field = field + self.declared_backref = backref + self.backref = None + self.on_delete = on_delete + self.on_update = on_update + self.deferrable = deferrable + self.deferred = _deferred + self.object_id_name = object_id_name + self.lazy_load = lazy_load + self.constraint_name = constraint_name + + @property + def field_type(self): + if not isinstance(self.rel_field, AutoField): + return self.rel_field.field_type + elif isinstance(self.rel_field, BigAutoField): + return BigIntegerField.field_type + return IntegerField.field_type + + def 
get_modifiers(self): + if not isinstance(self.rel_field, AutoField): + return self.rel_field.get_modifiers() + return super(ForeignKeyField, self).get_modifiers() + + def adapt(self, value): + return self.rel_field.adapt(value) + + def db_value(self, value): + if isinstance(value, self.rel_model): + value = getattr(value, self.rel_field.name) + return self.rel_field.db_value(value) + + def python_value(self, value): + if isinstance(value, self.rel_model): + return value + return self.rel_field.python_value(value) + + def bind(self, model, name, set_attribute=True): + if not self.column_name: + self.column_name = name if name.endswith('_id') else name + '_id' + if not self.object_id_name: + self.object_id_name = self.column_name + if self.object_id_name == name: + self.object_id_name += '_id' + elif self.object_id_name == name: + raise ValueError('ForeignKeyField "%s"."%s" specifies an ' + 'object_id_name that conflicts with its field ' + 'name.' % (model._meta.name, name)) + if self._is_self_reference: + self.rel_model = model + if isinstance(self.rel_field, basestring): + self.rel_field = getattr(self.rel_model, self.rel_field) + elif self.rel_field is None: + self.rel_field = self.rel_model._meta.primary_key + + # Bind field before assigning backref, so field is bound when + # calling declared_backref() (if callable). + super(ForeignKeyField, self).bind(model, name, set_attribute) + self.safe_name = self.object_id_name + + if callable_(self.declared_backref): + self.backref = self.declared_backref(self) + else: + self.backref, self.declared_backref = self.declared_backref, None + if not self.backref: + self.backref = '%s_set' % model._meta.name + + if set_attribute: + setattr(model, self.object_id_name, ObjectIdAccessor(self)) + if self.backref not in '!+': + setattr(self.rel_model, self.backref, BackrefAccessor(self)) + + def foreign_key_constraint(self): + parts = [] + if self.constraint_name: + parts.extend((SQL('CONSTRAINT'), Entity(self.constraint_name))) + parts.extend([ + SQL('FOREIGN KEY'), + EnclosedNodeList((self,)), + SQL('REFERENCES'), + self.rel_model, + EnclosedNodeList((self.rel_field,))]) + if self.on_delete: + parts.append(SQL('ON DELETE %s' % self.on_delete)) + if self.on_update: + parts.append(SQL('ON UPDATE %s' % self.on_update)) + if self.deferrable: + parts.append(SQL('DEFERRABLE %s' % self.deferrable)) + return NodeList(parts) + + def __getattr__(self, attr): + if attr.startswith('__'): + # Prevent recursion error when deep-copying. + raise AttributeError('Cannot look-up non-existant "__" methods.') + if attr in self.rel_model._meta.fields: + return self.rel_model._meta.fields[attr] + raise AttributeError('Foreign-key has no attribute %s, nor is it a ' + 'valid field on the related model.' 
% attr) + + +class DeferredForeignKey(Field): + _unresolved = set() + + def __init__(self, rel_model_name, **kwargs): + self.field_kwargs = kwargs + self.rel_model_name = rel_model_name.lower() + DeferredForeignKey._unresolved.add(self) + super(DeferredForeignKey, self).__init__( + column_name=kwargs.get('column_name'), + null=kwargs.get('null')) + + __hash__ = object.__hash__ + + def __deepcopy__(self, memo=None): + return DeferredForeignKey(self.rel_model_name, **self.field_kwargs) + + def set_model(self, rel_model): + field = ForeignKeyField(rel_model, _deferred=True, **self.field_kwargs) + self.model._meta.add_field(self.name, field) + + @staticmethod + def resolve(model_cls): + unresolved = sorted(DeferredForeignKey._unresolved, + key=operator.attrgetter('_order')) + for dr in unresolved: + if dr.rel_model_name == model_cls.__name__.lower(): + dr.set_model(model_cls) + DeferredForeignKey._unresolved.discard(dr) + + +class DeferredThroughModel(object): + def __init__(self): + self._refs = [] + + def set_field(self, model, field, name): + self._refs.append((model, field, name)) + + def set_model(self, through_model): + for src_model, m2mfield, name in self._refs: + m2mfield.through_model = through_model + src_model._meta.add_field(name, m2mfield) + + +class MetaField(Field): + column_name = default = model = name = None + primary_key = False + + +class ManyToManyFieldAccessor(FieldAccessor): + def __init__(self, model, field, name): + super(ManyToManyFieldAccessor, self).__init__(model, field, name) + self.model = field.model + self.rel_model = field.rel_model + self.through_model = field.through_model + src_fks = self.through_model._meta.model_refs[self.model] + dest_fks = self.through_model._meta.model_refs[self.rel_model] + if not src_fks: + raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' % + (self.model, self.through_model)) + elif not dest_fks: + raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' % + (self.rel_model, self.through_model)) + self.src_fk = src_fks[0] + self.dest_fk = dest_fks[0] + + def __get__(self, instance, instance_type=None, force_query=False): + if instance is not None: + if not force_query and self.src_fk.backref != '+': + backref = getattr(instance, self.src_fk.backref) + if isinstance(backref, list): + return [getattr(obj, self.dest_fk.name) for obj in backref] + + src_id = getattr(instance, self.src_fk.rel_field.name) + return (ManyToManyQuery(instance, self, self.rel_model) + .join(self.through_model) + .join(self.model) + .where(self.src_fk == src_id)) + + return self.field + + def __set__(self, instance, value): + query = self.__get__(instance, force_query=True) + query.add(value, clear_existing=True) + + +class ManyToManyField(MetaField): + accessor_class = ManyToManyFieldAccessor + + def __init__(self, model, backref=None, through_model=None, on_delete=None, + on_update=None, _is_backref=False): + if through_model is not None: + if not (isinstance(through_model, DeferredThroughModel) or + is_model(through_model)): + raise TypeError('Unexpected value for through_model. 
Expected ' + 'Model or DeferredThroughModel.') + if not _is_backref and (on_delete is not None or on_update is not None): + raise ValueError('Cannot specify on_delete or on_update when ' + 'through_model is specified.') + self.rel_model = model + self.backref = backref + self._through_model = through_model + self._on_delete = on_delete + self._on_update = on_update + self._is_backref = _is_backref + + def _get_descriptor(self): + return ManyToManyFieldAccessor(self) + + def bind(self, model, name, set_attribute=True): + if isinstance(self._through_model, DeferredThroughModel): + self._through_model.set_field(model, self, name) + return + + super(ManyToManyField, self).bind(model, name, set_attribute) + + if not self._is_backref: + many_to_many_field = ManyToManyField( + self.model, + backref=name, + through_model=self.through_model, + on_delete=self._on_delete, + on_update=self._on_update, + _is_backref=True) + self.backref = self.backref or model._meta.name + 's' + self.rel_model._meta.add_field(self.backref, many_to_many_field) + + def get_models(self): + return [model for _, model in sorted(( + (self._is_backref, self.model), + (not self._is_backref, self.rel_model)))] + + @property + def through_model(self): + if self._through_model is None: + self._through_model = self._create_through_model() + return self._through_model + + @through_model.setter + def through_model(self, value): + self._through_model = value + + def _create_through_model(self): + lhs, rhs = self.get_models() + tables = [model._meta.table_name for model in (lhs, rhs)] + + class Meta: + database = self.model._meta.database + schema = self.model._meta.schema + table_name = '%s_%s_through' % tuple(tables) + indexes = ( + ((lhs._meta.name, rhs._meta.name), + True),) + + params = {'on_delete': self._on_delete, 'on_update': self._on_update} + attrs = { + lhs._meta.name: ForeignKeyField(lhs, **params), + rhs._meta.name: ForeignKeyField(rhs, **params), + 'Meta': Meta} + + klass_name = '%s%sThrough' % (lhs.__name__, rhs.__name__) + return type(klass_name, (Model,), attrs) + + def get_through_model(self): + # XXX: Deprecated. Just use the "through_model" property. 
+ return self.through_model + + +class VirtualField(MetaField): + field_class = None + + def __init__(self, field_class=None, *args, **kwargs): + Field = field_class if field_class is not None else self.field_class + self.field_instance = Field() if Field is not None else None + super(VirtualField, self).__init__(*args, **kwargs) + + def db_value(self, value): + if self.field_instance is not None: + return self.field_instance.db_value(value) + return value + + def python_value(self, value): + if self.field_instance is not None: + return self.field_instance.python_value(value) + return value + + def bind(self, model, name, set_attribute=True): + self.model = model + self.column_name = self.name = self.safe_name = name + setattr(model, name, self.accessor_class(model, self, name)) + + +class CompositeKey(MetaField): + sequence = None + + def __init__(self, *field_names): + self.field_names = field_names + self._safe_field_names = None + + @property + def safe_field_names(self): + if self._safe_field_names is None: + if self.model is None: + return self.field_names + + self._safe_field_names = [self.model._meta.fields[f].safe_name + for f in self.field_names] + return self._safe_field_names + + def __get__(self, instance, instance_type=None): + if instance is not None: + return tuple([getattr(instance, f) for f in self.safe_field_names]) + return self + + def __set__(self, instance, value): + if not isinstance(value, (list, tuple)): + raise TypeError('A list or tuple must be used to set the value of ' + 'a composite primary key.') + if len(value) != len(self.field_names): + raise ValueError('The length of the value must equal the number ' + 'of columns of the composite primary key.') + for idx, field_value in enumerate(value): + setattr(instance, self.field_names[idx], field_value) + + def __eq__(self, other): + expressions = [(self.model._meta.fields[field] == value) + for field, value in zip(self.field_names, other)] + return reduce(operator.and_, expressions) + + def __ne__(self, other): + return ~(self == other) + + def __hash__(self): + return hash((self.model.__name__, self.field_names)) + + def __sql__(self, ctx): + # If the composite PK is being selected, do not use parens. Elsewhere, + # such as in an expression, we want to use parentheses and treat it as + # a row value. 
+ parens = ctx.scope != SCOPE_SOURCE + return ctx.sql(NodeList([self.model._meta.fields[field] + for field in self.field_names], ', ', parens)) + + def bind(self, model, name, set_attribute=True): + self.model = model + self.column_name = self.name = self.safe_name = name + setattr(model, self.name, self) + + +class _SortedFieldList(object): + __slots__ = ('_keys', '_items') + + def __init__(self): + self._keys = [] + self._items = [] + + def __getitem__(self, i): + return self._items[i] + + def __iter__(self): + return iter(self._items) + + def __contains__(self, item): + k = item._sort_key + i = bisect_left(self._keys, k) + j = bisect_right(self._keys, k) + return item in self._items[i:j] + + def index(self, field): + return self._keys.index(field._sort_key) + + def insert(self, item): + k = item._sort_key + i = bisect_left(self._keys, k) + self._keys.insert(i, k) + self._items.insert(i, item) + + def remove(self, item): + idx = self.index(item) + del self._items[idx] + del self._keys[idx] + + +# MODELS + + +class SchemaManager(object): + def __init__(self, model, database=None, **context_options): + self.model = model + self._database = database + context_options.setdefault('scope', SCOPE_VALUES) + self.context_options = context_options + + @property + def database(self): + db = self._database or self.model._meta.database + if db is None: + raise ImproperlyConfigured('database attribute does not appear to ' + 'be set on the model: %s' % self.model) + return db + + @database.setter + def database(self, value): + self._database = value + + def _create_context(self): + return self.database.get_sql_context(**self.context_options) + + def _create_table(self, safe=True, **options): + is_temp = options.pop('temporary', False) + ctx = self._create_context() + ctx.literal('CREATE TEMPORARY TABLE ' if is_temp else 'CREATE TABLE ') + if safe: + ctx.literal('IF NOT EXISTS ') + ctx.sql(self.model).literal(' ') + + columns = [] + constraints = [] + meta = self.model._meta + if meta.composite_key: + pk_columns = [meta.fields[field_name].column + for field_name in meta.primary_key.field_names] + constraints.append(NodeList((SQL('PRIMARY KEY'), + EnclosedNodeList(pk_columns)))) + + for field in meta.sorted_fields: + columns.append(field.ddl(ctx)) + if isinstance(field, ForeignKeyField) and not field.deferred: + constraints.append(field.foreign_key_constraint()) + + if meta.constraints: + constraints.extend(meta.constraints) + + constraints.extend(self._create_table_option_sql(options)) + ctx.sql(EnclosedNodeList(columns + constraints)) + + if meta.table_settings is not None: + table_settings = ensure_tuple(meta.table_settings) + for setting in table_settings: + if not isinstance(setting, basestring): + raise ValueError('table_settings must be strings') + ctx.literal(' ').literal(setting) + + if meta.without_rowid: + ctx.literal(' WITHOUT ROWID') + return ctx + + def _create_table_option_sql(self, options): + accum = [] + options = merge_dict(self.model._meta.options or {}, options) + if not options: + return accum + + for key, value in sorted(options.items()): + if not isinstance(value, Node): + if is_model(value): + value = value._meta.table + else: + value = SQL(str(value)) + accum.append(NodeList((SQL(key), value), glue='=')) + return accum + + def create_table(self, safe=True, **options): + self.database.execute(self._create_table(safe=safe, **options)) + + def _create_table_as(self, table_name, query, safe=True, **meta): + ctx = (self._create_context() + .literal('CREATE TEMPORARY TABLE ' + if 
meta.get('temporary') else 'CREATE TABLE ')) + if safe: + ctx.literal('IF NOT EXISTS ') + return (ctx + .sql(Entity(table_name)) + .literal(' AS ') + .sql(query)) + + def create_table_as(self, table_name, query, safe=True, **meta): + ctx = self._create_table_as(table_name, query, safe=safe, **meta) + self.database.execute(ctx) + + def _drop_table(self, safe=True, **options): + ctx = (self._create_context() + .literal('DROP TABLE IF EXISTS ' if safe else 'DROP TABLE ') + .sql(self.model)) + if options.get('cascade'): + ctx = ctx.literal(' CASCADE') + elif options.get('restrict'): + ctx = ctx.literal(' RESTRICT') + return ctx + + def drop_table(self, safe=True, **options): + self.database.execute(self._drop_table(safe=safe, **options)) + + def _truncate_table(self, restart_identity=False, cascade=False): + db = self.database + if not db.truncate_table: + return (self._create_context() + .literal('DELETE FROM ').sql(self.model)) + + ctx = self._create_context().literal('TRUNCATE TABLE ').sql(self.model) + if restart_identity: + ctx = ctx.literal(' RESTART IDENTITY') + if cascade: + ctx = ctx.literal(' CASCADE') + return ctx + + def truncate_table(self, restart_identity=False, cascade=False): + self.database.execute(self._truncate_table(restart_identity, cascade)) + + def _create_indexes(self, safe=True): + return [self._create_index(index, safe) + for index in self.model._meta.fields_to_index()] + + def _create_index(self, index, safe=True): + if isinstance(index, Index): + if not self.database.safe_create_index: + index = index.safe(False) + elif index._safe != safe: + index = index.safe(safe) + return self._create_context().sql(index) + + def create_indexes(self, safe=True): + for query in self._create_indexes(safe=safe): + self.database.execute(query) + + def _drop_indexes(self, safe=True): + return [self._drop_index(index, safe) + for index in self.model._meta.fields_to_index() + if isinstance(index, Index)] + + def _drop_index(self, index, safe): + statement = 'DROP INDEX ' + if safe and self.database.safe_drop_index: + statement += 'IF EXISTS ' + if isinstance(index._table, Table) and index._table._schema: + index_name = Entity(index._table._schema, index._name) + else: + index_name = Entity(index._name) + return (self + ._create_context() + .literal(statement) + .sql(index_name)) + + def drop_indexes(self, safe=True): + for query in self._drop_indexes(safe=safe): + self.database.execute(query) + + def _check_sequences(self, field): + if not field.sequence or not self.database.sequences: + raise ValueError('Sequences are either not supported, or are not ' + 'defined for "%s".' 
% field.name) + + def _sequence_for_field(self, field): + if field.model._meta.schema: + return Entity(field.model._meta.schema, field.sequence) + else: + return Entity(field.sequence) + + def _create_sequence(self, field): + self._check_sequences(field) + if not self.database.sequence_exists(field.sequence): + return (self + ._create_context() + .literal('CREATE SEQUENCE ') + .sql(self._sequence_for_field(field))) + + def create_sequence(self, field): + seq_ctx = self._create_sequence(field) + if seq_ctx is not None: + self.database.execute(seq_ctx) + + def _drop_sequence(self, field): + self._check_sequences(field) + if self.database.sequence_exists(field.sequence): + return (self + ._create_context() + .literal('DROP SEQUENCE ') + .sql(self._sequence_for_field(field))) + + def drop_sequence(self, field): + seq_ctx = self._drop_sequence(field) + if seq_ctx is not None: + self.database.execute(seq_ctx) + + def _create_foreign_key(self, field): + name = 'fk_%s_%s_refs_%s' % (field.model._meta.table_name, + field.column_name, + field.rel_model._meta.table_name) + return (self + ._create_context() + .literal('ALTER TABLE ') + .sql(field.model) + .literal(' ADD CONSTRAINT ') + .sql(Entity(_truncate_constraint_name(name))) + .literal(' ') + .sql(field.foreign_key_constraint())) + + def create_foreign_key(self, field): + self.database.execute(self._create_foreign_key(field)) + + def create_sequences(self): + if self.database.sequences: + for field in self.model._meta.sorted_fields: + if field.sequence: + self.create_sequence(field) + + def create_all(self, safe=True, **table_options): + self.create_sequences() + self.create_table(safe, **table_options) + self.create_indexes(safe=safe) + + def drop_sequences(self): + if self.database.sequences: + for field in self.model._meta.sorted_fields: + if field.sequence: + self.drop_sequence(field) + + def drop_all(self, safe=True, drop_sequences=True, **options): + self.drop_table(safe, **options) + if drop_sequences: + self.drop_sequences() + + +class Metadata(object): + def __init__(self, model, database=None, table_name=None, indexes=None, + primary_key=None, constraints=None, schema=None, + only_save_dirty=False, depends_on=None, options=None, + db_table=None, table_function=None, table_settings=None, + without_rowid=False, temporary=False, legacy_table_names=True, + **kwargs): + if db_table is not None: + __deprecated__('"db_table" has been deprecated in favor of ' + '"table_name" for Models.') + table_name = db_table + self.model = model + self.database = database + + self.fields = {} + self.columns = {} + self.combined = {} + + self._sorted_field_list = _SortedFieldList() + self.sorted_fields = [] + self.sorted_field_names = [] + + self.defaults = {} + self._default_by_name = {} + self._default_dict = {} + self._default_callables = {} + self._default_callable_list = [] + + self.name = model.__name__.lower() + self.table_function = table_function + self.legacy_table_names = legacy_table_names + if not table_name: + table_name = (self.table_function(model) + if self.table_function + else self.make_table_name()) + self.table_name = table_name + self._table = None + + self.indexes = list(indexes) if indexes else [] + self.constraints = constraints + self._schema = schema + self.primary_key = primary_key + self.composite_key = self.auto_increment = None + self.only_save_dirty = only_save_dirty + self.depends_on = depends_on + self.table_settings = table_settings + self.without_rowid = without_rowid + self.temporary = temporary + + self.refs = {} + 
self.backrefs = {} + self.model_refs = collections.defaultdict(list) + self.model_backrefs = collections.defaultdict(list) + self.manytomany = {} + + self.options = options or {} + for key, value in kwargs.items(): + setattr(self, key, value) + self._additional_keys = set(kwargs.keys()) + + # Allow objects to register hooks that are called if the model is bound + # to a different database. For example, BlobField uses a different + # Python data-type depending on the db driver / python version. When + # the database changes, we need to update any BlobField so they can use + # the appropriate data-type. + self._db_hooks = [] + + def make_table_name(self): + if self.legacy_table_names: + return re.sub(r'[^\w]+', '_', self.name) + return make_snake_case(self.model.__name__) + + def model_graph(self, refs=True, backrefs=True, depth_first=True): + if not refs and not backrefs: + raise ValueError('One of `refs` or `backrefs` must be True.') + + accum = [(None, self.model, None)] + seen = set() + queue = collections.deque((self,)) + method = queue.pop if depth_first else queue.popleft + + while queue: + curr = method() + if curr in seen: continue + seen.add(curr) + + if refs: + for fk, model in curr.refs.items(): + accum.append((fk, model, False)) + queue.append(model._meta) + if backrefs: + for fk, model in curr.backrefs.items(): + accum.append((fk, model, True)) + queue.append(model._meta) + + return accum + + def add_ref(self, field): + rel = field.rel_model + self.refs[field] = rel + self.model_refs[rel].append(field) + rel._meta.backrefs[field] = self.model + rel._meta.model_backrefs[self.model].append(field) + + def remove_ref(self, field): + rel = field.rel_model + del self.refs[field] + self.model_refs[rel].remove(field) + del rel._meta.backrefs[field] + rel._meta.model_backrefs[self.model].remove(field) + + def add_manytomany(self, field): + self.manytomany[field.name] = field + + def remove_manytomany(self, field): + del self.manytomany[field.name] + + @property + def table(self): + if self._table is None: + self._table = Table( + self.table_name, + [field.column_name for field in self.sorted_fields], + schema=self.schema, + _model=self.model, + _database=self.database) + return self._table + + @table.setter + def table(self, value): + raise AttributeError('Cannot set the "table".') + + @table.deleter + def table(self): + self._table = None + + @property + def schema(self): + return self._schema + + @schema.setter + def schema(self, value): + self._schema = value + del self.table + + @property + def entity(self): + if self._schema: + return Entity(self._schema, self.table_name) + else: + return Entity(self.table_name) + + def _update_sorted_fields(self): + self.sorted_fields = list(self._sorted_field_list) + self.sorted_field_names = [f.name for f in self.sorted_fields] + + def get_rel_for_model(self, model): + if isinstance(model, ModelAlias): + model = model.model + forwardrefs = self.model_refs.get(model, []) + backrefs = self.model_backrefs.get(model, []) + return (forwardrefs, backrefs) + + def add_field(self, field_name, field, set_attribute=True): + if field_name in self.fields: + self.remove_field(field_name) + elif field_name in self.manytomany: + self.remove_manytomany(self.manytomany[field_name]) + + if not isinstance(field, MetaField): + del self.table + field.bind(self.model, field_name, set_attribute) + self.fields[field.name] = field + self.columns[field.column_name] = field + self.combined[field.name] = field + self.combined[field.column_name] = field + + 
self._sorted_field_list.insert(field) + self._update_sorted_fields() + + if field.default is not None: + # This optimization helps speed up model instance construction. + self.defaults[field] = field.default + if callable_(field.default): + self._default_callables[field] = field.default + self._default_callable_list.append((field.name, + field.default)) + else: + self._default_dict[field] = field.default + self._default_by_name[field.name] = field.default + else: + field.bind(self.model, field_name, set_attribute) + + if isinstance(field, ForeignKeyField): + self.add_ref(field) + elif isinstance(field, ManyToManyField) and field.name: + self.add_manytomany(field) + + def remove_field(self, field_name): + if field_name not in self.fields: + return + + del self.table + original = self.fields.pop(field_name) + del self.columns[original.column_name] + del self.combined[field_name] + try: + del self.combined[original.column_name] + except KeyError: + pass + self._sorted_field_list.remove(original) + self._update_sorted_fields() + + if original.default is not None: + del self.defaults[original] + if self._default_callables.pop(original, None): + for i, (name, _) in enumerate(self._default_callable_list): + if name == field_name: + self._default_callable_list.pop(i) + break + else: + self._default_dict.pop(original, None) + self._default_by_name.pop(original.name, None) + + if isinstance(original, ForeignKeyField): + self.remove_ref(original) + + def set_primary_key(self, name, field): + self.composite_key = isinstance(field, CompositeKey) + self.add_field(name, field) + self.primary_key = field + self.auto_increment = ( + field.auto_increment or + bool(field.sequence)) + + def get_primary_keys(self): + if self.composite_key: + return tuple([self.fields[field_name] + for field_name in self.primary_key.field_names]) + else: + return (self.primary_key,) if self.primary_key is not False else () + + def get_default_dict(self): + dd = self._default_by_name.copy() + for field_name, default in self._default_callable_list: + dd[field_name] = default() + return dd + + def fields_to_index(self): + indexes = [] + for f in self.sorted_fields: + if f.primary_key: + continue + if f.index or f.unique: + indexes.append(ModelIndex(self.model, (f,), unique=f.unique, + using=f.index_type)) + + for index_obj in self.indexes: + if isinstance(index_obj, Node): + indexes.append(index_obj) + elif isinstance(index_obj, (list, tuple)): + index_parts, unique = index_obj + fields = [] + for part in index_parts: + if isinstance(part, basestring): + fields.append(self.combined[part]) + elif isinstance(part, Node): + fields.append(part) + else: + raise ValueError('Expected either a field name or a ' + 'subclass of Node. Got: %s' % part) + indexes.append(ModelIndex(self.model, fields, unique=unique)) + + return indexes + + def set_database(self, database): + self.database = database + self.model._schema._database = database + del self.table + + # Apply any hooks that have been registered. 
+ for hook in self._db_hooks: + hook(database) + + def set_table_name(self, table_name): + self.table_name = table_name + del self.table + + +class SubclassAwareMetadata(Metadata): + models = [] + + def __init__(self, model, *args, **kwargs): + super(SubclassAwareMetadata, self).__init__(model, *args, **kwargs) + self.models.append(model) + + def map_models(self, fn): + for model in self.models: + fn(model) + + +class DoesNotExist(Exception): pass + + +class ModelBase(type): + inheritable = set(['constraints', 'database', 'indexes', 'primary_key', + 'options', 'schema', 'table_function', 'temporary', + 'only_save_dirty', 'legacy_table_names', + 'table_settings']) + + def __new__(cls, name, bases, attrs): + if name == MODEL_BASE or bases[0].__name__ == MODEL_BASE: + return super(ModelBase, cls).__new__(cls, name, bases, attrs) + + meta_options = {} + meta = attrs.pop('Meta', None) + if meta: + for k, v in meta.__dict__.items(): + if not k.startswith('_'): + meta_options[k] = v + + pk = getattr(meta, 'primary_key', None) + pk_name = parent_pk = None + + # Inherit any field descriptors by deep copying the underlying field + # into the attrs of the new model, additionally see if the bases define + # inheritable model options and swipe them. + for b in bases: + if not hasattr(b, '_meta'): + continue + + base_meta = b._meta + if parent_pk is None: + parent_pk = deepcopy(base_meta.primary_key) + all_inheritable = cls.inheritable | base_meta._additional_keys + for k in base_meta.__dict__: + if k in all_inheritable and k not in meta_options: + meta_options[k] = base_meta.__dict__[k] + meta_options.setdefault('schema', base_meta.schema) + + for (k, v) in b.__dict__.items(): + if k in attrs: continue + + if isinstance(v, FieldAccessor) and not v.field.primary_key: + attrs[k] = deepcopy(v.field) + + sopts = meta_options.pop('schema_options', None) or {} + Meta = meta_options.get('model_metadata_class', Metadata) + Schema = meta_options.get('schema_manager_class', SchemaManager) + + # Construct the new class. + cls = super(ModelBase, cls).__new__(cls, name, bases, attrs) + cls.__data__ = cls.__rel__ = None + + cls._meta = Meta(cls, **meta_options) + cls._schema = Schema(cls, **sopts) + + fields = [] + for key, value in cls.__dict__.items(): + if isinstance(value, Field): + if value.primary_key and pk: + raise ValueError('over-determined primary key %s.' % name) + elif value.primary_key: + pk, pk_name = value, key + else: + fields.append((key, value)) + + if pk is None: + if parent_pk is not False: + pk, pk_name = ((parent_pk, parent_pk.name) + if parent_pk is not None else + (AutoField(), 'id')) + else: + pk = False + elif isinstance(pk, CompositeKey): + pk_name = '__composite_key__' + cls._meta.composite_key = True + + if pk is not False: + cls._meta.set_primary_key(pk_name, pk) + + for name, field in fields: + cls._meta.add_field(name, field) + + # Create a repr and error class before finalizing. + if hasattr(cls, '__str__') and '__repr__' not in attrs: + setattr(cls, '__repr__', lambda self: '<%s: %s>' % ( + cls.__name__, self.__str__())) + + exc_name = '%sDoesNotExist' % cls.__name__ + exc_attrs = {'__module__': cls.__module__} + exception_class = type(exc_name, (DoesNotExist,), exc_attrs) + cls.DoesNotExist = exception_class + + # Call validation hook, allowing additional model validation. 
+        cls.validate_model()
+        DeferredForeignKey.resolve(cls)
+        return cls
+
+    def __repr__(self):
+        return '<Model: %s>' % self.__name__
+
+    def __iter__(self):
+        return iter(self.select())
+
+    def __getitem__(self, key):
+        return self.get_by_id(key)
+
+    def __setitem__(self, key, value):
+        self.set_by_id(key, value)
+
+    def __delitem__(self, key):
+        self.delete_by_id(key)
+
+    def __contains__(self, key):
+        try:
+            self.get_by_id(key)
+        except self.DoesNotExist:
+            return False
+        else:
+            return True
+
+    def __len__(self):
+        return self.select().count()
+    def __bool__(self): return True
+    __nonzero__ = __bool__  # Python 2.
+
+    def __sql__(self, ctx):
+        return ctx.sql(self._meta.table)
+
+
+class _BoundModelsContext(_callable_context_manager):
+    def __init__(self, models, database, bind_refs, bind_backrefs):
+        self.models = models
+        self.database = database
+        self.bind_refs = bind_refs
+        self.bind_backrefs = bind_backrefs
+
+    def __enter__(self):
+        self._orig_database = []
+        for model in self.models:
+            self._orig_database.append(model._meta.database)
+            model.bind(self.database, self.bind_refs, self.bind_backrefs,
+                       _exclude=set(self.models))
+        return self.models
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        for model, db in zip(self.models, self._orig_database):
+            model.bind(db, self.bind_refs, self.bind_backrefs,
+                       _exclude=set(self.models))
+
+
+class Model(with_metaclass(ModelBase, Node)):
+    def __init__(self, *args, **kwargs):
+        if kwargs.pop('__no_default__', None):
+            self.__data__ = {}
+        else:
+            self.__data__ = self._meta.get_default_dict()
+        self._dirty = set(self.__data__)
+        self.__rel__ = {}
+
+        for k in kwargs:
+            setattr(self, k, kwargs[k])
+
+    def __str__(self):
+        return str(self._pk) if self._meta.primary_key is not False else 'n/a'
+
+    @classmethod
+    def validate_model(cls):
+        pass
+
+    @classmethod
+    def alias(cls, alias=None):
+        return ModelAlias(cls, alias)
+
+    @classmethod
+    def select(cls, *fields):
+        is_default = not fields
+        if not fields:
+            fields = cls._meta.sorted_fields
+        return ModelSelect(cls, fields, is_default=is_default)
+
+    @classmethod
+    def _normalize_data(cls, data, kwargs):
+        normalized = {}
+        if data:
+            if not isinstance(data, dict):
+                if kwargs:
+                    raise ValueError('Data cannot be mixed with keyword '
+                                     'arguments: %s' % data)
+                return data
+            for key in data:
+                try:
+                    field = (key if isinstance(key, Field)
+                             else cls._meta.combined[key])
+                except KeyError:
+                    if not isinstance(key, Node):
+                        raise ValueError('Unrecognized field name: "%s" in %s.'
+ % (key, data)) + field = key + normalized[field] = data[key] + if kwargs: + for key in kwargs: + try: + normalized[cls._meta.combined[key]] = kwargs[key] + except KeyError: + normalized[getattr(cls, key)] = kwargs[key] + return normalized + + @classmethod + def update(cls, __data=None, **update): + return ModelUpdate(cls, cls._normalize_data(__data, update)) + + @classmethod + def insert(cls, __data=None, **insert): + return ModelInsert(cls, cls._normalize_data(__data, insert)) + + @classmethod + def insert_many(cls, rows, fields=None): + return ModelInsert(cls, insert=rows, columns=fields) + + @classmethod + def insert_from(cls, query, fields): + columns = [getattr(cls, field) if isinstance(field, basestring) + else field for field in fields] + return ModelInsert(cls, insert=query, columns=columns) + + @classmethod + def replace(cls, __data=None, **insert): + return cls.insert(__data, **insert).on_conflict('REPLACE') + + @classmethod + def replace_many(cls, rows, fields=None): + return (cls + .insert_many(rows=rows, fields=fields) + .on_conflict('REPLACE')) + + @classmethod + def raw(cls, sql, *params): + return ModelRaw(cls, sql, params) + + @classmethod + def delete(cls): + return ModelDelete(cls) + + @classmethod + def create(cls, **query): + inst = cls(**query) + inst.save(force_insert=True) + return inst + + @classmethod + def bulk_create(cls, model_list, batch_size=None): + if batch_size is not None: + batches = chunked(model_list, batch_size) + else: + batches = [model_list] + + field_names = list(cls._meta.sorted_field_names) + if cls._meta.auto_increment: + pk_name = cls._meta.primary_key.name + field_names.remove(pk_name) + + if cls._meta.database.returning_clause and \ + cls._meta.primary_key is not False: + pk_fields = cls._meta.get_primary_keys() + else: + pk_fields = None + + fields = [cls._meta.fields[field_name] for field_name in field_names] + attrs = [] + for field in fields: + if isinstance(field, ForeignKeyField): + attrs.append(field.object_id_name) + else: + attrs.append(field.name) + + for batch in batches: + accum = ([getattr(model, f) for f in attrs] + for model in batch) + res = cls.insert_many(accum, fields=fields).execute() + if pk_fields and res is not None: + for row, model in zip(res, batch): + for (pk_field, obj_id) in zip(pk_fields, row): + setattr(model, pk_field.name, obj_id) + + @classmethod + def bulk_update(cls, model_list, fields, batch_size=None): + if isinstance(cls._meta.primary_key, CompositeKey): + raise ValueError('bulk_update() is not supported for models with ' + 'a composite primary key.') + + # First normalize list of fields so all are field instances. + fields = [cls._meta.fields[f] if isinstance(f, basestring) else f + for f in fields] + # Now collect list of attribute names to use for values. 
+ attrs = [field.object_id_name if isinstance(field, ForeignKeyField) + else field.name for field in fields] + + if batch_size is not None: + batches = chunked(model_list, batch_size) + else: + batches = [model_list] + + n = 0 + pk = cls._meta.primary_key + + for batch in batches: + id_list = [model._pk for model in batch] + update = {} + for field, attr in zip(fields, attrs): + accum = [] + for model in batch: + value = getattr(model, attr) + if not isinstance(value, Node): + value = field.to_value(value) + accum.append((pk.to_value(model._pk), value)) + case = Case(pk, accum) + update[field] = case + + n += (cls.update(update) + .where(cls._meta.primary_key.in_(id_list)) + .execute()) + return n + + @classmethod + def noop(cls): + return NoopModelSelect(cls, ()) + + @classmethod + def get(cls, *query, **filters): + sq = cls.select() + if query: + # Handle simple lookup using just the primary key. + if len(query) == 1 and isinstance(query[0], int): + sq = sq.where(cls._meta.primary_key == query[0]) + else: + sq = sq.where(*query) + if filters: + sq = sq.filter(**filters) + return sq.get() + + @classmethod + def get_or_none(cls, *query, **filters): + try: + return cls.get(*query, **filters) + except DoesNotExist: + pass + + @classmethod + def get_by_id(cls, pk): + return cls.get(cls._meta.primary_key == pk) + + @classmethod + def set_by_id(cls, key, value): + if key is None: + return cls.insert(value).execute() + else: + return (cls.update(value) + .where(cls._meta.primary_key == key).execute()) + + @classmethod + def delete_by_id(cls, pk): + return cls.delete().where(cls._meta.primary_key == pk).execute() + + @classmethod + def get_or_create(cls, **kwargs): + defaults = kwargs.pop('defaults', {}) + query = cls.select() + for field, value in kwargs.items(): + query = query.where(getattr(cls, field) == value) + + try: + return query.get(), False + except cls.DoesNotExist: + try: + if defaults: + kwargs.update(defaults) + with cls._meta.database.atomic(): + return cls.create(**kwargs), True + except IntegrityError as exc: + try: + return query.get(), False + except cls.DoesNotExist: + raise exc + + @classmethod + def filter(cls, *dq_nodes, **filters): + return cls.select().filter(*dq_nodes, **filters) + + def get_id(self): + # Using getattr(self, pk-name) could accidentally trigger a query if + # the primary-key is a foreign-key. So we use the safe_name attribute, + # which defaults to the field-name, but will be the object_id_name for + # foreign-key fields. 
+ if self._meta.primary_key is not False: + return getattr(self, self._meta.primary_key.safe_name) + + _pk = property(get_id) + + @_pk.setter + def _pk(self, value): + setattr(self, self._meta.primary_key.name, value) + + def _pk_expr(self): + return self._meta.primary_key == self._pk + + def _prune_fields(self, field_dict, only): + new_data = {} + for field in only: + if isinstance(field, basestring): + field = self._meta.combined[field] + if field.name in field_dict: + new_data[field.name] = field_dict[field.name] + return new_data + + def _populate_unsaved_relations(self, field_dict): + for foreign_key_field in self._meta.refs: + foreign_key = foreign_key_field.name + conditions = ( + foreign_key in field_dict and + field_dict[foreign_key] is None and + self.__rel__.get(foreign_key) is not None) + if conditions: + setattr(self, foreign_key, getattr(self, foreign_key)) + field_dict[foreign_key] = self.__data__[foreign_key] + + def save(self, force_insert=False, only=None): + field_dict = self.__data__.copy() + if self._meta.primary_key is not False: + pk_field = self._meta.primary_key + pk_value = self._pk + else: + pk_field = pk_value = None + if only is not None: + field_dict = self._prune_fields(field_dict, only) + elif self._meta.only_save_dirty and not force_insert: + field_dict = self._prune_fields(field_dict, self.dirty_fields) + if not field_dict: + self._dirty.clear() + return False + + self._populate_unsaved_relations(field_dict) + rows = 1 + + if self._meta.auto_increment and pk_value is None: + field_dict.pop(pk_field.name, None) + + if pk_value is not None and not force_insert: + if self._meta.composite_key: + for pk_part_name in pk_field.field_names: + field_dict.pop(pk_part_name, None) + else: + field_dict.pop(pk_field.name, None) + if not field_dict: + raise ValueError('no data to save!') + rows = self.update(**field_dict).where(self._pk_expr()).execute() + elif pk_field is not None: + pk = self.insert(**field_dict).execute() + if pk is not None and (self._meta.auto_increment or + pk_value is None): + self._pk = pk + else: + self.insert(**field_dict).execute() + + self._dirty.clear() + return rows + + def is_dirty(self): + return bool(self._dirty) + + @property + def dirty_fields(self): + return [f for f in self._meta.sorted_fields if f.name in self._dirty] + + def dependencies(self, search_nullable=False): + model_class = type(self) + stack = [(type(self), None)] + seen = set() + + while stack: + klass, query = stack.pop() + if klass in seen: + continue + seen.add(klass) + for fk, rel_model in klass._meta.backrefs.items(): + if rel_model is model_class or query is None: + node = (fk == self.__data__[fk.rel_field.name]) + else: + node = fk << query + subquery = (rel_model.select(rel_model._meta.primary_key) + .where(node)) + if not fk.null or search_nullable: + stack.append((rel_model, subquery)) + yield (node, fk) + + def delete_instance(self, recursive=False, delete_nullable=False): + if recursive: + dependencies = self.dependencies(delete_nullable) + for query, fk in reversed(list(dependencies)): + model = fk.model + if fk.null and not delete_nullable: + model.update(**{fk.name: None}).where(query).execute() + else: + model.delete().where(query).execute() + return type(self).delete().where(self._pk_expr()).execute() + + def __hash__(self): + return hash((self.__class__, self._pk)) + + def __eq__(self, other): + return ( + other.__class__ == self.__class__ and + self._pk is not None and + self._pk == other._pk) + + def __ne__(self, other): + return not self == other + + 
def __sql__(self, ctx): + # NOTE: when comparing a foreign-key field whose related-field is not a + # primary-key, then doing an equality test for the foreign-key with a + # model instance will return the wrong value; since we would return + # the primary key for a given model instance. + # + # This checks to see if we have a converter in the scope, and that we + # are converting a foreign-key expression. If so, we hand the model + # instance to the converter rather than blindly grabbing the primary- + # key. In the event the provided converter fails to handle the model + # instance, then we will return the primary-key. + if ctx.state.converter is not None and ctx.state.is_fk_expr: + try: + return ctx.sql(Value(self, converter=ctx.state.converter)) + except (TypeError, ValueError): + pass + + return ctx.sql(Value(getattr(self, self._meta.primary_key.name), + converter=self._meta.primary_key.db_value)) + + @classmethod + def bind(cls, database, bind_refs=True, bind_backrefs=True, _exclude=None): + is_different = cls._meta.database is not database + cls._meta.set_database(database) + if bind_refs or bind_backrefs: + if _exclude is None: + _exclude = set() + G = cls._meta.model_graph(refs=bind_refs, backrefs=bind_backrefs) + for _, model, is_backref in G: + if model not in _exclude: + model._meta.set_database(database) + _exclude.add(model) + return is_different + + @classmethod + def bind_ctx(cls, database, bind_refs=True, bind_backrefs=True): + return _BoundModelsContext((cls,), database, bind_refs, bind_backrefs) + + @classmethod + def table_exists(cls): + M = cls._meta + return cls._schema.database.table_exists(M.table.__name__, M.schema) + + @classmethod + def create_table(cls, safe=True, **options): + if 'fail_silently' in options: + __deprecated__('"fail_silently" has been deprecated in favor of ' + '"safe" for the create_table() method.') + safe = options.pop('fail_silently') + + if safe and not cls._schema.database.safe_create_index \ + and cls.table_exists(): + return + if cls._meta.temporary: + options.setdefault('temporary', cls._meta.temporary) + cls._schema.create_all(safe, **options) + + @classmethod + def drop_table(cls, safe=True, drop_sequences=True, **options): + if safe and not cls._schema.database.safe_drop_index \ + and not cls.table_exists(): + return + if cls._meta.temporary: + options.setdefault('temporary', cls._meta.temporary) + cls._schema.drop_all(safe, drop_sequences, **options) + + @classmethod + def truncate_table(cls, **options): + cls._schema.truncate_table(**options) + + @classmethod + def index(cls, *fields, **kwargs): + return ModelIndex(cls, fields, **kwargs) + + @classmethod + def add_index(cls, *fields, **kwargs): + if len(fields) == 1 and isinstance(fields[0], (SQL, Index)): + cls._meta.indexes.append(fields[0]) + else: + cls._meta.indexes.append(ModelIndex(cls, fields, **kwargs)) + + +class ModelAlias(Node): + """Provide a separate reference to a model in a query.""" + def __init__(self, model, alias=None): + self.__dict__['model'] = model + self.__dict__['alias'] = alias + + def __getattr__(self, attr): + # Hack to work-around the fact that properties or other objects + # implementing the descriptor protocol (on the model being aliased), + # will not work correctly when we use getattr(). So we explicitly pass + # the model alias to the descriptor's getter. 
+ try: + obj = self.model.__dict__[attr] + except KeyError: + pass + else: + if isinstance(obj, ModelDescriptor): + return obj.__get__(None, self) + + model_attr = getattr(self.model, attr) + if isinstance(model_attr, Field): + self.__dict__[attr] = FieldAlias.create(self, model_attr) + return self.__dict__[attr] + return model_attr + + def __setattr__(self, attr, value): + raise AttributeError('Cannot set attributes on model aliases.') + + def get_field_aliases(self): + return [getattr(self, n) for n in self.model._meta.sorted_field_names] + + def select(self, *selection): + if not selection: + selection = self.get_field_aliases() + return ModelSelect(self, selection) + + def __call__(self, **kwargs): + return self.model(**kwargs) + + def __sql__(self, ctx): + if ctx.scope == SCOPE_VALUES: + # Return the quoted table name. + return ctx.sql(self.model) + + if self.alias: + ctx.alias_manager[self] = self.alias + + if ctx.scope == SCOPE_SOURCE: + # Define the table and its alias. + return (ctx + .sql(self.model._meta.entity) + .literal(' AS ') + .sql(Entity(ctx.alias_manager[self]))) + else: + # Refer to the table using the alias. + return ctx.sql(Entity(ctx.alias_manager[self])) + + +class FieldAlias(Field): + def __init__(self, source, field): + self.source = source + self.model = source.model + self.field = field + + @classmethod + def create(cls, source, field): + class _FieldAlias(cls, type(field)): + pass + return _FieldAlias(source, field) + + def clone(self): + return FieldAlias(self.source, self.field) + + def adapt(self, value): return self.field.adapt(value) + def python_value(self, value): return self.field.python_value(value) + def db_value(self, value): return self.field.db_value(value) + def __getattr__(self, attr): + return self.source if attr == 'model' else getattr(self.field, attr) + + def __sql__(self, ctx): + return ctx.sql(Column(self.source, self.field.column_name)) + + +def sort_models(models): + models = set(models) + seen = set() + ordering = [] + def dfs(model): + if model in models and model not in seen: + seen.add(model) + for foreign_key, rel_model in model._meta.refs.items(): + # Do not depth-first search deferred foreign-keys as this can + # cause tables to be created in the incorrect order. 
+ if not foreign_key.deferred: + dfs(rel_model) + if model._meta.depends_on: + for dependency in model._meta.depends_on: + dfs(dependency) + ordering.append(model) + + names = lambda m: (m._meta.name, m._meta.table_name) + for m in sorted(models, key=names): + dfs(m) + return ordering + + +class _ModelQueryHelper(object): + default_row_type = ROW.MODEL + + def __init__(self, *args, **kwargs): + super(_ModelQueryHelper, self).__init__(*args, **kwargs) + if not self._database: + self._database = self.model._meta.database + + @Node.copy + def objects(self, constructor=None): + self._row_type = ROW.CONSTRUCTOR + self._constructor = self.model if constructor is None else constructor + + def _get_cursor_wrapper(self, cursor): + row_type = self._row_type or self.default_row_type + if row_type == ROW.MODEL: + return self._get_model_cursor_wrapper(cursor) + elif row_type == ROW.DICT: + return ModelDictCursorWrapper(cursor, self.model, self._returning) + elif row_type == ROW.TUPLE: + return ModelTupleCursorWrapper(cursor, self.model, self._returning) + elif row_type == ROW.NAMED_TUPLE: + return ModelNamedTupleCursorWrapper(cursor, self.model, + self._returning) + elif row_type == ROW.CONSTRUCTOR: + return ModelObjectCursorWrapper(cursor, self.model, + self._returning, self._constructor) + else: + raise ValueError('Unrecognized row type: "%s".' % row_type) + + def _get_model_cursor_wrapper(self, cursor): + return ModelObjectCursorWrapper(cursor, self.model, [], self.model) + + +class ModelRaw(_ModelQueryHelper, RawQuery): + def __init__(self, model, sql, params, **kwargs): + self.model = model + self._returning = () + super(ModelRaw, self).__init__(sql=sql, params=params, **kwargs) + + def get(self): + try: + return self.execute()[0] + except IndexError: + sql, params = self.sql() + raise self.model.DoesNotExist('%s instance matching query does ' + 'not exist:\nSQL: %s\nParams: %s' % + (self.model, sql, params)) + + +class BaseModelSelect(_ModelQueryHelper): + def union_all(self, rhs): + return ModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) + __add__ = union_all + + def union(self, rhs): + return ModelCompoundSelectQuery(self.model, self, 'UNION', rhs) + __or__ = union + + def intersect(self, rhs): + return ModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) + __and__ = intersect + + def except_(self, rhs): + return ModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) + __sub__ = except_ + + def __iter__(self): + if not self._cursor_wrapper: + self.execute() + return iter(self._cursor_wrapper) + + def prefetch(self, *subqueries): + return prefetch(self, *subqueries) + + def get(self, database=None): + clone = self.paginate(1, 1) + clone._cursor_wrapper = None + try: + return clone.execute(database)[0] + except IndexError: + sql, params = clone.sql() + raise self.model.DoesNotExist('%s instance matching query does ' + 'not exist:\nSQL: %s\nParams: %s' % + (clone.model, sql, params)) + + @Node.copy + def group_by(self, *columns): + grouping = [] + for column in columns: + if is_model(column): + grouping.extend(column._meta.sorted_fields) + elif isinstance(column, Table): + if not column._columns: + raise ValueError('Cannot pass a table to group_by() that ' + 'does not have columns explicitly ' + 'declared.') + grouping.extend([getattr(column, col_name) + for col_name in column._columns]) + else: + grouping.append(column) + self._group_by = grouping + + +class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): + def __init__(self, model, *args, **kwargs): + 
self.model = model + super(ModelCompoundSelectQuery, self).__init__(*args, **kwargs) + + def _get_model_cursor_wrapper(self, cursor): + return self.lhs._get_model_cursor_wrapper(cursor) + + +def _normalize_model_select(fields_or_models): + fields = [] + for fm in fields_or_models: + if is_model(fm): + fields.extend(fm._meta.sorted_fields) + elif isinstance(fm, ModelAlias): + fields.extend(fm.get_field_aliases()) + elif isinstance(fm, Table) and fm._columns: + fields.extend([getattr(fm, col) for col in fm._columns]) + else: + fields.append(fm) + return fields + + +class ModelSelect(BaseModelSelect, Select): + def __init__(self, model, fields_or_models, is_default=False): + self.model = self._join_ctx = model + self._joins = {} + self._is_default = is_default + fields = _normalize_model_select(fields_or_models) + super(ModelSelect, self).__init__([model], fields) + + def clone(self): + clone = super(ModelSelect, self).clone() + if clone._joins: + clone._joins = dict(clone._joins) + return clone + + def select(self, *fields_or_models): + if fields_or_models or not self._is_default: + self._is_default = False + fields = _normalize_model_select(fields_or_models) + return super(ModelSelect, self).select(*fields) + return self + + def switch(self, ctx=None): + self._join_ctx = self.model if ctx is None else ctx + return self + + def _get_model(self, src): + if is_model(src): + return src, True + elif isinstance(src, Table) and src._model: + return src._model, False + elif isinstance(src, ModelAlias): + return src.model, False + elif isinstance(src, ModelSelect): + return src.model, False + return None, False + + def _normalize_join(self, src, dest, on, attr): + # Allow "on" expression to have an alias that determines the + # destination attribute for the joined data. + on_alias = isinstance(on, Alias) + if on_alias: + attr = attr or on._alias + on = on.alias() + + # Obtain references to the source and destination models being joined. + src_model, src_is_model = self._get_model(src) + dest_model, dest_is_model = self._get_model(dest) + + if src_model and dest_model: + self._join_ctx = dest + constructor = dest_model + + # In the case where the "on" clause is a Column or Field, we will + # convert that field into the appropriate predicate expression. + if not (src_is_model and dest_is_model) and isinstance(on, Column): + if on.source is src: + to_field = src_model._meta.columns[on.name] + elif on.source is dest: + to_field = dest_model._meta.columns[on.name] + else: + raise AttributeError('"on" clause Column %s does not ' + 'belong to %s or %s.' 
% + (on, src_model, dest_model)) + on = None + elif isinstance(on, Field): + to_field = on + on = None + else: + to_field = None + + fk_field, is_backref = self._generate_on_clause( + src_model, dest_model, to_field, on) + + if on is None: + src_attr = 'name' if src_is_model else 'column_name' + dest_attr = 'name' if dest_is_model else 'column_name' + if is_backref: + lhs = getattr(dest, getattr(fk_field, dest_attr)) + rhs = getattr(src, getattr(fk_field.rel_field, src_attr)) + else: + lhs = getattr(src, getattr(fk_field, src_attr)) + rhs = getattr(dest, getattr(fk_field.rel_field, dest_attr)) + on = (lhs == rhs) + + if not attr: + if fk_field is not None and not is_backref: + attr = fk_field.name + else: + attr = dest_model._meta.name + elif on_alias and fk_field is not None and \ + attr == fk_field.object_id_name and not is_backref: + raise ValueError('Cannot assign join alias to "%s", as this ' + 'attribute is the object_id_name for the ' + 'foreign-key field "%s"' % (attr, fk_field)) + + elif isinstance(dest, Source): + constructor = dict + attr = attr or dest._alias + if not attr and isinstance(dest, Table): + attr = attr or dest.__name__ + + return (on, attr, constructor) + + def _generate_on_clause(self, src, dest, to_field=None, on=None): + meta = src._meta + is_backref = fk_fields = False + + # Get all the foreign keys between source and dest, and determine if + # the join is via a back-reference. + if dest in meta.model_refs: + fk_fields = meta.model_refs[dest] + elif dest in meta.model_backrefs: + fk_fields = meta.model_backrefs[dest] + is_backref = True + + if not fk_fields: + if on is not None: + return None, False + raise ValueError('Unable to find foreign key between %s and %s. ' + 'Please specify an explicit join condition.' % + (src, dest)) + elif to_field is not None: + # If the foreign-key field was specified explicitly, remove all + # other foreign-key fields from the list. + target = (to_field.field if isinstance(to_field, FieldAlias) + else to_field) + fk_fields = [f for f in fk_fields if ( + (f is target) or + (is_backref and f.rel_field is to_field))] + + if len(fk_fields) == 1: + return fk_fields[0], is_backref + + if on is None: + # If multiple foreign-keys exist, try using the FK whose name + # matches that of the related model. If not, raise an error as this + # is ambiguous. + for fk in fk_fields: + if fk.name == dest._meta.name: + return fk, is_backref + + raise ValueError('More than one foreign key between %s and %s.' + ' Please specify which you are joining on.' % + (src, dest)) + + # If there are multiple foreign-keys to choose from and the join + # predicate is an expression, we'll try to figure out which + # foreign-key field we're joining on so that we can assign to the + # correct attribute when resolving the model graph. + to_field = None + if isinstance(on, Expression): + lhs, rhs = on.lhs, on.rhs + # Coerce to set() so that we force Python to compare using the + # object's hash rather than equality test, which returns a + # false-positive due to overriding __eq__. 
+ fk_set = set(fk_fields) + + if isinstance(lhs, Field): + lhs_f = lhs.field if isinstance(lhs, FieldAlias) else lhs + if lhs_f in fk_set: + to_field = lhs_f + elif isinstance(rhs, Field): + rhs_f = rhs.field if isinstance(rhs, FieldAlias) else rhs + if rhs_f in fk_set: + to_field = rhs_f + + return to_field, False + + @Node.copy + def join(self, dest, join_type=JOIN.INNER, on=None, src=None, attr=None): + src = self._join_ctx if src is None else src + + if join_type == JOIN.LATERAL or join_type == JOIN.LEFT_LATERAL: + on = True + elif join_type != JOIN.CROSS: + on, attr, constructor = self._normalize_join(src, dest, on, attr) + if attr: + self._joins.setdefault(src, []) + self._joins[src].append((dest, attr, constructor, join_type)) + elif on is not None: + raise ValueError('Cannot specify on clause with cross join.') + + if not self._from_list: + raise ValueError('No sources to join on.') + + item = self._from_list.pop() + self._from_list.append(Join(item, dest, join_type, on)) + + def join_from(self, src, dest, join_type=JOIN.INNER, on=None, attr=None): + return self.join(dest, join_type, on, src, attr) + + def _get_model_cursor_wrapper(self, cursor): + if len(self._from_list) == 1 and not self._joins: + return ModelObjectCursorWrapper(cursor, self.model, + self._returning, self.model) + return ModelCursorWrapper(cursor, self.model, self._returning, + self._from_list, self._joins) + + def ensure_join(self, lm, rm, on=None, **join_kwargs): + join_ctx = self._join_ctx + for dest, _, constructor, _ in self._joins.get(lm, []): + if dest == rm: + return self + return self.switch(lm).join(rm, on=on, **join_kwargs).switch(join_ctx) + + def convert_dict_to_node(self, qdict): + accum = [] + joins = [] + fks = (ForeignKeyField, BackrefAccessor) + for key, value in sorted(qdict.items()): + curr = self.model + if '__' in key and key.rsplit('__', 1)[1] in DJANGO_MAP: + key, op = key.rsplit('__', 1) + op = DJANGO_MAP[op] + elif value is None: + op = DJANGO_MAP['is'] + else: + op = DJANGO_MAP['eq'] + + if '__' not in key: + # Handle simplest case. This avoids joining over-eagerly when a + # direct FK lookup is all that is required. + model_attr = getattr(curr, key) + else: + for piece in key.split('__'): + for dest, attr, _, _ in self._joins.get(curr, ()): + if attr == piece or (isinstance(dest, ModelAlias) and + dest.alias == piece): + curr = dest + break + else: + model_attr = getattr(curr, piece) + if value is not None and isinstance(model_attr, fks): + curr = model_attr.rel_model + joins.append(model_attr) + accum.append(op(model_attr, value)) + return accum, joins + + def filter(self, *args, **kwargs): + # normalize args and kwargs into a new expression + if args and kwargs: + dq_node = (reduce(operator.and_, [a.clone() for a in args]) & + DQ(**kwargs)) + elif args: + dq_node = (reduce(operator.and_, [a.clone() for a in args]) & + ColumnBase()) + elif kwargs: + dq_node = DQ(**kwargs) & ColumnBase() + else: + return self.clone() + + # dq_node should now be an Expression, lhs = Node(), rhs = ... + q = collections.deque([dq_node]) + dq_joins = [] + seen_joins = set() + while q: + curr = q.popleft() + if not isinstance(curr, Expression): + continue + for side, piece in (('lhs', curr.lhs), ('rhs', curr.rhs)): + if isinstance(piece, DQ): + query, joins = self.convert_dict_to_node(piece.query) + for join in joins: + if join not in seen_joins: + dq_joins.append(join) + seen_joins.add(join) + expression = reduce(operator.and_, query) + # Apply values from the DQ object. 
+ if piece._negated: + expression = Negated(expression) + #expression._alias = piece._alias + setattr(curr, side, expression) + else: + q.append(piece) + + if not args or not kwargs: + dq_node = dq_node.lhs + + query = self.clone() + for field in dq_joins: + if isinstance(field, ForeignKeyField): + lm, rm = field.model, field.rel_model + field_obj = field + elif isinstance(field, BackrefAccessor): + lm, rm = field.model, field.rel_model + field_obj = field.field + query = query.ensure_join(lm, rm, field_obj) + return query.where(dq_node) + + def create_table(self, name, safe=True, **meta): + return self.model._schema.create_table_as(name, self, safe, **meta) + + def __sql_selection__(self, ctx, is_subquery=False): + if self._is_default and is_subquery and len(self._returning) > 1 and \ + self.model._meta.primary_key is not False: + return ctx.sql(self.model._meta.primary_key) + + return ctx.sql(CommaNodeList(self._returning)) + + +class NoopModelSelect(ModelSelect): + def __sql__(self, ctx): + return self.model._meta.database.get_noop_select(ctx) + + def _get_cursor_wrapper(self, cursor): + return CursorWrapper(cursor) + + +class _ModelWriteQueryHelper(_ModelQueryHelper): + def __init__(self, model, *args, **kwargs): + self.model = model + super(_ModelWriteQueryHelper, self).__init__(model, *args, **kwargs) + + def returning(self, *returning): + accum = [] + for item in returning: + if is_model(item): + accum.extend(item._meta.sorted_fields) + else: + accum.append(item) + return super(_ModelWriteQueryHelper, self).returning(*accum) + + def _set_table_alias(self, ctx): + table = self.model._meta.table + ctx.alias_manager[table] = table.__name__ + + +class ModelUpdate(_ModelWriteQueryHelper, Update): + pass + + +class ModelInsert(_ModelWriteQueryHelper, Insert): + default_row_type = ROW.TUPLE + + def __init__(self, *args, **kwargs): + super(ModelInsert, self).__init__(*args, **kwargs) + if self._returning is None and self.model._meta.database is not None: + if self.model._meta.database.returning_clause: + self._returning = self.model._meta.get_primary_keys() + + def returning(self, *returning): + # By default ModelInsert will yield a `tuple` containing the + # primary-key of the newly inserted row. But if we are explicitly + # specifying a returning clause and have not set a row type, we will + # default to returning model instances instead. 
+ if returning and self._row_type is None: + self._row_type = ROW.MODEL + return super(ModelInsert, self).returning(*returning) + + def get_default_data(self): + return self.model._meta.defaults + + def get_default_columns(self): + fields = self.model._meta.sorted_fields + return fields[1:] if self.model._meta.auto_increment else fields + + +class ModelDelete(_ModelWriteQueryHelper, Delete): + pass + + +class ManyToManyQuery(ModelSelect): + def __init__(self, instance, accessor, rel, *args, **kwargs): + self._instance = instance + self._accessor = accessor + self._src_attr = accessor.src_fk.rel_field.name + self._dest_attr = accessor.dest_fk.rel_field.name + super(ManyToManyQuery, self).__init__(rel, (rel,), *args, **kwargs) + + def _id_list(self, model_or_id_list): + if isinstance(model_or_id_list[0], Model): + return [getattr(obj, self._dest_attr) for obj in model_or_id_list] + return model_or_id_list + + def add(self, value, clear_existing=False): + if clear_existing: + self.clear() + + accessor = self._accessor + src_id = getattr(self._instance, self._src_attr) + if isinstance(value, SelectQuery): + query = value.columns( + Value(src_id), + accessor.dest_fk.rel_field) + accessor.through_model.insert_from( + fields=[accessor.src_fk, accessor.dest_fk], + query=query).execute() + else: + value = ensure_tuple(value) + if not value: return + + inserts = [{ + accessor.src_fk.name: src_id, + accessor.dest_fk.name: rel_id} + for rel_id in self._id_list(value)] + accessor.through_model.insert_many(inserts).execute() + + def remove(self, value): + src_id = getattr(self._instance, self._src_attr) + if isinstance(value, SelectQuery): + column = getattr(value.model, self._dest_attr) + subquery = value.columns(column) + return (self._accessor.through_model + .delete() + .where( + (self._accessor.dest_fk << subquery) & + (self._accessor.src_fk == src_id)) + .execute()) + else: + value = ensure_tuple(value) + if not value: + return + return (self._accessor.through_model + .delete() + .where( + (self._accessor.dest_fk << self._id_list(value)) & + (self._accessor.src_fk == src_id)) + .execute()) + + def clear(self): + src_id = getattr(self._instance, self._src_attr) + return (self._accessor.through_model + .delete() + .where(self._accessor.src_fk == src_id) + .execute()) + + +def safe_python_value(conv_func): + def validate(value): + try: + return conv_func(value) + except (TypeError, ValueError): + return value + return validate + + +class BaseModelCursorWrapper(DictCursorWrapper): + def __init__(self, cursor, model, columns): + super(BaseModelCursorWrapper, self).__init__(cursor) + self.model = model + self.select = columns or [] + + def _initialize_columns(self): + combined = self.model._meta.combined + table = self.model._meta.table + description = self.cursor.description + + self.ncols = len(self.cursor.description) + self.columns = [] + self.converters = converters = [None] * self.ncols + self.fields = fields = [None] * self.ncols + + for idx, description_item in enumerate(description): + column = description_item[0] + dot_index = column.find('.') + if dot_index != -1: + column = column[dot_index + 1:] + + column = column.strip('")') + self.columns.append(column) + try: + raw_node = self.select[idx] + except IndexError: + if column in combined: + raw_node = node = combined[column] + else: + continue + else: + node = raw_node.unwrap() + + # Heuristics used to attempt to get the field associated with a + # given SELECT column, so that we can accurately convert the value + # returned by the 
database-cursor into a Python object. + if isinstance(node, Field): + if raw_node._coerce: + converters[idx] = node.python_value + fields[idx] = node + if not raw_node.is_alias(): + self.columns[idx] = node.name + elif isinstance(node, ColumnBase) and raw_node._converter: + converters[idx] = raw_node._converter + elif isinstance(node, Function) and node._coerce: + if node._python_value is not None: + converters[idx] = node._python_value + elif node.arguments and isinstance(node.arguments[0], Node): + # If the first argument is a field or references a column + # on a Model, try using that field's conversion function. + # This usually works, but we use "safe_python_value()" so + # that if a TypeError or ValueError occurs during + # conversion we can just fall-back to the raw cursor value. + first = node.arguments[0].unwrap() + if isinstance(first, Entity): + path = first._path[-1] # Try to look-up by name. + first = combined.get(path) + if isinstance(first, Field): + converters[idx] = safe_python_value(first.python_value) + elif column in combined: + if node._coerce: + converters[idx] = combined[column].python_value + if isinstance(node, Column) and node.source == table: + fields[idx] = combined[column] + + initialize = _initialize_columns + + def process_row(self, row): + raise NotImplementedError + + +class ModelDictCursorWrapper(BaseModelCursorWrapper): + def process_row(self, row): + result = {} + columns, converters = self.columns, self.converters + fields = self.fields + + for i in range(self.ncols): + attr = columns[i] + if attr in result: continue # Don't overwrite if we have dupes. + if converters[i] is not None: + result[attr] = converters[i](row[i]) + else: + result[attr] = row[i] + + return result + + +class ModelTupleCursorWrapper(ModelDictCursorWrapper): + constructor = tuple + + def process_row(self, row): + columns, converters = self.columns, self.converters + return self.constructor([ + (converters[i](row[i]) if converters[i] is not None else row[i]) + for i in range(self.ncols)]) + + +class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper): + def initialize(self): + self._initialize_columns() + attributes = [] + for i in range(self.ncols): + attributes.append(self.columns[i]) + self.tuple_class = collections.namedtuple('Row', attributes) + self.constructor = lambda row: self.tuple_class(*row) + + +class ModelObjectCursorWrapper(ModelDictCursorWrapper): + def __init__(self, cursor, model, select, constructor): + self.constructor = constructor + self.is_model = is_model(constructor) + super(ModelObjectCursorWrapper, self).__init__(cursor, model, select) + + def process_row(self, row): + data = super(ModelObjectCursorWrapper, self).process_row(row) + if self.is_model: + # Clear out any dirty fields before returning to the user. 
+ obj = self.constructor(__no_default__=1, **data) + obj._dirty.clear() + return obj + else: + return self.constructor(**data) + + +class ModelCursorWrapper(BaseModelCursorWrapper): + def __init__(self, cursor, model, select, from_list, joins): + super(ModelCursorWrapper, self).__init__(cursor, model, select) + self.from_list = from_list + self.joins = joins + + def initialize(self): + self._initialize_columns() + selected_src = set([field.model for field in self.fields + if field is not None]) + select, columns = self.select, self.columns + + self.key_to_constructor = {self.model: self.model} + self.src_is_dest = {} + self.src_to_dest = [] + accum = collections.deque(self.from_list) + dests = set() + + while accum: + curr = accum.popleft() + if isinstance(curr, Join): + accum.append(curr.lhs) + accum.append(curr.rhs) + continue + + if curr not in self.joins: + continue + + is_dict = isinstance(curr, dict) + for key, attr, constructor, join_type in self.joins[curr]: + if key not in self.key_to_constructor: + self.key_to_constructor[key] = constructor + + # (src, attr, dest, is_dict, join_type). + self.src_to_dest.append((curr, attr, key, is_dict, + join_type)) + dests.add(key) + accum.append(key) + + # Ensure that we accommodate everything selected. + for src in selected_src: + if src not in self.key_to_constructor: + if is_model(src): + self.key_to_constructor[src] = src + elif isinstance(src, ModelAlias): + self.key_to_constructor[src] = src.model + + # Indicate which sources are also dests. + for src, _, dest, _, _ in self.src_to_dest: + self.src_is_dest[src] = src in dests and (dest in selected_src + or src in selected_src) + + self.column_keys = [] + for idx, node in enumerate(select): + key = self.model + field = self.fields[idx] + if field is not None: + if isinstance(field, FieldAlias): + key = field.source + else: + key = field.model + else: + if isinstance(node, Node): + node = node.unwrap() + if isinstance(node, Column): + key = node.source + + self.column_keys.append(key) + + def process_row(self, row): + objects = {} + object_list = [] + for key, constructor in self.key_to_constructor.items(): + objects[key] = constructor(__no_default__=True) + object_list.append(objects[key]) + + default_instance = objects[self.model] + + set_keys = set() + for idx, key in enumerate(self.column_keys): + # Get the instance corresponding to the selected column/value, + # falling back to the "root" model instance. + instance = objects.get(key, default_instance) + column = self.columns[idx] + value = row[idx] + if value is not None: + set_keys.add(key) + if self.converters[idx]: + value = self.converters[idx](value) + + if isinstance(instance, dict): + instance[column] = value + else: + setattr(instance, column, value) + + # Need to do some analysis on the joins before this. + for (src, attr, dest, is_dict, join_type) in self.src_to_dest: + instance = objects[src] + try: + joined_instance = objects[dest] + except KeyError: + continue + + # If no fields were set on the destination instance then do not + # assign an "empty" instance. + if instance is None or dest is None or \ + (dest not in set_keys and not self.src_is_dest.get(dest)): + continue + + # If no fields were set on either the source or the destination, + # then we have nothing to do here. 
+ if instance not in set_keys and dest not in set_keys \ + and join_type.endswith('OUTER JOIN'): + continue + + if is_dict: + instance[attr] = joined_instance + else: + setattr(instance, attr, joined_instance) + + # When instantiating models from a cursor, we clear the dirty fields. + for instance in object_list: + if isinstance(instance, Model): + instance._dirty.clear() + + return objects[self.model] + + +class PrefetchQuery(collections.namedtuple('_PrefetchQuery', ( + 'query', 'fields', 'is_backref', 'rel_models', 'field_to_name', 'model'))): + def __new__(cls, query, fields=None, is_backref=None, rel_models=None, + field_to_name=None, model=None): + if fields: + if is_backref: + if rel_models is None: + rel_models = [field.model for field in fields] + foreign_key_attrs = [field.rel_field.name for field in fields] + else: + if rel_models is None: + rel_models = [field.rel_model for field in fields] + foreign_key_attrs = [field.name for field in fields] + field_to_name = list(zip(fields, foreign_key_attrs)) + model = query.model + return super(PrefetchQuery, cls).__new__( + cls, query, fields, is_backref, rel_models, field_to_name, model) + + def populate_instance(self, instance, id_map): + if self.is_backref: + for field in self.fields: + identifier = instance.__data__[field.name] + key = (field, identifier) + if key in id_map: + setattr(instance, field.name, id_map[key]) + else: + for field, attname in self.field_to_name: + identifier = instance.__data__[field.rel_field.name] + key = (field, identifier) + rel_instances = id_map.get(key, []) + for inst in rel_instances: + setattr(inst, attname, instance) + inst._dirty.clear() + setattr(instance, field.backref, rel_instances) + + def store_instance(self, instance, id_map): + for field, attname in self.field_to_name: + identity = field.rel_field.python_value(instance.__data__[attname]) + key = (field, identity) + if self.is_backref: + id_map[key] = instance + else: + id_map.setdefault(key, []) + id_map[key].append(instance) + + +def prefetch_add_subquery(sq, subqueries): + fixed_queries = [PrefetchQuery(sq)] + for i, subquery in enumerate(subqueries): + if isinstance(subquery, tuple): + subquery, target_model = subquery + else: + target_model = None + if not isinstance(subquery, Query) and is_model(subquery) or \ + isinstance(subquery, ModelAlias): + subquery = subquery.select() + subquery_model = subquery.model + fks = backrefs = None + for j in reversed(range(i + 1)): + fixed = fixed_queries[j] + last_query = fixed.query + last_model = last_obj = fixed.model + if isinstance(last_model, ModelAlias): + last_model = last_model.model + rels = subquery_model._meta.model_refs.get(last_model, []) + if rels: + fks = [getattr(subquery_model, fk.name) for fk in rels] + pks = [getattr(last_obj, fk.rel_field.name) for fk in rels] + else: + backrefs = subquery_model._meta.model_backrefs.get(last_model) + if (fks or backrefs) and ((target_model is last_obj) or + (target_model is None)): + break + + if not fks and not backrefs: + tgt_err = ' using %s' % target_model if target_model else '' + raise AttributeError('Error: unable to find foreign key for ' + 'query: %s%s' % (subquery, tgt_err)) + + dest = (target_model,) if target_model else None + + if fks: + expr = reduce(operator.or_, [ + (fk << last_query.select(pk)) + for (fk, pk) in zip(fks, pks)]) + subquery = subquery.where(expr) + fixed_queries.append(PrefetchQuery(subquery, fks, False, dest)) + elif backrefs: + expressions = [] + for backref in backrefs: + rel_field = getattr(subquery_model, 
backref.rel_field.name) + fk_field = getattr(last_obj, backref.name) + expressions.append(rel_field << last_query.select(fk_field)) + subquery = subquery.where(reduce(operator.or_, expressions)) + fixed_queries.append(PrefetchQuery(subquery, backrefs, True, dest)) + + return fixed_queries + + +def prefetch(sq, *subqueries): + if not subqueries: + return sq + + fixed_queries = prefetch_add_subquery(sq, subqueries) + deps = {} + rel_map = {} + for pq in reversed(fixed_queries): + query_model = pq.model + if pq.fields: + for rel_model in pq.rel_models: + rel_map.setdefault(rel_model, []) + rel_map[rel_model].append(pq) + + deps.setdefault(query_model, {}) + id_map = deps[query_model] + has_relations = bool(rel_map.get(query_model)) + + for instance in pq.query: + if pq.fields: + pq.store_instance(instance, id_map) + if has_relations: + for rel in rel_map[query_model]: + rel.populate_instance(instance, deps[rel.model]) + + return list(pq.query) diff --git a/libs/playhouse/README.md b/libs/playhouse/README.md new file mode 100644 index 000000000..faebd6902 --- /dev/null +++ b/libs/playhouse/README.md @@ -0,0 +1,48 @@ +## Playhouse + +The `playhouse` namespace contains numerous extensions to Peewee. These include vendor-specific database extensions, high-level abstractions to simplify working with databases, and tools for low-level database operations and introspection. + +### Vendor extensions + +* [SQLite extensions](http://docs.peewee-orm.com/en/latest/peewee/sqlite_ext.html) + * Full-text search (FTS3/4/5) + * BM25 ranking algorithm implemented as SQLite C extension, backported to FTS4 + * Virtual tables and C extensions + * Closure tables + * JSON extension support + * LSM1 (key/value database) support + * BLOB API + * Online backup API +* [APSW extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#apsw): use Peewee with the powerful [APSW](https://github.com/rogerbinns/apsw) SQLite driver. +* [SQLCipher](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#sqlcipher-ext): encrypted SQLite databases. +* [SqliteQ](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#sqliteq): dedicated writer thread for multi-threaded SQLite applications. [More info here](http://charlesleifer.com/blog/multi-threaded-sqlite-without-the-operationalerrors/). +* [Postgresql extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#postgres-ext) + * JSON and JSONB + * HStore + * Arrays + * Server-side cursors + * Full-text search +* [MySQL extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#mysql-ext) + +### High-level libraries + +* [Extra fields](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#extra-fields) + * Compressed field + * PickleField +* [Shortcuts / helpers](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#shortcuts) + * Model-to-dict serializer + * Dict-to-model deserializer +* [Hybrid attributes](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#hybrid) +* [Signals](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#signals): pre/post-save, pre/post-delete, pre-init. +* [Dataset](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#dataset): high-level API for working with databases popuarlized by the [project of the same name](https://dataset.readthedocs.io/). +* [Key/Value Store](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#kv): key/value store using SQLite. Supports *smart indexing*, for *Pandas*-style queries. 
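A minimal, self-contained sketch of the serialization shortcuts listed above (the model, field names and values are illustrative, not part of the library):

    from peewee import SqliteDatabase, Model, CharField
    from playhouse.shortcuts import model_to_dict, dict_to_model

    db = SqliteDatabase(':memory:')

    class User(Model):
        username = CharField()
        email = CharField()

        class Meta:
            database = db

    db.create_tables([User])
    user = User.create(username='charlie', email='charlie@example.com')

    data = model_to_dict(user)         # plain dict, e.g. {'id': 1, 'username': 'charlie', ...}
    clone = dict_to_model(User, data)  # rebuild a User instance from that dict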
+ +### Database management and framework support + +* [pwiz](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#pwiz): generate model code from a pre-existing database. +* [Schema migrations](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#migrate): modify your schema using high-level APIs. Even supports dropping or renaming columns in SQLite. +* [Connection pool](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#pool): simple connection pooling. +* [Reflection](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#reflection): low-level, cross-platform database introspection +* [Database URLs](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#db-url): use URLs to connect to database +* [Test utils](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#test-utils): helpers for unit-testing Peewee applications. +* [Flask utils](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#flask-utils): paginated object lists, database connection management, and more. diff --git a/libs/playhouse/__init__.py b/libs/playhouse/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/playhouse/_pysqlite/cache.h b/libs/playhouse/_pysqlite/cache.h new file mode 100644 index 000000000..06f957a77 --- /dev/null +++ b/libs/playhouse/_pysqlite/cache.h @@ -0,0 +1,73 @@ +/* cache.h - definitions for the LRU cache + * + * Copyright (C) 2004-2015 Gerhard Häring + * + * This file is part of pysqlite. + * + * This software is provided 'as-is', without any express or implied + * warranty. In no event will the authors be held liable for any damages + * arising from the use of this software. + * + * Permission is granted to anyone to use this software for any purpose, + * including commercial applications, and to alter it and redistribute it + * freely, subject to the following restrictions: + * + * 1. The origin of this software must not be misrepresented; you must not + * claim that you wrote the original software. If you use this software + * in a product, an acknowledgment in the product documentation would be + * appreciated but is not required. + * 2. Altered source versions must be plainly marked as such, and must not be + * misrepresented as being the original software. + * 3. This notice may not be removed or altered from any source distribution. + */ + +#ifndef PYSQLITE_CACHE_H +#define PYSQLITE_CACHE_H +#include "Python.h" + +/* The LRU cache is implemented as a combination of a doubly-linked with a + * dictionary. The list items are of type 'Node' and the dictionary has the + * nodes as values. */ + +typedef struct _pysqlite_Node +{ + PyObject_HEAD + PyObject* key; + PyObject* data; + long count; + struct _pysqlite_Node* prev; + struct _pysqlite_Node* next; +} pysqlite_Node; + +typedef struct +{ + PyObject_HEAD + int size; + + /* a dictionary mapping keys to Node entries */ + PyObject* mapping; + + /* the factory callable */ + PyObject* factory; + + pysqlite_Node* first; + pysqlite_Node* last; + + /* if set, decrement the factory function when the Cache is deallocated. 
+ * this is almost always desirable, but not in the pysqlite context */ + int decref_factory; +} pysqlite_Cache; + +extern PyTypeObject pysqlite_NodeType; +extern PyTypeObject pysqlite_CacheType; + +int pysqlite_node_init(pysqlite_Node* self, PyObject* args, PyObject* kwargs); +void pysqlite_node_dealloc(pysqlite_Node* self); + +int pysqlite_cache_init(pysqlite_Cache* self, PyObject* args, PyObject* kwargs); +void pysqlite_cache_dealloc(pysqlite_Cache* self); +PyObject* pysqlite_cache_get(pysqlite_Cache* self, PyObject* args); + +int pysqlite_cache_setup_types(void); + +#endif diff --git a/libs/playhouse/_pysqlite/connection.h b/libs/playhouse/_pysqlite/connection.h new file mode 100644 index 000000000..d35c13f9a --- /dev/null +++ b/libs/playhouse/_pysqlite/connection.h @@ -0,0 +1,129 @@ +/* connection.h - definitions for the connection type + * + * Copyright (C) 2004-2015 Gerhard Häring + * + * This file is part of pysqlite. + * + * This software is provided 'as-is', without any express or implied + * warranty. In no event will the authors be held liable for any damages + * arising from the use of this software. + * + * Permission is granted to anyone to use this software for any purpose, + * including commercial applications, and to alter it and redistribute it + * freely, subject to the following restrictions: + * + * 1. The origin of this software must not be misrepresented; you must not + * claim that you wrote the original software. If you use this software + * in a product, an acknowledgment in the product documentation would be + * appreciated but is not required. + * 2. Altered source versions must be plainly marked as such, and must not be + * misrepresented as being the original software. + * 3. This notice may not be removed or altered from any source distribution. + */ + +#ifndef PYSQLITE_CONNECTION_H +#define PYSQLITE_CONNECTION_H +#include "Python.h" +#include "pythread.h" +#include "structmember.h" + +#include "cache.h" +#include "module.h" + +#include "sqlite3.h" + +typedef struct +{ + PyObject_HEAD + sqlite3* db; + + /* the type detection mode. Only 0, PARSE_DECLTYPES, PARSE_COLNAMES or a + * bitwise combination thereof makes sense */ + int detect_types; + + /* the timeout value in seconds for database locks */ + double timeout; + + /* for internal use in the timeout handler: when did the timeout handler + * first get called with count=0? */ + double timeout_started; + + /* None for autocommit, otherwise a PyString with the isolation level */ + PyObject* isolation_level; + + /* NULL for autocommit, otherwise a string with the BEGIN statement; will be + * freed in connection destructor */ + char* begin_statement; + + /* 1 if a check should be performed for each API call if the connection is + * used from the same thread it was created in */ + int check_same_thread; + + int initialized; + + /* thread identification of the thread the connection was created in */ + long thread_ident; + + pysqlite_Cache* statement_cache; + + /* Lists of weak references to statements and cursors used within this connection */ + PyObject* statements; + PyObject* cursors; + + /* Counters for how many statements/cursors were created in the connection. 
May be + * reset to 0 at certain intervals */ + int created_statements; + int created_cursors; + + PyObject* row_factory; + + /* Determines how bytestrings from SQLite are converted to Python objects: + * - PyUnicode_Type: Python Unicode objects are constructed from UTF-8 bytestrings + * - OptimizedUnicode: Like before, but for ASCII data, only PyStrings are created. + * - PyString_Type: PyStrings are created as-is. + * - Any custom callable: Any object returned from the callable called with the bytestring + * as single parameter. + */ + PyObject* text_factory; + + /* remember references to functions/classes used in + * create_function/create/aggregate, use these as dictionary keys, so we + * can keep the total system refcount constant by clearing that dictionary + * in connection_dealloc */ + PyObject* function_pinboard; + + /* a dictionary of registered collation name => collation callable mappings */ + PyObject* collations; + + /* Exception objects */ + PyObject* Warning; + PyObject* Error; + PyObject* InterfaceError; + PyObject* DatabaseError; + PyObject* DataError; + PyObject* OperationalError; + PyObject* IntegrityError; + PyObject* InternalError; + PyObject* ProgrammingError; + PyObject* NotSupportedError; +} pysqlite_Connection; + +extern PyTypeObject pysqlite_ConnectionType; + +PyObject* pysqlite_connection_alloc(PyTypeObject* type, int aware); +void pysqlite_connection_dealloc(pysqlite_Connection* self); +PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs); +PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args); +PyObject* _pysqlite_connection_begin(pysqlite_Connection* self); +PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args); +PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args); +PyObject* pysqlite_connection_new(PyTypeObject* type, PyObject* args, PyObject* kw); +int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject* kwargs); + +int pysqlite_connection_register_cursor(pysqlite_Connection* connection, PyObject* cursor); +int pysqlite_check_thread(pysqlite_Connection* self); +int pysqlite_check_connection(pysqlite_Connection* con); + +int pysqlite_connection_setup_types(void); + +#endif diff --git a/libs/playhouse/_pysqlite/module.h b/libs/playhouse/_pysqlite/module.h new file mode 100644 index 000000000..08c566257 --- /dev/null +++ b/libs/playhouse/_pysqlite/module.h @@ -0,0 +1,58 @@ +/* module.h - definitions for the module + * + * Copyright (C) 2004-2015 Gerhard Häring + * + * This file is part of pysqlite. + * + * This software is provided 'as-is', without any express or implied + * warranty. In no event will the authors be held liable for any damages + * arising from the use of this software. + * + * Permission is granted to anyone to use this software for any purpose, + * including commercial applications, and to alter it and redistribute it + * freely, subject to the following restrictions: + * + * 1. The origin of this software must not be misrepresented; you must not + * claim that you wrote the original software. If you use this software + * in a product, an acknowledgment in the product documentation would be + * appreciated but is not required. + * 2. Altered source versions must be plainly marked as such, and must not be + * misrepresented as being the original software. + * 3. This notice may not be removed or altered from any source distribution. 
+ */ + +#ifndef PYSQLITE_MODULE_H +#define PYSQLITE_MODULE_H +#include "Python.h" + +#define PYSQLITE_VERSION "2.8.2" + +extern PyObject* pysqlite_Error; +extern PyObject* pysqlite_Warning; +extern PyObject* pysqlite_InterfaceError; +extern PyObject* pysqlite_DatabaseError; +extern PyObject* pysqlite_InternalError; +extern PyObject* pysqlite_OperationalError; +extern PyObject* pysqlite_ProgrammingError; +extern PyObject* pysqlite_IntegrityError; +extern PyObject* pysqlite_DataError; +extern PyObject* pysqlite_NotSupportedError; + +extern PyObject* pysqlite_OptimizedUnicode; + +/* the functions time.time() and time.sleep() */ +extern PyObject* time_time; +extern PyObject* time_sleep; + +/* A dictionary, mapping colum types (INTEGER, VARCHAR, etc.) to converter + * functions, that convert the SQL value to the appropriate Python value. + * The key is uppercase. + */ +extern PyObject* converters; + +extern int _enable_callback_tracebacks; +extern int pysqlite_BaseTypeAdapted; + +#define PARSE_DECLTYPES 1 +#define PARSE_COLNAMES 2 +#endif diff --git a/libs/playhouse/_sqlite_ext.pyx b/libs/playhouse/_sqlite_ext.pyx new file mode 100644 index 000000000..3e6408016 --- /dev/null +++ b/libs/playhouse/_sqlite_ext.pyx @@ -0,0 +1,1595 @@ +import hashlib +import zlib + +cimport cython +from cpython cimport datetime +from cpython.bytes cimport PyBytes_AsStringAndSize +from cpython.bytes cimport PyBytes_Check +from cpython.bytes cimport PyBytes_FromStringAndSize +from cpython.bytes cimport PyBytes_AS_STRING +from cpython.object cimport PyObject +from cpython.ref cimport Py_INCREF, Py_DECREF +from cpython.unicode cimport PyUnicode_AsUTF8String +from cpython.unicode cimport PyUnicode_Check +from cpython.unicode cimport PyUnicode_DecodeUTF8 +from cpython.version cimport PY_MAJOR_VERSION +from libc.float cimport DBL_MAX +from libc.math cimport ceil, log, sqrt +from libc.math cimport pow as cpow +#from libc.stdint cimport ssize_t +from libc.stdint cimport uint8_t +from libc.stdint cimport uint32_t +from libc.stdlib cimport calloc, free, malloc, rand +from libc.string cimport memcpy, memset, strlen + +from peewee import InterfaceError +from peewee import Node +from peewee import OperationalError +from peewee import sqlite3 as pysqlite + +import traceback + + +cdef struct sqlite3_index_constraint: + int iColumn # Column constrained, -1 for rowid. + unsigned char op # Constraint operator. + unsigned char usable # True if this constraint is usable. + int iTermOffset # Used internally - xBestIndex should ignore. + + +cdef struct sqlite3_index_orderby: + int iColumn + unsigned char desc + + +cdef struct sqlite3_index_constraint_usage: + int argvIndex # if > 0, constraint is part of argv to xFilter. + unsigned char omit + + +cdef extern from "sqlite3.h" nogil: + ctypedef struct sqlite3: + int busyTimeout + ctypedef struct sqlite3_backup + ctypedef struct sqlite3_blob + ctypedef struct sqlite3_context + ctypedef struct sqlite3_value + ctypedef long long sqlite3_int64 + ctypedef unsigned long long sqlite_uint64 + + # Virtual tables. + ctypedef struct sqlite3_module # Forward reference. 
+ ctypedef struct sqlite3_vtab: + const sqlite3_module *pModule + int nRef + char *zErrMsg + ctypedef struct sqlite3_vtab_cursor: + sqlite3_vtab *pVtab + + ctypedef struct sqlite3_index_info: + int nConstraint + sqlite3_index_constraint *aConstraint + int nOrderBy + sqlite3_index_orderby *aOrderBy + sqlite3_index_constraint_usage *aConstraintUsage + int idxNum + char *idxStr + int needToFreeIdxStr + int orderByConsumed + double estimatedCost + sqlite3_int64 estimatedRows + int idxFlags + + ctypedef struct sqlite3_module: + int iVersion + int (*xCreate)(sqlite3*, void *pAux, int argc, const char *const*argv, + sqlite3_vtab **ppVTab, char**) + int (*xConnect)(sqlite3*, void *pAux, int argc, const char *const*argv, + sqlite3_vtab **ppVTab, char**) + int (*xBestIndex)(sqlite3_vtab *pVTab, sqlite3_index_info*) + int (*xDisconnect)(sqlite3_vtab *pVTab) + int (*xDestroy)(sqlite3_vtab *pVTab) + int (*xOpen)(sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor) + int (*xClose)(sqlite3_vtab_cursor*) + int (*xFilter)(sqlite3_vtab_cursor*, int idxNum, const char *idxStr, + int argc, sqlite3_value **argv) + int (*xNext)(sqlite3_vtab_cursor*) + int (*xEof)(sqlite3_vtab_cursor*) + int (*xColumn)(sqlite3_vtab_cursor*, sqlite3_context *, int) + int (*xRowid)(sqlite3_vtab_cursor*, sqlite3_int64 *pRowid) + int (*xUpdate)(sqlite3_vtab *pVTab, int, sqlite3_value **, + sqlite3_int64 **) + int (*xBegin)(sqlite3_vtab *pVTab) + int (*xSync)(sqlite3_vtab *pVTab) + int (*xCommit)(sqlite3_vtab *pVTab) + int (*xRollback)(sqlite3_vtab *pVTab) + int (*xFindFunction)(sqlite3_vtab *pVTab, int nArg, const char *zName, + void (**pxFunc)(sqlite3_context *, int, + sqlite3_value **), + void **ppArg) + int (*xRename)(sqlite3_vtab *pVTab, const char *zNew) + int (*xSavepoint)(sqlite3_vtab *pVTab, int) + int (*xRelease)(sqlite3_vtab *pVTab, int) + int (*xRollbackTo)(sqlite3_vtab *pVTab, int) + + cdef int sqlite3_declare_vtab(sqlite3 *db, const char *zSQL) + cdef int sqlite3_create_module(sqlite3 *db, const char *zName, + const sqlite3_module *p, void *pClientData) + + cdef const char sqlite3_version[] + + # Encoding. + cdef int SQLITE_UTF8 = 1 + + # Return values. + cdef int SQLITE_OK = 0 + cdef int SQLITE_ERROR = 1 + cdef int SQLITE_INTERNAL = 2 + cdef int SQLITE_PERM = 3 + cdef int SQLITE_ABORT = 4 + cdef int SQLITE_BUSY = 5 + cdef int SQLITE_LOCKED = 6 + cdef int SQLITE_NOMEM = 7 + cdef int SQLITE_READONLY = 8 + cdef int SQLITE_INTERRUPT = 9 + cdef int SQLITE_DONE = 101 + + # Function type. + cdef int SQLITE_DETERMINISTIC = 0x800 + + # Types of filtering operations. + cdef int SQLITE_INDEX_CONSTRAINT_EQ = 2 + cdef int SQLITE_INDEX_CONSTRAINT_GT = 4 + cdef int SQLITE_INDEX_CONSTRAINT_LE = 8 + cdef int SQLITE_INDEX_CONSTRAINT_LT = 16 + cdef int SQLITE_INDEX_CONSTRAINT_GE = 32 + cdef int SQLITE_INDEX_CONSTRAINT_MATCH = 64 + + # sqlite_value_type. + cdef int SQLITE_INTEGER = 1 + cdef int SQLITE_FLOAT = 2 + cdef int SQLITE3_TEXT = 3 + cdef int SQLITE_TEXT = 3 + cdef int SQLITE_BLOB = 4 + cdef int SQLITE_NULL = 5 + + ctypedef void (*sqlite3_destructor_type)(void*) + + # Converting from Sqlite -> Python. 
+ cdef const void *sqlite3_value_blob(sqlite3_value*) + cdef int sqlite3_value_bytes(sqlite3_value*) + cdef double sqlite3_value_double(sqlite3_value*) + cdef int sqlite3_value_int(sqlite3_value*) + cdef sqlite3_int64 sqlite3_value_int64(sqlite3_value*) + cdef const unsigned char *sqlite3_value_text(sqlite3_value*) + cdef int sqlite3_value_type(sqlite3_value*) + cdef int sqlite3_value_numeric_type(sqlite3_value*) + + # Converting from Python -> Sqlite. + cdef void sqlite3_result_blob(sqlite3_context*, const void *, int, + void(*)(void*)) + cdef void sqlite3_result_double(sqlite3_context*, double) + cdef void sqlite3_result_error(sqlite3_context*, const char*, int) + cdef void sqlite3_result_error_toobig(sqlite3_context*) + cdef void sqlite3_result_error_nomem(sqlite3_context*) + cdef void sqlite3_result_error_code(sqlite3_context*, int) + cdef void sqlite3_result_int(sqlite3_context*, int) + cdef void sqlite3_result_int64(sqlite3_context*, sqlite3_int64) + cdef void sqlite3_result_null(sqlite3_context*) + cdef void sqlite3_result_text(sqlite3_context*, const char*, int, + void(*)(void*)) + cdef void sqlite3_result_value(sqlite3_context*, sqlite3_value*) + + # Memory management. + cdef void* sqlite3_malloc(int) + cdef void sqlite3_free(void *) + + cdef int sqlite3_changes(sqlite3 *db) + cdef int sqlite3_get_autocommit(sqlite3 *db) + cdef sqlite3_int64 sqlite3_last_insert_rowid(sqlite3 *db) + + cdef void *sqlite3_commit_hook(sqlite3 *, int(*)(void *), void *) + cdef void *sqlite3_rollback_hook(sqlite3 *, void(*)(void *), void *) + cdef void *sqlite3_update_hook( + sqlite3 *, + void(*)(void *, int, char *, char *, sqlite3_int64), + void *) + + cdef int SQLITE_STATUS_MEMORY_USED = 0 + cdef int SQLITE_STATUS_PAGECACHE_USED = 1 + cdef int SQLITE_STATUS_PAGECACHE_OVERFLOW = 2 + cdef int SQLITE_STATUS_SCRATCH_USED = 3 + cdef int SQLITE_STATUS_SCRATCH_OVERFLOW = 4 + cdef int SQLITE_STATUS_MALLOC_SIZE = 5 + cdef int SQLITE_STATUS_PARSER_STACK = 6 + cdef int SQLITE_STATUS_PAGECACHE_SIZE = 7 + cdef int SQLITE_STATUS_SCRATCH_SIZE = 8 + cdef int SQLITE_STATUS_MALLOC_COUNT = 9 + cdef int sqlite3_status(int op, int *pCurrent, int *pHighwater, int resetFlag) + + cdef int SQLITE_DBSTATUS_LOOKASIDE_USED = 0 + cdef int SQLITE_DBSTATUS_CACHE_USED = 1 + cdef int SQLITE_DBSTATUS_SCHEMA_USED = 2 + cdef int SQLITE_DBSTATUS_STMT_USED = 3 + cdef int SQLITE_DBSTATUS_LOOKASIDE_HIT = 4 + cdef int SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5 + cdef int SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6 + cdef int SQLITE_DBSTATUS_CACHE_HIT = 7 + cdef int SQLITE_DBSTATUS_CACHE_MISS = 8 + cdef int SQLITE_DBSTATUS_CACHE_WRITE = 9 + cdef int SQLITE_DBSTATUS_DEFERRED_FKS = 10 + #cdef int SQLITE_DBSTATUS_CACHE_USED_SHARED = 11 + cdef int sqlite3_db_status(sqlite3 *, int op, int *pCur, int *pHigh, int reset) + + cdef int SQLITE_DELETE = 9 + cdef int SQLITE_INSERT = 18 + cdef int SQLITE_UPDATE = 23 + + cdef int SQLITE_CONFIG_SINGLETHREAD = 1 # None + cdef int SQLITE_CONFIG_MULTITHREAD = 2 # None + cdef int SQLITE_CONFIG_SERIALIZED = 3 # None + cdef int SQLITE_CONFIG_SCRATCH = 6 # void *, int sz, int N + cdef int SQLITE_CONFIG_PAGECACHE = 7 # void *, int sz, int N + cdef int SQLITE_CONFIG_HEAP = 8 # void *, int nByte, int min + cdef int SQLITE_CONFIG_MEMSTATUS = 9 # boolean + cdef int SQLITE_CONFIG_LOOKASIDE = 13 # int, int + cdef int SQLITE_CONFIG_URI = 17 # int + cdef int SQLITE_CONFIG_MMAP_SIZE = 22 # sqlite3_int64, sqlite3_int64 + cdef int SQLITE_CONFIG_STMTJRNL_SPILL = 26 # int nByte + cdef int SQLITE_DBCONFIG_MAINDBNAME = 1000 # const 
char* + cdef int SQLITE_DBCONFIG_LOOKASIDE = 1001 # void* int int + cdef int SQLITE_DBCONFIG_ENABLE_FKEY = 1002 # int int* + cdef int SQLITE_DBCONFIG_ENABLE_TRIGGER = 1003 # int int* + cdef int SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER = 1004 # int int* + cdef int SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION = 1005 # int int* + cdef int SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE = 1006 # int int* + cdef int SQLITE_DBCONFIG_ENABLE_QPSG = 1007 # int int* + + cdef int sqlite3_config(int, ...) + cdef int sqlite3_db_config(sqlite3*, int op, ...) + + # Misc. + cdef int sqlite3_busy_handler(sqlite3 *db, int(*)(void *, int), void *) + cdef int sqlite3_sleep(int ms) + cdef sqlite3_backup *sqlite3_backup_init( + sqlite3 *pDest, + const char *zDestName, + sqlite3 *pSource, + const char *zSourceName) + + # Backup. + cdef int sqlite3_backup_step(sqlite3_backup *p, int nPage) + cdef int sqlite3_backup_finish(sqlite3_backup *p) + cdef int sqlite3_backup_remaining(sqlite3_backup *p) + cdef int sqlite3_backup_pagecount(sqlite3_backup *p) + + # Error handling. + cdef int sqlite3_errcode(sqlite3 *db) + cdef int sqlite3_errstr(int) + cdef const char *sqlite3_errmsg(sqlite3 *db) + + cdef int sqlite3_blob_open( + sqlite3*, + const char *zDb, + const char *zTable, + const char *zColumn, + sqlite3_int64 iRow, + int flags, + sqlite3_blob **ppBlob) + cdef int sqlite3_blob_reopen(sqlite3_blob *, sqlite3_int64) + cdef int sqlite3_blob_close(sqlite3_blob *) + cdef int sqlite3_blob_bytes(sqlite3_blob *) + cdef int sqlite3_blob_read(sqlite3_blob *, void *Z, int N, int iOffset) + cdef int sqlite3_blob_write(sqlite3_blob *, const void *z, int n, + int iOffset) + + +cdef extern from "_pysqlite/connection.h": + ctypedef struct pysqlite_Connection: + sqlite3* db + double timeout + int initialized + + +cdef sqlite_to_python(int argc, sqlite3_value **params): + cdef: + int i + int vtype + list pyargs = [] + + for i in range(argc): + vtype = sqlite3_value_type(params[i]) + if vtype == SQLITE_INTEGER: + pyval = sqlite3_value_int(params[i]) + elif vtype == SQLITE_FLOAT: + pyval = sqlite3_value_double(params[i]) + elif vtype == SQLITE_TEXT: + pyval = PyUnicode_DecodeUTF8( + sqlite3_value_text(params[i]), + sqlite3_value_bytes(params[i]), NULL) + elif vtype == SQLITE_BLOB: + pyval = PyBytes_FromStringAndSize( + sqlite3_value_blob(params[i]), + sqlite3_value_bytes(params[i])) + elif vtype == SQLITE_NULL: + pyval = None + else: + pyval = None + + pyargs.append(pyval) + + return pyargs + + +cdef python_to_sqlite(sqlite3_context *context, value): + if value is None: + sqlite3_result_null(context) + elif isinstance(value, (int, long)): + sqlite3_result_int64(context, value) + elif isinstance(value, float): + sqlite3_result_double(context, value) + elif isinstance(value, unicode): + bval = PyUnicode_AsUTF8String(value) + sqlite3_result_text( + context, + bval, + len(bval), + -1) + elif isinstance(value, bytes): + if PY_MAJOR_VERSION > 2: + sqlite3_result_blob( + context, + (value), + len(value), + -1) + else: + sqlite3_result_text( + context, + value, + len(value), + -1) + else: + sqlite3_result_error( + context, + encode('Unsupported type %s' % type(value)), + -1) + return SQLITE_ERROR + + return SQLITE_OK + + +cdef int SQLITE_CONSTRAINT = 19 # Abort due to constraint violation. + +USE_SQLITE_CONSTRAINT = sqlite3_version[:4] >= b'3.26' + +# The peewee_vtab struct embeds the base sqlite3_vtab struct, and adds a field +# to store a reference to the Python implementation. 
+ctypedef struct peewee_vtab: + sqlite3_vtab base + void *table_func_cls + + +# Like peewee_vtab, the peewee_cursor embeds the base sqlite3_vtab_cursor and +# adds fields to store references to the current index, the Python +# implementation, the current rows' data, and a flag for whether the cursor has +# been exhausted. +ctypedef struct peewee_cursor: + sqlite3_vtab_cursor base + long long idx + void *table_func + void *row_data + bint stopped + + +# We define an xConnect function, but leave xCreate NULL so that the +# table-function can be called eponymously. +cdef int pwConnect(sqlite3 *db, void *pAux, int argc, const char *const*argv, + sqlite3_vtab **ppVtab, char **pzErr) with gil: + cdef: + int rc + object table_func_cls = pAux + peewee_vtab *pNew = 0 + + rc = sqlite3_declare_vtab( + db, + encode('CREATE TABLE x(%s);' % + table_func_cls.get_table_columns_declaration())) + if rc == SQLITE_OK: + pNew = sqlite3_malloc(sizeof(pNew[0])) + memset(pNew, 0, sizeof(pNew[0])) + ppVtab[0] = &(pNew.base) + + pNew.table_func_cls = table_func_cls + Py_INCREF(table_func_cls) + + return rc + + +cdef int pwDisconnect(sqlite3_vtab *pBase) with gil: + cdef: + peewee_vtab *pVtab = pBase + object table_func_cls = (pVtab.table_func_cls) + + Py_DECREF(table_func_cls) + sqlite3_free(pVtab) + return SQLITE_OK + + +# The xOpen method is used to initialize a cursor. In this method we +# instantiate the TableFunction class and zero out a new cursor for iteration. +cdef int pwOpen(sqlite3_vtab *pBase, sqlite3_vtab_cursor **ppCursor) with gil: + cdef: + peewee_vtab *pVtab = pBase + peewee_cursor *pCur = 0 + object table_func_cls = pVtab.table_func_cls + + pCur = sqlite3_malloc(sizeof(pCur[0])) + memset(pCur, 0, sizeof(pCur[0])) + ppCursor[0] = &(pCur.base) + pCur.idx = 0 + try: + table_func = table_func_cls() + except: + if table_func_cls.print_tracebacks: + traceback.print_exc() + sqlite3_free(pCur) + return SQLITE_ERROR + + Py_INCREF(table_func) + pCur.table_func = table_func + pCur.stopped = False + return SQLITE_OK + + +cdef int pwClose(sqlite3_vtab_cursor *pBase) with gil: + cdef: + peewee_cursor *pCur = pBase + object table_func = pCur.table_func + Py_DECREF(table_func) + sqlite3_free(pCur) + return SQLITE_OK + + +# Iterate once, advancing the cursor's index and assigning the row data to the +# `row_data` field on the peewee_cursor struct. +cdef int pwNext(sqlite3_vtab_cursor *pBase) with gil: + cdef: + peewee_cursor *pCur = pBase + object table_func = pCur.table_func + tuple result + + if pCur.row_data: + Py_DECREF(pCur.row_data) + + pCur.row_data = NULL + try: + result = tuple(table_func.iterate(pCur.idx)) + except StopIteration: + pCur.stopped = True + except: + if table_func.print_tracebacks: + traceback.print_exc() + return SQLITE_ERROR + else: + Py_INCREF(result) + pCur.row_data = result + pCur.idx += 1 + pCur.stopped = False + + return SQLITE_OK + + +# Return the requested column from the current row. 
+cdef int pwColumn(sqlite3_vtab_cursor *pBase, sqlite3_context *ctx, + int iCol) with gil: + cdef: + bytes bval + peewee_cursor *pCur = pBase + sqlite3_int64 x = 0 + tuple row_data + + if iCol == -1: + sqlite3_result_int64(ctx, pCur.idx) + return SQLITE_OK + + if not pCur.row_data: + sqlite3_result_error(ctx, encode('no row data'), -1) + return SQLITE_ERROR + + row_data = pCur.row_data + return python_to_sqlite(ctx, row_data[iCol]) + + +cdef int pwRowid(sqlite3_vtab_cursor *pBase, sqlite3_int64 *pRowid): + cdef: + peewee_cursor *pCur = pBase + pRowid[0] = pCur.idx + return SQLITE_OK + + +# Return a boolean indicating whether the cursor has been consumed. +cdef int pwEof(sqlite3_vtab_cursor *pBase): + cdef: + peewee_cursor *pCur = pBase + return 1 if pCur.stopped else 0 + + +# The filter method is called on the first iteration. This method is where we +# get access to the parameters that the function was called with, and call the +# TableFunction's `initialize()` function. +cdef int pwFilter(sqlite3_vtab_cursor *pBase, int idxNum, + const char *idxStr, int argc, sqlite3_value **argv) with gil: + cdef: + peewee_cursor *pCur = pBase + object table_func = pCur.table_func + dict query = {} + int idx + int value_type + tuple row_data + void *row_data_raw + + if not idxStr or argc == 0 and len(table_func.params): + return SQLITE_ERROR + elif len(idxStr): + params = decode(idxStr).split(',') + else: + params = [] + + py_values = sqlite_to_python(argc, argv) + + for idx, param in enumerate(params): + value = argv[idx] + if not value: + query[param] = None + else: + query[param] = py_values[idx] + + try: + table_func.initialize(**query) + except: + if table_func.print_tracebacks: + traceback.print_exc() + return SQLITE_ERROR + + pCur.stopped = False + try: + row_data = tuple(table_func.iterate(0)) + except StopIteration: + pCur.stopped = True + except: + if table_func.print_tracebacks: + traceback.print_exc() + return SQLITE_ERROR + else: + Py_INCREF(row_data) + pCur.row_data = row_data + pCur.idx += 1 + return SQLITE_OK + + +# SQLite will (in some cases, repeatedly) call the xBestIndex method to try and +# find the best query plan. +cdef int pwBestIndex(sqlite3_vtab *pBase, sqlite3_index_info *pIdxInfo) \ + with gil: + cdef: + int i + int idxNum = 0, nArg = 0 + peewee_vtab *pVtab = pBase + object table_func_cls = pVtab.table_func_cls + sqlite3_index_constraint *pConstraint = 0 + list columns = [] + char *idxStr + int nParams = len(table_func_cls.params) + + for i in range(pIdxInfo.nConstraint): + pConstraint = pIdxInfo.aConstraint + i + if not pConstraint.usable: + continue + if pConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ: + continue + + columns.append(table_func_cls.params[pConstraint.iColumn - + table_func_cls._ncols]) + nArg += 1 + pIdxInfo.aConstraintUsage[i].argvIndex = nArg + pIdxInfo.aConstraintUsage[i].omit = 1 + + if nArg > 0 or nParams == 0: + if nArg == nParams: + # All parameters are present, this is ideal. + pIdxInfo.estimatedCost = 1 + pIdxInfo.estimatedRows = 10 + else: + # Penalize score based on number of missing params. + pIdxInfo.estimatedCost = 10000000000000 * (nParams - nArg) + pIdxInfo.estimatedRows = 10 ** (nParams - nArg) + + # Store a reference to the columns in the index info structure. 
+ joinedCols = encode(','.join(columns)) + idxStr = sqlite3_malloc((len(joinedCols) + 1) * sizeof(char)) + memcpy(idxStr, joinedCols, len(joinedCols)) + idxStr[len(joinedCols)] = '\x00' + pIdxInfo.idxStr = idxStr + pIdxInfo.needToFreeIdxStr = 0 + elif USE_SQLITE_CONSTRAINT: + return SQLITE_CONSTRAINT + else: + pIdxInfo.estimatedCost = DBL_MAX + pIdxInfo.estimatedRows = 100000 + return SQLITE_OK + + +cdef class _TableFunctionImpl(object): + cdef: + sqlite3_module module + object table_function + + def __cinit__(self, table_function): + self.table_function = table_function + + cdef create_module(self, pysqlite_Connection* sqlite_conn): + cdef: + bytes name = encode(self.table_function.name) + sqlite3 *db = sqlite_conn.db + int rc + + # Populate the SQLite module struct members. + self.module.iVersion = 0 + self.module.xCreate = NULL + self.module.xConnect = pwConnect + self.module.xBestIndex = pwBestIndex + self.module.xDisconnect = pwDisconnect + self.module.xDestroy = NULL + self.module.xOpen = pwOpen + self.module.xClose = pwClose + self.module.xFilter = pwFilter + self.module.xNext = pwNext + self.module.xEof = pwEof + self.module.xColumn = pwColumn + self.module.xRowid = pwRowid + self.module.xUpdate = NULL + self.module.xBegin = NULL + self.module.xSync = NULL + self.module.xCommit = NULL + self.module.xRollback = NULL + self.module.xFindFunction = NULL + self.module.xRename = NULL + + # Create the SQLite virtual table. + rc = sqlite3_create_module( + db, + name, + &self.module, + (self.table_function)) + + Py_INCREF(self) + + return rc == SQLITE_OK + + +class TableFunction(object): + columns = None + params = None + name = None + print_tracebacks = True + _ncols = None + + @classmethod + def register(cls, conn): + cdef _TableFunctionImpl impl = _TableFunctionImpl(cls) + impl.create_module(conn) + cls._ncols = len(cls.columns) + + def initialize(self, **filters): + raise NotImplementedError + + def iterate(self, idx): + raise NotImplementedError + + @classmethod + def get_table_columns_declaration(cls): + cdef list accum = [] + + for column in cls.columns: + if isinstance(column, tuple): + if len(column) != 2: + raise ValueError('Column must be either a string or a ' + '2-tuple of name, type') + accum.append('%s %s' % column) + else: + accum.append(column) + + for param in cls.params: + accum.append('%s HIDDEN' % param) + + return ', '.join(accum) + + +cdef tuple SQLITE_DATETIME_FORMATS = ( + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d', + '%H:%M:%S', + '%H:%M:%S.%f', + '%H:%M') + +cdef dict SQLITE_DATE_TRUNC_MAPPING = { + 'year': '%Y', + 'month': '%Y-%m', + 'day': '%Y-%m-%d', + 'hour': '%Y-%m-%d %H', + 'minute': '%Y-%m-%d %H:%M', + 'second': '%Y-%m-%d %H:%M:%S'} + + +cdef tuple validate_and_format_datetime(lookup, date_str): + if not date_str or not lookup: + return + + lookup = lookup.lower() + if lookup not in SQLITE_DATE_TRUNC_MAPPING: + return + + cdef datetime.datetime date_obj + cdef bint success = False + + for date_format in SQLITE_DATETIME_FORMATS: + try: + date_obj = datetime.datetime.strptime(date_str, date_format) + except ValueError: + pass + else: + return (date_obj, lookup) + + +cdef inline bytes encode(key): + cdef bytes bkey + if PyUnicode_Check(key): + bkey = PyUnicode_AsUTF8String(key) + elif PyBytes_Check(key): + bkey = key + elif key is None: + return None + else: + bkey = PyUnicode_AsUTF8String(str(key)) + return bkey + + +cdef inline unicode decode(key): + cdef unicode ukey + if PyBytes_Check(key): + ukey = key.decode('utf-8') + elif 
PyUnicode_Check(key): + ukey = key + elif key is None: + return None + else: + ukey = unicode(key) + return ukey + + +cdef double *get_weights(int ncol, tuple raw_weights): + cdef: + int argc = len(raw_weights) + int icol + double *weights = malloc(sizeof(double) * ncol) + + for icol in range(ncol): + if argc == 0: + weights[icol] = 1.0 + elif icol < argc: + weights[icol] = raw_weights[icol] + else: + weights[icol] = 0.0 + return weights + + +def peewee_rank(py_match_info, *raw_weights): + cdef: + unsigned int *match_info + unsigned int *phrase_info + bytes _match_info_buf = bytes(py_match_info) + char *match_info_buf = _match_info_buf + int nphrase, ncol, icol, iphrase, hits, global_hits + int P_O = 0, C_O = 1, X_O = 2 + double score = 0.0, weight + double *weights + + match_info = match_info_buf + nphrase = match_info[P_O] + ncol = match_info[C_O] + weights = get_weights(ncol, raw_weights) + + # matchinfo X value corresponds to, for each phrase in the search query, a + # list of 3 values for each column in the search table. + # So if we have a two-phrase search query and three columns of data, the + # following would be the layout: + # p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8] + # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17] + for iphrase in range(nphrase): + phrase_info = &match_info[X_O + iphrase * ncol * 3] + for icol in range(ncol): + weight = weights[icol] + if weight == 0: + continue + + # The idea is that we count the number of times the phrase appears + # in this column of the current row, compared to how many times it + # appears in this column across all rows. The ratio of these values + # provides a rough way to score based on "high value" terms. + hits = phrase_info[3 * icol] + global_hits = phrase_info[3 * icol + 1] + if hits > 0: + score += weight * (hits / global_hits) + + free(weights) + return -1 * score + + +def peewee_lucene(py_match_info, *raw_weights): + # Usage: peewee_lucene(matchinfo(table, 'pcnalx'), 1) + cdef: + unsigned int *match_info + bytes _match_info_buf = bytes(py_match_info) + char *match_info_buf = _match_info_buf + int nphrase, ncol + double total_docs, term_frequency + double doc_length, docs_with_term, avg_length + double idf, weight, rhs, denom + double *weights + int P_O = 0, C_O = 1, N_O = 2, L_O, X_O + int iphrase, icol, x + double score = 0.0 + + match_info = match_info_buf + nphrase = match_info[P_O] + ncol = match_info[C_O] + total_docs = match_info[N_O] + + L_O = 3 + ncol + X_O = L_O + ncol + weights = get_weights(ncol, raw_weights) + + for iphrase in range(nphrase): + for icol in range(ncol): + weight = weights[icol] + if weight == 0: + continue + doc_length = match_info[L_O + icol] + x = X_O + (3 * (icol + iphrase * ncol)) + term_frequency = match_info[x] # f(qi) + docs_with_term = match_info[x + 2] or 1. # n(qi) + idf = log(total_docs / (docs_with_term + 1.)) + tf = sqrt(term_frequency) + fieldNorms = 1.0 / sqrt(doc_length) + score += (idf * tf * fieldNorms) + + free(weights) + return -1 * score + + +def peewee_bm25(py_match_info, *raw_weights): + # Usage: peewee_bm25(matchinfo(table, 'pcnalx'), 1) + # where the second parameter is the index of the column and + # the 3rd and 4th specify k and b. 
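    # A usage sketch (the table name and weights below are illustrative): the
    # trailing arguments are consumed by get_weights() as per-column weights,
    # and register_rank_functions() exposes this function to SQL as fts_bm25,
    # so a ranked full-text query could look like:
    #
    #     SELECT docid, fts_bm25(matchinfo(doc_idx, 'pcnalx'), 1.0, 0.5) AS score
    #     FROM doc_idx WHERE doc_idx MATCH 'peewee' ORDER BY score;
    #
    # Scores are returned negated, so ordering ascending puts best matches first.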
+ cdef: + unsigned int *match_info + bytes _match_info_buf = bytes(py_match_info) + char *match_info_buf = _match_info_buf + int nphrase, ncol + double B = 0.75, K = 1.2 + double total_docs, term_frequency + double doc_length, docs_with_term, avg_length + double idf, weight, ratio, num, b_part, denom, pc_score + double *weights + int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O + int iphrase, icol, x + double score = 0.0 + + match_info = match_info_buf + # PCNALX = matchinfo format. + # P = 1 = phrase count within query. + # C = 1 = searchable columns in table. + # N = 1 = total rows in table. + # A = c = for each column, avg number of tokens + # L = c = for each column, length of current row (in tokens) + # X = 3 * c * p = for each phrase and table column, + # * phrase count within column for current row. + # * phrase count within column for all rows. + # * total rows for which column contains phrase. + nphrase = match_info[P_O] # n + ncol = match_info[C_O] + total_docs = match_info[N_O] # N + + L_O = A_O + ncol + X_O = L_O + ncol + weights = get_weights(ncol, raw_weights) + + for iphrase in range(nphrase): + for icol in range(ncol): + weight = weights[icol] + if weight == 0: + continue + + x = X_O + (3 * (icol + iphrase * ncol)) + term_frequency = match_info[x] # f(qi, D) + docs_with_term = match_info[x + 2] # n(qi) + + # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) + idf = log( + (total_docs - docs_with_term + 0.5) / + (docs_with_term + 0.5)) + if idf <= 0.0: + idf = 1e-6 + + doc_length = match_info[L_O + icol] # |D| + avg_length = match_info[A_O + icol] # avgdl + if avg_length == 0: + avg_length = 1 + ratio = doc_length / avg_length + + num = term_frequency * (K + 1) + b_part = 1 - B + (B * ratio) + denom = term_frequency + (K * b_part) + + pc_score = idf * (num / denom) + score += (pc_score * weight) + + free(weights) + return -1 * score + + +def peewee_bm25f(py_match_info, *raw_weights): + # Usage: peewee_bm25f(matchinfo(table, 'pcnalx'), 1) + # where the second parameter is the index of the column and + # the 3rd and 4th specify k and b. + cdef: + unsigned int *match_info + bytes _match_info_buf = bytes(py_match_info) + char *match_info_buf = _match_info_buf + int nphrase, ncol + double B = 0.75, K = 1.2, epsilon + double total_docs, term_frequency, docs_with_term + double doc_length = 0.0, avg_length = 0.0 + double idf, weight, ratio, num, b_part, denom, pc_score + double *weights + int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O + int iphrase, icol, x + double score = 0.0 + + match_info = match_info_buf + nphrase = match_info[P_O] # n + ncol = match_info[C_O] + total_docs = match_info[N_O] # N + + L_O = A_O + ncol + X_O = L_O + ncol + + for icol in range(ncol): + avg_length += match_info[A_O + icol] + doc_length += match_info[L_O + icol] + + epsilon = 1.0 / (total_docs * avg_length) + if avg_length == 0: + avg_length = 1 + ratio = doc_length / avg_length + weights = get_weights(ncol, raw_weights) + + for iphrase in range(nphrase): + for icol in range(ncol): + weight = weights[icol] + if weight == 0: + continue + + x = X_O + (3 * (icol + iphrase * ncol)) + term_frequency = match_info[x] # f(qi, D) + docs_with_term = match_info[x + 2] # n(qi) + + # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) + idf = log( + (total_docs - docs_with_term + 0.5) / + (docs_with_term + 0.5)) + idf = epsilon if idf <= 0 else idf + + num = term_frequency * (K + 1) + b_part = 1 - B + (B * ratio) + denom = term_frequency + (K * b_part) + + pc_score = idf * ((num / denom) + 1.) 
+ score += (pc_score * weight) + + free(weights) + return -1 * score + + +cdef uint32_t murmurhash2(const unsigned char *key, ssize_t nlen, + uint32_t seed): + cdef: + uint32_t m = 0x5bd1e995 + int r = 24 + const unsigned char *data = key + uint32_t h = seed ^ nlen + uint32_t k + + while nlen >= 4: + k = ((data)[0]) + + k *= m + k = k ^ (k >> r) + k *= m + + h *= m + h = h ^ k + + data += 4 + nlen -= 4 + + if nlen == 3: + h = h ^ (data[2] << 16) + if nlen >= 2: + h = h ^ (data[1] << 8) + if nlen >= 1: + h = h ^ (data[0]) + h *= m + + h = h ^ (h >> 13) + h *= m + h = h ^ (h >> 15) + return h + + +def peewee_murmurhash(key, seed=None): + if key is None: + return + + cdef: + bytes bkey = encode(key) + int nseed = seed or 0 + + if key: + return murmurhash2(bkey, len(bkey), nseed) + return 0 + + +def make_hash(hash_impl): + def inner(*items): + state = hash_impl() + for item in items: + state.update(encode(item)) + return state.hexdigest() + return inner + + +peewee_md5 = make_hash(hashlib.md5) +peewee_sha1 = make_hash(hashlib.sha1) +peewee_sha256 = make_hash(hashlib.sha256) + + +def _register_functions(database, pairs): + for func, name in pairs: + database.register_function(func, name) + + +def register_hash_functions(database): + _register_functions(database, ( + (peewee_murmurhash, 'murmurhash'), + (peewee_md5, 'md5'), + (peewee_sha1, 'sha1'), + (peewee_sha256, 'sha256'), + (zlib.adler32, 'adler32'), + (zlib.crc32, 'crc32'))) + + +def register_rank_functions(database): + _register_functions(database, ( + (peewee_bm25, 'fts_bm25'), + (peewee_bm25f, 'fts_bm25f'), + (peewee_lucene, 'fts_lucene'), + (peewee_rank, 'fts_rank'))) + + +ctypedef struct bf_t: + void *bits + size_t size + +cdef int seeds[10] +seeds[:] = [0, 1337, 37, 0xabcd, 0xdead, 0xface, 97, 0xed11, 0xcad9, 0x827b] + + +cdef bf_t *bf_create(size_t size): + cdef bf_t *bf = calloc(1, sizeof(bf_t)) + bf.size = size + bf.bits = calloc(1, size) + return bf + +@cython.cdivision(True) +cdef uint32_t bf_bitindex(bf_t *bf, unsigned char *key, size_t klen, int seed): + cdef: + uint32_t h = murmurhash2(key, klen, seed) + return h % (bf.size * 8) + +@cython.cdivision(True) +cdef bf_add(bf_t *bf, unsigned char *key): + cdef: + uint8_t *bits = (bf.bits) + uint32_t h + int pos, seed + size_t keylen = strlen(key) + + for seed in seeds: + h = bf_bitindex(bf, key, keylen, seed) + pos = h / 8 + bits[pos] = bits[pos] | (1 << (h % 8)) + +@cython.cdivision(True) +cdef int bf_contains(bf_t *bf, unsigned char *key): + cdef: + uint8_t *bits = (bf.bits) + uint32_t h + int pos, seed + size_t keylen = strlen(key) + + for seed in seeds: + h = bf_bitindex(bf, key, keylen, seed) + pos = h / 8 + if not (bits[pos] & (1 << (h % 8))): + return 0 + return 1 + +cdef bf_free(bf_t *bf): + free(bf.bits) + free(bf) + + +cdef class BloomFilter(object): + cdef: + bf_t *bf + + def __init__(self, size=1024 * 32): + self.bf = bf_create(size) + + def __dealloc__(self): + if self.bf: + bf_free(self.bf) + + def __len__(self): + return self.bf.size + + def add(self, *keys): + cdef bytes bkey + + for key in keys: + bkey = encode(key) + bf_add(self.bf, bkey) + + def __contains__(self, key): + cdef bytes bkey = encode(key) + return bf_contains(self.bf, bkey) + + def to_buffer(self): + # We have to do this so that embedded NULL bytes are preserved. + cdef bytes buf = PyBytes_FromStringAndSize((self.bf.bits), + self.bf.size) + # Similarly we wrap in a buffer object so pysqlite preserves the + # embedded NULL bytes. 
+ return buf + + @classmethod + def from_buffer(cls, data): + cdef: + char *buf + Py_ssize_t buflen + BloomFilter bloom + + PyBytes_AsStringAndSize(data, &buf, &buflen) + + bloom = BloomFilter(buflen) + memcpy(bloom.bf.bits, buf, buflen) + return bloom + + @classmethod + def calculate_size(cls, double n, double p): + cdef double m = ceil((n * log(p)) / log(1.0 / (pow(2.0, log(2.0))))) + return m + + +cdef class BloomFilterAggregate(object): + cdef: + BloomFilter bf + + def __init__(self): + self.bf = None + + def step(self, value, size=None): + if not self.bf: + size = size or 1024 + self.bf = BloomFilter(size) + + self.bf.add(value) + + def finalize(self): + if not self.bf: + return None + + return pysqlite.Binary(self.bf.to_buffer()) + + +def peewee_bloomfilter_contains(key, data): + cdef: + bf_t bf + bytes bkey + bytes bdata = bytes(data) + unsigned char *cdata = bdata + + bf.size = len(data) + bf.bits = cdata + bkey = encode(key) + + return bf_contains(&bf, bkey) + + +def peewee_bloomfilter_calculate_size(n_items, error_p): + return BloomFilter.calculate_size(n_items, error_p) + + +def register_bloomfilter(database): + database.register_aggregate(BloomFilterAggregate, 'bloomfilter') + database.register_function(peewee_bloomfilter_contains, + 'bloomfilter_contains') + database.register_function(peewee_bloomfilter_calculate_size, + 'bloomfilter_calculate_size') + + +cdef inline int _check_connection(pysqlite_Connection *conn) except -1: + """ + Check that the underlying SQLite database connection is usable. Raises an + InterfaceError if the connection is either uninitialized or closed. + """ + if not conn.db: + raise InterfaceError('Cannot operate on closed database.') + return 1 + + +class ZeroBlob(Node): + def __init__(self, length): + if not isinstance(length, int) or length < 0: + raise ValueError('Length must be a positive integer.') + self.length = length + + def __sql__(self, ctx): + return ctx.literal('zeroblob(%s)' % self.length) + + +cdef class Blob(object) # Forward declaration. 
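# Illustrative sketch of the incremental blob I/O pattern provided by ZeroBlob
# and the Blob class below (the Register model, table and column names are
# hypothetical; db is assumed to be a peewee SQLite database instance):
#
#     # Reserve 1 MB of space for the value up front; insert() returns the
#     # new rowid for an integer primary key.
#     rowid = Register.insert(data=ZeroBlob(1024 * 1024)).execute()
#
#     blob = Blob(db, 'register', 'data', rowid)
#     blob.write(b'first chunk...')
#     blob.seek(0)
#     head = blob.read(64)
#     blob.close()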
+ + +cdef inline int _check_blob_closed(Blob blob) except -1: + if not blob.pBlob: + raise InterfaceError('Cannot operate on closed blob.') + return 1 + + +cdef class Blob(object): + cdef: + int offset + pysqlite_Connection *conn + sqlite3_blob *pBlob + + def __init__(self, database, table, column, rowid, + read_only=False): + cdef: + bytes btable = encode(table) + bytes bcolumn = encode(column) + int flags = 0 if read_only else 1 + int rc + sqlite3_blob *blob + + self.conn = (database._state.conn) + _check_connection(self.conn) + + rc = sqlite3_blob_open( + self.conn.db, + 'main', + btable, + bcolumn, + rowid, + flags, + &blob) + if rc != SQLITE_OK: + raise OperationalError('Unable to open blob.') + if not blob: + raise MemoryError('Unable to allocate blob.') + + self.pBlob = blob + self.offset = 0 + + cdef _close(self): + if self.pBlob: + sqlite3_blob_close(self.pBlob) + self.pBlob = 0 + + def __dealloc__(self): + self._close() + + def __len__(self): + _check_blob_closed(self) + return sqlite3_blob_bytes(self.pBlob) + + def read(self, n=None): + cdef: + bytes pybuf + int length = -1 + int size + char *buf + + if n is not None: + length = n + + _check_blob_closed(self) + size = sqlite3_blob_bytes(self.pBlob) + if self.offset == size or length == 0: + return b'' + + if length < 0: + length = size - self.offset + + if self.offset + length > size: + length = size - self.offset + + pybuf = PyBytes_FromStringAndSize(NULL, length) + buf = PyBytes_AS_STRING(pybuf) + if sqlite3_blob_read(self.pBlob, buf, length, self.offset): + self._close() + raise OperationalError('Error reading from blob.') + + self.offset += length + return bytes(pybuf) + + def seek(self, offset, frame_of_reference=0): + cdef int size + _check_blob_closed(self) + size = sqlite3_blob_bytes(self.pBlob) + if frame_of_reference == 0: + if offset < 0 or offset > size: + raise ValueError('seek() offset outside of valid range.') + self.offset = offset + elif frame_of_reference == 1: + if self.offset + offset < 0 or self.offset + offset > size: + raise ValueError('seek() offset outside of valid range.') + self.offset += offset + elif frame_of_reference == 2: + if size + offset < 0 or size + offset > size: + raise ValueError('seek() offset outside of valid range.') + self.offset = size + offset + else: + raise ValueError('seek() frame of reference must be 0, 1 or 2.') + + def tell(self): + _check_blob_closed(self) + return self.offset + + def write(self, bytes data): + cdef: + char *buf + int size + Py_ssize_t buflen + + _check_blob_closed(self) + size = sqlite3_blob_bytes(self.pBlob) + PyBytes_AsStringAndSize(data, &buf, &buflen) + if ((buflen + self.offset)) < self.offset: + raise ValueError('Data is too large (integer wrap)') + if ((buflen + self.offset)) > size: + raise ValueError('Data would go beyond end of blob') + if sqlite3_blob_write(self.pBlob, buf, buflen, self.offset): + raise OperationalError('Error writing to blob.') + self.offset += buflen + + def close(self): + self._close() + + def reopen(self, rowid): + _check_blob_closed(self) + self.offset = 0 + if sqlite3_blob_reopen(self.pBlob, rowid): + self._close() + raise OperationalError('Unable to re-open blob.') + + +def sqlite_get_status(flag): + cdef: + int current, highwater, rc + + rc = sqlite3_status(flag, ¤t, &highwater, 0) + if rc == SQLITE_OK: + return (current, highwater) + raise Exception('Error requesting status: %s' % rc) + + +def sqlite_get_db_status(conn, flag): + cdef: + int current, highwater, rc + pysqlite_Connection *c_conn = conn + + rc = 
sqlite3_db_status(c_conn.db, flag, ¤t, &highwater, 0) + if rc == SQLITE_OK: + return (current, highwater) + raise Exception('Error requesting db status: %s' % rc) + + +cdef class ConnectionHelper(object): + cdef: + object _commit_hook, _rollback_hook, _update_hook + pysqlite_Connection *conn + + def __init__(self, connection): + self.conn = connection + self._commit_hook = self._rollback_hook = self._update_hook = None + + def __dealloc__(self): + # When deallocating a Database object, we need to ensure that we clear + # any commit, rollback or update hooks that may have been applied. + if not self.conn.initialized or not self.conn.db: + return + + if self._commit_hook is not None: + sqlite3_commit_hook(self.conn.db, NULL, NULL) + if self._rollback_hook is not None: + sqlite3_rollback_hook(self.conn.db, NULL, NULL) + if self._update_hook is not None: + sqlite3_update_hook(self.conn.db, NULL, NULL) + + def set_commit_hook(self, fn): + self._commit_hook = fn + if fn is None: + sqlite3_commit_hook(self.conn.db, NULL, NULL) + else: + sqlite3_commit_hook(self.conn.db, _commit_callback, fn) + + def set_rollback_hook(self, fn): + self._rollback_hook = fn + if fn is None: + sqlite3_rollback_hook(self.conn.db, NULL, NULL) + else: + sqlite3_rollback_hook(self.conn.db, _rollback_callback, fn) + + def set_update_hook(self, fn): + self._update_hook = fn + if fn is None: + sqlite3_update_hook(self.conn.db, NULL, NULL) + else: + sqlite3_update_hook(self.conn.db, _update_callback, fn) + + def set_busy_handler(self, timeout=5): + """ + Replace the default busy handler with one that introduces some "jitter" + into the amount of time delayed between checks. + """ + cdef sqlite3_int64 n = timeout * 1000 + sqlite3_busy_handler(self.conn.db, _aggressive_busy_handler, n) + return True + + def changes(self): + return sqlite3_changes(self.conn.db) + + def last_insert_rowid(self): + return sqlite3_last_insert_rowid(self.conn.db) + + def autocommit(self): + return sqlite3_get_autocommit(self.conn.db) != 0 + + +cdef int _commit_callback(void *userData) with gil: + # C-callback that delegates to the Python commit handler. If the Python + # function raises a ValueError, then the commit is aborted and the + # transaction rolled back. Otherwise, regardless of the function return + # value, the transaction will commit. + cdef object fn = userData + try: + fn() + except ValueError: + return 1 + else: + return SQLITE_OK + + +cdef void _rollback_callback(void *userData) with gil: + # C-callback that delegates to the Python rollback handler. + cdef object fn = userData + fn() + + +cdef void _update_callback(void *userData, int queryType, const char *database, + const char *table, sqlite3_int64 rowid) with gil: + # C-callback that delegates to a Python function that is executed whenever + # the database is updated (insert/update/delete queries). The Python + # callback receives a string indicating the query type, the name of the + # database, the name of the table being updated, and the rowid of the row + # being updatd. 
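    # A sketch of how these hooks can be wired up through the ConnectionHelper
    # above, assuming db is an open peewee SQLite database (callback names are
    # illustrative):
    #
    #     helper = ConnectionHelper(db.connection())
    #
    #     def log_change(query, dbname, table, rowid):
    #         # e.g. ('INSERT', 'main', 'users', 7)
    #         print(query, dbname, table, rowid)
    #     helper.set_update_hook(log_change)
    #
    #     def before_commit():
    #         # Raising ValueError here aborts the commit (see _commit_callback).
    #         pass
    #     helper.set_commit_hook(before_commit)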
+ cdef object fn = userData + if queryType == SQLITE_INSERT: + query = 'INSERT' + elif queryType == SQLITE_UPDATE: + query = 'UPDATE' + elif queryType == SQLITE_DELETE: + query = 'DELETE' + else: + query = '' + fn(query, decode(database), decode(table), rowid) + + +def backup(src_conn, dest_conn, pages=None, name=None, progress=None): + cdef: + bytes bname = encode(name or 'main') + int page_step = pages or -1 + int rc + pysqlite_Connection *src = src_conn + pysqlite_Connection *dest = dest_conn + sqlite3 *src_db = src.db + sqlite3 *dest_db = dest.db + sqlite3_backup *backup + + # We always backup to the "main" database in the dest db. + backup = sqlite3_backup_init(dest_db, b'main', src_db, bname) + if backup == NULL: + raise OperationalError('Unable to initialize backup.') + + while True: + with nogil: + rc = sqlite3_backup_step(backup, page_step) + if progress is not None: + # Progress-handler is called with (remaining, page count, is done?) + remaining = sqlite3_backup_remaining(backup) + page_count = sqlite3_backup_pagecount(backup) + try: + progress(remaining, page_count, rc == SQLITE_DONE) + except: + sqlite3_backup_finish(backup) + raise + if rc == SQLITE_BUSY or rc == SQLITE_LOCKED: + with nogil: + sqlite3_sleep(250) + elif rc == SQLITE_DONE: + break + + with nogil: + sqlite3_backup_finish(backup) + if sqlite3_errcode(dest_db): + raise OperationalError('Error backuping up database: %s' % + sqlite3_errmsg(dest_db)) + return True + + +def backup_to_file(src_conn, filename, pages=None, name=None, progress=None): + dest_conn = pysqlite.connect(filename) + backup(src_conn, dest_conn, pages=pages, name=name, progress=progress) + dest_conn.close() + return True + + +cdef int _aggressive_busy_handler(void *ptr, int n) nogil: + # In concurrent environments, it often seems that if multiple queries are + # kicked off at around the same time, they proceed in lock-step to check + # for the availability of the lock. By introducing some "jitter" we can + # ensure that this doesn't happen. Furthermore, this function makes more + # attempts in the same time period than the default handler. + cdef: + sqlite3_int64 busyTimeout = ptr + int current, total + + if n < 20: + current = 25 - (rand() % 10) # ~20ms + total = n * 20 + elif n < 40: + current = 50 - (rand() % 20) # ~40ms + total = 400 + ((n - 20) * 40) + else: + current = 120 - (rand() % 40) # ~100ms + total = 1200 + ((n - 40) * 100) # Estimate the amount of time slept. + + if total + current > busyTimeout: + current = busyTimeout - total + if current > 0: + sqlite3_sleep(current) + return 1 + return 0 diff --git a/libs/playhouse/_sqlite_udf.pyx b/libs/playhouse/_sqlite_udf.pyx new file mode 100644 index 000000000..9ff6e7430 --- /dev/null +++ b/libs/playhouse/_sqlite_udf.pyx @@ -0,0 +1,137 @@ +import sys +from difflib import SequenceMatcher +from random import randint + + +IS_PY3K = sys.version_info[0] == 3 + +# String UDF. 
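# For orientation, the string-distance UDFs defined below behave like their
# textbook counterparts; illustrative values:
#
#     levenshtein_dist('kitten', 'sitting')   # -> 3 (two substitutions + one insertion)
#     str_dist('peewee', 'peewee orm')        # -> 4 (length of the non-matching tail)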
+def damerau_levenshtein_dist(s1, s2): + cdef: + int i, j, del_cost, add_cost, sub_cost + int s1_len = len(s1), s2_len = len(s2) + list one_ago, two_ago, current_row + list zeroes = [0] * (s2_len + 1) + + if IS_PY3K: + current_row = list(range(1, s2_len + 2)) + else: + current_row = range(1, s2_len + 2) + + current_row[-1] = 0 + one_ago = None + + for i in range(s1_len): + two_ago = one_ago + one_ago = current_row + current_row = list(zeroes) + current_row[-1] = i + 1 + for j in range(s2_len): + del_cost = one_ago[j] + 1 + add_cost = current_row[j - 1] + 1 + sub_cost = one_ago[j - 1] + (s1[i] != s2[j]) + current_row[j] = min(del_cost, add_cost, sub_cost) + + # Handle transpositions. + if (i > 0 and j > 0 and s1[i] == s2[j - 1] + and s1[i-1] == s2[j] and s1[i] != s2[j]): + current_row[j] = min(current_row[j], two_ago[j - 2] + 1) + + return current_row[s2_len - 1] + +# String UDF. +def levenshtein_dist(a, b): + cdef: + int add, delete, change + int i, j + int n = len(a), m = len(b) + list current, previous + list zeroes + + if n > m: + a, b = b, a + n, m = m, n + + zeroes = [0] * (m + 1) + + if IS_PY3K: + current = list(range(n + 1)) + else: + current = range(n + 1) + + for i in range(1, m + 1): + previous = current + current = list(zeroes) + current[0] = i + + for j in range(1, n + 1): + add = previous[j] + 1 + delete = current[j - 1] + 1 + change = previous[j - 1] + if a[j - 1] != b[i - 1]: + change +=1 + current[j] = min(add, delete, change) + + return current[n] + +# String UDF. +def str_dist(a, b): + cdef: + int t = 0 + + for i in SequenceMatcher(None, a, b).get_opcodes(): + if i[0] == 'equal': + continue + t = t + max(i[4] - i[3], i[2] - i[1]) + return t + +# Math Aggregate. +cdef class median(object): + cdef: + int ct + list items + + def __init__(self): + self.ct = 0 + self.items = [] + + cdef selectKth(self, int k, int s=0, int e=-1): + cdef: + int idx + if e < 0: + e = len(self.items) + idx = randint(s, e-1) + idx = self.partition_k(idx, s, e) + if idx > k: + return self.selectKth(k, s, idx) + elif idx < k: + return self.selectKth(k, idx + 1, e) + else: + return self.items[idx] + + cdef int partition_k(self, int pi, int s, int e): + cdef: + int i, x + + val = self.items[pi] + # Swap pivot w/last item. + self.items[e - 1], self.items[pi] = self.items[pi], self.items[e - 1] + x = s + for i in range(s, e): + if self.items[i] < val: + self.items[i], self.items[x] = self.items[x], self.items[i] + x += 1 + self.items[x], self.items[e-1] = self.items[e-1], self.items[x] + return x + + def step(self, item): + self.items.append(item) + self.ct += 1 + + def finalize(self): + if self.ct == 0: + return None + elif self.ct < 3: + return self.items[0] + else: + return self.selectKth(self.ct / 2) diff --git a/libs/playhouse/apsw_ext.py b/libs/playhouse/apsw_ext.py new file mode 100644 index 000000000..654ee7739 --- /dev/null +++ b/libs/playhouse/apsw_ext.py @@ -0,0 +1,146 @@ +""" +Peewee integration with APSW, "another python sqlite wrapper". + +Project page: https://rogerbinns.github.io/apsw/ + +APSW is a really neat library that provides a thin wrapper on top of SQLite's +C interface. + +Here are just a few reasons to use APSW, taken from the documentation: + +* APSW gives all functionality of SQLite, including virtual tables, virtual + file system, blob i/o, backups and file control. +* Connections can be shared across threads without any additional locking. +* Transactions are managed explicitly by your code. +* APSW can handle nested transactions. +* Unicode is handled correctly. 
+* APSW is faster. +""" +import apsw +from peewee import * +from peewee import __exception_wrapper__ +from peewee import BooleanField as _BooleanField +from peewee import DateField as _DateField +from peewee import DateTimeField as _DateTimeField +from peewee import DecimalField as _DecimalField +from peewee import TimeField as _TimeField +from peewee import logger + +from playhouse.sqlite_ext import SqliteExtDatabase + + +class APSWDatabase(SqliteExtDatabase): + server_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.')) + + def __init__(self, database, **kwargs): + self._modules = {} + super(APSWDatabase, self).__init__(database, **kwargs) + + def register_module(self, mod_name, mod_inst): + self._modules[mod_name] = mod_inst + if not self.is_closed(): + self.connection().createmodule(mod_name, mod_inst) + + def unregister_module(self, mod_name): + del(self._modules[mod_name]) + + def _connect(self): + conn = apsw.Connection(self.database, **self.connect_params) + if self._timeout is not None: + conn.setbusytimeout(self._timeout * 1000) + try: + self._add_conn_hooks(conn) + except: + conn.close() + raise + return conn + + def _add_conn_hooks(self, conn): + super(APSWDatabase, self)._add_conn_hooks(conn) + self._load_modules(conn) # APSW-only. + + def _load_modules(self, conn): + for mod_name, mod_inst in self._modules.items(): + conn.createmodule(mod_name, mod_inst) + return conn + + def _load_aggregates(self, conn): + for name, (klass, num_params) in self._aggregates.items(): + def make_aggregate(): + return (klass(), klass.step, klass.finalize) + conn.createaggregatefunction(name, make_aggregate) + + def _load_collations(self, conn): + for name, fn in self._collations.items(): + conn.createcollation(name, fn) + + def _load_functions(self, conn): + for name, (fn, num_params) in self._functions.items(): + conn.createscalarfunction(name, fn, num_params) + + def _load_extensions(self, conn): + conn.enableloadextension(True) + for extension in self._extensions: + conn.loadextension(extension) + + def load_extension(self, extension): + self._extensions.add(extension) + if not self.is_closed(): + conn = self.connection() + conn.enableloadextension(True) + conn.loadextension(extension) + + def last_insert_id(self, cursor, query_type=None): + return cursor.getconnection().last_insert_rowid() + + def rows_affected(self, cursor): + return cursor.getconnection().changes() + + def begin(self, lock_type='deferred'): + self.cursor().execute('begin %s;' % lock_type) + + def commit(self): + with __exception_wrapper__: + curs = self.cursor() + if curs.getconnection().getautocommit(): + return False + curs.execute('commit;') + return True + + def rollback(self): + with __exception_wrapper__: + curs = self.cursor() + if curs.getconnection().getautocommit(): + return False + curs.execute('rollback;') + return True + + def execute_sql(self, sql, params=None, commit=True): + logger.debug((sql, params)) + with __exception_wrapper__: + cursor = self.cursor() + cursor.execute(sql, params or ()) + return cursor + + +def nh(s, v): + if v is not None: + return str(v) + +class BooleanField(_BooleanField): + def db_value(self, v): + v = super(BooleanField, self).db_value(v) + if v is not None: + return v and 1 or 0 + +class DateField(_DateField): + db_value = nh + +class TimeField(_TimeField): + db_value = nh + +class DateTimeField(_DateTimeField): + db_value = nh + +class DecimalField(_DecimalField): + db_value = nh diff --git a/libs/playhouse/cockroachdb.py b/libs/playhouse/cockroachdb.py new file 
mode 100644 index 000000000..1d4ff9f68 --- /dev/null +++ b/libs/playhouse/cockroachdb.py @@ -0,0 +1,207 @@ +import functools +import re + +from peewee import * +from peewee import _atomic +from peewee import _manual +from peewee import ColumnMetadata # (name, data_type, null, primary_key, table, default) +from peewee import ForeignKeyMetadata # (column, dest_table, dest_column, table). +from peewee import IndexMetadata +from playhouse.pool import _PooledPostgresqlDatabase +try: + from playhouse.postgres_ext import ArrayField + from playhouse.postgres_ext import BinaryJSONField + from playhouse.postgres_ext import IntervalField + JSONField = BinaryJSONField +except ImportError: # psycopg2 not installed, ignore. + ArrayField = BinaryJSONField = IntervalField = JSONField = None + + +NESTED_TX_MIN_VERSION = 200100 + +TXN_ERR_MSG = ('CockroachDB does not support nested transactions. You may ' + 'alternatively use the @transaction context-manager/decorator, ' + 'which only wraps the outer-most block in transactional logic. ' + 'To run a transaction with automatic retries, use the ' + 'run_transaction() helper.') + +class ExceededMaxAttempts(OperationalError): pass + + +class UUIDKeyField(UUIDField): + auto_increment = True + + def __init__(self, *args, **kwargs): + if kwargs.get('constraints'): + raise ValueError('%s cannot specify constraints.' % type(self)) + kwargs['constraints'] = [SQL('DEFAULT gen_random_uuid()')] + kwargs.setdefault('primary_key', True) + super(UUIDKeyField, self).__init__(*args, **kwargs) + + +class RowIDField(AutoField): + field_type = 'INT' + + def __init__(self, *args, **kwargs): + if kwargs.get('constraints'): + raise ValueError('%s cannot specify constraints.' % type(self)) + kwargs['constraints'] = [SQL('DEFAULT unique_rowid()')] + super(RowIDField, self).__init__(*args, **kwargs) + + +class CockroachDatabase(PostgresqlDatabase): + field_types = PostgresqlDatabase.field_types.copy() + field_types.update({ + 'BLOB': 'BYTES', + }) + + for_update = False + nulls_ordering = False + release_after_rollback = True + + def __init__(self, *args, **kwargs): + kwargs.setdefault('user', 'root') + kwargs.setdefault('port', 26257) + super(CockroachDatabase, self).__init__(*args, **kwargs) + + def _set_server_version(self, conn): + curs = conn.cursor() + curs.execute('select version()') + raw, = curs.fetchone() + match_obj = re.match(r'^CockroachDB.+?v(\d+)\.(\d+)\.(\d+)', raw) + if match_obj is not None: + clean = '%d%02d%02d' % tuple(int(i) for i in match_obj.groups()) + self.server_version = int(clean) # 19.1.5 -> 190105. + else: + # Fallback to use whatever cockroachdb tells us via protocol. + super(CockroachDatabase, self)._set_server_version(conn) + + def _get_pk_constraint(self, table, schema=None): + query = ('SELECT constraint_name ' + 'FROM information_schema.table_constraints ' + 'WHERE table_name = %s AND table_schema = %s ' + 'AND constraint_type = %s') + cursor = self.execute_sql(query, (table, schema or 'public', + 'PRIMARY KEY')) + row = cursor.fetchone() + return row and row[0] or None + + def get_indexes(self, table, schema=None): + # The primary-key index is returned by default, so we will just strip + # it out here. 
+ indexes = super(CockroachDatabase, self).get_indexes(table, schema) + pkc = self._get_pk_constraint(table, schema) + return [idx for idx in indexes if (not pkc) or (idx.name != pkc)] + + def conflict_statement(self, on_conflict, query): + if not on_conflict._action: return + + action = on_conflict._action.lower() + if action in ('replace', 'upsert'): + return SQL('UPSERT') + elif action not in ('ignore', 'nothing', 'update'): + raise ValueError('Un-supported action for conflict resolution. ' + 'CockroachDB supports REPLACE (UPSERT), IGNORE ' + 'and UPDATE.') + + def conflict_update(self, oc, query): + action = oc._action.lower() if oc._action else '' + if action in ('ignore', 'nothing'): + return SQL('ON CONFLICT DO NOTHING') + elif action in ('replace', 'upsert'): + # No special stuff is necessary, this is just indicated by starting + # the statement with UPSERT instead of INSERT. + return + elif oc._conflict_constraint: + raise ValueError('CockroachDB does not support the usage of a ' + 'constraint name. Use the column(s) instead.') + + return super(CockroachDatabase, self).conflict_update(oc, query) + + def extract_date(self, date_part, date_field): + return fn.extract(date_part, date_field) + + def from_timestamp(self, date_field): + # CRDB does not allow casting a decimal/float to timestamp, so we first + # cast to int, then to timestamptz. + return date_field.cast('int').cast('timestamptz') + + def begin(self, system_time=None, priority=None): + super(CockroachDatabase, self).begin() + if system_time is not None: + self.execute_sql('SET TRANSACTION AS OF SYSTEM TIME %s', + (system_time,), commit=False) + if priority is not None: + priority = priority.lower() + if priority not in ('low', 'normal', 'high'): + raise ValueError('priority must be low, normal or high') + self.execute_sql('SET TRANSACTION PRIORITY %s' % priority, + commit=False) + + def atomic(self, system_time=None, priority=None): + if self.server_version < NESTED_TX_MIN_VERSION: + return _crdb_atomic(self, system_time, priority) + return super(CockroachDatabase, self).atomic(system_time, priority) + + def savepoint(self): + if self.server_version < NESTED_TX_MIN_VERSION: + raise NotImplementedError(TXN_ERR_MSG) + return super(CockroachDatabase, self).savepoint() + + def retry_transaction(self, max_attempts=None, system_time=None, + priority=None): + def deco(cb): + @functools.wraps(cb) + def new_fn(): + return run_transaction(self, cb, max_attempts, system_time, + priority) + return new_fn + return deco + + def run_transaction(self, cb, max_attempts=None, system_time=None, + priority=None): + return run_transaction(self, cb, max_attempts, system_time, priority) + + +class _crdb_atomic(_atomic): + def __enter__(self): + if self.db.transaction_depth() > 0: + if not isinstance(self.db.top_transaction(), _manual): + raise NotImplementedError(TXN_ERR_MSG) + return super(_crdb_atomic, self).__enter__() + + +def run_transaction(db, callback, max_attempts=None, system_time=None, + priority=None): + """ + Run transactional SQL in a transaction with automatic retries. + + User-provided `callback`: + * Must accept one parameter, the `db` instance representing the connection + the transaction is running under. + * Must not attempt to commit, rollback or otherwise manage transactions. + * May be called more than once. + * Should ideally only contain SQL operations. + + Additionally, the database must not have any open transaction at the time + this function is called, as CRDB does not support nested transactions. 
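    A minimal usage sketch (the Account model and values are hypothetical):

        def transfer(db):
            src = Account.get(Account.name == 'alpha')
            dst = Account.get(Account.name == 'beta')
            src.balance -= 100
            dst.balance += 100
            src.save()
            dst.save()

        run_transaction(db, transfer, max_attempts=10)

    The callback is re-invoked whenever CockroachDB signals a retryable error
    (pgcode 40001), up to max_attempts times; any other error propagates
    unchanged.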
+ """ + max_attempts = max_attempts or -1 + with db.atomic(system_time=system_time, priority=priority) as txn: + db.execute_sql('SAVEPOINT cockroach_restart') + while max_attempts != 0: + try: + result = callback(db) + db.execute_sql('RELEASE SAVEPOINT cockroach_restart') + return result + except OperationalError as exc: + if exc.orig.pgcode == '40001': + max_attempts -= 1 + db.execute_sql('ROLLBACK TO SAVEPOINT cockroach_restart') + continue + raise + raise ExceededMaxAttempts(None, 'unable to commit transaction') + + +class PooledCockroachDatabase(_PooledPostgresqlDatabase, CockroachDatabase): + pass diff --git a/libs/playhouse/dataset.py b/libs/playhouse/dataset.py new file mode 100644 index 000000000..4bbc549a0 --- /dev/null +++ b/libs/playhouse/dataset.py @@ -0,0 +1,451 @@ +import csv +import datetime +from decimal import Decimal +import json +import operator +try: + from urlparse import urlparse +except ImportError: + from urllib.parse import urlparse +import sys + +from peewee import * +from playhouse.db_url import connect +from playhouse.migrate import migrate +from playhouse.migrate import SchemaMigrator +from playhouse.reflection import Introspector + +if sys.version_info[0] == 3: + basestring = str + from functools import reduce + def open_file(f, mode): + return open(f, mode, encoding='utf8') +else: + open_file = open + + +class DataSet(object): + def __init__(self, url, **kwargs): + if isinstance(url, Database): + self._url = None + self._database = url + self._database_path = self._database.database + else: + self._url = url + parse_result = urlparse(url) + self._database_path = parse_result.path[1:] + + # Connect to the database. + self._database = connect(url) + + self._database.connect() + + # Introspect the database and generate models. + self._introspector = Introspector.from_database(self._database) + self._models = self._introspector.generate_models( + skip_invalid=True, + literal_column_names=True, + **kwargs) + self._migrator = SchemaMigrator.from_database(self._database) + + class BaseModel(Model): + class Meta: + database = self._database + self._base_model = BaseModel + self._export_formats = self.get_export_formats() + self._import_formats = self.get_import_formats() + + def __repr__(self): + return '' % self._database_path + + def get_export_formats(self): + return { + 'csv': CSVExporter, + 'json': JSONExporter, + 'tsv': TSVExporter} + + def get_import_formats(self): + return { + 'csv': CSVImporter, + 'json': JSONImporter, + 'tsv': TSVImporter} + + def __getitem__(self, table): + if table not in self._models and table in self.tables: + self.update_cache(table) + return Table(self, table, self._models.get(table)) + + @property + def tables(self): + return self._database.get_tables() + + def __contains__(self, table): + return table in self.tables + + def connect(self): + self._database.connect() + + def close(self): + self._database.close() + + def update_cache(self, table=None): + if table: + dependencies = [table] + if table in self._models: + model_class = self._models[table] + dependencies.extend([ + related._meta.table_name for _, related, _ in + model_class._meta.model_graph()]) + else: + dependencies.extend(self.get_table_dependencies(table)) + else: + dependencies = None # Update all tables. 
+ self._models = {} + updated = self._introspector.generate_models( + skip_invalid=True, + table_names=dependencies, + literal_column_names=True) + self._models.update(updated) + + def get_table_dependencies(self, table): + stack = [table] + accum = [] + seen = set() + while stack: + table = stack.pop() + for fk_meta in self._database.get_foreign_keys(table): + dest = fk_meta.dest_table + if dest not in seen: + stack.append(dest) + accum.append(dest) + return accum + + def __enter__(self): + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._database.is_closed(): + self.close() + + def query(self, sql, params=None, commit=True): + return self._database.execute_sql(sql, params, commit) + + def transaction(self): + if self._database.transaction_depth() == 0: + return self._database.transaction() + else: + return self._database.savepoint() + + def _check_arguments(self, filename, file_obj, format, format_dict): + if filename and file_obj: + raise ValueError('file is over-specified. Please use either ' + 'filename or file_obj, but not both.') + if not filename and not file_obj: + raise ValueError('A filename or file-like object must be ' + 'specified.') + if format not in format_dict: + valid_formats = ', '.join(sorted(format_dict.keys())) + raise ValueError('Unsupported format "%s". Use one of %s.' % ( + format, valid_formats)) + + def freeze(self, query, format='csv', filename=None, file_obj=None, + **kwargs): + self._check_arguments(filename, file_obj, format, self._export_formats) + if filename: + file_obj = open_file(filename, 'w') + + exporter = self._export_formats[format](query) + exporter.export(file_obj, **kwargs) + + if filename: + file_obj.close() + + def thaw(self, table, format='csv', filename=None, file_obj=None, + strict=False, **kwargs): + self._check_arguments(filename, file_obj, format, self._export_formats) + if filename: + file_obj = open_file(filename, 'r') + + importer = self._import_formats[format](self[table], strict) + count = importer.load(file_obj, **kwargs) + + if filename: + file_obj.close() + + return count + + +class Table(object): + def __init__(self, dataset, name, model_class): + self.dataset = dataset + self.name = name + if model_class is None: + model_class = self._create_model() + model_class.create_table() + self.dataset._models[name] = model_class + + @property + def model_class(self): + return self.dataset._models[self.name] + + def __repr__(self): + return '' % self.name + + def __len__(self): + return self.find().count() + + def __iter__(self): + return iter(self.find().iterator()) + + def _create_model(self): + class Meta: + table_name = self.name + return type( + str(self.name), + (self.dataset._base_model,), + {'Meta': Meta}) + + def create_index(self, columns, unique=False): + index = ModelIndex(self.model_class, columns, unique=unique) + self.model_class.add_index(index) + self.dataset._database.execute(index) + + def _guess_field_type(self, value): + if isinstance(value, basestring): + return TextField + if isinstance(value, (datetime.date, datetime.datetime)): + return DateTimeField + elif value is True or value is False: + return BooleanField + elif isinstance(value, int): + return IntegerField + elif isinstance(value, float): + return FloatField + elif isinstance(value, Decimal): + return DecimalField + return TextField + + @property + def columns(self): + return [f.name for f in self.model_class._meta.sorted_fields] + + def _migrate_new_columns(self, data): + new_keys = set(data) - 
set(self.model_class._meta.fields) + if new_keys: + operations = [] + for key in new_keys: + field_class = self._guess_field_type(data[key]) + field = field_class(null=True) + operations.append( + self.dataset._migrator.add_column(self.name, key, field)) + field.bind(self.model_class, key) + + migrate(*operations) + + self.dataset.update_cache(self.name) + + def __getitem__(self, item): + try: + return self.model_class[item] + except self.model_class.DoesNotExist: + pass + + def __setitem__(self, item, value): + if not isinstance(value, dict): + raise ValueError('Table.__setitem__() value must be a dict') + + pk = self.model_class._meta.primary_key + value[pk.name] = item + + try: + with self.dataset.transaction() as txn: + self.insert(**value) + except IntegrityError: + self.dataset.update_cache(self.name) + self.update(columns=[pk.name], **value) + + def __delitem__(self, item): + del self.model_class[item] + + def insert(self, **data): + self._migrate_new_columns(data) + return self.model_class.insert(**data).execute() + + def _apply_where(self, query, filters, conjunction=None): + conjunction = conjunction or operator.and_ + if filters: + expressions = [ + (self.model_class._meta.fields[column] == value) + for column, value in filters.items()] + query = query.where(reduce(conjunction, expressions)) + return query + + def update(self, columns=None, conjunction=None, **data): + self._migrate_new_columns(data) + filters = {} + if columns: + for column in columns: + filters[column] = data.pop(column) + + return self._apply_where( + self.model_class.update(**data), + filters, + conjunction).execute() + + def _query(self, **query): + return self._apply_where(self.model_class.select(), query) + + def find(self, **query): + return self._query(**query).dicts() + + def find_one(self, **query): + try: + return self.find(**query).get() + except self.model_class.DoesNotExist: + return None + + def all(self): + return self.find() + + def delete(self, **query): + return self._apply_where(self.model_class.delete(), query).execute() + + def freeze(self, *args, **kwargs): + return self.dataset.freeze(self.all(), *args, **kwargs) + + def thaw(self, *args, **kwargs): + return self.dataset.thaw(self.name, *args, **kwargs) + + +class Exporter(object): + def __init__(self, query): + self.query = query + + def export(self, file_obj): + raise NotImplementedError + + +class JSONExporter(Exporter): + def __init__(self, query, iso8601_datetimes=False): + super(JSONExporter, self).__init__(query) + self.iso8601_datetimes = iso8601_datetimes + + def _make_default(self): + datetime_types = (datetime.datetime, datetime.date, datetime.time) + + if self.iso8601_datetimes: + def default(o): + if isinstance(o, datetime_types): + return o.isoformat() + elif isinstance(o, Decimal): + return str(o) + raise TypeError('Unable to serialize %r as JSON' % o) + else: + def default(o): + if isinstance(o, datetime_types + (Decimal,)): + return str(o) + raise TypeError('Unable to serialize %r as JSON' % o) + return default + + def export(self, file_obj, **kwargs): + json.dump( + list(self.query), + file_obj, + default=self._make_default(), + **kwargs) + + +class CSVExporter(Exporter): + def export(self, file_obj, header=True, **kwargs): + writer = csv.writer(file_obj, **kwargs) + tuples = self.query.tuples().execute() + tuples.initialize() + if header and getattr(tuples, 'columns', None): + writer.writerow([column for column in tuples.columns]) + for row in tuples: + writer.writerow(row) + + +class TSVExporter(CSVExporter): + def 
export(self, file_obj, header=True, **kwargs): + kwargs.setdefault('delimiter', '\t') + return super(TSVExporter, self).export(file_obj, header, **kwargs) + + +class Importer(object): + def __init__(self, table, strict=False): + self.table = table + self.strict = strict + + model = self.table.model_class + self.columns = model._meta.columns + self.columns.update(model._meta.fields) + + def load(self, file_obj): + raise NotImplementedError + + +class JSONImporter(Importer): + def load(self, file_obj, **kwargs): + data = json.load(file_obj, **kwargs) + count = 0 + + for row in data: + if self.strict: + obj = {} + for key in row: + field = self.columns.get(key) + if field is not None: + obj[field.name] = field.python_value(row[key]) + else: + obj = row + + if obj: + self.table.insert(**obj) + count += 1 + + return count + + +class CSVImporter(Importer): + def load(self, file_obj, header=True, **kwargs): + count = 0 + reader = csv.reader(file_obj, **kwargs) + if header: + try: + header_keys = next(reader) + except StopIteration: + return count + + if self.strict: + header_fields = [] + for idx, key in enumerate(header_keys): + if key in self.columns: + header_fields.append((idx, self.columns[key])) + else: + header_fields = list(enumerate(header_keys)) + else: + header_fields = list(enumerate(self.model._meta.sorted_fields)) + + if not header_fields: + return count + + for row in reader: + obj = {} + for idx, field in header_fields: + if self.strict: + obj[field.name] = field.python_value(row[idx]) + else: + obj[field] = row[idx] + + self.table.insert(**obj) + count += 1 + + return count + + +class TSVImporter(CSVImporter): + def load(self, file_obj, header=True, **kwargs): + kwargs.setdefault('delimiter', '\t') + return super(TSVImporter, self).load(file_obj, header, **kwargs) diff --git a/libs/playhouse/db_url.py b/libs/playhouse/db_url.py new file mode 100644 index 000000000..7176c806d --- /dev/null +++ b/libs/playhouse/db_url.py @@ -0,0 +1,130 @@ +try: + from urlparse import parse_qsl, unquote, urlparse +except ImportError: + from urllib.parse import parse_qsl, unquote, urlparse + +from peewee import * +from playhouse.cockroachdb import CockroachDatabase +from playhouse.cockroachdb import PooledCockroachDatabase +from playhouse.pool import PooledMySQLDatabase +from playhouse.pool import PooledPostgresqlDatabase +from playhouse.pool import PooledSqliteDatabase +from playhouse.pool import PooledSqliteExtDatabase +from playhouse.sqlite_ext import SqliteExtDatabase + + +schemes = { + 'cockroachdb': CockroachDatabase, + 'cockroachdb+pool': PooledCockroachDatabase, + 'crdb': CockroachDatabase, + 'crdb+pool': PooledCockroachDatabase, + 'mysql': MySQLDatabase, + 'mysql+pool': PooledMySQLDatabase, + 'postgres': PostgresqlDatabase, + 'postgresql': PostgresqlDatabase, + 'postgres+pool': PooledPostgresqlDatabase, + 'postgresql+pool': PooledPostgresqlDatabase, + 'sqlite': SqliteDatabase, + 'sqliteext': SqliteExtDatabase, + 'sqlite+pool': PooledSqliteDatabase, + 'sqliteext+pool': PooledSqliteExtDatabase, +} + +def register_database(db_class, *names): + global schemes + for name in names: + schemes[name] = db_class + +def parseresult_to_dict(parsed, unquote_password=False): + + # urlparse in python 2.6 is broken so query will be empty and instead + # appended to path complete with '?' 
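    # Illustrative examples of the URL forms this module accepts (database
    # names, hosts and credentials are hypothetical):
    #
    #     connect('sqlite:///app.db')                  # SqliteDatabase('app.db')
    #     connect('sqliteext:///:memory:')             # in-memory SqliteExtDatabase
    #     connect('postgres://user:secret@db-host:5432/analytics')
    #     connect('mysql+pool://user:secret@localhost/app?max_connections=20')
    #
    # Query-string values are coerced before being passed to the database class
    # as keyword arguments: 'true'/'false' become booleans, digit strings become
    # ints, and 'null'/'none' become None.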
+ path_parts = parsed.path[1:].split('?') + try: + query = path_parts[1] + except IndexError: + query = parsed.query + + connect_kwargs = {'database': path_parts[0]} + if parsed.username: + connect_kwargs['user'] = parsed.username + if parsed.password: + connect_kwargs['password'] = parsed.password + if unquote_password: + connect_kwargs['password'] = unquote(connect_kwargs['password']) + if parsed.hostname: + connect_kwargs['host'] = parsed.hostname + if parsed.port: + connect_kwargs['port'] = parsed.port + + # Adjust parameters for MySQL. + if parsed.scheme == 'mysql' and 'password' in connect_kwargs: + connect_kwargs['passwd'] = connect_kwargs.pop('password') + elif 'sqlite' in parsed.scheme and not connect_kwargs['database']: + connect_kwargs['database'] = ':memory:' + + # Get additional connection args from the query string + qs_args = parse_qsl(query, keep_blank_values=True) + for key, value in qs_args: + if value.lower() == 'false': + value = False + elif value.lower() == 'true': + value = True + elif value.isdigit(): + value = int(value) + elif '.' in value and all(p.isdigit() for p in value.split('.', 1)): + try: + value = float(value) + except ValueError: + pass + elif value.lower() in ('null', 'none'): + value = None + + connect_kwargs[key] = value + + return connect_kwargs + +def parse(url, unquote_password=False): + parsed = urlparse(url) + return parseresult_to_dict(parsed, unquote_password) + +def connect(url, unquote_password=False, **connect_params): + parsed = urlparse(url) + connect_kwargs = parseresult_to_dict(parsed, unquote_password) + connect_kwargs.update(connect_params) + database_class = schemes.get(parsed.scheme) + + if database_class is None: + if database_class in schemes: + raise RuntimeError('Attempted to use "%s" but a required library ' + 'could not be imported.' % parsed.scheme) + else: + raise RuntimeError('Unrecognized or unsupported scheme: "%s".' % + parsed.scheme) + + return database_class(**connect_kwargs) + +# Conditionally register additional databases. +try: + from playhouse.pool import PooledPostgresqlExtDatabase +except ImportError: + pass +else: + register_database( + PooledPostgresqlExtDatabase, + 'postgresext+pool', + 'postgresqlext+pool') + +try: + from playhouse.apsw_ext import APSWDatabase +except ImportError: + pass +else: + register_database(APSWDatabase, 'apsw') + +try: + from playhouse.postgres_ext import PostgresqlExtDatabase +except ImportError: + pass +else: + register_database(PostgresqlExtDatabase, 'postgresext', 'postgresqlext') diff --git a/libs/playhouse/fields.py b/libs/playhouse/fields.py new file mode 100644 index 000000000..fce1a3d6d --- /dev/null +++ b/libs/playhouse/fields.py @@ -0,0 +1,64 @@ +try: + import bz2 +except ImportError: + bz2 = None +try: + import zlib +except ImportError: + zlib = None +try: + import cPickle as pickle +except ImportError: + import pickle +import sys + +from peewee import BlobField +from peewee import buffer_type + + +PY2 = sys.version_info[0] == 2 + + +class CompressedField(BlobField): + ZLIB = 'zlib' + BZ2 = 'bz2' + algorithm_to_import = { + ZLIB: zlib, + BZ2: bz2, + } + + def __init__(self, compression_level=6, algorithm=ZLIB, *args, + **kwargs): + self.compression_level = compression_level + if algorithm not in self.algorithm_to_import: + raise ValueError('Unrecognized algorithm %s' % algorithm) + compress_module = self.algorithm_to_import[algorithm] + if compress_module is None: + raise ValueError('Missing library required for %s.' 
% algorithm) + + self.algorithm = algorithm + self.compress = compress_module.compress + self.decompress = compress_module.decompress + super(CompressedField, self).__init__(*args, **kwargs) + + def python_value(self, value): + if value is not None: + return self.decompress(value) + + def db_value(self, value): + if value is not None: + return self._constructor( + self.compress(value, self.compression_level)) + + +class PickleField(BlobField): + def python_value(self, value): + if value is not None: + if isinstance(value, buffer_type): + value = bytes(value) + return pickle.loads(value) + + def db_value(self, value): + if value is not None: + pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL) + return self._constructor(pickled) diff --git a/libs/playhouse/flask_utils.py b/libs/playhouse/flask_utils.py new file mode 100644 index 000000000..76a2a62f4 --- /dev/null +++ b/libs/playhouse/flask_utils.py @@ -0,0 +1,185 @@ +import math +import sys + +from flask import abort +from flask import render_template +from flask import request +from peewee import Database +from peewee import DoesNotExist +from peewee import Model +from peewee import Proxy +from peewee import SelectQuery +from playhouse.db_url import connect as db_url_connect + + +class PaginatedQuery(object): + def __init__(self, query_or_model, paginate_by, page_var='page', page=None, + check_bounds=False): + self.paginate_by = paginate_by + self.page_var = page_var + self.page = page or None + self.check_bounds = check_bounds + + if isinstance(query_or_model, SelectQuery): + self.query = query_or_model + self.model = self.query.model + else: + self.model = query_or_model + self.query = self.model.select() + + def get_page(self): + if self.page is not None: + return self.page + + curr_page = request.args.get(self.page_var) + if curr_page and curr_page.isdigit(): + return max(1, int(curr_page)) + return 1 + + def get_page_count(self): + if not hasattr(self, '_page_count'): + self._page_count = int(math.ceil( + float(self.query.count()) / self.paginate_by)) + return self._page_count + + def get_object_list(self): + if self.check_bounds and self.get_page() > self.get_page_count(): + abort(404) + return self.query.paginate(self.get_page(), self.paginate_by) + + +def get_object_or_404(query_or_model, *query): + if not isinstance(query_or_model, SelectQuery): + query_or_model = query_or_model.select() + try: + return query_or_model.where(*query).get() + except DoesNotExist: + abort(404) + +def object_list(template_name, query, context_variable='object_list', + paginate_by=20, page_var='page', page=None, check_bounds=True, + **kwargs): + paginated_query = PaginatedQuery( + query, + paginate_by=paginate_by, + page_var=page_var, + page=page, + check_bounds=check_bounds) + kwargs[context_variable] = paginated_query.get_object_list() + return render_template( + template_name, + pagination=paginated_query, + page=paginated_query.get_page(), + **kwargs) + +def get_current_url(): + if not request.query_string: + return request.path + return '%s?%s' % (request.path, request.query_string) + +def get_next_url(default='/'): + if request.args.get('next'): + return request.args['next'] + elif request.form.get('next'): + return request.form['next'] + return default + +class FlaskDB(object): + def __init__(self, app=None, database=None, model_class=Model): + self.database = None # Reference to actual Peewee database instance. + self.base_model_class = model_class + self._app = app + self._db = database # dict, url, Database, or None (default). 
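        # Illustrative wiring of FlaskDB into a Flask application (app name and
        # config values are hypothetical):
        #
        #     from flask import Flask
        #     from peewee import TextField
        #
        #     app = Flask(__name__)
        #     app.config['DATABASE'] = 'sqlite:///notes.db'
        #     db_wrapper = FlaskDB(app)
        #
        #     class Note(db_wrapper.Model):
        #         content = TextField()
        #
        # Request handlers then get an open connection via the before/teardown
        # hooks registered in _register_handlers() below.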
+ if app is not None: + self.init_app(app) + + def init_app(self, app): + self._app = app + + if self._db is None: + if 'DATABASE' in app.config: + initial_db = app.config['DATABASE'] + elif 'DATABASE_URL' in app.config: + initial_db = app.config['DATABASE_URL'] + else: + raise ValueError('Missing required configuration data for ' + 'database: DATABASE or DATABASE_URL.') + else: + initial_db = self._db + + self._load_database(app, initial_db) + self._register_handlers(app) + + def _load_database(self, app, config_value): + if isinstance(config_value, Database): + database = config_value + elif isinstance(config_value, dict): + database = self._load_from_config_dict(dict(config_value)) + else: + # Assume a database connection URL. + database = db_url_connect(config_value) + + if isinstance(self.database, Proxy): + self.database.initialize(database) + else: + self.database = database + + def _load_from_config_dict(self, config_dict): + try: + name = config_dict.pop('name') + engine = config_dict.pop('engine') + except KeyError: + raise RuntimeError('DATABASE configuration must specify a ' + '`name` and `engine`.') + + if '.' in engine: + path, class_name = engine.rsplit('.', 1) + else: + path, class_name = 'peewee', engine + + try: + __import__(path) + module = sys.modules[path] + database_class = getattr(module, class_name) + assert issubclass(database_class, Database) + except ImportError: + raise RuntimeError('Unable to import %s' % engine) + except AttributeError: + raise RuntimeError('Database engine not found %s' % engine) + except AssertionError: + raise RuntimeError('Database engine not a subclass of ' + 'peewee.Database: %s' % engine) + + return database_class(name, **config_dict) + + def _register_handlers(self, app): + app.before_request(self.connect_db) + app.teardown_request(self.close_db) + + def get_model_class(self): + if self.database is None: + raise RuntimeError('Database must be initialized.') + + class BaseModel(self.base_model_class): + class Meta: + database = self.database + + return BaseModel + + @property + def Model(self): + if self._app is None: + database = getattr(self, 'database', None) + if database is None: + self.database = Proxy() + + if not hasattr(self, '_model_class'): + self._model_class = self.get_model_class() + return self._model_class + + def connect_db(self): + self.database.connect() + + def close_db(self, exc): + if not self.database.is_closed(): + self.database.close() diff --git a/libs/playhouse/hybrid.py b/libs/playhouse/hybrid.py new file mode 100644 index 000000000..50531cc35 --- /dev/null +++ b/libs/playhouse/hybrid.py @@ -0,0 +1,53 @@ +from peewee import ModelDescriptor + + +# Hybrid methods/attributes, based on similar functionality in SQLAlchemy: +# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html +class hybrid_method(ModelDescriptor): + def __init__(self, func, expr=None): + self.func = func + self.expr = expr or func + + def __get__(self, instance, instance_type): + if instance is None: + return self.expr.__get__(instance_type, instance_type.__class__) + return self.func.__get__(instance, instance_type) + + def expression(self, expr): + self.expr = expr + return self + + +class hybrid_property(ModelDescriptor): + def __init__(self, fget, fset=None, fdel=None, expr=None): + self.fget = fget + self.fset = fset + self.fdel = fdel + self.expr = expr or fget + + def __get__(self, instance, instance_type): + if instance is None: + return self.expr(instance_type) + return self.fget(instance) + + def __set__(self, instance, 
value): + if self.fset is None: + raise AttributeError('Cannot set attribute.') + self.fset(instance, value) + + def __delete__(self, instance): + if self.fdel is None: + raise AttributeError('Cannot delete attribute.') + self.fdel(instance) + + def setter(self, fset): + self.fset = fset + return self + + def deleter(self, fdel): + self.fdel = fdel + return self + + def expression(self, expr): + self.expr = expr + return self diff --git a/libs/playhouse/kv.py b/libs/playhouse/kv.py new file mode 100644 index 000000000..742b49cad --- /dev/null +++ b/libs/playhouse/kv.py @@ -0,0 +1,172 @@ +import operator + +from peewee import * +from peewee import Expression +from playhouse.fields import PickleField +try: + from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase +except ImportError: + from playhouse.sqlite_ext import SqliteExtDatabase + + +Sentinel = type('Sentinel', (object,), {}) + + +class KeyValue(object): + """ + Persistent dictionary. + + :param Field key_field: field to use for key. Defaults to CharField. + :param Field value_field: field to use for value. Defaults to PickleField. + :param bool ordered: data should be returned in key-sorted order. + :param Database database: database where key/value data is stored. + :param str table_name: table name for data. + """ + def __init__(self, key_field=None, value_field=None, ordered=False, + database=None, table_name='keyvalue'): + if key_field is None: + key_field = CharField(max_length=255, primary_key=True) + if not key_field.primary_key: + raise ValueError('key_field must have primary_key=True.') + + if value_field is None: + value_field = PickleField() + + self._key_field = key_field + self._value_field = value_field + self._ordered = ordered + self._database = database or SqliteExtDatabase(':memory:') + self._table_name = table_name + if isinstance(self._database, PostgresqlDatabase): + self.upsert = self._postgres_upsert + self.update = self._postgres_update + else: + self.upsert = self._upsert + self.update = self._update + + self.model = self.create_model() + self.key = self.model.key + self.value = self.model.value + + # Ensure table exists. 
+ self.model.create_table() + + def create_model(self): + class KeyValue(Model): + key = self._key_field + value = self._value_field + class Meta: + database = self._database + table_name = self._table_name + return KeyValue + + def query(self, *select): + query = self.model.select(*select).tuples() + if self._ordered: + query = query.order_by(self.key) + return query + + def convert_expression(self, expr): + if not isinstance(expr, Expression): + return (self.key == expr), True + return expr, False + + def __contains__(self, key): + expr, _ = self.convert_expression(key) + return self.model.select().where(expr).exists() + + def __len__(self): + return len(self.model) + + def __getitem__(self, expr): + converted, is_single = self.convert_expression(expr) + query = self.query(self.value).where(converted) + item_getter = operator.itemgetter(0) + result = [item_getter(row) for row in query] + if len(result) == 0 and is_single: + raise KeyError(expr) + elif is_single: + return result[0] + return result + + def _upsert(self, key, value): + (self.model + .insert(key=key, value=value) + .on_conflict('replace') + .execute()) + + def _postgres_upsert(self, key, value): + (self.model + .insert(key=key, value=value) + .on_conflict(conflict_target=[self.key], + preserve=[self.value]) + .execute()) + + def __setitem__(self, expr, value): + if isinstance(expr, Expression): + self.model.update(value=value).where(expr).execute() + else: + self.upsert(expr, value) + + def __delitem__(self, expr): + converted, _ = self.convert_expression(expr) + self.model.delete().where(converted).execute() + + def __iter__(self): + return iter(self.query().execute()) + + def keys(self): + return map(operator.itemgetter(0), self.query(self.key)) + + def values(self): + return map(operator.itemgetter(0), self.query(self.value)) + + def items(self): + return iter(self.query().execute()) + + def _update(self, __data=None, **mapping): + if __data is not None: + mapping.update(__data) + return (self.model + .insert_many(list(mapping.items()), + fields=[self.key, self.value]) + .on_conflict('replace') + .execute()) + + def _postgres_update(self, __data=None, **mapping): + if __data is not None: + mapping.update(__data) + return (self.model + .insert_many(list(mapping.items()), + fields=[self.key, self.value]) + .on_conflict(conflict_target=[self.key], + preserve=[self.value]) + .execute()) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + + def pop(self, key, default=Sentinel): + with self._database.atomic(): + try: + result = self[key] + except KeyError: + if default is Sentinel: + raise + return default + del self[key] + + return result + + def clear(self): + self.model.delete().execute() diff --git a/libs/playhouse/migrate.py b/libs/playhouse/migrate.py new file mode 100644 index 000000000..f536a45c9 --- /dev/null +++ b/libs/playhouse/migrate.py @@ -0,0 +1,886 @@ +""" +Lightweight schema migrations. + +NOTE: Currently tested with SQLite and Postgresql. MySQL may be missing some +features. + +Example Usage +------------- + +Instantiate a migrator: + + # Postgres example: + my_db = PostgresqlDatabase(...) 
+ migrator = PostgresqlMigrator(my_db) + + # SQLite example: + my_db = SqliteDatabase('my_database.db') + migrator = SqliteMigrator(my_db) + +Then you will use the `migrate` function to run various `Operation`s which +are generated by the migrator: + + migrate( + migrator.add_column('some_table', 'column_name', CharField(default='')) + ) + +Migrations are not run inside a transaction, so if you wish the migration to +run in a transaction you will need to wrap the call to `migrate` in a +transaction block, e.g.: + + with my_db.transaction(): + migrate(...) + +Supported Operations +-------------------- + +Add new field(s) to an existing model: + + # Create your field instances. For non-null fields you must specify a + # default value. + pubdate_field = DateTimeField(null=True) + comment_field = TextField(default='') + + # Run the migration, specifying the database table, field name and field. + migrate( + migrator.add_column('comment_tbl', 'pub_date', pubdate_field), + migrator.add_column('comment_tbl', 'comment', comment_field), + ) + +Renaming a field: + + # Specify the table, original name of the column, and its new name. + migrate( + migrator.rename_column('story', 'pub_date', 'publish_date'), + migrator.rename_column('story', 'mod_date', 'modified_date'), + ) + +Dropping a field: + + migrate( + migrator.drop_column('story', 'some_old_field'), + ) + +Making a field nullable or not nullable: + + # Note that when making a field not null that field must not have any + # NULL values present. + migrate( + # Make `pub_date` allow NULL values. + migrator.drop_not_null('story', 'pub_date'), + + # Prevent `modified_date` from containing NULL values. + migrator.add_not_null('story', 'modified_date'), + ) + +Renaming a table: + + migrate( + migrator.rename_table('story', 'stories_tbl'), + ) + +Adding an index: + + # Specify the table, column names, and whether the index should be + # UNIQUE or not. + migrate( + # Create an index on the `pub_date` column. + migrator.add_index('story', ('pub_date',), False), + + # Create a multi-column index on the `pub_date` and `status` fields. + migrator.add_index('story', ('pub_date', 'status'), False), + + # Create a unique index on the category and title fields. + migrator.add_index('story', ('category_id', 'title'), True), + ) + +Dropping an index: + + # Specify the index name. + migrate(migrator.drop_index('story', 'story_pub_date_status')) + +Adding or dropping table constraints: + +.. code-block:: python + + # Add a CHECK() constraint to enforce the price cannot be negative. + migrate(migrator.add_constraint( + 'products', + 'price_check', + Check('price >= 0'))) + + # Remove the price check constraint. + migrate(migrator.drop_constraint('products', 'price_check')) + + # Add a UNIQUE constraint on the first and last names. 
+ migrate(migrator.add_unique('person', 'first_name', 'last_name')) +""" +from collections import namedtuple +import functools +import hashlib +import re + +from peewee import * +from peewee import CommaNodeList +from peewee import EnclosedNodeList +from peewee import Entity +from peewee import Expression +from peewee import Node +from peewee import NodeList +from peewee import OP +from peewee import callable_ +from peewee import sort_models +from peewee import _truncate_constraint_name +try: + from playhouse.cockroachdb import CockroachDatabase +except ImportError: + CockroachDatabase = None + + +class Operation(object): + """Encapsulate a single schema altering operation.""" + def __init__(self, migrator, method, *args, **kwargs): + self.migrator = migrator + self.method = method + self.args = args + self.kwargs = kwargs + + def execute(self, node): + self.migrator.database.execute(node) + + def _handle_result(self, result): + if isinstance(result, (Node, Context)): + self.execute(result) + elif isinstance(result, Operation): + result.run() + elif isinstance(result, (list, tuple)): + for item in result: + self._handle_result(item) + + def run(self): + kwargs = self.kwargs.copy() + kwargs['with_context'] = True + method = getattr(self.migrator, self.method) + self._handle_result(method(*self.args, **kwargs)) + + +def operation(fn): + @functools.wraps(fn) + def inner(self, *args, **kwargs): + with_context = kwargs.pop('with_context', False) + if with_context: + return fn(self, *args, **kwargs) + return Operation(self, fn.__name__, *args, **kwargs) + return inner + + +def make_index_name(table_name, columns): + index_name = '_'.join((table_name,) + tuple(columns)) + if len(index_name) > 64: + index_hash = hashlib.md5(index_name.encode('utf-8')).hexdigest() + index_name = '%s_%s' % (index_name[:56], index_hash[:7]) + return index_name + + +class SchemaMigrator(object): + explicit_create_foreign_key = False + explicit_delete_foreign_key = False + + def __init__(self, database): + self.database = database + + def make_context(self): + return self.database.get_sql_context() + + @classmethod + def from_database(cls, database): + if CockroachDatabase and isinstance(database, CockroachDatabase): + return CockroachDBMigrator(database) + elif isinstance(database, PostgresqlDatabase): + return PostgresqlMigrator(database) + elif isinstance(database, MySQLDatabase): + return MySQLMigrator(database) + elif isinstance(database, SqliteDatabase): + return SqliteMigrator(database) + raise ValueError('Unsupported database: %s' % database) + + @operation + def apply_default(self, table, column_name, field): + default = field.default + if callable_(default): + default = default() + + return (self.make_context() + .literal('UPDATE ') + .sql(Entity(table)) + .literal(' SET ') + .sql(Expression( + Entity(column_name), + OP.EQ, + field.db_value(default), + flat=True))) + + def _alter_table(self, ctx, table): + return ctx.literal('ALTER TABLE ').sql(Entity(table)) + + def _alter_column(self, ctx, table, column): + return (self + ._alter_table(ctx, table) + .literal(' ALTER COLUMN ') + .sql(Entity(column))) + + @operation + def alter_add_column(self, table, column_name, field): + # Make field null at first. + ctx = self.make_context() + field_null, field.null = field.null, True + + # Set the field's column-name and name, if it is not set or doesn't + # match the new value. 
+ if field.column_name != column_name: + field.name = field.column_name = column_name + + (self + ._alter_table(ctx, table) + .literal(' ADD COLUMN ') + .sql(field.ddl(ctx))) + + field.null = field_null + if isinstance(field, ForeignKeyField): + self.add_inline_fk_sql(ctx, field) + return ctx + + @operation + def add_constraint(self, table, name, constraint): + return (self + ._alter_table(self.make_context(), table) + .literal(' ADD CONSTRAINT ') + .sql(Entity(name)) + .literal(' ') + .sql(constraint)) + + @operation + def add_unique(self, table, *column_names): + constraint_name = 'uniq_%s' % '_'.join(column_names) + constraint = NodeList(( + SQL('UNIQUE'), + EnclosedNodeList([Entity(column) for column in column_names]))) + return self.add_constraint(table, constraint_name, constraint) + + @operation + def drop_constraint(self, table, name): + return (self + ._alter_table(self.make_context(), table) + .literal(' DROP CONSTRAINT ') + .sql(Entity(name))) + + def add_inline_fk_sql(self, ctx, field): + ctx = (ctx + .literal(' REFERENCES ') + .sql(Entity(field.rel_model._meta.table_name)) + .literal(' ') + .sql(EnclosedNodeList((Entity(field.rel_field.column_name),)))) + if field.on_delete is not None: + ctx = ctx.literal(' ON DELETE %s' % field.on_delete) + if field.on_update is not None: + ctx = ctx.literal(' ON UPDATE %s' % field.on_update) + return ctx + + @operation + def add_foreign_key_constraint(self, table, column_name, rel, rel_column, + on_delete=None, on_update=None): + constraint = 'fk_%s_%s_refs_%s' % (table, column_name, rel) + ctx = (self + .make_context() + .literal('ALTER TABLE ') + .sql(Entity(table)) + .literal(' ADD CONSTRAINT ') + .sql(Entity(_truncate_constraint_name(constraint))) + .literal(' FOREIGN KEY ') + .sql(EnclosedNodeList((Entity(column_name),))) + .literal(' REFERENCES ') + .sql(Entity(rel)) + .literal(' (') + .sql(Entity(rel_column)) + .literal(')')) + if on_delete is not None: + ctx = ctx.literal(' ON DELETE %s' % on_delete) + if on_update is not None: + ctx = ctx.literal(' ON UPDATE %s' % on_update) + return ctx + + @operation + def add_column(self, table, column_name, field): + # Adding a column is complicated by the fact that if there are rows + # present and the field is non-null, then we need to first add the + # column as a nullable field, then set the value, then add a not null + # constraint. + if not field.null and field.default is None: + raise ValueError('%s is not null but has no default' % column_name) + + is_foreign_key = isinstance(field, ForeignKeyField) + if is_foreign_key and not field.rel_field: + raise ValueError('Foreign keys must specify a `field`.') + + operations = [self.alter_add_column(table, column_name, field)] + + # In the event the field is *not* nullable, update with the default + # value and set not null. 
+ if not field.null: + operations.extend([ + self.apply_default(table, column_name, field), + self.add_not_null(table, column_name)]) + + if is_foreign_key and self.explicit_create_foreign_key: + operations.append( + self.add_foreign_key_constraint( + table, + column_name, + field.rel_model._meta.table_name, + field.rel_field.column_name, + field.on_delete, + field.on_update)) + + if field.index or field.unique: + using = getattr(field, 'index_type', None) + operations.append(self.add_index(table, (column_name,), + field.unique, using)) + + return operations + + @operation + def drop_foreign_key_constraint(self, table, column_name): + raise NotImplementedError + + @operation + def drop_column(self, table, column_name, cascade=True): + ctx = self.make_context() + (self._alter_table(ctx, table) + .literal(' DROP COLUMN ') + .sql(Entity(column_name))) + + if cascade: + ctx.literal(' CASCADE') + + fk_columns = [ + foreign_key.column + for foreign_key in self.database.get_foreign_keys(table)] + if column_name in fk_columns and self.explicit_delete_foreign_key: + return [self.drop_foreign_key_constraint(table, column_name), ctx] + + return ctx + + @operation + def rename_column(self, table, old_name, new_name): + return (self + ._alter_table(self.make_context(), table) + .literal(' RENAME COLUMN ') + .sql(Entity(old_name)) + .literal(' TO ') + .sql(Entity(new_name))) + + @operation + def add_not_null(self, table, column): + return (self + ._alter_column(self.make_context(), table, column) + .literal(' SET NOT NULL')) + + @operation + def drop_not_null(self, table, column): + return (self + ._alter_column(self.make_context(), table, column) + .literal(' DROP NOT NULL')) + + @operation + def alter_column_type(self, table, column, field, cast=None): + # ALTER TABLE
ALTER COLUMN + ctx = self.make_context() + ctx = (self + ._alter_column(ctx, table, column) + .literal(' TYPE ') + .sql(field.ddl_datatype(ctx))) + if cast is not None: + if not isinstance(cast, Node): + cast = SQL(cast) + ctx = ctx.literal(' USING ').sql(cast) + return ctx + + @operation + def rename_table(self, old_name, new_name): + return (self + ._alter_table(self.make_context(), old_name) + .literal(' RENAME TO ') + .sql(Entity(new_name))) + + @operation + def add_index(self, table, columns, unique=False, using=None): + ctx = self.make_context() + index_name = make_index_name(table, columns) + table_obj = Table(table) + cols = [getattr(table_obj.c, column) for column in columns] + index = Index(index_name, table_obj, cols, unique=unique, using=using) + return ctx.sql(index) + + @operation + def drop_index(self, table, index_name): + return (self + .make_context() + .literal('DROP INDEX ') + .sql(Entity(index_name))) + + +class PostgresqlMigrator(SchemaMigrator): + def _primary_key_columns(self, tbl): + query = """ + SELECT pg_attribute.attname + FROM pg_index, pg_class, pg_attribute + WHERE + pg_class.oid = '%s'::regclass AND + indrelid = pg_class.oid AND + pg_attribute.attrelid = pg_class.oid AND + pg_attribute.attnum = any(pg_index.indkey) AND + indisprimary; + """ + cursor = self.database.execute_sql(query % tbl) + return [row[0] for row in cursor.fetchall()] + + @operation + def set_search_path(self, schema_name): + return (self + .make_context() + .literal('SET search_path TO %s' % schema_name)) + + @operation + def rename_table(self, old_name, new_name): + pk_names = self._primary_key_columns(old_name) + ParentClass = super(PostgresqlMigrator, self) + + operations = [ + ParentClass.rename_table(old_name, new_name, with_context=True)] + + if len(pk_names) == 1: + # Check for existence of primary key sequence. 
+ seq_name = '%s_%s_seq' % (old_name, pk_names[0]) + query = """ + SELECT 1 + FROM information_schema.sequences + WHERE LOWER(sequence_name) = LOWER(%s) + """ + cursor = self.database.execute_sql(query, (seq_name,)) + if bool(cursor.fetchone()): + new_seq_name = '%s_%s_seq' % (new_name, pk_names[0]) + operations.append(ParentClass.rename_table( + seq_name, new_seq_name)) + + return operations + + +class CockroachDBMigrator(PostgresqlMigrator): + explicit_create_foreign_key = True + + def add_inline_fk_sql(self, ctx, field): + pass + + @operation + def drop_index(self, table, index_name): + return (self + .make_context() + .literal('DROP INDEX ') + .sql(Entity(index_name)) + .literal(' CASCADE')) + + +class MySQLColumn(namedtuple('_Column', ('name', 'definition', 'null', 'pk', + 'default', 'extra'))): + @property + def is_pk(self): + return self.pk == 'PRI' + + @property + def is_unique(self): + return self.pk == 'UNI' + + @property + def is_null(self): + return self.null == 'YES' + + def sql(self, column_name=None, is_null=None): + if is_null is None: + is_null = self.is_null + if column_name is None: + column_name = self.name + parts = [ + Entity(column_name), + SQL(self.definition)] + if self.is_unique: + parts.append(SQL('UNIQUE')) + if is_null: + parts.append(SQL('NULL')) + else: + parts.append(SQL('NOT NULL')) + if self.is_pk: + parts.append(SQL('PRIMARY KEY')) + if self.extra: + parts.append(SQL(self.extra)) + return NodeList(parts) + + +class MySQLMigrator(SchemaMigrator): + explicit_create_foreign_key = True + explicit_delete_foreign_key = True + + def _alter_column(self, ctx, table, column): + return (self + ._alter_table(ctx, table) + .literal(' MODIFY ') + .sql(Entity(column))) + + @operation + def rename_table(self, old_name, new_name): + return (self + .make_context() + .literal('RENAME TABLE ') + .sql(Entity(old_name)) + .literal(' TO ') + .sql(Entity(new_name))) + + def _get_column_definition(self, table, column_name): + cursor = self.database.execute_sql('DESCRIBE `%s`;' % table) + rows = cursor.fetchall() + for row in rows: + column = MySQLColumn(*row) + if column.name == column_name: + return column + return False + + def get_foreign_key_constraint(self, table, column_name): + cursor = self.database.execute_sql( + ('SELECT constraint_name ' + 'FROM information_schema.key_column_usage WHERE ' + 'table_schema = DATABASE() AND ' + 'table_name = %s AND ' + 'column_name = %s AND ' + 'referenced_table_name IS NOT NULL AND ' + 'referenced_column_name IS NOT NULL;'), + (table, column_name)) + result = cursor.fetchone() + if not result: + raise AttributeError( + 'Unable to find foreign key constraint for ' + '"%s" on table "%s".' 
% (table, column_name)) + return result[0] + + @operation + def drop_foreign_key_constraint(self, table, column_name): + fk_constraint = self.get_foreign_key_constraint(table, column_name) + return (self + ._alter_table(self.make_context(), table) + .literal(' DROP FOREIGN KEY ') + .sql(Entity(fk_constraint))) + + def add_inline_fk_sql(self, ctx, field): + pass + + @operation + def add_not_null(self, table, column): + column_def = self._get_column_definition(table, column) + add_not_null = (self + ._alter_table(self.make_context(), table) + .literal(' MODIFY ') + .sql(column_def.sql(is_null=False))) + + fk_objects = dict( + (fk.column, fk) + for fk in self.database.get_foreign_keys(table)) + if column not in fk_objects: + return add_not_null + + fk_metadata = fk_objects[column] + return (self.drop_foreign_key_constraint(table, column), + add_not_null, + self.add_foreign_key_constraint( + table, + column, + fk_metadata.dest_table, + fk_metadata.dest_column)) + + @operation + def drop_not_null(self, table, column): + column = self._get_column_definition(table, column) + if column.is_pk: + raise ValueError('Primary keys can not be null') + return (self + ._alter_table(self.make_context(), table) + .literal(' MODIFY ') + .sql(column.sql(is_null=True))) + + @operation + def rename_column(self, table, old_name, new_name): + fk_objects = dict( + (fk.column, fk) + for fk in self.database.get_foreign_keys(table)) + is_foreign_key = old_name in fk_objects + + column = self._get_column_definition(table, old_name) + rename_ctx = (self + ._alter_table(self.make_context(), table) + .literal(' CHANGE ') + .sql(Entity(old_name)) + .literal(' ') + .sql(column.sql(column_name=new_name))) + if is_foreign_key: + fk_metadata = fk_objects[old_name] + return [ + self.drop_foreign_key_constraint(table, old_name), + rename_ctx, + self.add_foreign_key_constraint( + table, + new_name, + fk_metadata.dest_table, + fk_metadata.dest_column), + ] + else: + return rename_ctx + + @operation + def alter_column_type(self, table, column, field, cast=None): + if cast is not None: + raise ValueError('alter_column_type() does not support cast with ' + 'MySQL.') + ctx = self.make_context() + return (self + ._alter_table(ctx, table) + .literal(' MODIFY ') + .sql(Entity(column)) + .literal(' ') + .sql(field.ddl(ctx))) + + @operation + def drop_index(self, table, index_name): + return (self + .make_context() + .literal('DROP INDEX ') + .sql(Entity(index_name)) + .literal(' ON ') + .sql(Entity(table))) + + +class SqliteMigrator(SchemaMigrator): + """ + SQLite supports a subset of ALTER TABLE queries, view the docs for the + full details http://sqlite.org/lang_altertable.html + """ + column_re = re.compile('(.+?)\((.+)\)') + column_split_re = re.compile(r'(?:[^,(]|\([^)]*\))+') + column_name_re = re.compile(r'''["`']?([\w]+)''') + fk_re = re.compile(r'FOREIGN KEY\s+\("?([\w]+)"?\)\s+', re.I) + + def _get_column_names(self, table): + res = self.database.execute_sql('select * from "%s" limit 1' % table) + return [item[0] for item in res.description] + + def _get_create_table(self, table): + res = self.database.execute_sql( + ('select name, sql from sqlite_master ' + 'where type=? 
and LOWER(name)=?'), + ['table', table.lower()]) + return res.fetchone() + + @operation + def _update_column(self, table, column_to_update, fn): + columns = set(column.name.lower() + for column in self.database.get_columns(table)) + if column_to_update.lower() not in columns: + raise ValueError('Column "%s" does not exist on "%s"' % + (column_to_update, table)) + + # Get the SQL used to create the given table. + table, create_table = self._get_create_table(table) + + # Get the indexes and SQL to re-create indexes. + indexes = self.database.get_indexes(table) + + # Find any foreign keys we may need to remove. + self.database.get_foreign_keys(table) + + # Make sure the create_table does not contain any newlines or tabs, + # allowing the regex to work correctly. + create_table = re.sub(r'\s+', ' ', create_table) + + # Parse out the `CREATE TABLE` and column list portions of the query. + raw_create, raw_columns = self.column_re.search(create_table).groups() + + # Clean up the individual column definitions. + split_columns = self.column_split_re.findall(raw_columns) + column_defs = [col.strip() for col in split_columns] + + new_column_defs = [] + new_column_names = [] + original_column_names = [] + constraint_terms = ('foreign ', 'primary ', 'constraint ', 'check ') + + for column_def in column_defs: + column_name, = self.column_name_re.match(column_def).groups() + + if column_name == column_to_update: + new_column_def = fn(column_name, column_def) + if new_column_def: + new_column_defs.append(new_column_def) + original_column_names.append(column_name) + column_name, = self.column_name_re.match( + new_column_def).groups() + new_column_names.append(column_name) + else: + new_column_defs.append(column_def) + + # Avoid treating constraints as columns. + if not column_def.lower().startswith(constraint_terms): + new_column_names.append(column_name) + original_column_names.append(column_name) + + # Create a mapping of original columns to new columns. + original_to_new = dict(zip(original_column_names, new_column_names)) + new_column = original_to_new.get(column_to_update) + + fk_filter_fn = lambda column_def: column_def + if not new_column: + # Remove any foreign keys associated with this column. + fk_filter_fn = lambda column_def: None + elif new_column != column_to_update: + # Update any foreign keys for this column. + fk_filter_fn = lambda column_def: self.fk_re.sub( + 'FOREIGN KEY ("%s") ' % new_column, + column_def) + + cleaned_columns = [] + for column_def in new_column_defs: + match = self.fk_re.match(column_def) + if match is not None and match.groups()[0] == column_to_update: + column_def = fk_filter_fn(column_def) + if column_def: + cleaned_columns.append(column_def) + + # Update the name of the new CREATE TABLE query. + temp_table = table + '__tmp__' + rgx = re.compile('("?)%s("?)' % table, re.I) + create = rgx.sub( + '\\1%s\\2' % temp_table, + raw_create) + + # Create the new table. + columns = ', '.join(cleaned_columns) + queries = [ + NodeList([SQL('DROP TABLE IF EXISTS'), Entity(temp_table)]), + SQL('%s (%s)' % (create.strip(), columns))] + + # Populate new table. + populate_table = NodeList(( + SQL('INSERT INTO'), + Entity(temp_table), + EnclosedNodeList([Entity(col) for col in new_column_names]), + SQL('SELECT'), + CommaNodeList([Entity(col) for col in original_column_names]), + SQL('FROM'), + Entity(table))) + drop_original = NodeList([SQL('DROP TABLE'), Entity(table)]) + + # Drop existing table and rename temp table. 
+ queries += [ + populate_table, + drop_original, + self.rename_table(temp_table, table)] + + # Re-create user-defined indexes. User-defined indexes will have a + # non-empty SQL attribute. + for index in filter(lambda idx: idx.sql, indexes): + if column_to_update not in index.columns: + queries.append(SQL(index.sql)) + elif new_column: + sql = self._fix_index(index.sql, column_to_update, new_column) + if sql is not None: + queries.append(SQL(sql)) + + return queries + + def _fix_index(self, sql, column_to_update, new_column): + # Split on the name of the column to update. If it splits into two + # pieces, then there's no ambiguity and we can simply replace the + # old with the new. + parts = sql.split(column_to_update) + if len(parts) == 2: + return sql.replace(column_to_update, new_column) + + # Find the list of columns in the index expression. + lhs, rhs = sql.rsplit('(', 1) + + # Apply the same "split in two" logic to the column list portion of + # the query. + if len(rhs.split(column_to_update)) == 2: + return '%s(%s' % (lhs, rhs.replace(column_to_update, new_column)) + + # Strip off the trailing parentheses and go through each column. + parts = rhs.rsplit(')', 1)[0].split(',') + columns = [part.strip('"`[]\' ') for part in parts] + + # `columns` looks something like: ['status', 'timestamp" DESC'] + # https://www.sqlite.org/lang_keywords.html + # Strip out any junk after the column name. + clean = [] + for column in columns: + if re.match('%s(?:[\'"`\]]?\s|$)' % column_to_update, column): + column = new_column + column[len(column_to_update):] + clean.append(column) + + return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean)) + + @operation + def drop_column(self, table, column_name, cascade=True): + return self._update_column(table, column_name, lambda a, b: None) + + @operation + def rename_column(self, table, old_name, new_name): + def _rename(column_name, column_def): + return column_def.replace(column_name, new_name) + return self._update_column(table, old_name, _rename) + + @operation + def add_not_null(self, table, column): + def _add_not_null(column_name, column_def): + return column_def + ' NOT NULL' + return self._update_column(table, column, _add_not_null) + + @operation + def drop_not_null(self, table, column): + def _drop_not_null(column_name, column_def): + return column_def.replace('NOT NULL', '') + return self._update_column(table, column, _drop_not_null) + + @operation + def alter_column_type(self, table, column, field, cast=None): + if cast is not None: + raise ValueError('alter_column_type() does not support cast with ' + 'Sqlite.') + ctx = self.make_context() + def _alter_column_type(column_name, column_def): + node_list = field.ddl(ctx) + sql, _ = ctx.sql(Entity(column)).sql(node_list).query() + return sql + return self._update_column(table, column, _alter_column_type) + + @operation + def add_constraint(self, table, name, constraint): + raise NotImplementedError + + @operation + def drop_constraint(self, table, name): + raise NotImplementedError + + @operation + def add_foreign_key_constraint(self, table, column_name, field, + on_delete=None, on_update=None): + raise NotImplementedError + + +def migrate(*operations, **kwargs): + for operation in operations: + operation.run() diff --git a/libs/playhouse/mysql_ext.py b/libs/playhouse/mysql_ext.py new file mode 100644 index 000000000..9ee265573 --- /dev/null +++ b/libs/playhouse/mysql_ext.py @@ -0,0 +1,49 @@ +import json + +try: + import mysql.connector as mysql_connector +except ImportError: + mysql_connector = 
None + +from peewee import ImproperlyConfigured +from peewee import MySQLDatabase +from peewee import NodeList +from peewee import SQL +from peewee import TextField +from peewee import fn + + +class MySQLConnectorDatabase(MySQLDatabase): + def _connect(self): + if mysql_connector is None: + raise ImproperlyConfigured('MySQL connector not installed!') + return mysql_connector.connect(db=self.database, **self.connect_params) + + def cursor(self, commit=None): + if self.is_closed(): + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') + return self._state.conn.cursor(buffered=True) + + +class JSONField(TextField): + field_type = 'JSON' + + def db_value(self, value): + if value is not None: + return json.dumps(value) + + def python_value(self, value): + if value is not None: + return json.loads(value) + + +def Match(columns, expr, modifier=None): + if isinstance(columns, (list, tuple)): + match = fn.MATCH(*columns) # Tuple of one or more columns / fields. + else: + match = fn.MATCH(columns) # Single column / field. + args = expr if modifier is None else NodeList((expr, SQL(modifier))) + return NodeList((match, fn.AGAINST(args))) diff --git a/libs/playhouse/pool.py b/libs/playhouse/pool.py new file mode 100644 index 000000000..2ee3b486f --- /dev/null +++ b/libs/playhouse/pool.py @@ -0,0 +1,318 @@ +""" +Lightweight connection pooling for peewee. + +In a multi-threaded application, up to `max_connections` will be opened. Each +thread (or, if using gevent, greenlet) will have it's own connection. + +In a single-threaded application, only one connection will be created. It will +be continually recycled until either it exceeds the stale timeout or is closed +explicitly (using `.manual_close()`). + +By default, all your application needs to do is ensure that connections are +closed when you are finished with them, and they will be returned to the pool. +For web applications, this typically means that at the beginning of a request, +you will open a connection, and when you return a response, you will close the +connection. + +Simple Postgres pool example code: + + # Use the special postgresql extensions. + from playhouse.pool import PooledPostgresqlExtDatabase + + db = PooledPostgresqlExtDatabase( + 'my_app', + max_connections=32, + stale_timeout=300, # 5 minutes. + user='postgres') + + class BaseModel(Model): + class Meta: + database = db + +That's it! 
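Bazarr's own database is SQLite, so the analogous pooled setup would look roughly like the sketch below; the filename and limits are hypothetical, and PooledSqliteDatabase is defined further down in this module:

    from peewee import Model
    from playhouse.pool import PooledSqliteDatabase

    db = PooledSqliteDatabase(
        'bazarr.db',            # hypothetical path
        max_connections=8,
        stale_timeout=300,      # recycle connections idle longer than 5 minutes
        timeout=30)             # wait up to 30s when the pool is exhausted

    class BaseModel(Model):
        class Meta:
            database = db

    # Per thread/request: open, work, close. close() returns the connection
    # to the pool rather than discarding it.
    db.connect(reuse_if_open=True)
    db.close()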
+""" +import heapq +import logging +import random +import time +from collections import namedtuple +from itertools import chain + +try: + from psycopg2.extensions import TRANSACTION_STATUS_IDLE + from psycopg2.extensions import TRANSACTION_STATUS_INERROR + from psycopg2.extensions import TRANSACTION_STATUS_UNKNOWN +except ImportError: + TRANSACTION_STATUS_IDLE = \ + TRANSACTION_STATUS_INERROR = \ + TRANSACTION_STATUS_UNKNOWN = None + +from peewee import MySQLDatabase +from peewee import PostgresqlDatabase +from peewee import SqliteDatabase + +logger = logging.getLogger('peewee.pool') + + +def make_int(val): + if val is not None and not isinstance(val, (int, float)): + return int(val) + return val + + +class MaxConnectionsExceeded(ValueError): pass + + +PoolConnection = namedtuple('PoolConnection', ('timestamp', 'connection', + 'checked_out')) + + +class PooledDatabase(object): + def __init__(self, database, max_connections=20, stale_timeout=None, + timeout=None, **kwargs): + self._max_connections = make_int(max_connections) + self._stale_timeout = make_int(stale_timeout) + self._wait_timeout = make_int(timeout) + if self._wait_timeout == 0: + self._wait_timeout = float('inf') + + # Available / idle connections stored in a heap, sorted oldest first. + self._connections = [] + + # Mapping of connection id to PoolConnection. Ordinarily we would want + # to use something like a WeakKeyDictionary, but Python typically won't + # allow us to create weak references to connection objects. + self._in_use = {} + + # Use the memory address of the connection as the key in the event the + # connection object is not hashable. Connections will not get + # garbage-collected, however, because a reference to them will persist + # in "_in_use" as long as the conn has not been closed. + self.conn_key = id + + super(PooledDatabase, self).__init__(database, **kwargs) + + def init(self, database, max_connections=None, stale_timeout=None, + timeout=None, **connect_kwargs): + super(PooledDatabase, self).init(database, **connect_kwargs) + if max_connections is not None: + self._max_connections = make_int(max_connections) + if stale_timeout is not None: + self._stale_timeout = make_int(stale_timeout) + if timeout is not None: + self._wait_timeout = make_int(timeout) + if self._wait_timeout == 0: + self._wait_timeout = float('inf') + + def connect(self, reuse_if_open=False): + if not self._wait_timeout: + return super(PooledDatabase, self).connect(reuse_if_open) + + expires = time.time() + self._wait_timeout + while expires > time.time(): + try: + ret = super(PooledDatabase, self).connect(reuse_if_open) + except MaxConnectionsExceeded: + time.sleep(0.1) + else: + return ret + raise MaxConnectionsExceeded('Max connections exceeded, timed out ' + 'attempting to connect.') + + def _connect(self): + while True: + try: + # Remove the oldest connection from the heap. + ts, conn = heapq.heappop(self._connections) + key = self.conn_key(conn) + except IndexError: + ts = conn = None + logger.debug('No connection available in pool.') + break + else: + if self._is_closed(conn): + # This connecton was closed, but since it was not stale + # it got added back to the queue of available conns. We + # then closed it and marked it as explicitly closed, so + # it's safe to throw it away now. + # (Because Database.close() calls Database._close()). 
+ logger.debug('Connection %s was closed.', key) + ts = conn = None + elif self._stale_timeout and self._is_stale(ts): + # If we are attempting to check out a stale connection, + # then close it. We don't need to mark it in the "closed" + # set, because it is not in the list of available conns + # anymore. + logger.debug('Connection %s was stale, closing.', key) + self._close(conn, True) + ts = conn = None + else: + break + + if conn is None: + if self._max_connections and ( + len(self._in_use) >= self._max_connections): + raise MaxConnectionsExceeded('Exceeded maximum connections.') + conn = super(PooledDatabase, self)._connect() + ts = time.time() - random.random() / 1000 + key = self.conn_key(conn) + logger.debug('Created new connection %s.', key) + + self._in_use[key] = PoolConnection(ts, conn, time.time()) + return conn + + def _is_stale(self, timestamp): + # Called on check-out and check-in to ensure the connection has + # not outlived the stale timeout. + return (time.time() - timestamp) > self._stale_timeout + + def _is_closed(self, conn): + return False + + def _can_reuse(self, conn): + # Called on check-in to make sure the connection can be re-used. + return True + + def _close(self, conn, close_conn=False): + key = self.conn_key(conn) + if close_conn: + super(PooledDatabase, self)._close(conn) + elif key in self._in_use: + pool_conn = self._in_use.pop(key) + if self._stale_timeout and self._is_stale(pool_conn.timestamp): + logger.debug('Closing stale connection %s.', key) + super(PooledDatabase, self)._close(conn) + elif self._can_reuse(conn): + logger.debug('Returning %s to pool.', key) + heapq.heappush(self._connections, (pool_conn.timestamp, conn)) + else: + logger.debug('Closed %s.', key) + + def manual_close(self): + """ + Close the underlying connection without returning it to the pool. + """ + if self.is_closed(): + return False + + # Obtain reference to the connection in-use by the calling thread. + conn = self.connection() + + # A connection will only be re-added to the available list if it is + # marked as "in use" at the time it is closed. We will explicitly + # remove it from the "in use" list, call "close()" for the + # side-effects, and then explicitly close the connection. + self._in_use.pop(self.conn_key(conn), None) + self.close() + self._close(conn, close_conn=True) + + def close_idle(self): + # Close any open connections that are not currently in-use. + with self._lock: + for _, conn in self._connections: + self._close(conn, close_conn=True) + self._connections = [] + + def close_stale(self, age=600): + # Close any connections that are in-use but were checked out quite some + # time ago and can be considered stale. + with self._lock: + in_use = {} + cutoff = time.time() - age + n = 0 + for key, pool_conn in self._in_use.items(): + if pool_conn.checked_out < cutoff: + self._close(pool_conn.connection, close_conn=True) + n += 1 + else: + in_use[key] = pool_conn + self._in_use = in_use + return n + + def close_all(self): + # Close all connections -- available and in-use. Warning: may break any + # active connections used by other threads. 
+ self.close() + with self._lock: + for _, conn in self._connections: + self._close(conn, close_conn=True) + for pool_conn in self._in_use.values(): + self._close(pool_conn.connection, close_conn=True) + self._connections = [] + self._in_use = {} + + +class PooledMySQLDatabase(PooledDatabase, MySQLDatabase): + def _is_closed(self, conn): + try: + conn.ping(False) + except: + return True + else: + return False + + +class _PooledPostgresqlDatabase(PooledDatabase): + def _is_closed(self, conn): + if conn.closed: + return True + + txn_status = conn.get_transaction_status() + if txn_status == TRANSACTION_STATUS_UNKNOWN: + return True + elif txn_status != TRANSACTION_STATUS_IDLE: + conn.rollback() + return False + + def _can_reuse(self, conn): + txn_status = conn.get_transaction_status() + # Do not return connection in an error state, as subsequent queries + # will all fail. If the status is unknown then we lost the connection + # to the server and the connection should not be re-used. + if txn_status == TRANSACTION_STATUS_UNKNOWN: + return False + elif txn_status == TRANSACTION_STATUS_INERROR: + conn.reset() + elif txn_status != TRANSACTION_STATUS_IDLE: + conn.rollback() + return True + +class PooledPostgresqlDatabase(_PooledPostgresqlDatabase, PostgresqlDatabase): + pass + +try: + from playhouse.postgres_ext import PostgresqlExtDatabase + + class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase, PostgresqlExtDatabase): + pass +except ImportError: + PooledPostgresqlExtDatabase = None + + +class _PooledSqliteDatabase(PooledDatabase): + def _is_closed(self, conn): + try: + conn.total_changes + except: + return True + else: + return False + +class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase): + pass + +try: + from playhouse.sqlite_ext import SqliteExtDatabase + + class PooledSqliteExtDatabase(_PooledSqliteDatabase, SqliteExtDatabase): + pass +except ImportError: + PooledSqliteExtDatabase = None + +try: + from playhouse.sqlite_ext import CSqliteExtDatabase + + class PooledCSqliteExtDatabase(_PooledSqliteDatabase, CSqliteExtDatabase): + pass +except ImportError: + PooledCSqliteExtDatabase = None diff --git a/libs/playhouse/postgres_ext.py b/libs/playhouse/postgres_ext.py new file mode 100644 index 000000000..a50510d0f --- /dev/null +++ b/libs/playhouse/postgres_ext.py @@ -0,0 +1,493 @@ +""" +Collection of postgres-specific extensions, currently including: + +* Support for hstore, a key/value type storage +""" +import json +import logging +import uuid + +from peewee import * +from peewee import ColumnBase +from peewee import Expression +from peewee import Node +from peewee import NodeList +from peewee import SENTINEL +from peewee import __exception_wrapper__ + +try: + from psycopg2cffi import compat + compat.register() +except ImportError: + pass + +try: + from psycopg2.extras import register_hstore +except ImportError: + def register_hstore(c, globally): + pass +try: + from psycopg2.extras import Json +except: + Json = None + + +logger = logging.getLogger('peewee') + + +HCONTAINS_DICT = '@>' +HCONTAINS_KEYS = '?&' +HCONTAINS_KEY = '?' +HCONTAINS_ANY_KEY = '?|' +HKEY = '->' +HUPDATE = '||' +ACONTAINS = '@>' +ACONTAINS_ANY = '&&' +TS_MATCH = '@@' +JSONB_CONTAINS = '@>' +JSONB_CONTAINED_BY = '<@' +JSONB_CONTAINS_KEY = '?' +JSONB_CONTAINS_ANY_KEY = '?|' +JSONB_CONTAINS_ALL_KEYS = '?&' +JSONB_EXISTS = '?' 
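These operator strings surface through the field types defined later in this module (ArrayField, HStoreField, JSONField, BinaryJSONField). As a rough sketch of how a couple of them are used, assuming psycopg2 is installed and using a hypothetical database and model:

    from peewee import Model
    from playhouse.postgres_ext import (BinaryJSONField, HStoreField,
                                        PostgresqlExtDatabase)

    db = PostgresqlExtDatabase('example', register_hstore=True)  # hypothetical

    class Event(Model):
        data = BinaryJSONField()
        attrs = HStoreField()

        class Meta:
            database = db

    # jsonb '@>' containment and '?' key existence:
    has_type = Event.select().where(Event.data.contains({'type': 'episode'}))
    has_subs = Event.select().where(Event.data.has_key('subtitles'))

    # hstore '->' key lookup:
    langs = Event.select(Event.attrs['lang']).where(Event.attrs.exists('lang'))

The queries above are only constructed, not executed, so no server is contacted until they are iterated.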
+JSONB_REMOVE = '-' + + +class _LookupNode(ColumnBase): + def __init__(self, node, parts): + self.node = node + self.parts = parts + super(_LookupNode, self).__init__() + + def clone(self): + return type(self)(self.node, list(self.parts)) + + def __hash__(self): + return hash((self.__class__.__name__, id(self))) + + +class _JsonLookupBase(_LookupNode): + def __init__(self, node, parts, as_json=False): + super(_JsonLookupBase, self).__init__(node, parts) + self._as_json = as_json + + def clone(self): + return type(self)(self.node, list(self.parts), self._as_json) + + @Node.copy + def as_json(self, as_json=True): + self._as_json = as_json + + def concat(self, rhs): + if not isinstance(rhs, Node): + rhs = Json(rhs) + return Expression(self.as_json(True), OP.CONCAT, rhs) + + def contains(self, other): + clone = self.as_json(True) + if isinstance(other, (list, dict)): + return Expression(clone, JSONB_CONTAINS, Json(other)) + return Expression(clone, JSONB_EXISTS, other) + + def contains_any(self, *keys): + return Expression( + self.as_json(True), + JSONB_CONTAINS_ANY_KEY, + Value(list(keys), unpack=False)) + + def contains_all(self, *keys): + return Expression( + self.as_json(True), + JSONB_CONTAINS_ALL_KEYS, + Value(list(keys), unpack=False)) + + def has_key(self, key): + return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key) + + +class JsonLookup(_JsonLookupBase): + def __getitem__(self, value): + return JsonLookup(self.node, self.parts + [value], self._as_json) + + def __sql__(self, ctx): + ctx.sql(self.node) + for part in self.parts[:-1]: + ctx.literal('->').sql(part) + if self.parts: + (ctx + .literal('->' if self._as_json else '->>') + .sql(self.parts[-1])) + + return ctx + + +class JsonPath(_JsonLookupBase): + def __sql__(self, ctx): + return (ctx + .sql(self.node) + .literal('#>' if self._as_json else '#>>') + .sql(Value('{%s}' % ','.join(map(str, self.parts))))) + + +class ObjectSlice(_LookupNode): + @classmethod + def create(cls, node, value): + if isinstance(value, slice): + parts = [value.start or 0, value.stop or 0] + elif isinstance(value, int): + parts = [value] + elif isinstance(value, Node): + parts = value + else: + # Assumes colon-separated integer indexes. + parts = [int(i) for i in value.split(':')] + return cls(node, parts) + + def __sql__(self, ctx): + ctx.sql(self.node) + if isinstance(self.parts, Node): + ctx.literal('[').sql(self.parts).literal(']') + else: + ctx.literal('[%s]' % ':'.join(str(p + 1) for p in self.parts)) + return ctx + + def __getitem__(self, value): + return ObjectSlice.create(self, value) + + +class IndexedFieldMixin(object): + default_index_type = 'GIN' + + def __init__(self, *args, **kwargs): + kwargs.setdefault('index', True) # By default, use an index. 
+ super(IndexedFieldMixin, self).__init__(*args, **kwargs) + + +class ArrayField(IndexedFieldMixin, Field): + passthrough = True + + def __init__(self, field_class=IntegerField, field_kwargs=None, + dimensions=1, convert_values=False, *args, **kwargs): + self.__field = field_class(**(field_kwargs or {})) + self.dimensions = dimensions + self.convert_values = convert_values + self.field_type = self.__field.field_type + super(ArrayField, self).__init__(*args, **kwargs) + + def bind(self, model, name, set_attribute=True): + ret = super(ArrayField, self).bind(model, name, set_attribute) + self.__field.bind(model, '__array_%s' % name, False) + return ret + + def ddl_datatype(self, ctx): + data_type = self.__field.ddl_datatype(ctx) + return NodeList((data_type, SQL('[]' * self.dimensions)), glue='') + + def db_value(self, value): + if value is None or isinstance(value, Node): + return value + elif self.convert_values: + return self._process(self.__field.db_value, value, self.dimensions) + else: + return value if isinstance(value, list) else list(value) + + def python_value(self, value): + if self.convert_values and value is not None: + conv = self.__field.python_value + if isinstance(value, list): + return self._process(conv, value, self.dimensions) + else: + return conv(value) + else: + return value + + def _process(self, conv, value, dimensions): + dimensions -= 1 + if dimensions == 0: + return [conv(v) for v in value] + else: + return [self._process(conv, v, dimensions) for v in value] + + def __getitem__(self, value): + return ObjectSlice.create(self, value) + + def _e(op): + def inner(self, rhs): + return Expression(self, op, ArrayValue(self, rhs)) + return inner + __eq__ = _e(OP.EQ) + __ne__ = _e(OP.NE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __hash__ = Field.__hash__ + + def contains(self, *items): + return Expression(self, ACONTAINS, ArrayValue(self, items)) + + def contains_any(self, *items): + return Expression(self, ACONTAINS_ANY, ArrayValue(self, items)) + + +class ArrayValue(Node): + def __init__(self, field, value): + self.field = field + self.value = value + + def __sql__(self, ctx): + return (ctx + .sql(Value(self.value, unpack=False)) + .literal('::') + .sql(self.field.ddl_datatype(ctx))) + + +class DateTimeTZField(DateTimeField): + field_type = 'TIMESTAMPTZ' + + +class HStoreField(IndexedFieldMixin, Field): + field_type = 'HSTORE' + __hash__ = Field.__hash__ + + def __getitem__(self, key): + return Expression(self, HKEY, Value(key)) + + def keys(self): + return fn.akeys(self) + + def values(self): + return fn.avals(self) + + def items(self): + return fn.hstore_to_matrix(self) + + def slice(self, *args): + return fn.slice(self, Value(list(args), unpack=False)) + + def exists(self, key): + return fn.exist(self, key) + + def defined(self, key): + return fn.defined(self, key) + + def update(self, **data): + return Expression(self, HUPDATE, data) + + def delete(self, *keys): + return fn.delete(self, Value(list(keys), unpack=False)) + + def contains(self, value): + if isinstance(value, dict): + rhs = Value(value, unpack=False) + return Expression(self, HCONTAINS_DICT, rhs) + elif isinstance(value, (list, tuple)): + rhs = Value(value, unpack=False) + return Expression(self, HCONTAINS_KEYS, rhs) + return Expression(self, HCONTAINS_KEY, value) + + def contains_any(self, *keys): + return Expression(self, HCONTAINS_ANY_KEY, Value(list(keys), + unpack=False)) + + +class JSONField(Field): + field_type = 'JSON' + _json_datatype = 'json' + + def 
__init__(self, dumps=None, *args, **kwargs): + if Json is None: + raise Exception('Your version of psycopg2 does not support JSON.') + self.dumps = dumps or json.dumps + super(JSONField, self).__init__(*args, **kwargs) + + def db_value(self, value): + if value is None: + return value + if not isinstance(value, Json): + return Cast(self.dumps(value), self._json_datatype) + return value + + def __getitem__(self, value): + return JsonLookup(self, [value]) + + def path(self, *keys): + return JsonPath(self, keys) + + def concat(self, value): + if not isinstance(value, Node): + value = Json(value) + return super(JSONField, self).concat(value) + + +def cast_jsonb(node): + return NodeList((node, SQL('::jsonb')), glue='') + + +class BinaryJSONField(IndexedFieldMixin, JSONField): + field_type = 'JSONB' + _json_datatype = 'jsonb' + __hash__ = Field.__hash__ + + def contains(self, other): + if isinstance(other, (list, dict)): + return Expression(self, JSONB_CONTAINS, Json(other)) + elif isinstance(other, JSONField): + return Expression(self, JSONB_CONTAINS, other) + return Expression(cast_jsonb(self), JSONB_EXISTS, other) + + def contained_by(self, other): + return Expression(cast_jsonb(self), JSONB_CONTAINED_BY, Json(other)) + + def contains_any(self, *items): + return Expression( + cast_jsonb(self), + JSONB_CONTAINS_ANY_KEY, + Value(list(items), unpack=False)) + + def contains_all(self, *items): + return Expression( + cast_jsonb(self), + JSONB_CONTAINS_ALL_KEYS, + Value(list(items), unpack=False)) + + def has_key(self, key): + return Expression(cast_jsonb(self), JSONB_CONTAINS_KEY, key) + + def remove(self, *items): + return Expression( + cast_jsonb(self), + JSONB_REMOVE, + Value(list(items), unpack=False)) + + +class TSVectorField(IndexedFieldMixin, TextField): + field_type = 'TSVECTOR' + __hash__ = Field.__hash__ + + def match(self, query, language=None, plain=False): + params = (language, query) if language is not None else (query,) + func = fn.plainto_tsquery if plain else fn.to_tsquery + return Expression(self, TS_MATCH, func(*params)) + + +def Match(field, query, language=None): + params = (language, query) if language is not None else (query,) + field_params = (language, field) if language is not None else (field,) + return Expression( + fn.to_tsvector(*field_params), + TS_MATCH, + fn.to_tsquery(*params)) + + +class IntervalField(Field): + field_type = 'INTERVAL' + + +class FetchManyCursor(object): + __slots__ = ('cursor', 'array_size', 'exhausted', 'iterable') + + def __init__(self, cursor, array_size=None): + self.cursor = cursor + self.array_size = array_size or cursor.itersize + self.exhausted = False + self.iterable = self.row_gen() + + @property + def description(self): + return self.cursor.description + + def close(self): + self.cursor.close() + + def row_gen(self): + while True: + rows = self.cursor.fetchmany(self.array_size) + if not rows: + return + for row in rows: + yield row + + def fetchone(self): + if self.exhausted: + return + try: + return next(self.iterable) + except StopIteration: + self.exhausted = True + + +class ServerSideQuery(Node): + def __init__(self, query, array_size=None): + self.query = query + self.array_size = array_size + self._cursor_wrapper = None + + def __sql__(self, ctx): + return self.query.__sql__(ctx) + + def __iter__(self): + if self._cursor_wrapper is None: + self._execute(self.query._database) + return iter(self._cursor_wrapper.iterator()) + + def _execute(self, database): + if self._cursor_wrapper is None: + cursor = database.execute(self.query, 
named_cursor=True, + array_size=self.array_size) + self._cursor_wrapper = self.query._get_cursor_wrapper(cursor) + return self._cursor_wrapper + + +def ServerSide(query, database=None, array_size=None): + if database is None: + database = query._database + with database.transaction(): + server_side_query = ServerSideQuery(query, array_size=array_size) + for row in server_side_query: + yield row + + +class _empty_object(object): + __slots__ = () + def __nonzero__(self): + return False + __bool__ = __nonzero__ + +__named_cursor__ = _empty_object() + + +class PostgresqlExtDatabase(PostgresqlDatabase): + def __init__(self, *args, **kwargs): + self._register_hstore = kwargs.pop('register_hstore', False) + self._server_side_cursors = kwargs.pop('server_side_cursors', False) + super(PostgresqlExtDatabase, self).__init__(*args, **kwargs) + + def _connect(self): + conn = super(PostgresqlExtDatabase, self)._connect() + if self._register_hstore: + register_hstore(conn, globally=True) + return conn + + def cursor(self, commit=None): + if self.is_closed(): + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') + if commit is __named_cursor__: + return self._state.conn.cursor(name=str(uuid.uuid1())) + return self._state.conn.cursor() + + def execute(self, query, commit=SENTINEL, named_cursor=False, + array_size=None, **context_options): + ctx = self.get_sql_context(**context_options) + sql, params = ctx.sql(query).query() + named_cursor = named_cursor or (self._server_side_cursors and + sql[:6].lower() == 'select') + if named_cursor: + commit = __named_cursor__ + cursor = self.execute_sql(sql, params, commit=commit) + if named_cursor: + cursor = FetchManyCursor(cursor, array_size) + return cursor diff --git a/libs/playhouse/reflection.py b/libs/playhouse/reflection.py new file mode 100644 index 000000000..accaaa774 --- /dev/null +++ b/libs/playhouse/reflection.py @@ -0,0 +1,833 @@ +try: + from collections import OrderedDict +except ImportError: + OrderedDict = dict +from collections import namedtuple +from inspect import isclass +import re + +from peewee import * +from peewee import _StringField +from peewee import _query_val_transform +from peewee import CommaNodeList +from peewee import SCOPE_VALUES +from peewee import make_snake_case +from peewee import text_type +try: + from pymysql.constants import FIELD_TYPE +except ImportError: + try: + from MySQLdb.constants import FIELD_TYPE + except ImportError: + FIELD_TYPE = None +try: + from playhouse import postgres_ext +except ImportError: + postgres_ext = None +try: + from playhouse.cockroachdb import CockroachDatabase +except ImportError: + CockroachDatabase = None + +RESERVED_WORDS = set([ + 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif', + 'else', 'except', 'exec', 'finally', 'for', 'from', 'global', 'if', + 'import', 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', 'raise', + 'return', 'try', 'while', 'with', 'yield', +]) + + +class UnknownField(object): + pass + + +class Column(object): + """ + Store metadata about a database column. 
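One note on the postgres_ext module that closes just above: the ServerSide() helper streams a query through a named (server-side) cursor inside a transaction, so large result sets are fetched in chunks rather than loaded at once. A minimal sketch, with a hypothetical model and database:

    from peewee import Model, TextField
    from playhouse.postgres_ext import PostgresqlExtDatabase, ServerSide

    db = PostgresqlExtDatabase('example', user='postgres')   # hypothetical

    class Subtitle(Model):
        path = TextField()

        class Meta:
            database = db

    # Rows are pulled array_size at a time through the named cursor.
    for row in ServerSide(Subtitle.select(), array_size=500):
        print(row.path)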
+ """ + primary_key_types = (IntegerField, AutoField) + + def __init__(self, name, field_class, raw_column_type, nullable, + primary_key=False, column_name=None, index=False, + unique=False, default=None, extra_parameters=None): + self.name = name + self.field_class = field_class + self.raw_column_type = raw_column_type + self.nullable = nullable + self.primary_key = primary_key + self.column_name = column_name + self.index = index + self.unique = unique + self.default = default + self.extra_parameters = extra_parameters + + # Foreign key metadata. + self.rel_model = None + self.related_name = None + self.to_field = None + + def __repr__(self): + attrs = [ + 'field_class', + 'raw_column_type', + 'nullable', + 'primary_key', + 'column_name'] + keyword_args = ', '.join( + '%s=%s' % (attr, getattr(self, attr)) + for attr in attrs) + return 'Column(%s, %s)' % (self.name, keyword_args) + + def get_field_parameters(self): + params = {} + if self.extra_parameters is not None: + params.update(self.extra_parameters) + + # Set up default attributes. + if self.nullable: + params['null'] = True + if self.field_class is ForeignKeyField or self.name != self.column_name: + params['column_name'] = "'%s'" % self.column_name + if self.primary_key and not issubclass(self.field_class, AutoField): + params['primary_key'] = True + if self.default is not None: + params['constraints'] = '[SQL("DEFAULT %s")]' % self.default + + # Handle ForeignKeyField-specific attributes. + if self.is_foreign_key(): + params['model'] = self.rel_model + if self.to_field: + params['field'] = "'%s'" % self.to_field + if self.related_name: + params['backref'] = "'%s'" % self.related_name + + # Handle indexes on column. + if not self.is_primary_key(): + if self.unique: + params['unique'] = 'True' + elif self.index and not self.is_foreign_key(): + params['index'] = 'True' + + return params + + def is_primary_key(self): + return self.field_class is AutoField or self.primary_key + + def is_foreign_key(self): + return self.field_class is ForeignKeyField + + def is_self_referential_fk(self): + return (self.field_class is ForeignKeyField and + self.rel_model == "'self'") + + def set_foreign_key(self, foreign_key, model_names, dest=None, + related_name=None): + self.foreign_key = foreign_key + self.field_class = ForeignKeyField + if foreign_key.dest_table == foreign_key.table: + self.rel_model = "'self'" + else: + self.rel_model = model_names[foreign_key.dest_table] + self.to_field = dest and dest.name or None + self.related_name = related_name or None + + def get_field(self): + # Generate the field definition for this column. + field_params = {} + for key, value in self.get_field_parameters().items(): + if isclass(value) and issubclass(value, Field): + value = value.__name__ + field_params[key] = value + + param_str = ', '.join('%s=%s' % (k, v) + for k, v in sorted(field_params.items())) + field = '%s = %s(%s)' % ( + self.name, + self.field_class.__name__, + param_str) + + if self.field_class is UnknownField: + field = '%s # %s' % (field, self.raw_column_type) + + return field + + +class Metadata(object): + column_map = {} + extension_import = '' + + def __init__(self, database): + self.database = database + self.requires_extension = False + + def execute(self, sql, *params): + return self.database.execute_sql(sql, params) + + def get_columns(self, table, schema=None): + metadata = OrderedDict( + (metadata.name, metadata) + for metadata in self.database.get_columns(table, schema)) + + # Look up the actual column type for each column. 
+ column_types, extra_params = self.get_column_types(table, schema) + + # Look up the primary keys. + pk_names = self.get_primary_keys(table, schema) + if len(pk_names) == 1: + pk = pk_names[0] + if column_types[pk] is IntegerField: + column_types[pk] = AutoField + elif column_types[pk] is BigIntegerField: + column_types[pk] = BigAutoField + + columns = OrderedDict() + for name, column_data in metadata.items(): + field_class = column_types[name] + default = self._clean_default(field_class, column_data.default) + + columns[name] = Column( + name, + field_class=field_class, + raw_column_type=column_data.data_type, + nullable=column_data.null, + primary_key=column_data.primary_key, + column_name=name, + default=default, + extra_parameters=extra_params.get(name)) + + return columns + + def get_column_types(self, table, schema=None): + raise NotImplementedError + + def _clean_default(self, field_class, default): + if default is None or field_class in (AutoField, BigAutoField) or \ + default.lower() == 'null': + return + if issubclass(field_class, _StringField) and \ + isinstance(default, text_type) and not default.startswith("'"): + default = "'%s'" % default + return default or "''" + + def get_foreign_keys(self, table, schema=None): + return self.database.get_foreign_keys(table, schema) + + def get_primary_keys(self, table, schema=None): + return self.database.get_primary_keys(table, schema) + + def get_indexes(self, table, schema=None): + return self.database.get_indexes(table, schema) + + +class PostgresqlMetadata(Metadata): + column_map = { + 16: BooleanField, + 17: BlobField, + 20: BigIntegerField, + 21: SmallIntegerField, + 23: IntegerField, + 25: TextField, + 700: FloatField, + 701: DoubleField, + 1042: CharField, # blank-padded CHAR + 1043: CharField, + 1082: DateField, + 1114: DateTimeField, + 1184: DateTimeField, + 1083: TimeField, + 1266: TimeField, + 1700: DecimalField, + 2950: UUIDField, # UUID + } + array_types = { + 1000: BooleanField, + 1001: BlobField, + 1005: SmallIntegerField, + 1007: IntegerField, + 1009: TextField, + 1014: CharField, + 1015: CharField, + 1016: BigIntegerField, + 1115: DateTimeField, + 1182: DateField, + 1183: TimeField, + 2951: UUIDField, + } + extension_import = 'from playhouse.postgres_ext import *' + + def __init__(self, database): + super(PostgresqlMetadata, self).__init__(database) + + if postgres_ext is not None: + # Attempt to add types like HStore and JSON. + cursor = self.execute('select oid, typname, format_type(oid, NULL)' + ' from pg_type;') + results = cursor.fetchall() + + for oid, typname, formatted_type in results: + if typname == 'json': + self.column_map[oid] = postgres_ext.JSONField + elif typname == 'jsonb': + self.column_map[oid] = postgres_ext.BinaryJSONField + elif typname == 'hstore': + self.column_map[oid] = postgres_ext.HStoreField + elif typname == 'tsvector': + self.column_map[oid] = postgres_ext.TSVectorField + + for oid in self.array_types: + self.column_map[oid] = postgres_ext.ArrayField + + def get_column_types(self, table, schema): + column_types = {} + extra_params = {} + extension_types = set(( + postgres_ext.ArrayField, + postgres_ext.BinaryJSONField, + postgres_ext.JSONField, + postgres_ext.TSVectorField, + postgres_ext.HStoreField)) if postgres_ext is not None else set() + + # Look up the actual column type for each column. 
+ identifier = '%s."%s"' % (schema, table) + cursor = self.execute( + 'SELECT attname, atttypid FROM pg_catalog.pg_attribute ' + 'WHERE attrelid = %s::regclass AND attnum > %s', identifier, 0) + + # Store column metadata in dictionary keyed by column name. + for name, oid in cursor.fetchall(): + column_types[name] = self.column_map.get(oid, UnknownField) + if column_types[name] in extension_types: + self.requires_extension = True + if oid in self.array_types: + extra_params[name] = {'field_class': self.array_types[oid]} + + return column_types, extra_params + + def get_columns(self, table, schema=None): + schema = schema or 'public' + return super(PostgresqlMetadata, self).get_columns(table, schema) + + def get_foreign_keys(self, table, schema=None): + schema = schema or 'public' + return super(PostgresqlMetadata, self).get_foreign_keys(table, schema) + + def get_primary_keys(self, table, schema=None): + schema = schema or 'public' + return super(PostgresqlMetadata, self).get_primary_keys(table, schema) + + def get_indexes(self, table, schema=None): + schema = schema or 'public' + return super(PostgresqlMetadata, self).get_indexes(table, schema) + + +class CockroachDBMetadata(PostgresqlMetadata): + # CRDB treats INT the same as BIGINT, so we just map bigint type OIDs to + # regular IntegerField. + column_map = PostgresqlMetadata.column_map.copy() + column_map[20] = IntegerField + array_types = PostgresqlMetadata.array_types.copy() + array_types[1016] = IntegerField + extension_import = 'from playhouse.cockroachdb import *' + + def __init__(self, database): + Metadata.__init__(self, database) + self.requires_extension = True + + if postgres_ext is not None: + # Attempt to add JSON types. + cursor = self.execute('select oid, typname, format_type(oid, NULL)' + ' from pg_type;') + results = cursor.fetchall() + + for oid, typname, formatted_type in results: + if typname == 'jsonb': + self.column_map[oid] = postgres_ext.BinaryJSONField + + for oid in self.array_types: + self.column_map[oid] = postgres_ext.ArrayField + + +class MySQLMetadata(Metadata): + if FIELD_TYPE is None: + column_map = {} + else: + column_map = { + FIELD_TYPE.BLOB: TextField, + FIELD_TYPE.CHAR: CharField, + FIELD_TYPE.DATE: DateField, + FIELD_TYPE.DATETIME: DateTimeField, + FIELD_TYPE.DECIMAL: DecimalField, + FIELD_TYPE.DOUBLE: FloatField, + FIELD_TYPE.FLOAT: FloatField, + FIELD_TYPE.INT24: IntegerField, + FIELD_TYPE.LONG_BLOB: TextField, + FIELD_TYPE.LONG: IntegerField, + FIELD_TYPE.LONGLONG: BigIntegerField, + FIELD_TYPE.MEDIUM_BLOB: TextField, + FIELD_TYPE.NEWDECIMAL: DecimalField, + FIELD_TYPE.SHORT: IntegerField, + FIELD_TYPE.STRING: CharField, + FIELD_TYPE.TIMESTAMP: DateTimeField, + FIELD_TYPE.TIME: TimeField, + FIELD_TYPE.TINY_BLOB: TextField, + FIELD_TYPE.TINY: IntegerField, + FIELD_TYPE.VAR_STRING: CharField, + } + + def __init__(self, database, **kwargs): + if 'password' in kwargs: + kwargs['passwd'] = kwargs.pop('password') + super(MySQLMetadata, self).__init__(database, **kwargs) + + def get_column_types(self, table, schema=None): + column_types = {} + + # Look up the actual column type for each column. + cursor = self.execute('SELECT * FROM `%s` LIMIT 1' % table) + + # Store column metadata in dictionary keyed by column name. 
+ for column_description in cursor.description: + name, type_code = column_description[:2] + column_types[name] = self.column_map.get(type_code, UnknownField) + + return column_types, {} + + +class SqliteMetadata(Metadata): + column_map = { + 'bigint': BigIntegerField, + 'blob': BlobField, + 'bool': BooleanField, + 'boolean': BooleanField, + 'char': CharField, + 'date': DateField, + 'datetime': DateTimeField, + 'decimal': DecimalField, + 'float': FloatField, + 'integer': IntegerField, + 'integer unsigned': IntegerField, + 'int': IntegerField, + 'long': BigIntegerField, + 'numeric': DecimalField, + 'real': FloatField, + 'smallinteger': IntegerField, + 'smallint': IntegerField, + 'smallint unsigned': IntegerField, + 'text': TextField, + 'time': TimeField, + 'varchar': CharField, + } + + begin = '(?:["\[\(]+)?' + end = '(?:["\]\)]+)?' + re_foreign_key = ( + '(?:FOREIGN KEY\s*)?' + '{begin}(.+?){end}\s+(?:.+\s+)?' + 'references\s+{begin}(.+?){end}' + '\s*\(["|\[]?(.+?)["|\]]?\)').format(begin=begin, end=end) + re_varchar = r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$' + + def _map_col(self, column_type): + raw_column_type = column_type.lower() + if raw_column_type in self.column_map: + field_class = self.column_map[raw_column_type] + elif re.search(self.re_varchar, raw_column_type): + field_class = CharField + else: + column_type = re.sub('\(.+\)', '', raw_column_type) + if column_type == '': + field_class = BareField + else: + field_class = self.column_map.get(column_type, UnknownField) + return field_class + + def get_column_types(self, table, schema=None): + column_types = {} + columns = self.database.get_columns(table) + + for column in columns: + column_types[column.name] = self._map_col(column.data_type) + + return column_types, {} + + +_DatabaseMetadata = namedtuple('_DatabaseMetadata', ( + 'columns', + 'primary_keys', + 'foreign_keys', + 'model_names', + 'indexes')) + + +class DatabaseMetadata(_DatabaseMetadata): + def multi_column_indexes(self, table): + accum = [] + for index in self.indexes[table]: + if len(index.columns) > 1: + field_names = [self.columns[table][column].name + for column in index.columns + if column in self.columns[table]] + accum.append((field_names, index.unique)) + return accum + + def column_indexes(self, table): + accum = {} + for index in self.indexes[table]: + if len(index.columns) == 1: + accum[index.columns[0]] = index.unique + return accum + + +class Introspector(object): + pk_classes = [AutoField, IntegerField] + + def __init__(self, metadata, schema=None): + self.metadata = metadata + self.schema = schema + + def __repr__(self): + return '' % self.metadata.database + + @classmethod + def from_database(cls, database, schema=None): + if CockroachDatabase and isinstance(database, CockroachDatabase): + metadata = CockroachDBMetadata(database) + elif isinstance(database, PostgresqlDatabase): + metadata = PostgresqlMetadata(database) + elif isinstance(database, MySQLDatabase): + metadata = MySQLMetadata(database) + elif isinstance(database, SqliteDatabase): + metadata = SqliteMetadata(database) + else: + raise ValueError('Introspection not supported for %r' % database) + return cls(metadata, schema=schema) + + def get_database_class(self): + return type(self.metadata.database) + + def get_database_name(self): + return self.metadata.database.database + + def get_database_kwargs(self): + return self.metadata.database.connect_params + + def get_additional_imports(self): + if self.metadata.requires_extension: + return '\n' + self.metadata.extension_import + return '' 
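A minimal usage sketch of the reflection helpers added in this module (generate_models() and print_model() are defined further below). The database path and table name are illustrative only and are not taken from this patch:

    from peewee import SqliteDatabase
    from playhouse.reflection import generate_models, print_model

    db = SqliteDatabase('/path/to/app.db')   # illustrative path
    models = generate_models(db)             # maps table name -> generated Model class
    print_model(models['some_table'])        # assumes a table named 'some_table' exists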
+ + def make_model_name(self, table, snake_case=True): + if snake_case: + table = make_snake_case(table) + model = re.sub(r'[^\w]+', '', table) + model_name = ''.join(sub.title() for sub in model.split('_')) + if not model_name[0].isalpha(): + model_name = 'T' + model_name + return model_name + + def make_column_name(self, column, is_foreign_key=False, snake_case=True): + column = column.strip() + if snake_case: + column = make_snake_case(column) + column = column.lower() + if is_foreign_key: + # Strip "_id" from foreign keys, unless the foreign-key happens to + # be named "_id", in which case the name is retained. + column = re.sub('_id$', '', column) or column + + # Remove characters that are invalid for Python identifiers. + column = re.sub(r'[^\w]+', '_', column) + if column in RESERVED_WORDS: + column += '_' + if len(column) and column[0].isdigit(): + column = '_' + column + return column + + def introspect(self, table_names=None, literal_column_names=False, + include_views=False, snake_case=True): + # Retrieve all the tables in the database. + tables = self.metadata.database.get_tables(schema=self.schema) + if include_views: + views = self.metadata.database.get_views(schema=self.schema) + tables.extend([view.name for view in views]) + + if table_names is not None: + tables = [table for table in tables if table in table_names] + table_set = set(tables) + + # Store a mapping of table name -> dictionary of columns. + columns = {} + + # Store a mapping of table name -> set of primary key columns. + primary_keys = {} + + # Store a mapping of table -> foreign keys. + foreign_keys = {} + + # Store a mapping of table name -> model name. + model_names = {} + + # Store a mapping of table name -> indexes. + indexes = {} + + # Gather the columns for each table. + for table in tables: + table_indexes = self.metadata.get_indexes(table, self.schema) + table_columns = self.metadata.get_columns(table, self.schema) + try: + foreign_keys[table] = self.metadata.get_foreign_keys( + table, self.schema) + except ValueError as exc: + err(*exc.args) + foreign_keys[table] = [] + else: + # If there is a possibility we could exclude a dependent table, + # ensure that we introspect it so FKs will work. + if table_names is not None: + for foreign_key in foreign_keys[table]: + if foreign_key.dest_table not in table_set: + tables.append(foreign_key.dest_table) + table_set.add(foreign_key.dest_table) + + model_names[table] = self.make_model_name(table, snake_case) + + # Collect sets of all the column names as well as all the + # foreign-key column names. + lower_col_names = set(column_name.lower() + for column_name in table_columns) + fks = set(fk_col.column for fk_col in foreign_keys[table]) + + for col_name, column in table_columns.items(): + if literal_column_names: + new_name = re.sub(r'[^\w]+', '_', col_name) + else: + new_name = self.make_column_name(col_name, col_name in fks, + snake_case) + + # If we have two columns, "parent" and "parent_id", ensure + # that when we don't introduce naming conflicts. 
+ lower_name = col_name.lower() + if lower_name.endswith('_id') and new_name in lower_col_names: + new_name = col_name.lower() + + column.name = new_name + + for index in table_indexes: + if len(index.columns) == 1: + column = index.columns[0] + if column in table_columns: + table_columns[column].unique = index.unique + table_columns[column].index = True + + primary_keys[table] = self.metadata.get_primary_keys( + table, self.schema) + columns[table] = table_columns + indexes[table] = table_indexes + + # Gather all instances where we might have a `related_name` conflict, + # either due to multiple FKs on a table pointing to the same table, + # or a related_name that would conflict with an existing field. + related_names = {} + sort_fn = lambda foreign_key: foreign_key.column + for table in tables: + models_referenced = set() + for foreign_key in sorted(foreign_keys[table], key=sort_fn): + try: + column = columns[table][foreign_key.column] + except KeyError: + continue + + dest_table = foreign_key.dest_table + if dest_table in models_referenced: + related_names[column] = '%s_%s_set' % ( + dest_table, + column.name) + else: + models_referenced.add(dest_table) + + # On the second pass convert all foreign keys. + for table in tables: + for foreign_key in foreign_keys[table]: + src = columns[foreign_key.table][foreign_key.column] + try: + dest = columns[foreign_key.dest_table][ + foreign_key.dest_column] + except KeyError: + dest = None + + src.set_foreign_key( + foreign_key=foreign_key, + model_names=model_names, + dest=dest, + related_name=related_names.get(src)) + + return DatabaseMetadata( + columns, + primary_keys, + foreign_keys, + model_names, + indexes) + + def generate_models(self, skip_invalid=False, table_names=None, + literal_column_names=False, bare_fields=False, + include_views=False): + database = self.introspect(table_names, literal_column_names, + include_views) + models = {} + + class BaseModel(Model): + class Meta: + database = self.metadata.database + schema = self.schema + + def _create_model(table, models): + for foreign_key in database.foreign_keys[table]: + dest = foreign_key.dest_table + + if dest not in models and dest != table: + _create_model(dest, models) + + primary_keys = [] + columns = database.columns[table] + for column_name, column in columns.items(): + if column.primary_key: + primary_keys.append(column.name) + + multi_column_indexes = database.multi_column_indexes(table) + column_indexes = database.column_indexes(table) + + class Meta: + indexes = multi_column_indexes + table_name = table + + # Fix models with multi-column primary keys. 
+ composite_key = False + if len(primary_keys) == 0: + primary_keys = columns.keys() + if len(primary_keys) > 1: + Meta.primary_key = CompositeKey(*[ + field.name for col, field in columns.items() + if col in primary_keys]) + composite_key = True + + attrs = {'Meta': Meta} + for column_name, column in columns.items(): + FieldClass = column.field_class + if FieldClass is not ForeignKeyField and bare_fields: + FieldClass = BareField + elif FieldClass is UnknownField: + FieldClass = BareField + + params = { + 'column_name': column_name, + 'null': column.nullable} + if column.primary_key and composite_key: + if FieldClass is AutoField: + FieldClass = IntegerField + params['primary_key'] = False + elif column.primary_key and FieldClass is not AutoField: + params['primary_key'] = True + if column.is_foreign_key(): + if column.is_self_referential_fk(): + params['model'] = 'self' + else: + dest_table = column.foreign_key.dest_table + params['model'] = models[dest_table] + if column.to_field: + params['field'] = column.to_field + + # Generate a unique related name. + params['backref'] = '%s_%s_rel' % (table, column_name) + + if column.default is not None: + constraint = SQL('DEFAULT %s' % column.default) + params['constraints'] = [constraint] + + if column_name in column_indexes and not \ + column.is_primary_key(): + if column_indexes[column_name]: + params['unique'] = True + elif not column.is_foreign_key(): + params['index'] = True + + attrs[column.name] = FieldClass(**params) + + try: + models[table] = type(str(table), (BaseModel,), attrs) + except ValueError: + if not skip_invalid: + raise + + # Actually generate Model classes. + for table, model in sorted(database.model_names.items()): + if table not in models: + _create_model(table, models) + + return models + + +def introspect(database, schema=None): + introspector = Introspector.from_database(database, schema=schema) + return introspector.introspect() + + +def generate_models(database, schema=None, **options): + introspector = Introspector.from_database(database, schema=schema) + return introspector.generate_models(**options) + + +def print_model(model, indexes=True, inline_indexes=False): + print(model._meta.name) + for field in model._meta.sorted_fields: + parts = [' %s %s' % (field.name, field.field_type)] + if field.primary_key: + parts.append(' PK') + elif inline_indexes: + if field.unique: + parts.append(' UNIQUE') + elif field.index: + parts.append(' INDEX') + if isinstance(field, ForeignKeyField): + parts.append(' FK: %s.%s' % (field.rel_model.__name__, + field.rel_field.name)) + print(''.join(parts)) + + if indexes: + index_list = model._meta.fields_to_index() + if not index_list: + return + + print('\nindex(es)') + for index in index_list: + parts = [' '] + ctx = model._meta.database.get_sql_context() + with ctx.scope_values(param='%s', quote='""'): + ctx.sql(CommaNodeList(index._expressions)) + if index._where: + ctx.literal(' WHERE ') + ctx.sql(index._where) + sql, params = ctx.query() + + clean = sql % tuple(map(_query_val_transform, params)) + parts.append(clean.replace('"', '')) + + if index._unique: + parts.append(' UNIQUE') + print(''.join(parts)) + + +def get_table_sql(model): + sql, params = model._schema._create_table().query() + if model._meta.database.param != '%s': + sql = sql.replace(model._meta.database.param, '%s') + + # Format and indent the table declaration, simplest possible approach. 
+ match_obj = re.match('^(.+?\()(.+)(\).*)', sql) + create, columns, extra = match_obj.groups() + indented = ',\n'.join(' %s' % column for column in columns.split(', ')) + + clean = '\n'.join((create, indented, extra)).strip() + return clean % tuple(map(_query_val_transform, params)) + +def print_table_sql(model): + print(get_table_sql(model)) diff --git a/libs/playhouse/shortcuts.py b/libs/playhouse/shortcuts.py new file mode 100644 index 000000000..cefa6e437 --- /dev/null +++ b/libs/playhouse/shortcuts.py @@ -0,0 +1,252 @@ +from peewee import * +from peewee import Alias +from peewee import CompoundSelectQuery +from peewee import SENTINEL +from peewee import callable_ + + +_clone_set = lambda s: set(s) if s else set() + + +def model_to_dict(model, recurse=True, backrefs=False, only=None, + exclude=None, seen=None, extra_attrs=None, + fields_from_query=None, max_depth=None, manytomany=False): + """ + Convert a model instance (and any related objects) to a dictionary. + + :param bool recurse: Whether foreign-keys should be recursed. + :param bool backrefs: Whether lists of related objects should be recursed. + :param only: A list (or set) of field instances indicating which fields + should be included. + :param exclude: A list (or set) of field instances that should be + excluded from the dictionary. + :param list extra_attrs: Names of model instance attributes or methods + that should be included. + :param SelectQuery fields_from_query: Query that was source of model. Take + fields explicitly selected by the query and serialize them. + :param int max_depth: Maximum depth to recurse, value <= 0 means no max. + :param bool manytomany: Process many-to-many fields. + """ + max_depth = -1 if max_depth is None else max_depth + if max_depth == 0: + recurse = False + + only = _clone_set(only) + extra_attrs = _clone_set(extra_attrs) + should_skip = lambda n: (n in exclude) or (only and (n not in only)) + + if fields_from_query is not None: + for item in fields_from_query._returning: + if isinstance(item, Field): + only.add(item) + elif isinstance(item, Alias): + extra_attrs.add(item._alias) + + data = {} + exclude = _clone_set(exclude) + seen = _clone_set(seen) + exclude |= seen + model_class = type(model) + + if manytomany: + for name, m2m in model._meta.manytomany.items(): + if should_skip(name): + continue + + exclude.update((m2m, m2m.rel_model._meta.manytomany[m2m.backref])) + for fkf in m2m.through_model._meta.refs: + exclude.add(fkf) + + accum = [] + for rel_obj in getattr(model, name): + accum.append(model_to_dict( + rel_obj, + recurse=recurse, + backrefs=backrefs, + only=only, + exclude=exclude, + max_depth=max_depth - 1)) + data[name] = accum + + for field in model._meta.sorted_fields: + if should_skip(field): + continue + + field_data = model.__data__.get(field.name) + if isinstance(field, ForeignKeyField) and recurse: + if field_data is not None: + seen.add(field) + rel_obj = getattr(model, field.name) + field_data = model_to_dict( + rel_obj, + recurse=recurse, + backrefs=backrefs, + only=only, + exclude=exclude, + seen=seen, + max_depth=max_depth - 1) + else: + field_data = None + + data[field.name] = field_data + + if extra_attrs: + for attr_name in extra_attrs: + attr = getattr(model, attr_name) + if callable_(attr): + data[attr_name] = attr() + else: + data[attr_name] = attr + + if backrefs and recurse: + for foreign_key, rel_model in model._meta.backrefs.items(): + if foreign_key.backref == '+': continue + descriptor = getattr(model_class, foreign_key.backref) + if descriptor in 
exclude or foreign_key in exclude: + continue + if only and (descriptor not in only) and (foreign_key not in only): + continue + + accum = [] + exclude.add(foreign_key) + related_query = getattr(model, foreign_key.backref) + + for rel_obj in related_query: + accum.append(model_to_dict( + rel_obj, + recurse=recurse, + backrefs=backrefs, + only=only, + exclude=exclude, + max_depth=max_depth - 1)) + + data[foreign_key.backref] = accum + + return data + + +def update_model_from_dict(instance, data, ignore_unknown=False): + meta = instance._meta + backrefs = dict([(fk.backref, fk) for fk in meta.backrefs]) + + for key, value in data.items(): + if key in meta.combined: + field = meta.combined[key] + is_backref = False + elif key in backrefs: + field = backrefs[key] + is_backref = True + elif ignore_unknown: + setattr(instance, key, value) + continue + else: + raise AttributeError('Unrecognized attribute "%s" for model ' + 'class %s.' % (key, type(instance))) + + is_foreign_key = isinstance(field, ForeignKeyField) + + if not is_backref and is_foreign_key and isinstance(value, dict): + try: + rel_instance = instance.__rel__[field.name] + except KeyError: + rel_instance = field.rel_model() + setattr( + instance, + field.name, + update_model_from_dict(rel_instance, value, ignore_unknown)) + elif is_backref and isinstance(value, (list, tuple)): + instances = [ + dict_to_model(field.model, row_data, ignore_unknown) + for row_data in value] + for rel_instance in instances: + setattr(rel_instance, field.name, instance) + setattr(instance, field.backref, instances) + else: + setattr(instance, field.name, value) + + return instance + + +def dict_to_model(model_class, data, ignore_unknown=False): + return update_model_from_dict(model_class(), data, ignore_unknown) + + +class ReconnectMixin(object): + """ + Mixin class that attempts to automatically reconnect to the database under + certain error conditions. + + For example, MySQL servers will typically close connections that are idle + for 28800 seconds ("wait_timeout" setting). If your application makes use + of long-lived connections, you may find your connections are closed after + a period of no activity. This mixin will attempt to reconnect automatically + when these errors occur. + + This mixin class probably should not be used with Postgres (unless you + REALLY know what you are doing) and definitely has no business being used + with Sqlite. If you wish to use with Postgres, you will need to adapt the + `reconnect_errors` attribute to something appropriate for Postgres. + """ + reconnect_errors = ( + # Error class, error message fragment (or empty string for all). + (OperationalError, '2006'), # MySQL server has gone away. + (OperationalError, '2013'), # Lost connection to MySQL server. + (OperationalError, '2014'), # Commands out of sync. + + # mysql-connector raises a slightly different error when an idle + # connection is terminated by the server. This is equivalent to 2013. + (OperationalError, 'MySQL Connection not available.'), + ) + + def __init__(self, *args, **kwargs): + super(ReconnectMixin, self).__init__(*args, **kwargs) + + # Normalize the reconnect errors to a more efficient data-structure. 
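A brief sketch of how the helpers in shortcuts.py are commonly combined; everything below is illustrative (the User model and the connection credentials are hypothetical), and the reconnecting database simply mixes ReconnectMixin into a concrete database class:

    from peewee import MySQLDatabase, Model, CharField
    from playhouse.shortcuts import model_to_dict, dict_to_model, ReconnectMixin

    class ReconnectMySQLDatabase(ReconnectMixin, MySQLDatabase):
        pass  # retries a query once after a "server has gone away" style error

    db = ReconnectMySQLDatabase('app', user='app', password='...')  # illustrative credentials

    class User(Model):              # hypothetical model, for illustration only
        name = CharField()

        class Meta:
            database = db

    user = User(name='morpheus')
    data = model_to_dict(user)           # plain dict; foreign keys are recursed by default
    clone = dict_to_model(User, data)    # rebuild an unsaved instance from the dict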
+ self._reconnect_errors = {} + for exc_class, err_fragment in self.reconnect_errors: + self._reconnect_errors.setdefault(exc_class, []) + self._reconnect_errors[exc_class].append(err_fragment.lower()) + + def execute_sql(self, sql, params=None, commit=SENTINEL): + try: + return super(ReconnectMixin, self).execute_sql(sql, params, commit) + except Exception as exc: + exc_class = type(exc) + if exc_class not in self._reconnect_errors: + raise exc + + exc_repr = str(exc).lower() + for err_fragment in self._reconnect_errors[exc_class]: + if err_fragment in exc_repr: + break + else: + raise exc + + if not self.is_closed(): + self.close() + self.connect() + + return super(ReconnectMixin, self).execute_sql(sql, params, commit) + + +def resolve_multimodel_query(query, key='_model_identifier'): + mapping = {} + accum = [query] + while accum: + curr = accum.pop() + if isinstance(curr, CompoundSelectQuery): + accum.extend((curr.lhs, curr.rhs)) + continue + + model_class = curr.model + name = model_class._meta.table_name + mapping[name] = model_class + curr._returning.append(Value(name).alias(key)) + + def wrapped_iterator(): + for row in query.dicts().iterator(): + identifier = row.pop(key) + model = mapping[identifier] + yield model(**row) + + return wrapped_iterator() diff --git a/libs/playhouse/signals.py b/libs/playhouse/signals.py new file mode 100644 index 000000000..4e92872e5 --- /dev/null +++ b/libs/playhouse/signals.py @@ -0,0 +1,79 @@ +""" +Provide django-style hooks for model events. +""" +from peewee import Model as _Model + + +class Signal(object): + def __init__(self): + self._flush() + + def _flush(self): + self._receivers = set() + self._receiver_list = [] + + def connect(self, receiver, name=None, sender=None): + name = name or receiver.__name__ + key = (name, sender) + if key not in self._receivers: + self._receivers.add(key) + self._receiver_list.append((name, receiver, sender)) + else: + raise ValueError('receiver named %s (for sender=%s) already ' + 'connected' % (name, sender or 'any')) + + def disconnect(self, receiver=None, name=None, sender=None): + if receiver: + name = name or receiver.__name__ + if not name: + raise ValueError('a receiver or a name must be provided') + + key = (name, sender) + if key not in self._receivers: + raise ValueError('receiver named %s for sender=%s not found.' 
% + (name, sender or 'any')) + + self._receivers.remove(key) + self._receiver_list = [(n, r, s) for n, r, s in self._receiver_list + if n != name and s != sender] + + def __call__(self, name=None, sender=None): + def decorator(fn): + self.connect(fn, name, sender) + return fn + return decorator + + def send(self, instance, *args, **kwargs): + sender = type(instance) + responses = [] + for n, r, s in self._receiver_list: + if s is None or isinstance(instance, s): + responses.append((r, r(sender, instance, *args, **kwargs))) + return responses + + +pre_save = Signal() +post_save = Signal() +pre_delete = Signal() +post_delete = Signal() +pre_init = Signal() + + +class Model(_Model): + def __init__(self, *args, **kwargs): + super(Model, self).__init__(*args, **kwargs) + pre_init.send(self) + + def save(self, *args, **kwargs): + pk_value = self._pk if self._meta.primary_key else True + created = kwargs.get('force_insert', False) or not bool(pk_value) + pre_save.send(self, created=created) + ret = super(Model, self).save(*args, **kwargs) + post_save.send(self, created=created) + return ret + + def delete_instance(self, *args, **kwargs): + pre_delete.send(self) + ret = super(Model, self).delete_instance(*args, **kwargs) + post_delete.send(self) + return ret diff --git a/libs/playhouse/sqlcipher_ext.py b/libs/playhouse/sqlcipher_ext.py new file mode 100644 index 000000000..9bad1eca6 --- /dev/null +++ b/libs/playhouse/sqlcipher_ext.py @@ -0,0 +1,103 @@ +""" +Peewee integration with pysqlcipher. + +Project page: https://github.com/leapcode/pysqlcipher/ + +**WARNING!!! EXPERIMENTAL!!!** + +* Although this extention's code is short, it has not been properly + peer-reviewed yet and may have introduced vulnerabilities. + +Also note that this code relies on pysqlcipher and sqlcipher, and +the code there might have vulnerabilities as well, but since these +are widely used crypto modules, we can expect "short zero days" there. + +Example usage: + + from peewee.playground.ciphersql_ext import SqlCipherDatabase + db = SqlCipherDatabase('/path/to/my.db', passphrase="don'tuseme4real") + +* `passphrase`: should be "long enough". + Note that *length beats vocabulary* (much exponential), and even + a lowercase-only passphrase like easytorememberyethardforotherstoguess + packs more noise than 8 random printable characters and *can* be memorized. + +When opening an existing database, passphrase should be the one used when the +database was created. If the passphrase is incorrect, an exception will only be +raised **when you access the database**. + +If you need to ask for an interactive passphrase, here's example code you can +put after the `db = ...` line: + + try: # Just access the database so that it checks the encryption. + db.get_tables() + # We're looking for a DatabaseError with a specific error message. + except peewee.DatabaseError as e: + # Check whether the message *means* "passphrase is wrong" + if e.args[0] == 'file is encrypted or is not a database': + raise Exception('Developer should Prompt user for passphrase ' + 'again.') + else: + # A different DatabaseError. Raise it. 
+ raise e + +See a more elaborate example with this code at +https://gist.github.com/thedod/11048875 +""" +import datetime +import decimal +import sys + +from peewee import * +from playhouse.sqlite_ext import SqliteExtDatabase +if sys.version_info[0] != 3: + from pysqlcipher import dbapi2 as sqlcipher +else: + try: + from sqlcipher3 import dbapi2 as sqlcipher + except ImportError: + from pysqlcipher3 import dbapi2 as sqlcipher + +sqlcipher.register_adapter(decimal.Decimal, str) +sqlcipher.register_adapter(datetime.date, str) +sqlcipher.register_adapter(datetime.time, str) + + +class _SqlCipherDatabase(object): + def _connect(self): + params = dict(self.connect_params) + passphrase = params.pop('passphrase', '').replace("'", "''") + + conn = sqlcipher.connect(self.database, isolation_level=None, **params) + try: + if passphrase: + conn.execute("PRAGMA key='%s'" % passphrase) + self._add_conn_hooks(conn) + except: + conn.close() + raise + return conn + + def set_passphrase(self, passphrase): + if not self.is_closed(): + raise ImproperlyConfigured('Cannot set passphrase when database ' + 'is open. To change passphrase of an ' + 'open database use the rekey() method.') + + self.connect_params['passphrase'] = passphrase + + def rekey(self, passphrase): + if self.is_closed(): + self.connect() + + self.execute_sql("PRAGMA rekey='%s'" % passphrase.replace("'", "''")) + self.connect_params['passphrase'] = passphrase + return True + + +class SqlCipherDatabase(_SqlCipherDatabase, SqliteDatabase): + pass + + +class SqlCipherExtDatabase(_SqlCipherDatabase, SqliteExtDatabase): + pass diff --git a/libs/playhouse/sqlite_changelog.py b/libs/playhouse/sqlite_changelog.py new file mode 100644 index 000000000..b036af20f --- /dev/null +++ b/libs/playhouse/sqlite_changelog.py @@ -0,0 +1,123 @@ +from peewee import * +from playhouse.sqlite_ext import JSONField + + +class BaseChangeLog(Model): + timestamp = DateTimeField(constraints=[SQL('DEFAULT CURRENT_TIMESTAMP')]) + action = TextField() + table = TextField() + primary_key = IntegerField() + changes = JSONField() + + +class ChangeLog(object): + # Model class that will serve as the base for the changelog. This model + # will be subclassed and mapped to your application database. + base_model = BaseChangeLog + + # Template for the triggers that handle updating the changelog table. + # table: table name + # action: insert / update / delete + # new_old: NEW or OLD (OLD is for DELETE) + # primary_key: table primary key column name + # column_array: output of build_column_array() + # change_table: changelog table name + template = """CREATE TRIGGER IF NOT EXISTS %(table)s_changes_%(action)s + AFTER %(action)s ON %(table)s + BEGIN + INSERT INTO %(change_table)s + ("action", "table", "primary_key", "changes") + SELECT + '%(action)s', '%(table)s', %(new_old)s."%(primary_key)s", "changes" + FROM ( + SELECT json_group_object( + col, + json_array("oldval", "newval")) AS "changes" + FROM ( + SELECT json_extract(value, '$[0]') as "col", + json_extract(value, '$[1]') as "oldval", + json_extract(value, '$[2]') as "newval" + FROM json_each(json_array(%(column_array)s)) + WHERE "oldval" IS NOT "newval" + ) + ); + END;""" + + drop_template = 'DROP TRIGGER IF EXISTS %(table)s_changes_%(action)s' + + _actions = ('INSERT', 'UPDATE', 'DELETE') + + def __init__(self, db, table_name='changelog'): + self.db = db + self.table_name = table_name + + def _build_column_array(self, model, use_old, use_new, skip_fields=None): + # Builds a list of SQL expressions for each field we are tracking. 
This + # is used as the data source for change tracking in our trigger. + col_array = [] + for field in model._meta.sorted_fields: + if field.primary_key: + continue + + if skip_fields is not None and field.name in skip_fields: + continue + + column = field.column_name + new = 'NULL' if not use_new else 'NEW."%s"' % column + old = 'NULL' if not use_old else 'OLD."%s"' % column + + if isinstance(field, JSONField): + # Ensure that values are cast to JSON so that the serialization + # is preserved when calculating the old / new. + if use_old: old = 'json(%s)' % old + if use_new: new = 'json(%s)' % new + + col_array.append("json_array('%s', %s, %s)" % (column, old, new)) + + return ', '.join(col_array) + + def trigger_sql(self, model, action, skip_fields=None): + assert action in self._actions + use_old = action != 'INSERT' + use_new = action != 'DELETE' + cols = self._build_column_array(model, use_old, use_new, skip_fields) + return self.template % { + 'table': model._meta.table_name, + 'action': action, + 'new_old': 'NEW' if action != 'DELETE' else 'OLD', + 'primary_key': model._meta.primary_key.column_name, + 'column_array': cols, + 'change_table': self.table_name} + + def drop_trigger_sql(self, model, action): + assert action in self._actions + return self.drop_template % { + 'table': model._meta.table_name, + 'action': action} + + @property + def model(self): + if not hasattr(self, '_changelog_model'): + class ChangeLog(self.base_model): + class Meta: + database = self.db + table_name = self.table_name + self._changelog_model = ChangeLog + + return self._changelog_model + + def install(self, model, skip_fields=None, drop=True, insert=True, + update=True, delete=True, create_table=True): + ChangeLog = self.model + if create_table: + ChangeLog.create_table() + + actions = list(zip((insert, update, delete), self._actions)) + if drop: + for _, action in actions: + self.db.execute_sql(self.drop_trigger_sql(model, action)) + + for enabled, action in actions: + if enabled: + sql = self.trigger_sql(model, action, skip_fields) + self.db.execute_sql(sql) diff --git a/libs/playhouse/sqlite_ext.py b/libs/playhouse/sqlite_ext.py new file mode 100644 index 000000000..09ff7a99c --- /dev/null +++ b/libs/playhouse/sqlite_ext.py @@ -0,0 +1,1294 @@ +import json +import math +import re +import struct +import sys + +from peewee import * +from peewee import ColumnBase +from peewee import EnclosedNodeList +from peewee import Entity +from peewee import Expression +from peewee import Node +from peewee import NodeList +from peewee import OP +from peewee import VirtualField +from peewee import merge_dict +from peewee import sqlite3 +try: + from playhouse._sqlite_ext import ( + backup, + backup_to_file, + Blob, + ConnectionHelper, + register_bloomfilter, + register_hash_functions, + register_rank_functions, + sqlite_get_db_status, + sqlite_get_status, + TableFunction, + ZeroBlob, + ) + CYTHON_SQLITE_EXTENSIONS = True +except ImportError: + CYTHON_SQLITE_EXTENSIONS = False + + +if sys.version_info[0] == 3: + basestring = str + + +FTS3_MATCHINFO = 'pcx' +FTS4_MATCHINFO = 'pcnalx' +if sqlite3 is not None: + FTS_VERSION = 4 if sqlite3.sqlite_version_info[:3] >= (3, 7, 4) else 3 +else: + FTS_VERSION = 3 + +FTS5_MIN_SQLITE_VERSION = (3, 9, 0) + + +class RowIDField(AutoField): + auto_increment = True + column_name = name = required_name = 'rowid' + + def bind(self, model, name, *args): + if name != self.required_name: + raise ValueError('%s must be named "%s".' 
% + (type(self), self.required_name)) + super(RowIDField, self).bind(model, name, *args) + + +class DocIDField(RowIDField): + column_name = name = required_name = 'docid' + + +class AutoIncrementField(AutoField): + def ddl(self, ctx): + node_list = super(AutoIncrementField, self).ddl(ctx) + return NodeList((node_list, SQL('AUTOINCREMENT'))) + + +class TDecimalField(DecimalField): + field_type = 'TEXT' + def get_modifiers(self): pass + + +class JSONPath(ColumnBase): + def __init__(self, field, path=None): + super(JSONPath, self).__init__() + self._field = field + self._path = path or () + + @property + def path(self): + return Value('$%s' % ''.join(self._path)) + + def __getitem__(self, idx): + if isinstance(idx, int): + item = '[%s]' % idx + else: + item = '.%s' % idx + return JSONPath(self._field, self._path + (item,)) + + def set(self, value, as_json=None): + if as_json or isinstance(value, (list, dict)): + value = fn.json(self._field._json_dumps(value)) + return fn.json_set(self._field, self.path, value) + + def update(self, value): + return self.set(fn.json_patch(self, self._field._json_dumps(value))) + + def remove(self): + return fn.json_remove(self._field, self.path) + + def json_type(self): + return fn.json_type(self._field, self.path) + + def length(self): + return fn.json_array_length(self._field, self.path) + + def children(self): + return fn.json_each(self._field, self.path) + + def tree(self): + return fn.json_tree(self._field, self.path) + + def __sql__(self, ctx): + return ctx.sql(fn.json_extract(self._field, self.path) + if self._path else self._field) + + +class JSONField(TextField): + field_type = 'JSON' + unpack = False + + def __init__(self, json_dumps=None, json_loads=None, **kwargs): + self._json_dumps = json_dumps or json.dumps + self._json_loads = json_loads or json.loads + super(JSONField, self).__init__(**kwargs) + + def python_value(self, value): + if value is not None: + try: + return self._json_loads(value) + except (TypeError, ValueError): + return value + + def db_value(self, value): + if value is not None: + if not isinstance(value, Node): + value = fn.json(self._json_dumps(value)) + return value + + def _e(op): + def inner(self, rhs): + if isinstance(rhs, (list, dict)): + rhs = Value(rhs, converter=self.db_value, unpack=False) + return Expression(self, op, rhs) + return inner + __eq__ = _e(OP.EQ) + __ne__ = _e(OP.NE) + __gt__ = _e(OP.GT) + __ge__ = _e(OP.GTE) + __lt__ = _e(OP.LT) + __le__ = _e(OP.LTE) + __hash__ = Field.__hash__ + + def __getitem__(self, item): + return JSONPath(self)[item] + + def set(self, value, as_json=None): + return JSONPath(self).set(value, as_json) + + def update(self, data): + return JSONPath(self).update(data) + + def remove(self): + return JSONPath(self).remove() + + def json_type(self): + return fn.json_type(self) + + def length(self): + return fn.json_array_length(self) + + def children(self): + """ + Schema of `json_each` and `json_tree`: + + key, + value, + type TEXT (object, array, string, etc), + atom (value for primitive/scalar types, NULL for array and object) + id INTEGER (unique identifier for element) + parent INTEGER (unique identifier of parent element or NULL) + fullkey TEXT (full path describing element) + path TEXT (path to the container of the current element) + json JSON hidden (1st input parameter to function) + root TEXT hidden (2nd input parameter, path at which to start) + """ + return fn.json_each(self) + + def tree(self): + return fn.json_tree(self) + + +class SearchField(Field): + def __init__(self, 
unindexed=False, column_name=None, **k): + if k: + raise ValueError('SearchField does not accept these keyword ' + 'arguments: %s.' % sorted(k)) + super(SearchField, self).__init__(unindexed=unindexed, + column_name=column_name, null=True) + + def match(self, term): + return match(self, term) + + +class VirtualTableSchemaManager(SchemaManager): + def _create_virtual_table(self, safe=True, **options): + options = self.model.clean_options( + merge_dict(self.model._meta.options, options)) + + # Structure: + # CREATE VIRTUAL TABLE + # USING + # ([prefix_arguments, ...] fields, ... [arguments, ...], [options...]) + ctx = self._create_context() + ctx.literal('CREATE VIRTUAL TABLE ') + if safe: + ctx.literal('IF NOT EXISTS ') + (ctx + .sql(self.model) + .literal(' USING ')) + + ext_module = self.model._meta.extension_module + if isinstance(ext_module, Node): + return ctx.sql(ext_module) + + ctx.sql(SQL(ext_module)).literal(' ') + arguments = [] + meta = self.model._meta + + if meta.prefix_arguments: + arguments.extend([SQL(a) for a in meta.prefix_arguments]) + + # Constraints, data-types, foreign and primary keys are all omitted. + for field in meta.sorted_fields: + if isinstance(field, (RowIDField)) or field._hidden: + continue + field_def = [Entity(field.column_name)] + if field.unindexed: + field_def.append(SQL('UNINDEXED')) + arguments.append(NodeList(field_def)) + + if meta.arguments: + arguments.extend([SQL(a) for a in meta.arguments]) + + if options: + arguments.extend(self._create_table_option_sql(options)) + return ctx.sql(EnclosedNodeList(arguments)) + + def _create_table(self, safe=True, **options): + if issubclass(self.model, VirtualModel): + return self._create_virtual_table(safe, **options) + + return super(VirtualTableSchemaManager, self)._create_table( + safe, **options) + + +class VirtualModel(Model): + class Meta: + arguments = None + extension_module = None + prefix_arguments = None + primary_key = False + schema_manager_class = VirtualTableSchemaManager + + @classmethod + def clean_options(cls, options): + return options + + +class BaseFTSModel(VirtualModel): + @classmethod + def clean_options(cls, options): + content = options.get('content') + prefix = options.get('prefix') + tokenize = options.get('tokenize') + + if isinstance(content, basestring) and content == '': + # Special-case content-less full-text search tables. + options['content'] = "''" + elif isinstance(content, Field): + # Special-case to ensure fields are fully-qualified. + options['content'] = Entity(content.model._meta.table_name, + content.column_name) + + if prefix: + if isinstance(prefix, (list, tuple)): + prefix = ','.join([str(i) for i in prefix]) + options['prefix'] = "'%s'" % prefix.strip("' ") + + if tokenize and cls._meta.extension_module.lower() == 'fts5': + # Tokenizers need to be in quoted string for FTS5, but not for FTS3 + # or FTS4. + options['tokenize'] = '"%s"' % tokenize + + return options + + +class FTSModel(BaseFTSModel): + """ + VirtualModel class for creating tables that use either the FTS3 or FTS4 + search extensions. Peewee automatically determines which version of the + FTS extension is supported and will use FTS4 if possible. + """ + # FTS3/4 uses "docid" in the same way a normal table uses "rowid". 
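A short sketch of the JSONField/JSONPath API defined above, assuming a SQLite build with the json1 extension available; the Event model and its data are purely illustrative:

    from peewee import Model
    from playhouse.sqlite_ext import SqliteExtDatabase, JSONField

    db = SqliteExtDatabase(':memory:')

    class Event(Model):                      # hypothetical model
        payload = JSONField()

        class Meta:
            database = db

    db.create_tables([Event])
    Event.create(payload={'kind': 'episode', 'ids': [1, 2]})
    kinds = Event.select(Event.payload['kind'].alias('kind'))           # json_extract()
    Event.update(payload=Event.payload['kind'].set('movie')).execute()  # json_set()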
+ docid = DocIDField() + + class Meta: + extension_module = 'FTS%s' % FTS_VERSION + + @classmethod + def _fts_cmd(cls, cmd): + tbl = cls._meta.table_name + res = cls._meta.database.execute_sql( + "INSERT INTO %s(%s) VALUES('%s');" % (tbl, tbl, cmd)) + return res.fetchone() + + @classmethod + def optimize(cls): + return cls._fts_cmd('optimize') + + @classmethod + def rebuild(cls): + return cls._fts_cmd('rebuild') + + @classmethod + def integrity_check(cls): + return cls._fts_cmd('integrity-check') + + @classmethod + def merge(cls, blocks=200, segments=8): + return cls._fts_cmd('merge=%s,%s' % (blocks, segments)) + + @classmethod + def automerge(cls, state=True): + return cls._fts_cmd('automerge=%s' % (state and '1' or '0')) + + @classmethod + def match(cls, term): + """ + Generate a `MATCH` expression appropriate for searching this table. + """ + return match(cls._meta.entity, term) + + @classmethod + def rank(cls, *weights): + matchinfo = fn.matchinfo(cls._meta.entity, FTS3_MATCHINFO) + return fn.fts_rank(matchinfo, *weights) + + @classmethod + def bm25(cls, *weights): + match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) + return fn.fts_bm25(match_info, *weights) + + @classmethod + def bm25f(cls, *weights): + match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) + return fn.fts_bm25f(match_info, *weights) + + @classmethod + def lucene(cls, *weights): + match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO) + return fn.fts_lucene(match_info, *weights) + + @classmethod + def _search(cls, term, weights, with_score, score_alias, score_fn, + explicit_ordering): + if not weights: + rank = score_fn() + elif isinstance(weights, dict): + weight_args = [] + for field in cls._meta.sorted_fields: + # Attempt to get the specified weight of the field by looking + # it up using it's field instance followed by name. 
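A usage sketch for the FTS3/FTS4 model machinery above, assuming full-text search support in the local SQLite build; the NoteIndex model and its content are illustrative:

    from peewee import TextField
    from playhouse.sqlite_ext import SqliteExtDatabase, FTSModel

    db = SqliteExtDatabase(':memory:')       # registers fts_rank()/fts_bm25() by default

    class NoteIndex(FTSModel):               # hypothetical full-text index
        content = TextField()

        class Meta:
            database = db

    NoteIndex.create_table()
    NoteIndex.create(content='forced subtitles for the pilot episode')
    hits = NoteIndex.search('subtitles', with_score=True)   # ordered by rank; score aliased as 'score'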
+ field_weight = weights.get(field, weights.get(field.name, 1.0)) + weight_args.append(field_weight) + rank = score_fn(*weight_args) + else: + rank = score_fn(*weights) + + selection = () + order_by = rank + if with_score: + selection = (cls, rank.alias(score_alias)) + if with_score and not explicit_ordering: + order_by = SQL(score_alias) + + return (cls + .select(*selection) + .where(cls.match(term)) + .order_by(order_by)) + + @classmethod + def search(cls, term, weights=None, with_score=False, score_alias='score', + explicit_ordering=False): + """Full-text search using selected `term`.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.rank, + explicit_ordering) + + @classmethod + def search_bm25(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search for selected `term` using BM25 algorithm.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.bm25, + explicit_ordering) + + @classmethod + def search_bm25f(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search for selected `term` using BM25 algorithm.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.bm25f, + explicit_ordering) + + @classmethod + def search_lucene(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search for selected `term` using BM25 algorithm.""" + return cls._search( + term, + weights, + with_score, + score_alias, + cls.lucene, + explicit_ordering) + + +_alphabet = 'abcdefghijklmnopqrstuvwxyz' +_alphanum = (set('\t ,"(){}*:_+0123456789') | + set(_alphabet) | + set(_alphabet.upper()) | + set((chr(26),))) +_invalid_ascii = set(chr(p) for p in range(128) if chr(p) not in _alphanum) +_quote_re = re.compile(r'(?:[^\s"]|"(?:\\.|[^"])*")+') + + +class FTS5Model(BaseFTSModel): + """ + Requires SQLite >= 3.9.0. + + Table options: + + content: table name of external content, or empty string for "contentless" + content_rowid: column name of external content primary key + prefix: integer(s). Ex: '2' or '2 3 4' + tokenize: porter, unicode61, ascii. Ex: 'porter unicode61' + + The unicode tokenizer supports the following parameters: + + * remove_diacritics (1 or 0, default is 1) + * tokenchars (string of characters, e.g. '-_' + * separators (string of characters) + + Parameters are passed as alternating parameter name and value, so: + + {'tokenize': "unicode61 remove_diacritics 0 tokenchars '-_'"} + + Content-less tables: + + If you don't need the full-text content in it's original form, you can + specify a content-less table. Searches and auxiliary functions will work + as usual, but the only values returned when SELECT-ing can be rowid. Also + content-less tables do not support UPDATE or DELETE. + + External content tables: + + You can set up triggers to sync these, e.g. + + -- Create a table. And an external content fts5 table to index it. + CREATE TABLE tbl(a INTEGER PRIMARY KEY, b); + CREATE VIRTUAL TABLE ft USING fts5(b, content='tbl', content_rowid='a'); + + -- Triggers to keep the FTS index up to date. 
+ CREATE TRIGGER tbl_ai AFTER INSERT ON tbl BEGIN + INSERT INTO ft(rowid, b) VALUES (new.a, new.b); + END; + CREATE TRIGGER tbl_ad AFTER DELETE ON tbl BEGIN + INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b); + END; + CREATE TRIGGER tbl_au AFTER UPDATE ON tbl BEGIN + INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b); + INSERT INTO ft(rowid, b) VALUES (new.a, new.b); + END; + + Built-in auxiliary functions: + + * bm25(tbl[, weight_0, ... weight_n]) + * highlight(tbl, col_idx, prefix, suffix) + * snippet(tbl, col_idx, prefix, suffix, ?, max_tokens) + """ + # FTS5 does not support declared primary keys, but we can use the + # implicit rowid. + rowid = RowIDField() + + class Meta: + extension_module = 'fts5' + + _error_messages = { + 'field_type': ('Besides the implicit `rowid` column, all columns must ' + 'be instances of SearchField'), + 'index': 'Secondary indexes are not supported for FTS5 models', + 'pk': 'FTS5 models must use the default `rowid` primary key', + } + + @classmethod + def validate_model(cls): + # Perform FTS5-specific validation and options post-processing. + if cls._meta.primary_key.name != 'rowid': + raise ImproperlyConfigured(cls._error_messages['pk']) + for field in cls._meta.fields.values(): + if not isinstance(field, (SearchField, RowIDField)): + raise ImproperlyConfigured(cls._error_messages['field_type']) + if cls._meta.indexes: + raise ImproperlyConfigured(cls._error_messages['index']) + + @classmethod + def fts5_installed(cls): + if sqlite3.sqlite_version_info[:3] < FTS5_MIN_SQLITE_VERSION: + return False + + # Test in-memory DB to determine if the FTS5 extension is installed. + tmp_db = sqlite3.connect(':memory:') + try: + tmp_db.execute('CREATE VIRTUAL TABLE fts5test USING fts5 (data);') + except: + try: + tmp_db.enable_load_extension(True) + tmp_db.load_extension('fts5') + except: + return False + else: + cls._meta.database.load_extension('fts5') + finally: + tmp_db.close() + + return True + + @staticmethod + def validate_query(query): + """ + Simple helper function to indicate whether a search query is a + valid FTS5 query. Note: this simply looks at the characters being + used, and is not guaranteed to catch all problematic queries. + """ + tokens = _quote_re.findall(query) + for token in tokens: + if token.startswith('"') and token.endswith('"'): + continue + if set(token) & _invalid_ascii: + return False + return True + + @staticmethod + def clean_query(query, replace=chr(26)): + """ + Clean a query of invalid tokens. + """ + accum = [] + any_invalid = False + tokens = _quote_re.findall(query) + for token in tokens: + if token.startswith('"') and token.endswith('"'): + accum.append(token) + continue + token_set = set(token) + invalid_for_token = token_set & _invalid_ascii + if invalid_for_token: + any_invalid = True + for c in invalid_for_token: + token = token.replace(c, replace) + accum.append(token) + + if any_invalid: + return ' '.join(accum) + return query + + @classmethod + def match(cls, term): + """ + Generate a `MATCH` expression appropriate for searching this table. 
+ """ + return match(cls._meta.entity, term) + + @classmethod + def rank(cls, *args): + return cls.bm25(*args) if args else SQL('rank') + + @classmethod + def bm25(cls, *weights): + return fn.bm25(cls._meta.entity, *weights) + + @classmethod + def search(cls, term, weights=None, with_score=False, score_alias='score', + explicit_ordering=False): + """Full-text search using selected `term`.""" + return cls.search_bm25( + FTS5Model.clean_query(term), + weights, + with_score, + score_alias, + explicit_ordering) + + @classmethod + def search_bm25(cls, term, weights=None, with_score=False, + score_alias='score', explicit_ordering=False): + """Full-text search using selected `term`.""" + if not weights: + rank = SQL('rank') + elif isinstance(weights, dict): + weight_args = [] + for field in cls._meta.sorted_fields: + if isinstance(field, SearchField) and not field.unindexed: + weight_args.append( + weights.get(field, weights.get(field.name, 1.0))) + rank = fn.bm25(cls._meta.entity, *weight_args) + else: + rank = fn.bm25(cls._meta.entity, *weights) + + selection = () + order_by = rank + if with_score: + selection = (cls, rank.alias(score_alias)) + if with_score and not explicit_ordering: + order_by = SQL(score_alias) + + return (cls + .select(*selection) + .where(cls.match(FTS5Model.clean_query(term))) + .order_by(order_by)) + + @classmethod + def _fts_cmd_sql(cls, cmd, **extra_params): + tbl = cls._meta.entity + columns = [tbl] + values = [cmd] + for key, value in extra_params.items(): + columns.append(Entity(key)) + values.append(value) + + return NodeList(( + SQL('INSERT INTO'), + cls._meta.entity, + EnclosedNodeList(columns), + SQL('VALUES'), + EnclosedNodeList(values))) + + @classmethod + def _fts_cmd(cls, cmd, **extra_params): + query = cls._fts_cmd_sql(cmd, **extra_params) + return cls._meta.database.execute(query) + + @classmethod + def automerge(cls, level): + if not (0 <= level <= 16): + raise ValueError('level must be between 0 and 16') + return cls._fts_cmd('automerge', rank=level) + + @classmethod + def merge(cls, npages): + return cls._fts_cmd('merge', rank=npages) + + @classmethod + def set_pgsz(cls, pgsz): + return cls._fts_cmd('pgsz', rank=pgsz) + + @classmethod + def set_rank(cls, rank_expression): + return cls._fts_cmd('rank', rank=rank_expression) + + @classmethod + def delete_all(cls): + return cls._fts_cmd('delete-all') + + @classmethod + def VocabModel(cls, table_type='row', table=None): + if table_type not in ('row', 'col', 'instance'): + raise ValueError('table_type must be either "row", "col" or ' + '"instance".') + + attr = '_vocab_model_%s' % table_type + + if not hasattr(cls, attr): + class Meta: + database = cls._meta.database + table_name = table or cls._meta.table_name + '_v' + extension_module = fn.fts5vocab( + cls._meta.entity, + SQL(table_type)) + + attrs = { + 'term': VirtualField(TextField), + 'doc': IntegerField(), + 'cnt': IntegerField(), + 'rowid': RowIDField(), + 'Meta': Meta, + } + if table_type == 'col': + attrs['col'] = VirtualField(TextField) + elif table_type == 'instance': + attrs['offset'] = VirtualField(IntegerField) + + class_name = '%sVocab' % cls.__name__ + setattr(cls, attr, type(class_name, (VirtualModel,), attrs)) + + return getattr(cls, attr) + + +def ClosureTable(model_class, foreign_key=None, referencing_class=None, + referencing_key=None): + """Model factory for the transitive closure extension.""" + if referencing_class is None: + referencing_class = model_class + + if foreign_key is None: + for field_obj in model_class._meta.refs: + if 
field_obj.rel_model is model_class: + foreign_key = field_obj + break + else: + raise ValueError('Unable to find self-referential foreign key.') + + source_key = model_class._meta.primary_key + if referencing_key is None: + referencing_key = source_key + + class BaseClosureTable(VirtualModel): + depth = VirtualField(IntegerField) + id = VirtualField(IntegerField) + idcolumn = VirtualField(TextField) + parentcolumn = VirtualField(TextField) + root = VirtualField(IntegerField) + tablename = VirtualField(TextField) + + class Meta: + extension_module = 'transitive_closure' + + @classmethod + def descendants(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(source_key == cls.id)) + .where(cls.root == node) + .objects()) + if depth is not None: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def ancestors(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(source_key == cls.root)) + .where(cls.id == node) + .objects()) + if depth: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def siblings(cls, node, include_node=False): + if referencing_class is model_class: + # self-join + fk_value = node.__data__.get(foreign_key.name) + query = model_class.select().where(foreign_key == fk_value) + else: + # siblings as given in reference_class + siblings = (referencing_class + .select(referencing_key) + .join(cls, on=(foreign_key == cls.root)) + .where((cls.id == node) & (cls.depth == 1))) + + # the according models + query = (model_class + .select() + .where(source_key << siblings) + .objects()) + + if not include_node: + query = query.where(source_key != node) + + return query + + class Meta: + database = referencing_class._meta.database + options = { + 'tablename': referencing_class._meta.table_name, + 'idcolumn': referencing_key.column_name, + 'parentcolumn': foreign_key.column_name} + primary_key = False + + name = '%sClosure' % model_class.__name__ + return type(name, (BaseClosureTable,), {'Meta': Meta}) + + +class LSMTable(VirtualModel): + class Meta: + extension_module = 'lsm1' + filename = None + + @classmethod + def clean_options(cls, options): + filename = cls._meta.filename + if not filename: + raise ValueError('LSM1 extension requires that you specify a ' + 'filename for the LSM database.') + else: + if len(filename) >= 2 and filename[0] != '"': + filename = '"%s"' % filename + if not cls._meta.primary_key: + raise ValueError('LSM1 models must specify a primary-key field.') + + key = cls._meta.primary_key + if isinstance(key, AutoField): + raise ValueError('LSM1 models must explicitly declare a primary ' + 'key field.') + if not isinstance(key, (TextField, BlobField, IntegerField)): + raise ValueError('LSM1 key must be a TextField, BlobField, or ' + 'IntegerField.') + key._hidden = True + if isinstance(key, IntegerField): + data_type = 'UINT' + elif isinstance(key, BlobField): + data_type = 'BLOB' + else: + data_type = 'TEXT' + cls._meta.prefix_arguments = [filename, '"%s"' % key.name, data_type] + + # Does the key map to a scalar value, or a tuple of values? 
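For reference, a sketch of wiring the ClosureTable() factory above to a self-referential model. The Category model is hypothetical, and the transitive_closure loadable extension must be available; the 'closure' extension path below is an assumption about the local build:

    from peewee import Model, TextField, ForeignKeyField
    from playhouse.sqlite_ext import SqliteExtDatabase, ClosureTable

    db = SqliteExtDatabase(':memory:')
    db.load_extension('closure')             # assumed name/path of the transitive_closure extension

    class Category(Model):                   # hypothetical self-referential model
        name = TextField()
        parent = ForeignKeyField('self', backref='children', null=True)

        class Meta:
            database = db

    CategoryClosure = ClosureTable(Category)
    db.create_tables([Category, CategoryClosure])
    root = Category.create(name='root')
    child = Category.create(name='child', parent=root)
    subtree = CategoryClosure.descendants(root, include_node=True)   # walks the closure virtual table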
+ if len(cls._meta.sorted_fields) == 2: + cls._meta._value_field = cls._meta.sorted_fields[1] + else: + cls._meta._value_field = None + + return options + + @classmethod + def load_extension(cls, path='lsm.so'): + cls._meta.database.load_extension(path) + + @staticmethod + def slice_to_expr(key, idx): + if idx.start is not None and idx.stop is not None: + return key.between(idx.start, idx.stop) + elif idx.start is not None: + return key >= idx.start + elif idx.stop is not None: + return key <= idx.stop + + @staticmethod + def _apply_lookup_to_query(query, key, lookup): + if isinstance(lookup, slice): + expr = LSMTable.slice_to_expr(key, lookup) + if expr is not None: + query = query.where(expr) + return query, False + elif isinstance(lookup, Expression): + return query.where(lookup), False + else: + return query.where(key == lookup), True + + @classmethod + def get_by_id(cls, pk): + query, is_single = cls._apply_lookup_to_query( + cls.select().namedtuples(), + cls._meta.primary_key, + pk) + + if is_single: + try: + row = query.get() + except cls.DoesNotExist: + raise KeyError(pk) + return row[1] if cls._meta._value_field is not None else row + else: + return query + + @classmethod + def set_by_id(cls, key, value): + if cls._meta._value_field is not None: + data = {cls._meta._value_field: value} + elif isinstance(value, tuple): + data = {} + for field, fval in zip(cls._meta.sorted_fields[1:], value): + data[field] = fval + elif isinstance(value, dict): + data = value + elif isinstance(value, cls): + data = value.__dict__ + data[cls._meta.primary_key] = key + cls.replace(data).execute() + + @classmethod + def delete_by_id(cls, pk): + query, is_single = cls._apply_lookup_to_query( + cls.delete(), + cls._meta.primary_key, + pk) + return query.execute() + + +OP.MATCH = 'MATCH' + +def _sqlite_regexp(regex, value): + return re.search(regex, value) is not None + + +class SqliteExtDatabase(SqliteDatabase): + def __init__(self, database, c_extensions=None, rank_functions=True, + hash_functions=False, regexp_function=False, + bloomfilter=False, json_contains=False, *args, **kwargs): + super(SqliteExtDatabase, self).__init__(database, *args, **kwargs) + self._row_factory = None + + if c_extensions and not CYTHON_SQLITE_EXTENSIONS: + raise ImproperlyConfigured('SqliteExtDatabase initialized with ' + 'C extensions, but shared library was ' + 'not found!') + prefer_c = CYTHON_SQLITE_EXTENSIONS and (c_extensions is not False) + if rank_functions: + if prefer_c: + register_rank_functions(self) + else: + self.register_function(bm25, 'fts_bm25') + self.register_function(rank, 'fts_rank') + self.register_function(bm25, 'fts_bm25f') # Fall back to bm25. 
+ self.register_function(bm25, 'fts_lucene') + if hash_functions: + if not prefer_c: + raise ValueError('C extension required to register hash ' + 'functions.') + register_hash_functions(self) + if regexp_function: + self.register_function(_sqlite_regexp, 'regexp', 2) + if bloomfilter: + if not prefer_c: + raise ValueError('C extension required to use bloomfilter.') + register_bloomfilter(self) + if json_contains: + self.register_function(_json_contains, 'json_contains') + + self._c_extensions = prefer_c + + def _add_conn_hooks(self, conn): + super(SqliteExtDatabase, self)._add_conn_hooks(conn) + if self._row_factory: + conn.row_factory = self._row_factory + + def row_factory(self, fn): + self._row_factory = fn + + +if CYTHON_SQLITE_EXTENSIONS: + SQLITE_STATUS_MEMORY_USED = 0 + SQLITE_STATUS_PAGECACHE_USED = 1 + SQLITE_STATUS_PAGECACHE_OVERFLOW = 2 + SQLITE_STATUS_SCRATCH_USED = 3 + SQLITE_STATUS_SCRATCH_OVERFLOW = 4 + SQLITE_STATUS_MALLOC_SIZE = 5 + SQLITE_STATUS_PARSER_STACK = 6 + SQLITE_STATUS_PAGECACHE_SIZE = 7 + SQLITE_STATUS_SCRATCH_SIZE = 8 + SQLITE_STATUS_MALLOC_COUNT = 9 + SQLITE_DBSTATUS_LOOKASIDE_USED = 0 + SQLITE_DBSTATUS_CACHE_USED = 1 + SQLITE_DBSTATUS_SCHEMA_USED = 2 + SQLITE_DBSTATUS_STMT_USED = 3 + SQLITE_DBSTATUS_LOOKASIDE_HIT = 4 + SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5 + SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6 + SQLITE_DBSTATUS_CACHE_HIT = 7 + SQLITE_DBSTATUS_CACHE_MISS = 8 + SQLITE_DBSTATUS_CACHE_WRITE = 9 + SQLITE_DBSTATUS_DEFERRED_FKS = 10 + #SQLITE_DBSTATUS_CACHE_USED_SHARED = 11 + + def __status__(flag, return_highwater=False): + """ + Expose a sqlite3_status() call for a particular flag as a property of + the Database object. + """ + def getter(self): + result = sqlite_get_status(flag) + return result[1] if return_highwater else result + return property(getter) + + def __dbstatus__(flag, return_highwater=False, return_current=False): + """ + Expose a sqlite3_dbstatus() call for a particular flag as a property of + the Database instance. Unlike sqlite3_status(), the dbstatus properties + pertain to the current connection. 
+ """ + def getter(self): + if self._state.conn is None: + raise ImproperlyConfigured('database connection not opened.') + result = sqlite_get_db_status(self._state.conn, flag) + if return_current: + return result[0] + return result[1] if return_highwater else result + return property(getter) + + class CSqliteExtDatabase(SqliteExtDatabase): + def __init__(self, *args, **kwargs): + self._conn_helper = None + self._commit_hook = self._rollback_hook = self._update_hook = None + self._replace_busy_handler = False + super(CSqliteExtDatabase, self).__init__(*args, **kwargs) + + def init(self, database, replace_busy_handler=False, **kwargs): + super(CSqliteExtDatabase, self).init(database, **kwargs) + self._replace_busy_handler = replace_busy_handler + + def _close(self, conn): + if self._commit_hook: + self._conn_helper.set_commit_hook(None) + if self._rollback_hook: + self._conn_helper.set_rollback_hook(None) + if self._update_hook: + self._conn_helper.set_update_hook(None) + return super(CSqliteExtDatabase, self)._close(conn) + + def _add_conn_hooks(self, conn): + super(CSqliteExtDatabase, self)._add_conn_hooks(conn) + self._conn_helper = ConnectionHelper(conn) + if self._commit_hook is not None: + self._conn_helper.set_commit_hook(self._commit_hook) + if self._rollback_hook is not None: + self._conn_helper.set_rollback_hook(self._rollback_hook) + if self._update_hook is not None: + self._conn_helper.set_update_hook(self._update_hook) + if self._replace_busy_handler: + timeout = self._timeout or 5 + self._conn_helper.set_busy_handler(timeout * 1000) + + def on_commit(self, fn): + self._commit_hook = fn + if not self.is_closed(): + self._conn_helper.set_commit_hook(fn) + return fn + + def on_rollback(self, fn): + self._rollback_hook = fn + if not self.is_closed(): + self._conn_helper.set_rollback_hook(fn) + return fn + + def on_update(self, fn): + self._update_hook = fn + if not self.is_closed(): + self._conn_helper.set_update_hook(fn) + return fn + + def changes(self): + return self._conn_helper.changes() + + @property + def last_insert_rowid(self): + return self._conn_helper.last_insert_rowid() + + @property + def autocommit(self): + return self._conn_helper.autocommit() + + def backup(self, destination, pages=None, name=None, progress=None): + return backup(self.connection(), destination.connection(), + pages=pages, name=name, progress=progress) + + def backup_to_file(self, filename, pages=None, name=None, + progress=None): + return backup_to_file(self.connection(), filename, pages=pages, + name=name, progress=progress) + + def blob_open(self, table, column, rowid, read_only=False): + return Blob(self, table, column, rowid, read_only) + + # Status properties. + memory_used = __status__(SQLITE_STATUS_MEMORY_USED) + malloc_size = __status__(SQLITE_STATUS_MALLOC_SIZE, True) + malloc_count = __status__(SQLITE_STATUS_MALLOC_COUNT) + pagecache_used = __status__(SQLITE_STATUS_PAGECACHE_USED) + pagecache_overflow = __status__(SQLITE_STATUS_PAGECACHE_OVERFLOW) + pagecache_size = __status__(SQLITE_STATUS_PAGECACHE_SIZE, True) + scratch_used = __status__(SQLITE_STATUS_SCRATCH_USED) + scratch_overflow = __status__(SQLITE_STATUS_SCRATCH_OVERFLOW) + scratch_size = __status__(SQLITE_STATUS_SCRATCH_SIZE, True) + + # Connection status properties. 
+ lookaside_used = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_USED) + lookaside_hit = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_HIT, True) + lookaside_miss = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE, + True) + lookaside_miss_full = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL, + True) + cache_used = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED, False, True) + #cache_used_shared = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED_SHARED, + # False, True) + schema_used = __dbstatus__(SQLITE_DBSTATUS_SCHEMA_USED, False, True) + statement_used = __dbstatus__(SQLITE_DBSTATUS_STMT_USED, False, True) + cache_hit = __dbstatus__(SQLITE_DBSTATUS_CACHE_HIT, False, True) + cache_miss = __dbstatus__(SQLITE_DBSTATUS_CACHE_MISS, False, True) + cache_write = __dbstatus__(SQLITE_DBSTATUS_CACHE_WRITE, False, True) + + +def match(lhs, rhs): + return Expression(lhs, OP.MATCH, rhs) + +def _parse_match_info(buf): + # See http://sqlite.org/fts3.html#matchinfo + bufsize = len(buf) # Length in bytes. + return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)] + +def get_weights(ncol, raw_weights): + if not raw_weights: + return [1] * ncol + else: + weights = [0] * ncol + for i, weight in enumerate(raw_weights): + weights[i] = weight + return weights + +# Ranking implementation, which parse matchinfo. +def rank(raw_match_info, *raw_weights): + # Handle match_info called w/default args 'pcx' - based on the example rank + # function http://sqlite.org/fts3.html#appendix_a + match_info = _parse_match_info(raw_match_info) + score = 0.0 + + p, c = match_info[:2] + weights = get_weights(c, raw_weights) + + # matchinfo X value corresponds to, for each phrase in the search query, a + # list of 3 values for each column in the search table. + # So if we have a two-phrase search query and three columns of data, the + # following would be the layout: + # p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8] + # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17] + for phrase_num in range(p): + phrase_info_idx = 2 + (phrase_num * c * 3) + for col_num in range(c): + weight = weights[col_num] + if not weight: + continue + + col_idx = phrase_info_idx + (col_num * 3) + + # The idea is that we count the number of times the phrase appears + # in this column of the current row, compared to how many times it + # appears in this column across all rows. The ratio of these values + # provides a rough way to score based on "high value" terms. + row_hits = match_info[col_idx] + all_rows_hits = match_info[col_idx + 1] + if row_hits > 0: + score += weight * (float(row_hits) / all_rows_hits) + + return -score + +# Okapi BM25 ranking implementation (FTS4 only). +def bm25(raw_match_info, *args): + """ + Usage: + + # Format string *must* be pcnalx + # Second parameter to bm25 specifies the index of the column, on + # the table being queries. + bm25(matchinfo(document_tbl, 'pcnalx'), 1) AS rank + """ + match_info = _parse_match_info(raw_match_info) + K = 1.2 + B = 0.75 + score = 0.0 + + P_O, C_O, N_O, A_O = range(4) # Offsets into the matchinfo buffer. + term_count = match_info[P_O] # n + col_count = match_info[C_O] + total_docs = match_info[N_O] # N + L_O = A_O + col_count + X_O = L_O + col_count + + # Worked example of pcnalx for two columns and two phrases, 100 docs total. + # { + # p = 2 + # c = 2 + # n = 100 + # a0 = 4 -- avg number of tokens for col0, e.g. title + # a1 = 40 -- avg number of tokens for col1, e.g. 
body + # l0 = 5 -- curr doc has 5 tokens in col0 + # l1 = 30 -- curr doc has 30 tokens in col1 + # + # x000 -- hits this row for phrase0, col0 + # x001 -- hits all rows for phrase0, col0 + # x002 -- rows with phrase0 in col0 at least once + # + # x010 -- hits this row for phrase0, col1 + # x011 -- hits all rows for phrase0, col1 + # x012 -- rows with phrase0 in col1 at least once + # + # x100 -- hits this row for phrase1, col0 + # x101 -- hits all rows for phrase1, col0 + # x102 -- rows with phrase1 in col0 at least once + # + # x110 -- hits this row for phrase1, col1 + # x111 -- hits all rows for phrase1, col1 + # x112 -- rows with phrase1 in col1 at least once + # } + + weights = get_weights(col_count, args) + + for i in range(term_count): + for j in range(col_count): + weight = weights[j] + if weight == 0: + continue + + x = X_O + (3 * (j + i * col_count)) + term_frequency = float(match_info[x]) # f(qi, D) + docs_with_term = float(match_info[x + 2]) # n(qi) + + # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) ) + idf = math.log( + (total_docs - docs_with_term + 0.5) / + (docs_with_term + 0.5)) + if idf <= 0.0: + idf = 1e-6 + + doc_length = float(match_info[L_O + j]) # |D| + avg_length = float(match_info[A_O + j]) or 1. # avgdl + ratio = doc_length / avg_length + + num = term_frequency * (K + 1.0) + b_part = 1.0 - B + (B * ratio) + denom = term_frequency + (K * b_part) + + pc_score = idf * (num / denom) + score += (pc_score * weight) + + return -score + + +def _json_contains(src_json, obj_json): + stack = [] + try: + stack.append((json.loads(obj_json), json.loads(src_json))) + except: + # Invalid JSON! + return False + + while stack: + obj, src = stack.pop() + if isinstance(src, dict): + if isinstance(obj, dict): + for key in obj: + if key not in src: + return False + stack.append((obj[key], src[key])) + elif isinstance(obj, list): + for item in obj: + if item not in src: + return False + elif obj not in src: + return False + elif isinstance(src, list): + if isinstance(obj, dict): + return False + elif isinstance(obj, list): + try: + for i in range(len(obj)): + stack.append((obj[i], src[i])) + except IndexError: + return False + elif obj not in src: + return False + elif obj != src: + return False + return True diff --git a/libs/playhouse/sqlite_udf.py b/libs/playhouse/sqlite_udf.py new file mode 100644 index 000000000..050dc9b15 --- /dev/null +++ b/libs/playhouse/sqlite_udf.py @@ -0,0 +1,536 @@ +import datetime +import hashlib +import heapq +import math +import os +import random +import re +import sys +import threading +import zlib +try: + from collections import Counter +except ImportError: + Counter = None +try: + from urlparse import urlparse +except ImportError: + from urllib.parse import urlparse + +try: + from playhouse._sqlite_ext import TableFunction +except ImportError: + TableFunction = None + + +SQLITE_DATETIME_FORMATS = ( + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d', + '%H:%M:%S', + '%H:%M:%S.%f', + '%H:%M') + +from peewee import format_date_time + +def format_date_time_sqlite(date_value): + return format_date_time(date_value, SQLITE_DATETIME_FORMATS) + +try: + from playhouse import _sqlite_udf as cython_udf +except ImportError: + cython_udf = None + + +# Group udf by function. 
+CONTROL_FLOW = 'control_flow' +DATE = 'date' +FILE = 'file' +HELPER = 'helpers' +MATH = 'math' +STRING = 'string' + +AGGREGATE_COLLECTION = {} +TABLE_FUNCTION_COLLECTION = {} +UDF_COLLECTION = {} + + +class synchronized_dict(dict): + def __init__(self, *args, **kwargs): + super(synchronized_dict, self).__init__(*args, **kwargs) + self._lock = threading.Lock() + + def __getitem__(self, key): + with self._lock: + return super(synchronized_dict, self).__getitem__(key) + + def __setitem__(self, key, value): + with self._lock: + return super(synchronized_dict, self).__setitem__(key, value) + + def __delitem__(self, key): + with self._lock: + return super(synchronized_dict, self).__delitem__(key) + + +STATE = synchronized_dict() +SETTINGS = synchronized_dict() + +# Class and function decorators. +def aggregate(*groups): + def decorator(klass): + for group in groups: + AGGREGATE_COLLECTION.setdefault(group, []) + AGGREGATE_COLLECTION[group].append(klass) + return klass + return decorator + +def table_function(*groups): + def decorator(klass): + for group in groups: + TABLE_FUNCTION_COLLECTION.setdefault(group, []) + TABLE_FUNCTION_COLLECTION[group].append(klass) + return klass + return decorator + +def udf(*groups): + def decorator(fn): + for group in groups: + UDF_COLLECTION.setdefault(group, []) + UDF_COLLECTION[group].append(fn) + return fn + return decorator + +# Register aggregates / functions with connection. +def register_aggregate_groups(db, *groups): + seen = set() + for group in groups: + klasses = AGGREGATE_COLLECTION.get(group, ()) + for klass in klasses: + name = getattr(klass, 'name', klass.__name__) + if name not in seen: + seen.add(name) + db.register_aggregate(klass, name) + +def register_table_function_groups(db, *groups): + seen = set() + for group in groups: + klasses = TABLE_FUNCTION_COLLECTION.get(group, ()) + for klass in klasses: + if klass.name not in seen: + seen.add(klass.name) + db.register_table_function(klass) + +def register_udf_groups(db, *groups): + seen = set() + for group in groups: + functions = UDF_COLLECTION.get(group, ()) + for function in functions: + name = function.__name__ + if name not in seen: + seen.add(name) + db.register_function(function, name) + +def register_groups(db, *groups): + register_aggregate_groups(db, *groups) + register_table_function_groups(db, *groups) + register_udf_groups(db, *groups) + +def register_all(db): + register_aggregate_groups(db, *AGGREGATE_COLLECTION) + register_table_function_groups(db, *TABLE_FUNCTION_COLLECTION) + register_udf_groups(db, *UDF_COLLECTION) + + +# Begin actual user-defined functions and aggregates. + +# Scalar functions. 
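The decorators and helpers above (aggregate, table_function, udf, and the register_*_groups / register_groups / register_all functions) are the registration surface for everything that follows in this module: each UDF or aggregate is tagged with one or more group names and then attached to a connection in bulk. A minimal sketch of how a caller might use them, assuming a hypothetical example.db file; sqrt is one of the MATH scalar functions defined further down, and nothing in this snippet is part of the Bazarr patch itself:

# Illustrative sketch only: "example.db" is a hypothetical, throwaway database.
from playhouse.sqlite_ext import SqliteExtDatabase
from playhouse import sqlite_udf

db = SqliteExtDatabase('example.db')

# Attach only the DATE and MATH groups to this connection...
sqlite_udf.register_groups(db, sqlite_udf.DATE, sqlite_udf.MATH)
# ...or call sqlite_udf.register_all(db) to attach every group at once.

# Registered scalar functions are then callable from SQL by name.
cursor = db.execute_sql('SELECT sqrt(?)', (16,))
print(cursor.fetchone()[0])  # 4.0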
+@udf(CONTROL_FLOW) +def if_then_else(cond, truthy, falsey=None): + if cond: + return truthy + return falsey + +@udf(DATE) +def strip_tz(date_str): + date_str = date_str.replace('T', ' ') + tz_idx1 = date_str.find('+') + if tz_idx1 != -1: + return date_str[:tz_idx1] + tz_idx2 = date_str.find('-') + if tz_idx2 > 13: + return date_str[:tz_idx2] + return date_str + +@udf(DATE) +def human_delta(nseconds, glue=', '): + parts = ( + (86400 * 365, 'year'), + (86400 * 30, 'month'), + (86400 * 7, 'week'), + (86400, 'day'), + (3600, 'hour'), + (60, 'minute'), + (1, 'second'), + ) + accum = [] + for offset, name in parts: + val, nseconds = divmod(nseconds, offset) + if val: + suffix = val != 1 and 's' or '' + accum.append('%s %s%s' % (val, name, suffix)) + if not accum: + return '0 seconds' + return glue.join(accum) + +@udf(FILE) +def file_ext(filename): + try: + res = os.path.splitext(filename) + except ValueError: + return None + return res[1] + +@udf(FILE) +def file_read(filename): + try: + with open(filename) as fh: + return fh.read() + except: + pass + +if sys.version_info[0] == 2: + @udf(HELPER) + def gzip(data, compression=9): + return buffer(zlib.compress(data, compression)) + + @udf(HELPER) + def gunzip(data): + return zlib.decompress(data) +else: + @udf(HELPER) + def gzip(data, compression=9): + if isinstance(data, str): + data = bytes(data.encode('raw_unicode_escape')) + return zlib.compress(data, compression) + + @udf(HELPER) + def gunzip(data): + return zlib.decompress(data) + +@udf(HELPER) +def hostname(url): + parse_result = urlparse(url) + if parse_result: + return parse_result.netloc + +@udf(HELPER) +def toggle(key): + key = key.lower() + STATE[key] = ret = not STATE.get(key) + return ret + +@udf(HELPER) +def setting(key, value=None): + if value is None: + return SETTINGS.get(key) + else: + SETTINGS[key] = value + return value + +@udf(HELPER) +def clear_settings(): + SETTINGS.clear() + +@udf(HELPER) +def clear_toggles(): + STATE.clear() + +@udf(MATH) +def randomrange(start, end=None, step=None): + if end is None: + start, end = 0, start + elif step is None: + step = 1 + return random.randrange(start, end, step) + +@udf(MATH) +def gauss_distribution(mean, sigma): + try: + return random.gauss(mean, sigma) + except ValueError: + return None + +@udf(MATH) +def sqrt(n): + try: + return math.sqrt(n) + except ValueError: + return None + +@udf(MATH) +def tonumber(s): + try: + return int(s) + except ValueError: + try: + return float(s) + except: + return None + +@udf(STRING) +def substr_count(haystack, needle): + if not haystack or not needle: + return 0 + return haystack.count(needle) + +@udf(STRING) +def strip_chars(haystack, chars): + return haystack.strip(chars) + +def _hash(constructor, *args): + hash_obj = constructor() + for arg in args: + hash_obj.update(arg) + return hash_obj.hexdigest() + +# Aggregates. 
+class _heap_agg(object): + def __init__(self): + self.heap = [] + self.ct = 0 + + def process(self, value): + return value + + def step(self, value): + self.ct += 1 + heapq.heappush(self.heap, self.process(value)) + +class _datetime_heap_agg(_heap_agg): + def process(self, value): + return format_date_time_sqlite(value) + +if sys.version_info[:2] == (2, 6): + def total_seconds(td): + return (td.seconds + + (td.days * 86400) + + (td.microseconds / (10.**6))) +else: + total_seconds = lambda td: td.total_seconds() + +@aggregate(DATE) +class mintdiff(_datetime_heap_agg): + def finalize(self): + dtp = min_diff = None + while self.heap: + if min_diff is None: + if dtp is None: + dtp = heapq.heappop(self.heap) + continue + dt = heapq.heappop(self.heap) + diff = dt - dtp + if min_diff is None or min_diff > diff: + min_diff = diff + dtp = dt + if min_diff is not None: + return total_seconds(min_diff) + +@aggregate(DATE) +class avgtdiff(_datetime_heap_agg): + def finalize(self): + if self.ct < 1: + return + elif self.ct == 1: + return 0 + + total = ct = 0 + dtp = None + while self.heap: + if total == 0: + if dtp is None: + dtp = heapq.heappop(self.heap) + continue + + dt = heapq.heappop(self.heap) + diff = dt - dtp + ct += 1 + total += total_seconds(diff) + dtp = dt + + return float(total) / ct + +@aggregate(DATE) +class duration(object): + def __init__(self): + self._min = self._max = None + + def step(self, value): + dt = format_date_time_sqlite(value) + if self._min is None or dt < self._min: + self._min = dt + if self._max is None or dt > self._max: + self._max = dt + + def finalize(self): + if self._min and self._max: + td = (self._max - self._min) + return total_seconds(td) + return None + +@aggregate(MATH) +class mode(object): + if Counter: + def __init__(self): + self.items = Counter() + + def step(self, *args): + self.items.update(args) + + def finalize(self): + if self.items: + return self.items.most_common(1)[0][0] + else: + def __init__(self): + self.items = [] + + def step(self, item): + self.items.append(item) + + def finalize(self): + if self.items: + return max(set(self.items), key=self.items.count) + +@aggregate(MATH) +class minrange(_heap_agg): + def finalize(self): + if self.ct == 0: + return + elif self.ct == 1: + return 0 + + prev = min_diff = None + + while self.heap: + if min_diff is None: + if prev is None: + prev = heapq.heappop(self.heap) + continue + curr = heapq.heappop(self.heap) + diff = curr - prev + if min_diff is None or min_diff > diff: + min_diff = diff + prev = curr + return min_diff + +@aggregate(MATH) +class avgrange(_heap_agg): + def finalize(self): + if self.ct == 0: + return + elif self.ct == 1: + return 0 + + total = ct = 0 + prev = None + while self.heap: + if total == 0: + if prev is None: + prev = heapq.heappop(self.heap) + continue + + curr = heapq.heappop(self.heap) + diff = curr - prev + ct += 1 + total += diff + prev = curr + + return float(total) / ct + +@aggregate(MATH) +class _range(object): + name = 'range' + + def __init__(self): + self._min = self._max = None + + def step(self, value): + if self._min is None or value < self._min: + self._min = value + if self._max is None or value > self._max: + self._max = value + + def finalize(self): + if self._min is not None and self._max is not None: + return self._max - self._min + return None + +@aggregate(MATH) +class stddev(object): + def __init__(self): + self.n = 0 + self.values = [] + def step(self, v): + self.n += 1 + self.values.append(v) + def finalize(self): + if self.n <= 1: + return 0 + mean 
= sum(self.values) / self.n + return math.sqrt(sum((i - mean) ** 2 for i in self.values) / (self.n - 1)) + + +if cython_udf is not None: + damerau_levenshtein_dist = udf(STRING)(cython_udf.damerau_levenshtein_dist) + levenshtein_dist = udf(STRING)(cython_udf.levenshtein_dist) + str_dist = udf(STRING)(cython_udf.str_dist) + median = aggregate(MATH)(cython_udf.median) + + +if TableFunction is not None: + @table_function(STRING) + class RegexSearch(TableFunction): + params = ['regex', 'search_string'] + columns = ['match'] + name = 'regex_search' + + def initialize(self, regex=None, search_string=None): + self._iter = re.finditer(regex, search_string) + + def iterate(self, idx): + return (next(self._iter).group(0),) + + @table_function(DATE) + class DateSeries(TableFunction): + params = ['start', 'stop', 'step_seconds'] + columns = ['date'] + name = 'date_series' + + def initialize(self, start, stop, step_seconds=86400): + self.start = format_date_time_sqlite(start) + self.stop = format_date_time_sqlite(stop) + step_seconds = int(step_seconds) + self.step_seconds = datetime.timedelta(seconds=step_seconds) + + if (self.start.hour == 0 and + self.start.minute == 0 and + self.start.second == 0 and + step_seconds >= 86400): + self.format = '%Y-%m-%d' + elif (self.start.year == 1900 and + self.start.month == 1 and + self.start.day == 1 and + self.stop.year == 1900 and + self.stop.month == 1 and + self.stop.day == 1 and + step_seconds < 86400): + self.format = '%H:%M:%S' + else: + self.format = '%Y-%m-%d %H:%M:%S' + + def iterate(self, idx): + if self.start > self.stop: + raise StopIteration + current = self.start + self.start += self.step_seconds + return (current.strftime(self.format),) diff --git a/libs/playhouse/sqliteq.py b/libs/playhouse/sqliteq.py new file mode 100644 index 000000000..e0015878b --- /dev/null +++ b/libs/playhouse/sqliteq.py @@ -0,0 +1,331 @@ +import logging +import weakref +from threading import local as thread_local +from threading import Event +from threading import Thread +try: + from Queue import Queue +except ImportError: + from queue import Queue + +try: + import gevent + from gevent import Greenlet as GThread + from gevent.event import Event as GEvent + from gevent.local import local as greenlet_local + from gevent.queue import Queue as GQueue +except ImportError: + GThread = GQueue = GEvent = None + +from peewee import SENTINEL +from playhouse.sqlite_ext import SqliteExtDatabase + + +logger = logging.getLogger('peewee.sqliteq') + + +class ResultTimeout(Exception): + pass + +class WriterPaused(Exception): + pass + +class ShutdownException(Exception): + pass + + +class AsyncCursor(object): + __slots__ = ('sql', 'params', 'commit', 'timeout', + '_event', '_cursor', '_exc', '_idx', '_rows', '_ready') + + def __init__(self, event, sql, params, commit, timeout): + self._event = event + self.sql = sql + self.params = params + self.commit = commit + self.timeout = timeout + self._cursor = self._exc = self._idx = self._rows = None + self._ready = False + + def set_result(self, cursor, exc=None): + self._cursor = cursor + self._exc = exc + self._idx = 0 + self._rows = cursor.fetchall() if exc is None else [] + self._event.set() + return self + + def _wait(self, timeout=None): + timeout = timeout if timeout is not None else self.timeout + if not self._event.wait(timeout=timeout) and timeout: + raise ResultTimeout('results not ready, timed out.') + if self._exc is not None: + raise self._exc + self._ready = True + + def __iter__(self): + if not self._ready: + self._wait() + if 
self._exc is not None: + raise self._exc + return self + + def next(self): + if not self._ready: + self._wait() + try: + obj = self._rows[self._idx] + except IndexError: + raise StopIteration + else: + self._idx += 1 + return obj + __next__ = next + + @property + def lastrowid(self): + if not self._ready: + self._wait() + return self._cursor.lastrowid + + @property + def rowcount(self): + if not self._ready: + self._wait() + return self._cursor.rowcount + + @property + def description(self): + return self._cursor.description + + def close(self): + self._cursor.close() + + def fetchall(self): + return list(self) # Iterating implies waiting until populated. + + def fetchone(self): + if not self._ready: + self._wait() + try: + return next(self) + except StopIteration: + return None + +SHUTDOWN = StopIteration +PAUSE = object() +UNPAUSE = object() + + +class Writer(object): + __slots__ = ('database', 'queue') + + def __init__(self, database, queue): + self.database = database + self.queue = queue + + def run(self): + conn = self.database.connection() + try: + while True: + try: + if conn is None: # Paused. + if self.wait_unpause(): + conn = self.database.connection() + else: + conn = self.loop(conn) + except ShutdownException: + logger.info('writer received shutdown request, exiting.') + return + finally: + if conn is not None: + self.database._close(conn) + self.database._state.reset() + + def wait_unpause(self): + obj = self.queue.get() + if obj is UNPAUSE: + logger.info('writer unpaused - reconnecting to database.') + return True + elif obj is SHUTDOWN: + raise ShutdownException() + elif obj is PAUSE: + logger.error('writer received pause, but is already paused.') + else: + obj.set_result(None, WriterPaused()) + logger.warning('writer paused, not handling %s', obj) + + def loop(self, conn): + obj = self.queue.get() + if isinstance(obj, AsyncCursor): + self.execute(obj) + elif obj is PAUSE: + logger.info('writer paused - closing database connection.') + self.database._close(conn) + self.database._state.reset() + return + elif obj is UNPAUSE: + logger.error('writer received unpause, but is already running.') + elif obj is SHUTDOWN: + raise ShutdownException() + else: + logger.error('writer received unsupported object: %s', obj) + return conn + + def execute(self, obj): + logger.debug('received query %s', obj.sql) + try: + cursor = self.database._execute(obj.sql, obj.params, obj.commit) + except Exception as execute_err: + cursor = None + exc = execute_err # python3 is so fucking lame. + else: + exc = None + return obj.set_result(cursor, exc) + + +class SqliteQueueDatabase(SqliteExtDatabase): + WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL ' + 'journal mode when using this feature. WAL mode ' + 'allows one or more readers to continue reading ' + 'while another connection writes to the ' + 'database.') + + def __init__(self, database, use_gevent=False, autostart=True, + queue_max_size=None, results_timeout=None, *args, **kwargs): + kwargs['check_same_thread'] = False + + # Ensure that journal_mode is WAL. This value is passed to the parent + # class constructor below. + pragmas = self._validate_journal_mode(kwargs.pop('pragmas', None)) + + # Reference to execute_sql on the parent class. Since we've overridden + # execute_sql(), this is just a handy way to reference the real + # implementation. + Parent = super(SqliteQueueDatabase, self) + self._execute = Parent.execute_sql + + # Call the parent class constructor with our modified pragmas. 
+ Parent.__init__(database, pragmas=pragmas, *args, **kwargs) + + self._autostart = autostart + self._results_timeout = results_timeout + self._is_stopped = True + + # Get different objects depending on the threading implementation. + self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size) + + # Create the writer thread, optionally starting it. + self._create_write_queue() + if self._autostart: + self.start() + + def get_thread_impl(self, use_gevent): + return GreenletHelper if use_gevent else ThreadHelper + + def _validate_journal_mode(self, pragmas=None): + if not pragmas: + return {'journal_mode': 'wal'} + + if not isinstance(pragmas, dict): + pragmas = dict((k.lower(), v) for (k, v) in pragmas) + if pragmas.get('journal_mode', 'wal').lower() != 'wal': + raise ValueError(self.WAL_MODE_ERROR_MESSAGE) + + pragmas['journal_mode'] = 'wal' + return pragmas + + def _create_write_queue(self): + self._write_queue = self._thread_helper.queue() + + def queue_size(self): + return self._write_queue.qsize() + + def execute_sql(self, sql, params=None, commit=SENTINEL, timeout=None): + if commit is SENTINEL: + commit = not sql.lower().startswith('select') + + if not commit: + return self._execute(sql, params, commit=commit) + + cursor = AsyncCursor( + event=self._thread_helper.event(), + sql=sql, + params=params, + commit=commit, + timeout=self._results_timeout if timeout is None else timeout) + self._write_queue.put(cursor) + return cursor + + def start(self): + with self._lock: + if not self._is_stopped: + return False + def run(): + writer = Writer(self, self._write_queue) + writer.run() + + self._writer = self._thread_helper.thread(run) + self._writer.start() + self._is_stopped = False + return True + + def stop(self): + logger.debug('environment stop requested.') + with self._lock: + if self._is_stopped: + return False + self._write_queue.put(SHUTDOWN) + self._writer.join() + self._is_stopped = True + return True + + def is_stopped(self): + with self._lock: + return self._is_stopped + + def pause(self): + with self._lock: + self._write_queue.put(PAUSE) + + def unpause(self): + with self._lock: + self._write_queue.put(UNPAUSE) + + def __unsupported__(self, *args, **kwargs): + raise ValueError('This method is not supported by %r.' 
% type(self)) + atomic = transaction = savepoint = __unsupported__ + + +class ThreadHelper(object): + __slots__ = ('queue_max_size',) + + def __init__(self, queue_max_size=None): + self.queue_max_size = queue_max_size + + def event(self): return Event() + + def queue(self, max_size=None): + max_size = max_size if max_size is not None else self.queue_max_size + return Queue(maxsize=max_size or 0) + + def thread(self, fn, *args, **kwargs): + thread = Thread(target=fn, args=args, kwargs=kwargs) + thread.daemon = True + return thread + + +class GreenletHelper(ThreadHelper): + __slots__ = () + + def event(self): return GEvent() + + def queue(self, max_size=None): + max_size = max_size if max_size is not None else self.queue_max_size + return GQueue(maxsize=max_size or 0) + + def thread(self, fn, *args, **kwargs): + def wrap(*a, **k): + gevent.sleep() + return fn(*a, **k) + return GThread(wrap, *args, **kwargs) diff --git a/libs/playhouse/test_utils.py b/libs/playhouse/test_utils.py new file mode 100644 index 000000000..333dc078b --- /dev/null +++ b/libs/playhouse/test_utils.py @@ -0,0 +1,62 @@ +from functools import wraps +import logging + + +logger = logging.getLogger('peewee') + + +class _QueryLogHandler(logging.Handler): + def __init__(self, *args, **kwargs): + self.queries = [] + logging.Handler.__init__(self, *args, **kwargs) + + def emit(self, record): + self.queries.append(record) + + +class count_queries(object): + def __init__(self, only_select=False): + self.only_select = only_select + self.count = 0 + + def get_queries(self): + return self._handler.queries + + def __enter__(self): + self._handler = _QueryLogHandler() + logger.setLevel(logging.DEBUG) + logger.addHandler(self._handler) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + logger.removeHandler(self._handler) + if self.only_select: + self.count = len([q for q in self._handler.queries + if q.msg[0].startswith('SELECT ')]) + else: + self.count = len(self._handler.queries) + + +class assert_query_count(count_queries): + def __init__(self, expected, only_select=False): + super(assert_query_count, self).__init__(only_select=only_select) + self.expected = expected + + def __call__(self, f): + @wraps(f) + def decorated(*args, **kwds): + with self: + ret = f(*args, **kwds) + + self._assert_count() + return ret + + return decorated + + def _assert_count(self): + error_msg = '%s != %s' % (self.count, self.expected) + assert self.count == self.expected, error_msg + + def __exit__(self, exc_type, exc_val, exc_tb): + super(assert_query_count, self).__exit__(exc_type, exc_val, exc_tb) + self._assert_count() diff --git a/libs/sqlite3worker.py b/libs/sqlite3worker.py deleted file mode 100644 index c955066d8..000000000 --- a/libs/sqlite3worker.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright (c) 2014 Palantir Technologies -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -"""Thread safe sqlite3 interface.""" - -__author__ = "Shawn Lee" -__email__ = "shawnl@palantir.com" -__license__ = "MIT" - -import logging -try: - import queue as Queue # module re-named in Python 3 -except ImportError: - import Queue -import sqlite3 -import threading -import time -import uuid - -LOGGER = logging.getLogger('sqlite3worker') - - -class Sqlite3Worker(threading.Thread): - """Sqlite thread safe object. - - Example: - from sqlite3worker import Sqlite3Worker - sql_worker = Sqlite3Worker("/tmp/test.sqlite") - sql_worker.execute( - "CREATE TABLE tester (timestamp DATETIME, uuid TEXT)") - sql_worker.execute( - "INSERT into tester values (?, ?)", ("2010-01-01 13:00:00", "bow")) - sql_worker.execute( - "INSERT into tester values (?, ?)", ("2011-02-02 14:14:14", "dog")) - sql_worker.execute("SELECT * from tester") - sql_worker.close() - """ - def __init__(self, file_name, max_queue_size=100, as_dict=False): - """Automatically starts the thread. - - Args: - file_name: The name of the file. - max_queue_size: The max queries that will be queued. - as_dict: Return result as a dictionary. - """ - threading.Thread.__init__(self) - self.daemon = True - self.sqlite3_conn = sqlite3.connect( - file_name, check_same_thread=False, - detect_types=sqlite3.PARSE_DECLTYPES) - if as_dict: - self.sqlite3_conn.row_factory = dict_factory - self.sqlite3_cursor = self.sqlite3_conn.cursor() - self.sql_queue = Queue.Queue(maxsize=max_queue_size) - self.results = {} - self.max_queue_size = max_queue_size - self.exit_set = False - # Token that is put into queue when close() is called. - self.exit_token = str(uuid.uuid4()) - self.start() - self.thread_running = True - - def run(self): - """Thread loop. - - This is an infinite loop. The iter method calls self.sql_queue.get() - which blocks if there are not values in the queue. As soon as values - are placed into the queue the process will continue. - - If many executes happen at once it will churn through them all before - calling commit() to speed things up by reducing the number of times - commit is called. - """ - LOGGER.debug("run: Thread started") - execute_count = 0 - for token, query, values, only_one, execute_many in iter(self.sql_queue.get, None): - LOGGER.debug("sql_queue: %s", self.sql_queue.qsize()) - if token != self.exit_token: - LOGGER.debug("run: %s, %s", query, values) - self.run_query(token, query, values, only_one, execute_many) - execute_count += 1 - # Let the executes build up a little before committing to disk - # to speed things up. - if ( - self.sql_queue.empty() or - execute_count == self.max_queue_size): - LOGGER.debug("run: commit") - self.sqlite3_conn.commit() - execute_count = 0 - # Only exit if the queue is empty. Otherwise keep getting - # through the queue until it's empty. - if self.exit_set and self.sql_queue.empty(): - self.sqlite3_conn.commit() - self.sqlite3_conn.close() - self.thread_running = False - return - - def run_query(self, token, query, values, only_one=False, execute_many=False): - """Run a query. 
- - Args: - token: A uuid object of the query you want returned. - query: A sql query with ? placeholders for values. - values: A tuple of values to replace "?" in query. - """ - if query.lower().strip().startswith(("select", "pragma")): - try: - self.sqlite3_cursor.execute(query, values) - if only_one: - self.results[token] = self.sqlite3_cursor.fetchone() - else: - self.results[token] = self.sqlite3_cursor.fetchall() - except sqlite3.Error as err: - # Put the error into the output queue since a response - # is required. - self.results[token] = ( - "Query returned error: %s: %s: %s" % (query, values, err)) - LOGGER.error( - "Query returned error: %s: %s: %s", query, values, err) - else: - try: - if execute_many: - self.sqlite3_cursor.executemany(query, values) - if query.lower().strip().startswith(("insert", "update", "delete")): - self.results[token] = self.sqlite3_cursor.rowcount - else: - self.sqlite3_cursor.execute(query, values) - if query.lower().strip().startswith(("insert", "update", "delete")): - self.results[token] = self.sqlite3_cursor.rowcount - except sqlite3.Error as err: - self.results[token] = ( - "Query returned error: %s: %s: %s" % (query, values, err)) - LOGGER.error( - "Query returned error: %s: %s: %s", query, values, err) - - def close(self): - """Close down the thread and close the sqlite3 database file.""" - self.exit_set = True - self.sql_queue.put((self.exit_token, "", "", "", ""), timeout=5) - # Sleep and check that the thread is done before returning. - while self.thread_running: - time.sleep(.01) # Don't kill the CPU waiting. - - @property - def queue_size(self): - """Return the queue size.""" - return self.sql_queue.qsize() - - def query_results(self, token): - """Get the query results for a specific token. - - Args: - token: A uuid object of the query you want returned. - - Returns: - Return the results of the query when it's executed by the thread. - """ - delay = .001 - while True: - if token in self.results: - return_val = self.results[token] - del self.results[token] - return return_val - # Double back on the delay to a max of 8 seconds. This prevents - # a long lived select statement from trashing the CPU with this - # infinite loop as it's waiting for the query results. - LOGGER.debug("Sleeping: %s %s", delay, token) - time.sleep(delay) - if delay < 8: - delay += delay - - def execute(self, query, values=None, only_one=False, execute_many=False): - """Execute a query. - - Args: - query: The sql string using ? for placeholders of dynamic values. - values: A tuple of values to be replaced into the ? of the query. - - Returns: - If it's a select query it will return the results of the query. - """ - if self.exit_set: - LOGGER.debug("Exit set, not running: %s, %s", query, values) - return "Exit Called" - LOGGER.debug("execute: %s, %s", query, values) - values = values or [] - # A token to track this query with. - token = str(uuid.uuid4()) - # If it's a select we queue it up with a token to mark the results - # into the output queue so we know what results are ours. 
- if query.lower().strip().startswith(("select", "insert", "update", "delete", "pragma")): - self.sql_queue.put((token, query, values, only_one, execute_many), timeout=5) - return self.query_results(token) - else: - self.sql_queue.put((token, query, values, only_one, execute_many), timeout=5) - - -def dict_factory(cursor, row): - d = {} - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d diff --git a/libs/version.txt b/libs/version.txt index f11ffcce3..5c04c9373 100644 --- a/libs/version.txt +++ b/libs/version.txt @@ -20,6 +20,7 @@ guess_language-spirit=0.5.3 Js2Py=0.63 <-- modified: manually merged from upstream: https://github.com/PiotrDabkowski/Js2Py/pull/192/files knowit=0.3.0-dev msgpack=1.0.2 +peewee=3.14.4 py-pretty=1 pycountry=18.2.23 pyga=2.6.1
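With sqlite3worker.py deleted and peewee pinned at 3.14.4, the closest replacement for that module's single-writer pattern among the files added by this patch is playhouse.sqliteq.SqliteQueueDatabase: write statements are queued to one background writer thread, SELECTs run on the calling thread's connection, and WAL journal mode is enforced so readers are not blocked by the writer. A rough, hypothetical sketch of the removed module's tester example rewritten against it; example.db and the Tester model are illustrative only, not part of Bazarr's schema, and this is not necessarily how Bazarr itself wires up its database:

# Hypothetical example: Tester mirrors the tester table from the removed
# sqlite3worker docstring; "example.db" is a made-up filename.
from peewee import Model, DateTimeField, TextField
from playhouse.sqliteq import SqliteQueueDatabase

# journal_mode is forced to WAL and check_same_thread to False by the class.
db = SqliteQueueDatabase('example.db', autostart=True, queue_max_size=100,
                         results_timeout=5.0)

class Tester(Model):
    timestamp = DateTimeField()
    uuid = TextField()

    class Meta:
        database = db

Tester.create_table(safe=True)          # DDL is queued to the writer thread
Tester.insert(timestamp='2010-01-01 13:00:00', uuid='bow').execute()
rows = list(Tester.select().dicts())    # SELECTs bypass the write queue
db.stop()                               # flush pending writes and join the writer

Note that explicit transactions are deliberately unsupported by this class (atomic, transaction and savepoint all raise), so callers rely on the writer thread's serialization of statements instead.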