Implemented Peewee ORM to replace raw SQL queries.

pull/1419/head
commit 2b9d892ca9 (parent c2eec34aff)
morpheus65535 authored 4 years ago, committed via GitHub
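For orientation, here is a minimal sketch (not part of the commit) of the pattern this diff applies throughout: raw SQL strings executed through the worker are replaced by Peewee query expressions built on the model classes declared in database.py. The radarr_id value is hypothetical.

from database import TableMovies

def get_movie_title(radarr_id):
    # Old style: database.execute("SELECT title FROM table_movies WHERE radarrId=?",
    #                             (radarr_id,), only_one=True)
    # New style: a composable Peewee query. .dicts() yields plain dictionaries
    # instead of model instances, so call sites keep their dict-based access.
    movie = TableMovies.select(TableMovies.title)\
        .where(TableMovies.radarrId == radarr_id)\
        .dicts()\
        .get()
    return movie['title']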

@@ -5,13 +5,16 @@ from datetime import timedelta
from dateutil import rrule
import pretty
import time
import operator
from operator import itemgetter
from functools import reduce
import platform
import re
import json
import hashlib
import apprise
import gc
from peewee import fn, Value
from get_args import args
from config import settings, base_url, save_settings, get_settings
@@ -19,8 +22,10 @@ from logger import empty_log
from init import *
import logging
from database import database, get_exclusion_clause, get_profiles_list, get_desired_languages, get_profile_id_name, \
get_audio_profile_languages, update_profile_id_list, convert_list_to_clause
from database import get_exclusion_clause, get_profiles_list, get_desired_languages, get_profile_id_name, \
get_audio_profile_languages, update_profile_id_list, convert_list_to_clause, TableEpisodes, TableShows, \
TableMovies, TableSettingsLanguages, TableSettingsNotifier, TableLanguagesProfiles, TableHistory, \
TableHistoryMovie, TableBlacklist, TableBlacklistMovie
from helper import path_mappings
from get_languages import language_from_alpha2, language_from_alpha3, alpha2_from_alpha3, alpha3_from_alpha2
from get_subtitle import download_subtitle, series_download_subtitles, manual_search, manual_download_subtitle, \
@@ -70,7 +75,7 @@ def authenticate(actual_method):
return wrapper
def postprocess(item: dict):
def postprocess(item):
# Remove ffprobe_cache
if 'ffprobe_cache' in item:
del item['ffprobe_cache']
@@ -310,15 +315,15 @@ class System(Resource):
class Badges(Resource):
@authenticate
def get(self):
missing_episodes = database.execute("SELECT table_shows.tags, table_episodes.monitored, table_shows.seriesType "
"FROM table_episodes INNER JOIN table_shows on table_shows.sonarrSeriesId ="
" table_episodes.sonarrSeriesId WHERE missing_subtitles is not null AND "
"missing_subtitles != '[]'" + get_exclusion_clause('series'))
missing_episodes = len(missing_episodes)
episodes_conditions = [(TableEpisodes.missing_subtitles.is_null(False)),
(TableEpisodes.missing_subtitles != '[]')]
episodes_conditions += get_exclusion_clause('series')
missing_episodes = TableEpisodes.select().where(reduce(operator.and_, episodes_conditions)).count()
missing_movies = database.execute("SELECT tags, monitored FROM table_movies WHERE missing_subtitles is not "
"null AND missing_subtitles != '[]'" + get_exclusion_clause('movie'))
missing_movies = len(missing_movies)
movies_conditions = [(TableMovies.missing_subtitles.is_null(False)),
(TableMovies.missing_subtitles != '[]')]
movies_conditions += get_exclusion_clause('movie')
missing_movies = TableMovies.select().where(reduce(operator.and_, movies_conditions)).count()
throttled_providers = len(eval(str(get_throttled_providers())))
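The reduce(operator.and_, conditions) idiom above folds a list of Peewee expressions into a single WHERE clause. A standalone sketch (the monitored value is hypothetical):

import operator
from functools import reduce
from database import TableEpisodes

conditions = [(TableEpisodes.missing_subtitles != '[]'),
              (TableEpisodes.monitored == 'True')]
# Each element is a Peewee expression; operator.and_ combines them with &,
# producing one SQL AND clause.
where_clause = reduce(operator.and_, conditions)
missing = TableEpisodes.select().where(where_clause).count()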
@@ -336,7 +341,11 @@ class Badges(Resource):
class Languages(Resource):
@authenticate
def get(self):
result = database.execute("SELECT name, code2, enabled FROM table_settings_languages ORDER BY name")
result = TableSettingsLanguages.select(TableSettingsLanguages.name,
TableSettingsLanguages.code2,
TableSettingsLanguages.enabled)\
.order_by(TableSettingsLanguages.name).dicts()
result = list(result)
for item in result:
item['enabled'] = item['enabled'] == 1
return jsonify(result)
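Note what .dicts() changes: iteration yields plain dictionaries instead of model instances, which is why the result can be handed to jsonify() once the 0/1 integers are coerced to booleans. A small illustration:

from database import TableSettingsLanguages

# Default cursor: model instances with attribute access
for lang in TableSettingsLanguages.select():
    print(lang.name, lang.enabled)

# With .dicts(): plain dicts with key access, JSON-serializable as-is
for row in TableSettingsLanguages.select().dicts():
    print(row['name'], row['enabled'] == 1)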
@@ -376,16 +385,24 @@ class Searches(Resource):
if query:
if settings.general.getboolean('use_sonarr'):
# Get matching series
series = database.execute("SELECT title, sonarrSeriesId, year FROM table_shows WHERE title LIKE ? "
"ORDER BY title ASC", ("%" + query + "%",))
series = TableShows.select(TableShows.title,
TableShows.sonarrSeriesId,
TableShows.year)\
.where(TableShows.title.contains(query))\
.order_by(TableShows.title)\
.dicts()
series = list(series)
search_list += series
if settings.general.getboolean('use_radarr'):
# Get matching movies
movies = database.execute("SELECT title, radarrId, year FROM table_movies WHERE title LIKE ? ORDER BY "
"title ASC", ("%" + query + "%",))
movies = TableMovies.select(TableMovies.title,
TableMovies.radarrId,
TableMovies.year) \
.where(TableMovies.title.contains(query)) \
.order_by(TableMovies.title) \
.dicts()
movies = list(movies)
search_list += movies
return jsonify(search_list)
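Two Peewee helpers replace hand-built SQL fragments here: Field.contains(x) renders as LIKE '%x%' (replacing the manual "%" + query + "%" parameter), and Field.in_(list) replaces the IN (...) clauses the old code assembled with convert_list_to_clause(). For example (hypothetical values):

from database import TableShows

expr = TableShows.title.contains('office')
# -> "table_shows"."title" LIKE '%office%'
expr = TableShows.sonarrSeriesId.in_([1, 2, 3])
# -> "table_shows"."sonarrSeriesId" IN (1, 2, 3)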
@@ -396,7 +413,8 @@ class SystemSettings(Resource):
def get(self):
data = get_settings()
notifications = database.execute("SELECT * FROM table_settings_notifier ORDER BY name")
notifications = TableSettingsNotifier.select().order_by(TableSettingsNotifier.name).dicts()
notifications = list(notifications)
for i, item in enumerate(notifications):
item["enabled"] = item["enabled"] == 1
notifications[i] = item
@@ -410,37 +428,51 @@ class SystemSettings(Resource):
def post(self):
enabled_languages = request.form.getlist('languages-enabled')
if len(enabled_languages) != 0:
database.execute("UPDATE table_settings_languages SET enabled=0")
TableSettingsLanguages.update({
TableSettingsLanguages.enabled: 0
}).execute()
for code in enabled_languages:
database.execute("UPDATE table_settings_languages SET enabled=1 WHERE code2=?",(code,))
TableSettingsLanguages.update({
TableSettingsLanguages.enabled: 1
})\
.where(TableSettingsLanguages.code2 == code)\
.execute()
event_stream("languages")
languages_profiles = request.form.get('languages-profiles')
if languages_profiles:
existing_ids = database.execute('SELECT profileId FROM table_languages_profiles')
existing_ids = TableLanguagesProfiles.select(TableLanguagesProfiles.profileId).dicts()
existing_ids = list(existing_ids)
existing = [x['profileId'] for x in existing_ids]
for item in json.loads(languages_profiles):
if item['profileId'] in existing:
# Update existing profiles
database.execute('UPDATE table_languages_profiles SET name = ?, cutoff = ?, items = ? '
'WHERE profileId = ?', (item['name'],
item['cutoff'] if item['cutoff'] != 'null' else None,
json.dumps(item['items']),
item['profileId']))
TableLanguagesProfiles.update({
TableLanguagesProfiles.name: item['name'],
TableLanguagesProfiles.cutoff: item['cutoff'] if item['cutoff'] != 'null' else None,
TableLanguagesProfiles.items: json.dumps(item['items'])
})\
.where(TableLanguagesProfiles.profileId == item['profileId'])\
.execute()
existing.remove(item['profileId'])
else:
# Add new profiles
database.execute('INSERT INTO table_languages_profiles (profileId, name, cutoff, items) '
'VALUES (?, ?, ?, ?)', (item['profileId'],
item['name'],
item['cutoff'] if item['cutoff'] != 'null' else None,
json.dumps(item['items'])))
TableLanguagesProfiles.insert({
TableLanguagesProfiles.profileId: item['profileId'],
TableLanguagesProfiles.name: item['name'],
TableLanguagesProfiles.cutoff: item['cutoff'] if item['cutoff'] != 'null' else None,
TableLanguagesProfiles.items: json.dumps(item['items'])
}).execute()
for profileId in existing:
# Unassign this profileId from series and movies
database.execute('UPDATE table_shows SET profileId = null WHERE profileId = ?', (profileId,))
database.execute('UPDATE table_movies SET profileId = null WHERE profileId = ?', (profileId,))
TableShows.update({
TableShows.profileId: None
}).where(TableShows.profileId == profileId).execute()
TableMovies.update({
TableMovies.profileId: None
}).where(TableMovies.profileId == profileId).execute()
# Remove deleted profiles
database.execute('DELETE FROM table_languages_profiles WHERE profileId = ?', (profileId,))
TableLanguagesProfiles.delete().where(TableLanguagesProfiles.profileId == profileId).execute()
update_profile_id_list()
event_stream("languages")
@@ -454,8 +486,10 @@ class SystemSettings(Resource):
notifications = request.form.getlist('notifications-providers')
for item in notifications:
item = json.loads(item)
database.execute("UPDATE table_settings_notifier SET enabled = ?, url = ? WHERE name = ?",
(item['enabled'], item['url'], item['name']))
TableSettingsNotifier.update({
TableSettingsNotifier.enabled: item['enabled'],
TableSettingsNotifier.url: item['url']
}).where(TableSettingsNotifier.name == item['name']).execute()
save_settings(zip(request.form.keys(), request.form.listvalues()))
event_stream("settings")
@@ -574,35 +608,43 @@ class Series(Resource):
length = request.args.get('length') or -1
seriesId = request.args.getlist('seriesid[]')
count = database.execute("SELECT COUNT(*) as count FROM table_shows", only_one=True)['count']
count = TableShows.select().count()
if len(seriesId) != 0:
result = database.execute(
f"SELECT * FROM table_shows WHERE sonarrSeriesId in {convert_list_to_clause(seriesId)} ORDER BY sortTitle ASC")
result = TableShows.select()\
.where(TableShows.sonarrSeriesId.in_(seriesId))\
.order_by(TableShows.sortTitle).dicts()
else:
result = database.execute("SELECT * FROM table_shows ORDER BY sortTitle ASC LIMIT ? OFFSET ?"
, (length, start))
result = TableShows.select().order_by(TableShows.sortTitle).limit(length).offset(start).dicts()
result = list(result)
for item in result:
postprocessSeries(item)
# Add missing subtitles episode count
episodeMissingCount = database.execute("SELECT table_shows.tags, table_episodes.monitored, "
"table_shows.seriesType FROM table_episodes INNER JOIN table_shows "
"on table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId "
"WHERE table_episodes.sonarrSeriesId=? AND missing_subtitles is not "
"null AND missing_subtitles != '[]'" +
get_exclusion_clause('series'), (item['sonarrSeriesId'],))
episodeMissingCount = len(episodeMissingCount)
episodes_missing_conditions = [(TableEpisodes.sonarrSeriesId == item['sonarrSeriesId']),
(TableEpisodes.missing_subtitles != '[]')]
episodes_missing_conditions += get_exclusion_clause('series')
episodeMissingCount = TableEpisodes.select(TableShows.tags,
TableEpisodes.monitored,
TableShows.seriesType)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(reduce(operator.and_, episodes_missing_conditions))\
.count()
item.update({"episodeMissingCount": episodeMissingCount})
# Add episode count
episodeFileCount = database.execute("SELECT table_shows.tags, table_episodes.monitored, "
"table_shows.seriesType FROM table_episodes INNER JOIN table_shows on "
"table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE "
"table_episodes.sonarrSeriesId=?" + get_exclusion_clause('series'),
(item['sonarrSeriesId'],))
episodeFileCount = len(episodeFileCount)
episodes_count_conditions = [(TableEpisodes.sonarrSeriesId == item['sonarrSeriesId'])]
episodes_count_conditions += get_exclusion_clause('series')
episodeFileCount = TableEpisodes.select(TableShows.tags,
TableEpisodes.monitored,
TableShows.seriesType)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(reduce(operator.and_, episodes_count_conditions))\
.count()
item.update({"episodeFileCount": episodeFileCount})
return jsonify(data=result, total=count)
@@ -624,7 +666,11 @@ class Series(Resource):
except Exception:
return '', 400
database.execute("UPDATE table_shows SET profileId=? WHERE sonarrSeriesId=?", (profileId, seriesId))
TableShows.update({
TableShows.profileId: profileId
})\
.where(TableShows.sonarrSeriesId == seriesId)\
.execute()
list_missing_subtitles(no=seriesId, send_event=False)
@@ -657,14 +703,16 @@ class Episodes(Resource):
episodeId = request.args.getlist('episodeid[]')
if len(episodeId) > 0:
result = database.execute(f"SELECT * FROM table_episodes WHERE sonarrEpisodeId in {convert_list_to_clause(episodeId)}")
result = TableEpisodes.select().where(TableEpisodes.sonarrEpisodeId.in_(episodeId)).dicts()
elif len(seriesId) > 0:
result = database.execute("SELECT * FROM table_episodes "
f"WHERE sonarrSeriesId in {convert_list_to_clause(seriesId)} ORDER BY season DESC, "
"episode DESC")
result = TableEpisodes.select()\
.where(TableEpisodes.sonarrSeriesId.in_(seriesId))\
.order_by(TableEpisodes.season.desc(), TableEpisodes.episode.desc())\
.dicts()
else:
return "Series or Episode ID not provided", 400
result = list(result)
for item in result:
postprocessEpisode(item)
@@ -679,9 +727,13 @@ class EpisodesSubtitles(Resource):
def patch(self):
sonarrSeriesId = request.args.get('seriesid')
sonarrEpisodeId = request.args.get('episodeid')
episodeInfo = database.execute(
"SELECT title, path, scene_name, audio_language FROM table_episodes WHERE sonarrEpisodeId=?",
(sonarrEpisodeId,), only_one=True)
episodeInfo = TableEpisodes.select(TableEpisodes.title,
TableEpisodes.path,
TableEpisodes.scene_name,
TableEpisodes.audio_language)\
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)\
.dicts()\
.get()
title = episodeInfo['title']
episodePath = path_mappings.path_replace(episodeInfo['path'])
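One behavioural caveat with this pattern: .get() raises TableEpisodes.DoesNotExist when no row matches the WHERE clause. If an unknown episode id should be tolerated rather than raise, a guard along these lines (a sketch, not in the commit) would be needed:

try:
    episodeInfo = TableEpisodes.select(TableEpisodes.path)\
        .where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)\
        .dicts()\
        .get()
except TableEpisodes.DoesNotExist:
    episodeInfo = None  # handle the missing row instead of propagating the error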
@@ -735,9 +787,13 @@ class EpisodesSubtitles(Resource):
def post(self):
sonarrSeriesId = request.args.get('seriesid')
sonarrEpisodeId = request.args.get('episodeid')
episodeInfo = database.execute(
"SELECT title, path, scene_name, audio_language FROM table_episodes WHERE sonarrEpisodeId=?",
(sonarrEpisodeId,), only_one=True)
episodeInfo = TableEpisodes.select(TableEpisodes.title,
TableEpisodes.path,
TableEpisodes.scene_name,
TableEpisodes.audio_language)\
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)\
.dicts()\
.get()
title = episodeInfo['title']
episodePath = path_mappings.path_replace(episodeInfo['path'])
@@ -789,9 +845,13 @@ class EpisodesSubtitles(Resource):
def delete(self):
sonarrSeriesId = request.args.get('seriesid')
sonarrEpisodeId = request.args.get('episodeid')
episodeInfo = database.execute(
"SELECT title, path, scene_name, audio_language FROM table_episodes WHERE sonarrEpisodeId=?",
(sonarrEpisodeId,), only_one=True)
episodeInfo = TableEpisodes.select(TableEpisodes.title,
TableEpisodes.path,
TableEpisodes.scene_name,
TableEpisodes.audio_language)\
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)\
.dicts()\
.get()
episodePath = path_mappings.path_replace(episodeInfo['path'])
@@ -819,13 +879,16 @@ class Movies(Resource):
length = request.args.get('length') or -1
radarrId = request.args.getlist('radarrid[]')
count = database.execute("SELECT COUNT(*) as count FROM table_movies", only_one=True)['count']
count = TableMovies.select().count()
if len(radarrId) != 0:
result = database.execute(f"SELECT * FROM table_movies WHERE radarrId in {convert_list_to_clause(radarrId)} ORDER BY sortTitle ASC")
result = TableMovies.select()\
.where(TableMovies.radarrId.in_(radarrId))\
.order_by(TableMovies.sortTitle)\
.dicts()
else:
result = database.execute("SELECT * FROM table_movies ORDER BY sortTitle ASC LIMIT ? OFFSET ?",
(length, start))
result = TableMovies.select().order_by(TableMovies.sortTitle).limit(length).offset(start).dicts()
result = list(result)
for item in result:
postprocessMovie(item)
@@ -848,7 +911,11 @@ class Movies(Resource):
except Exception:
return '', 400
database.execute("UPDATE table_movies SET profileId=? WHERE radarrId=?", (profileId, radarrId))
TableMovies.update({
TableMovies.profileId: profileId
})\
.where(TableMovies.radarrId == radarrId)\
.execute()
list_missing_subtitles_movies(no=radarrId)
@@ -885,8 +952,13 @@ class MoviesSubtitles(Resource):
# Download
radarrId = request.args.get('radarrid')
movieInfo = database.execute("SELECT title, path, sceneName, audio_language FROM table_movies WHERE radarrId=?",
(radarrId,), only_one=True)
movieInfo = TableMovies.select(TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
TableMovies.audio_language)\
.where(TableMovies.radarrId == radarrId)\
.dicts()\
.get()
moviePath = path_mappings.path_replace_movie(movieInfo['path'])
sceneName = movieInfo['sceneName']
@@ -940,8 +1012,13 @@ class MoviesSubtitles(Resource):
# Upload
# TODO: Support Multiple Uploads
radarrId = request.args.get('radarrid')
movieInfo = database.execute("SELECT title, path, sceneName, audio_language FROM table_movies WHERE radarrId=?",
(radarrId,), only_one=True)
movieInfo = TableMovies.select(TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
TableMovies.audio_language) \
.where(TableMovies.radarrId == radarrId) \
.dicts() \
.get()
moviePath = path_mappings.path_replace_movie(movieInfo['path'])
sceneName = movieInfo['sceneName']
@@ -992,7 +1069,10 @@ class MoviesSubtitles(Resource):
def delete(self):
# Delete
radarrId = request.args.get('radarrid')
movieInfo = database.execute("SELECT path FROM table_movies WHERE radarrId=?", (radarrId,), only_one=True)
movieInfo = TableMovies.select(TableMovies.path) \
.where(TableMovies.radarrId == radarrId) \
.dicts() \
.get()
moviePath = path_mappings.path_replace_movie(movieInfo['path'])
@@ -1044,8 +1124,13 @@ class ProviderMovies(Resource):
def get(self):
# Manual Search
radarrId = request.args.get('radarrid')
movieInfo = database.execute("SELECT title, path, sceneName, profileId FROM table_movies WHERE radarrId=?",
(radarrId,), only_one=True)
movieInfo = TableMovies.select(TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
TableMovies.profileId) \
.where(TableMovies.radarrId == radarrId) \
.dicts() \
.get()
title = movieInfo['title']
moviePath = path_mappings.path_replace_movie(movieInfo['path'])
@@ -1066,8 +1151,13 @@ class ProviderMovies(Resource):
def post(self):
# Manual Download
radarrId = request.args.get('radarrid')
movieInfo = database.execute("SELECT title, path, sceneName, audio_language FROM table_movies WHERE radarrId=?",
(radarrId,), only_one=True)
movieInfo = TableMovies.select(TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
TableMovies.audio_language) \
.where(TableMovies.radarrId == radarrId) \
.dicts() \
.get()
title = movieInfo['title']
moviePath = path_mappings.path_replace_movie(movieInfo['path'])
@@ -1121,19 +1211,19 @@ class ProviderEpisodes(Resource):
def get(self):
# Manual Search
sonarrEpisodeId = request.args.get('episodeid')
episodeInfo = database.execute(
"SELECT title, path, scene_name, audio_language, sonarrSeriesId FROM table_episodes WHERE sonarrEpisodeId=?",
(sonarrEpisodeId,), only_one=True)
episodeInfo = TableEpisodes.select(TableEpisodes.title,
TableEpisodes.path,
TableEpisodes.scene_name,
TableShows.profileId) \
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \
.dicts() \
.get()
title = episodeInfo['title']
episodePath = path_mappings.path_replace(episodeInfo['path'])
sceneName = episodeInfo['scene_name']
seriesId = episodeInfo['sonarrSeriesId']
seriesInfo = database.execute("SELECT profileId FROM table_shows WHERE sonarrSeriesId=?", (seriesId,),
only_one=True)
profileId = seriesInfo['profileId']
profileId = episodeInfo['profileId']
if sceneName is None: sceneName = "None"
providers_list = get_providers()
@@ -1150,8 +1240,12 @@ class ProviderEpisodes(Resource):
# Manual Download
sonarrSeriesId = request.args.get('seriesid')
sonarrEpisodeId = request.args.get('episodeid')
episodeInfo = database.execute("SELECT title, path, scene_name FROM table_episodes WHERE sonarrEpisodeId=?",
(sonarrEpisodeId,), only_one=True)
episodeInfo = TableEpisodes.select(TableEpisodes.title,
TableEpisodes.path,
TableEpisodes.scene_name) \
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \
.dicts() \
.get()
title = episodeInfo['title']
episodePath = path_mappings.path_replace(episodeInfo['path'])
@@ -1218,14 +1312,22 @@ class EpisodesHistory(Resource):
else:
query_actions = [1, 3]
upgradable_episodes = database.execute(
"SELECT video_path, MAX(timestamp) as timestamp, score, table_shows.tags, table_episodes.monitored, "
"table_shows.seriesType FROM table_history INNER JOIN table_episodes on "
"table_episodes.sonarrEpisodeId = table_history.sonarrEpisodeId INNER JOIN table_shows on "
"table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE action IN (" +
','.join(map(str, query_actions)) + ") AND timestamp > ? AND score is not null" +
get_exclusion_clause('series') + " GROUP BY table_history.video_path", (minimum_timestamp,))
upgradable_episodes_conditions = [(TableHistory.action.in_(query_actions)),
(TableHistory.timestamp > minimum_timestamp),
(TableHistory.score.is_null(False))]
upgradable_episodes_conditions += get_exclusion_clause('series')
upgradable_episodes = TableHistory.select(TableHistory.video_path,
fn.MAX(TableHistory.timestamp).alias('timestamp'),
TableHistory.score,
TableShows.tags,
TableEpisodes.monitored,
TableShows.seriesType)\
.join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId))\
.join(TableShows, on=(TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(reduce(operator.and_, upgradable_episodes_conditions))\
.group_by(TableHistory.video_path)\
.dicts()
upgradable_episodes = list(upgradable_episodes)
for upgradable_episode in upgradable_episodes:
if upgradable_episode['timestamp'] > minimum_timestamp:
try:
@@ -1236,24 +1338,38 @@ class EpisodesHistory(Resource):
if int(upgradable_episode['score']) < 360:
upgradable_episodes_not_perfect.append(upgradable_episode)
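The aggregate above relies on peewee's fn helper, which turns attribute access into SQL function calls, with .alias() controlling the key name in the resulting dicts. Reduced to its core (filters omitted):

from peewee import fn
from database import TableHistory

# SELECT video_path, MAX(timestamp) AS timestamp
# FROM table_history GROUP BY video_path
latest = TableHistory.select(TableHistory.video_path,
                             fn.MAX(TableHistory.timestamp).alias('timestamp'))\
    .group_by(TableHistory.video_path)\
    .dicts()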
# TODO: Find a better solution
query_limit = ""
query_conditions = [(TableEpisodes.title.is_null(False))]
if episodeid:
query_limit = f"AND table_episodes.sonarrEpisodeId={episodeid}"
episode_history = database.execute("SELECT table_shows.title as seriesTitle, table_episodes.monitored, "
"table_episodes.season || 'x' || table_episodes.episode as episode_number, "
"table_episodes.title as episodeTitle, table_history.timestamp, table_history.subs_id, "
"table_history.description, table_history.sonarrSeriesId, table_episodes.path, "
"table_history.language, table_history.score, table_shows.tags, table_history.action, "
"table_history.subtitles_path, table_history.sonarrEpisodeId, table_history.provider, "
"table_shows.seriesType FROM table_history LEFT JOIN table_shows on "
"table_shows.sonarrSeriesId = table_history.sonarrSeriesId LEFT JOIN table_episodes on "
"table_episodes.sonarrEpisodeId = table_history.sonarrEpisodeId WHERE "
"table_episodes.title is not NULL " + query_limit + " ORDER BY timestamp DESC LIMIT ? OFFSET ?",
(length, start))
blacklist_db = database.execute("SELECT provider, subs_id FROM table_blacklist ")
query_conditions.append((TableEpisodes.sonarrEpisodeId == episodeid))
query_condition = reduce(operator.and_, query_conditions)
episode_history = TableHistory.select(TableShows.title.alias('seriesTitle'),
TableEpisodes.monitored,
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'),
TableEpisodes.title.alias('episodeTitle'),
TableHistory.timestamp,
TableHistory.subs_id,
TableHistory.description,
TableHistory.sonarrSeriesId,
TableEpisodes.path,
TableHistory.language,
TableHistory.score,
TableShows.tags,
TableHistory.action,
TableHistory.subtitles_path,
TableHistory.sonarrEpisodeId,
TableHistory.provider,
TableShows.seriesType)\
.join(TableShows, on=(TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId))\
.join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId))\
.where(query_condition)\
.order_by(TableHistory.timestamp.desc())\
.limit(length)\
.offset(start)\
.dicts()
episode_history = list(episode_history)
blacklist_db = TableBlacklist.select(TableBlacklist.provider, TableBlacklist.subs_id).dicts()
blacklist_db = list(blacklist_db)
for item in episode_history:
# Mark episode as upgradable or not
@@ -1281,14 +1397,14 @@ class EpisodesHistory(Resource):
item.update({"blacklisted": False})
if item['action'] not in [0, 4, 5]:
for blacklisted_item in blacklist_db:
if blacklisted_item['provider'] == item['provider'] and blacklisted_item['subs_id'] == item[
'subs_id']:
if blacklisted_item['provider'] == item['provider'] and \
blacklisted_item['subs_id'] == item['subs_id']:
item.update({"blacklisted": True})
break
count = database.execute("SELECT COUNT(*) as count FROM table_history LEFT JOIN table_episodes "
"on table_episodes.sonarrEpisodeId = table_history.sonarrEpisodeId WHERE "
"table_episodes.title is not NULL", only_one=True)['count']
count = TableHistory.select()\
.join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId))\
.where(TableEpisodes.title.is_null(False)).count()
return jsonify(data=episode_history, total=count)
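Field.concat() used above renders as SQLite's || operator, reproducing the old season || 'x' || episode formatting. For example:

episode_number = TableEpisodes.season.concat('x').concat(TableEpisodes.episode)
# -> "table_episodes"."season" || 'x' || "table_episodes"."episode", e.g. '1x3'
first_row = TableEpisodes.select(episode_number.alias('episode_number')).dicts().first()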
@@ -1312,11 +1428,20 @@ class MoviesHistory(Resource):
else:
query_actions = [1, 3]
upgradable_movies = database.execute(
"SELECT video_path, MAX(timestamp) as timestamp, score, tags, monitored FROM table_history_movie "
"INNER JOIN table_movies on table_movies.radarrId=table_history_movie.radarrId WHERE action IN (" +
','.join(map(str, query_actions)) + ") AND timestamp > ? AND score is not NULL" +
get_exclusion_clause('movie') + " GROUP BY video_path", (minimum_timestamp,))
upgradable_movies_conditions = [(TableHistoryMovie.action.in_(query_actions)),
(TableHistoryMovie.timestamp > minimum_timestamp),
(TableHistoryMovie.score.is_null(False))]
upgradable_movies_conditions += get_exclusion_clause('movie')
upgradable_movies = TableHistoryMovie.select(TableHistoryMovie.video_path,
fn.MAX(TableHistoryMovie.timestamp).alias('timestamp'),
TableHistoryMovie.score,
TableMovies.tags,
TableMovies.monitored)\
.join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId))\
.where(reduce(operator.and_, upgradable_movies_conditions))\
.group_by(TableHistoryMovie.video_path)\
.dicts()
upgradable_movies = list(upgradable_movies)
for upgradable_movie in upgradable_movies:
if upgradable_movie['timestamp'] > minimum_timestamp:
@@ -1328,22 +1453,34 @@ class MoviesHistory(Resource):
if int(upgradable_movie['score']) < 120:
upgradable_movies_not_perfect.append(upgradable_movie)
# TODO: Find a better solution
query_limit = ""
query_conditions = [(TableMovies.title.is_null(False))]
if radarrid:
query_limit = f"AND table_movies.radarrid={radarrid}"
movie_history = database.execute(
"SELECT table_history_movie.action, table_movies.title, table_history_movie.timestamp, "
"table_history_movie.description, table_history_movie.radarrId, table_movies.monitored,"
"table_history_movie.video_path as path, table_history_movie.language, table_movies.tags, "
"table_history_movie.score, table_history_movie.subs_id, table_history_movie.provider, "
"table_history_movie.subtitles_path, table_history_movie.subtitles_path FROM "
"table_history_movie LEFT JOIN table_movies on table_movies.radarrId = "
"table_history_movie.radarrId WHERE table_movies.title is not NULL " + query_limit + " ORDER BY timestamp DESC LIMIT ? OFFSET ?",
(length, start))
blacklist_db = database.execute("SELECT provider, subs_id FROM table_blacklist_movie")
query_conditions.append((TableMovies.radarrId == radarrid))
query_condition = reduce(operator.and_, query_conditions)
movie_history = TableHistoryMovie.select(TableHistoryMovie.action,
TableMovies.title,
TableHistoryMovie.timestamp,
TableHistoryMovie.description,
TableHistoryMovie.radarrId,
TableMovies.monitored,
TableHistoryMovie.video_path.alias('path'),
TableHistoryMovie.language,
TableMovies.tags,
TableHistoryMovie.score,
TableHistoryMovie.subs_id,
TableHistoryMovie.provider,
TableHistoryMovie.subtitles_path)\
.join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId))\
.where(query_condition)\
.order_by(TableHistoryMovie.timestamp.desc())\
.limit(length)\
.offset(start)\
.dicts()
movie_history = list(movie_history)
blacklist_db = TableBlacklistMovie.select(TableBlacklistMovie.provider, TableBlacklistMovie.subs_id).dicts()
blacklist_db = list(blacklist_db)
for item in movie_history:
# Mark movies as upgradable or not
@@ -1375,9 +1512,10 @@ class MoviesHistory(Resource):
item.update({"blacklisted": True})
break
count = database.execute("SELECT COUNT(*) as count FROM table_history_movie LEFT JOIN table_movies on "
"table_movies.radarrId = table_history_movie.radarrId WHERE table_movies.title "
"is not NULL", only_one=True)['count']
count = TableHistoryMovie.select()\
.join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId))\
.where(TableMovies.title.is_null(False))\
.count()
return jsonify(data=movie_history, total=count)
@@ -1390,38 +1528,56 @@ class HistoryStats(Resource):
provider = request.args.get('provider') or 'All'
language = request.args.get('language') or 'All'
history_where_clause = " WHERE id"
# timeframe must be in ['week', 'month', 'trimester', 'year']
if timeframe == 'year':
days = 364
delay = 364 * 24 * 60 * 60
elif timeframe == 'trimester':
days = 90
delay = 90 * 24 * 60 * 60
elif timeframe == 'month':
days = 30
delay = 30 * 24 * 60 * 60
elif timeframe == 'week':
days = 6
delay = 6 * 24 * 60 * 60
now = time.time()
past = now - delay
history_where_clauses = [(TableHistory.timestamp.between(past, now))]
history_where_clauses_movie = [(TableHistoryMovie.timestamp.between(past, now))]
history_where_clause += " AND datetime(timestamp, 'unixepoch') BETWEEN datetime('now', '-" + str(days) + \
" days') AND datetime('now', 'localtime')"
if action != 'All':
history_where_clause += " AND action = " + action
history_where_clauses.append((TableHistory.action == action))
history_where_clauses_movie.append((TableHistoryMovie.action == action))
else:
history_where_clause += " AND action IN (1,2,3)"
history_where_clauses.append((TableHistory.action.in_([1, 2, 3])))
history_where_clauses_movie.append((TableHistoryMovie.action.in_([1, 2, 3])))
if provider != 'All':
history_where_clause += " AND provider = '" + provider + "'"
if language != 'All':
history_where_clause += " AND language = '" + language + "'"
history_where_clauses.append((TableHistory.provider == provider))
history_where_clauses_movie.append((TableHistoryMovie.provider == provider))
data_series = database.execute("SELECT strftime ('%Y-%m-%d',datetime(timestamp, 'unixepoch')) as date, "
"COUNT(id) as count FROM table_history" + history_where_clause +
" GROUP BY strftime ('%Y-%m-%d',datetime(timestamp, 'unixepoch'))")
data_movies = database.execute("SELECT strftime ('%Y-%m-%d',datetime(timestamp, 'unixepoch')) as date, "
"COUNT(id) as count FROM table_history_movie" + history_where_clause +
" GROUP BY strftime ('%Y-%m-%d',datetime(timestamp, 'unixepoch'))")
if language != 'All':
history_where_clauses.append((TableHistory.language == language))
history_where_clauses_movie.append((TableHistoryMovie.language == language))
history_where_clause = reduce(operator.and_, history_where_clauses)
history_where_clause_movie = reduce(operator.and_, history_where_clauses_movie)
data_series = TableHistory.select(fn.strftime('%Y-%m-%d', TableHistory.timestamp, 'unixepoch').alias('date'),
fn.COUNT(TableHistory.id).alias('count'))\
.where(history_where_clause) \
.group_by(fn.strftime('%Y-%m-%d', TableHistory.timestamp, 'unixepoch'))\
.dicts()
data_series = list(data_series)
data_movies = TableHistoryMovie.select(fn.strftime('%Y-%m-%d', TableHistoryMovie.timestamp, 'unixepoch').alias('date'),
fn.COUNT(TableHistoryMovie.id).alias('count')) \
.where(history_where_clause_movie) \
.group_by(fn.strftime('%Y-%m-%d', TableHistoryMovie.timestamp, 'unixepoch')) \
.dicts()
data_movies = list(data_movies)
for dt in rrule.rrule(rrule.DAILY,
dtstart=datetime.datetime.now() - datetime.timedelta(days=days),
dtstart=datetime.datetime.now() - datetime.timedelta(seconds=delay),
until=datetime.datetime.now()):
if not any(d['date'] == dt.strftime('%Y-%m-%d') for d in data_series):
data_series.append({'date': dt.strftime('%Y-%m-%d'), 'count': 0})
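The grouping works because fn.strftime maps directly onto SQLite's strftime(); the rrule loop above then pads the series with zero counts for days without activity. Stripped of the filters, the generated query is essentially:

from peewee import fn
from database import TableHistory

day = fn.strftime('%Y-%m-%d', TableHistory.timestamp, 'unixepoch')
# SELECT strftime('%Y-%m-%d', timestamp, 'unixepoch') AS date, COUNT(id) AS count
# FROM table_history GROUP BY strftime('%Y-%m-%d', timestamp, 'unixepoch')
data = TableHistory.select(day.alias('date'),
                           fn.COUNT(TableHistory.id).alias('count'))\
    .group_by(day)\
    .dicts()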
@@ -1439,37 +1595,59 @@ class EpisodesWanted(Resource):
@authenticate
def get(self):
episodeid = request.args.getlist('episodeid[]')
wanted_conditions = [(TableEpisodes.missing_subtitles != '[]')]
if len(episodeid) > 0:
data = database.execute("SELECT table_shows.title as seriesTitle, table_episodes.monitored, "
"table_episodes.season || 'x' || table_episodes.episode as episode_number, "
"table_episodes.title as episodeTitle, table_episodes.missing_subtitles, "
"table_episodes.sonarrSeriesId, "
"table_episodes.sonarrEpisodeId, table_episodes.scene_name as sceneName, table_shows.tags, "
"table_episodes.failedAttempts, table_shows.seriesType FROM table_episodes INNER JOIN "
"table_shows on table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE "
"table_episodes.missing_subtitles != '[]'" + get_exclusion_clause('series') +
f" AND sonarrEpisodeId in {convert_list_to_clause(episodeid)}")
wanted_conditions.append((TableEpisodes.sonarrEpisodeId.in_(episodeid)))
wanted_conditions += get_exclusion_clause('series')
wanted_condition = reduce(operator.and_, wanted_conditions)
if len(episodeid) > 0:
data = TableEpisodes.select(TableShows.title.alias('seriesTitle'),
TableEpisodes.monitored,
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'),
TableEpisodes.title.alias('episodeTitle'),
TableEpisodes.missing_subtitles,
TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.scene_name.alias('sceneName'),
TableShows.tags,
TableEpisodes.failedAttempts,
TableShows.seriesType)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(wanted_condition)\
.dicts()
else:
start = request.args.get('start') or 0
length = request.args.get('length') or -1
data = database.execute("SELECT table_shows.title as seriesTitle, table_episodes.monitored, "
"table_episodes.season || 'x' || table_episodes.episode as episode_number, "
"table_episodes.title as episodeTitle, table_episodes.missing_subtitles, "
"table_episodes.sonarrSeriesId, "
"table_episodes.sonarrEpisodeId, table_episodes.scene_name as sceneName, table_shows.tags, "
"table_episodes.failedAttempts, table_shows.seriesType FROM table_episodes INNER JOIN "
"table_shows on table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE "
"table_episodes.missing_subtitles != '[]'" + get_exclusion_clause('series') +
" ORDER BY table_episodes._rowid_ DESC LIMIT ? OFFSET ?", (length, start))
data = TableEpisodes.select(TableShows.title.alias('seriesTitle'),
TableEpisodes.monitored,
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'),
TableEpisodes.title.alias('episodeTitle'),
TableEpisodes.missing_subtitles,
TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.scene_name.alias('sceneName'),
TableShows.tags,
TableEpisodes.failedAttempts,
TableShows.seriesType)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(wanted_condition)\
.limit(length)\
.offset(start)\
.dicts()
data = list(data)
for item in data:
postprocessEpisode(item)
count = database.execute("SELECT COUNT(*) as count, table_shows.tags, table_shows.seriesType FROM "
"table_episodes INNER JOIN table_shows on table_shows.sonarrSeriesId = "
"table_episodes.sonarrSeriesId WHERE missing_subtitles != '[]'" +
get_exclusion_clause('series'), only_one=True)['count']
count_conditions = [(TableEpisodes.missing_subtitles != '[]')]
count_conditions += get_exclusion_clause('series')
count = TableEpisodes.select(TableShows.tags,
TableShows.seriesType)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(reduce(operator.and_, count_conditions))\
.count()
return jsonify(data=data, total=count)
@@ -1479,25 +1657,48 @@ class MoviesWanted(Resource):
@authenticate
def get(self):
radarrid = request.args.getlist("radarrid[]")
wanted_conditions = [(TableMovies.missing_subtitles != '[]')]
if len(radarrid) > 0:
result = database.execute("SELECT title, missing_subtitles, radarrId, sceneName, "
"failedAttempts, tags, monitored FROM table_movies WHERE missing_subtitles != '[]'" +
get_exclusion_clause('movie') +
f" AND radarrId in {convert_list_to_clause(radarrid)}")
wanted_conditions.append((TableMovies.radarrId.in_(radarrid)))
wanted_conditions += get_exclusion_clause('movie')
wanted_condition = reduce(operator.and_, wanted_conditions)
if len(radarrid) > 0:
result = TableMovies.select(TableMovies.title,
TableMovies.missing_subtitles,
TableMovies.radarrId,
TableMovies.sceneName,
TableMovies.failedAttempts,
TableMovies.tags,
TableMovies.monitored)\
.where(wanted_condition)\
.dicts()
else:
start = request.args.get('start') or 0
length = request.args.get('length') or -1
result = database.execute("SELECT title, missing_subtitles, radarrId, sceneName, "
"failedAttempts, tags, monitored FROM table_movies WHERE missing_subtitles != '[]'" +
get_exclusion_clause('movie') +
" ORDER BY _rowid_ DESC LIMIT ? OFFSET ?", (length, start))
result = TableMovies.select(TableMovies.title,
TableMovies.missing_subtitles,
TableMovies.radarrId,
TableMovies.sceneName,
TableMovies.failedAttempts,
TableMovies.tags,
TableMovies.monitored)\
.where(wanted_condition)\
.order_by(TableMovies.radarrId.desc())\
.limit(length)\
.offset(start)\
.dicts()
result = list(result)
for item in result:
postprocessMovie(item)
count = database.execute("SELECT COUNT(*) as count FROM table_movies WHERE missing_subtitles != '[]'" +
get_exclusion_clause('movie'), only_one=True)['count']
count_conditions = [(TableMovies.missing_subtitles != '[]')]
count_conditions += get_exclusion_clause('movie')
count = TableMovies.select()\
.where(reduce(operator.and_, count_conditions))\
.count()
return jsonify(data=result, total=count)
@@ -1511,14 +1712,21 @@ class EpisodesBlacklist(Resource):
start = request.args.get('start') or 0
length = request.args.get('length') or -1
data = database.execute("SELECT table_shows.title as seriesTitle, table_episodes.season || 'x' || "
"table_episodes.episode as episode_number, table_episodes.title as episodeTitle, "
"table_episodes.sonarrSeriesId, table_blacklist.provider, table_blacklist.subs_id, "
"table_blacklist.language, table_blacklist.timestamp FROM table_blacklist INNER JOIN "
"table_episodes on table_episodes.sonarrEpisodeId = table_blacklist.sonarr_episode_id "
"INNER JOIN table_shows on table_shows.sonarrSeriesId = "
"table_blacklist.sonarr_series_id ORDER BY table_blacklist.timestamp DESC LIMIT ? "
"OFFSET ?", (length, start))
data = TableBlacklist.select(TableShows.title.alias('seriesTitle'),
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'),
TableEpisodes.title.alias('episodeTitle'),
TableEpisodes.sonarrSeriesId,
TableBlacklist.provider,
TableBlacklist.subs_id,
TableBlacklist.language,
TableBlacklist.timestamp)\
.join(TableEpisodes, on=(TableBlacklist.sonarr_episode_id == TableEpisodes.sonarrEpisodeId))\
.join(TableShows, on=(TableBlacklist.sonarr_series_id == TableShows.sonarrSeriesId))\
.order_by(TableBlacklist.timestamp.desc())\
.limit(length)\
.offset(start)\
.dicts()
data = list(data)
for item in data:
# Make timestamp pretty
@@ -1537,8 +1745,10 @@ class EpisodesBlacklist(Resource):
subs_id = request.form.get('subs_id')
language = request.form.get('language')
episodeInfo = database.execute("SELECT path FROM table_episodes WHERE sonarrEpisodeId=?",
(sonarr_episode_id,), only_one=True)
episodeInfo = TableEpisodes.select(TableEpisodes.path)\
.where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id)\
.dicts()\
.get()
media_path = episodeInfo['path']
subtitles_path = request.form.get('subtitles_path')
@@ -1580,12 +1790,18 @@ class MoviesBlacklist(Resource):
start = request.args.get('start') or 0
length = request.args.get('length') or -1
data = database.execute("SELECT table_movies.title, table_movies.radarrId, table_blacklist_movie.provider, "
"table_blacklist_movie.subs_id, table_blacklist_movie.language, "
"table_blacklist_movie.timestamp FROM table_blacklist_movie INNER JOIN "
"table_movies on table_movies.radarrId = table_blacklist_movie.radarr_id "
"ORDER BY table_blacklist_movie.timestamp DESC LIMIT ? "
"OFFSET ?", (length, start))
data = TableBlacklistMovie.select(TableMovies.title,
TableMovies.radarrId,
TableBlacklistMovie.provider,
TableBlacklistMovie.subs_id,
TableBlacklistMovie.language,
TableBlacklistMovie.timestamp)\
.join(TableMovies, on=(TableBlacklistMovie.radarr_id == TableMovies.radarrId))\
.order_by(TableBlacklistMovie.timestamp.desc())\
.limit(length)\
.offset(start)\
.dicts()
data = list(data)
for item in data:
postprocessMovie(item)
@@ -1606,7 +1822,7 @@ class MoviesBlacklist(Resource):
forced = False
hi = False
data = database.execute("SELECT path FROM table_movies WHERE radarrId=?", (radarr_id,), only_one=True)
data = TableMovies.select(TableMovies.path).where(TableMovies.radarrId == radarr_id).dicts().get()
media_path = data['path']
subtitles_path = request.form.get('subtitles_path')
@@ -1649,13 +1865,14 @@ class Subtitles(Resource):
if media_type == 'episode':
subtitles_path = path_mappings.path_replace(subtitles_path)
metadata = database.execute("SELECT path, sonarrSeriesId FROM table_episodes"
" WHERE sonarrEpisodeId = ?", (id,), only_one=True)
metadata = TableEpisodes.select(TableEpisodes.path, TableEpisodes.sonarrSeriesId)\
.where(TableEpisodes.sonarrEpisodeId == id)\
.dicts()\
.get()
video_path = path_mappings.path_replace(metadata['path'])
else:
subtitles_path = path_mappings.path_replace_movie(subtitles_path)
metadata = database.execute("SELECT path FROM table_movies WHERE radarrId = ?",
(id,), only_one=True)
metadata = TableMovies.select(TableMovies.path).where(TableMovies.radarrId == id).dicts().get()
video_path = path_mappings.path_replace_movie(metadata['path'])
if action == 'sync':

@@ -1,88 +0,0 @@
BEGIN TRANSACTION;
CREATE TABLE "table_shows" (
`tvdbId` INTEGER NOT NULL UNIQUE,
`title` TEXT NOT NULL,
`path` TEXT NOT NULL UNIQUE,
`languages` TEXT,
`hearing_impaired` TEXT,
`sonarrSeriesId` INTEGER NOT NULL UNIQUE,
`overview` TEXT,
`poster` TEXT,
`fanart` TEXT,
`audio_language` "text",
`sortTitle` "text",
PRIMARY KEY(`tvdbId`)
);
CREATE TABLE "table_settings_providers" (
`name` TEXT NOT NULL UNIQUE,
`enabled` INTEGER,
`username` "text",
`password` "text",
PRIMARY KEY(`name`)
);
CREATE TABLE "table_settings_notifier" (
`name` TEXT,
`url` TEXT,
`enabled` INTEGER,
PRIMARY KEY(`name`)
);
CREATE TABLE "table_settings_languages" (
`code3` TEXT NOT NULL UNIQUE,
`code2` TEXT,
`name` TEXT NOT NULL,
`enabled` INTEGER,
`code3b` TEXT,
PRIMARY KEY(`code3`)
);
CREATE TABLE "table_history" (
`id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT UNIQUE,
`action` INTEGER NOT NULL,
`sonarrSeriesId` INTEGER NOT NULL,
`sonarrEpisodeId` INTEGER NOT NULL,
`timestamp` INTEGER NOT NULL,
`description` TEXT NOT NULL
);
CREATE TABLE "table_episodes" (
`sonarrSeriesId` INTEGER NOT NULL,
`sonarrEpisodeId` INTEGER NOT NULL UNIQUE,
`title` TEXT NOT NULL,
`path` TEXT NOT NULL,
`season` INTEGER NOT NULL,
`episode` INTEGER NOT NULL,
`subtitles` TEXT,
`missing_subtitles` TEXT,
`scene_name` TEXT,
`monitored` TEXT,
`failedAttempts` "text"
);
CREATE TABLE "table_movies" (
`tmdbId` TEXT NOT NULL UNIQUE,
`title` TEXT NOT NULL,
`path` TEXT NOT NULL UNIQUE,
`languages` TEXT,
`subtitles` TEXT,
`missing_subtitles` TEXT,
`hearing_impaired` TEXT,
`radarrId` INTEGER NOT NULL UNIQUE,
`overview` TEXT,
`poster` TEXT,
`fanart` TEXT,
`audio_language` "text",
`sceneName` TEXT,
`monitored` TEXT,
`failedAttempts` "text",
PRIMARY KEY(`tmdbId`)
);
CREATE TABLE "table_history_movie" (
`id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT UNIQUE,
`action` INTEGER NOT NULL,
`radarrId` INTEGER NOT NULL,
`timestamp` INTEGER NOT NULL,
`description` TEXT NOT NULL
);
CREATE TABLE "system" (
`configured` TEXT,
`updated` TEXT
);
INSERT INTO `system` (configured, updated) VALUES ('0', '0');
COMMIT;

@@ -1,78 +1,258 @@
# coding=utf-8
import os
import atexit
import json
import ast
import sqlite3
import logging
import json
import re
from sqlite3worker import Sqlite3Worker
from peewee import *
from playhouse.sqliteq import SqliteQueueDatabase
from playhouse.shortcuts import model_to_dict
from get_args import args
from helper import path_mappings
from config import settings, get_array_from
from get_args import args
global profile_id_list
profile_id_list = []
def db_init():
if not os.path.exists(os.path.join(args.config_dir, 'db', 'bazarr.db')):
# Get SQL script from file
fd = open(os.path.join(os.path.dirname(__file__), 'create_db.sql'), 'r')
script = fd.read()
# Close SQL script file
fd.close()
# Open database connection
db = sqlite3.connect(os.path.join(args.config_dir, 'db', 'bazarr.db'), timeout=30)
c = db.cursor()
# Execute script and commit change to database
c.executescript(script)
# Close database connection
db.close()
logging.info('BAZARR Database created successfully')
database = Sqlite3Worker(os.path.join(args.config_dir, 'db', 'bazarr.db'), max_queue_size=256, as_dict=True)
class SqliteDictConverter:
def __init__(self):
self.keys_insert = tuple()
self.keys_update = tuple()
self.values = tuple()
self.question_marks = tuple()
def convert(self, values_dict):
if type(values_dict) is dict:
self.keys_insert = tuple()
self.keys_update = tuple()
self.values = tuple()
self.question_marks = tuple()
temp_keys = list()
temp_values = list()
for item in values_dict.items():
temp_keys.append(item[0])
temp_values.append(item[1])
self.keys_insert = ','.join(temp_keys)
self.keys_update = ','.join([k + '=?' for k in temp_keys])
self.values = tuple(temp_values)
self.question_marks = ','.join(list('?'*len(values_dict)))
return self
else:
pass
dict_converter = SqliteDictConverter()
database = SqliteQueueDatabase(os.path.join(args.config_dir, 'db', 'bazarr.db'),
use_gevent=True,
autostart=True,
queue_max_size=256)
@atexit.register
def _stop_worker_threads():
database.stop()
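SqliteQueueDatabase funnels every write through a single background writer (a gevent greenlet here, given use_gevent=True), replacing the old Sqlite3Worker queue and sidestepping SQLite's single-writer locking; reads execute normally. Models bound to it are used like any other Peewee database, e.g. (hypothetical values):

# Queued on the writer and applied in submission order:
TableMovies.update({TableMovies.monitored: 'True'})\
    .where(TableMovies.radarrId == 123)\
    .execute()
# database.stop(), registered above via atexit, blocks until all queued
# writes have been flushed before the process exits.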
class UnknownField(object):
def __init__(self, *_, **__): pass
class BaseModel(Model):
class Meta:
database = database
class System(BaseModel):
configured = TextField(null=True)
updated = TextField(null=True)
class Meta:
table_name = 'system'
primary_key = False
class TableBlacklist(BaseModel):
language = TextField(null=True)
provider = TextField(null=True)
sonarr_episode_id = IntegerField(null=True)
sonarr_series_id = IntegerField(null=True)
subs_id = TextField(null=True)
timestamp = IntegerField(null=True)
class Meta:
table_name = 'table_blacklist'
primary_key = False
class TableBlacklistMovie(BaseModel):
language = TextField(null=True)
provider = TextField(null=True)
radarr_id = IntegerField(null=True)
subs_id = TextField(null=True)
timestamp = IntegerField(null=True)
class Meta:
table_name = 'table_blacklist_movie'
primary_key = False
class TableEpisodes(BaseModel):
audio_codec = TextField(null=True)
audio_language = TextField(null=True)
episode = IntegerField()
episode_file_id = IntegerField(null=True)
failedAttempts = TextField(null=True)
ffprobe_cache = BlobField(null=True)
file_size = IntegerField(default=0, null=True)
format = TextField(null=True)
missing_subtitles = TextField(null=True)
monitored = TextField(null=True)
path = TextField()
resolution = TextField(null=True)
scene_name = TextField(null=True)
season = IntegerField()
sonarrEpisodeId = IntegerField(unique=True)
sonarrSeriesId = IntegerField()
subtitles = TextField(null=True)
title = TextField()
video_codec = TextField(null=True)
class Meta:
table_name = 'table_episodes'
primary_key = False
class TableHistory(BaseModel):
action = IntegerField()
description = TextField()
id = AutoField()
language = TextField(null=True)
provider = TextField(null=True)
score = TextField(null=True)
sonarrEpisodeId = IntegerField()
sonarrSeriesId = IntegerField()
subs_id = TextField(null=True)
subtitles_path = TextField(null=True)
timestamp = IntegerField()
video_path = TextField(null=True)
class Meta:
table_name = 'table_history'
class TableHistoryMovie(BaseModel):
action = IntegerField()
description = TextField()
id = AutoField()
language = TextField(null=True)
provider = TextField(null=True)
radarrId = IntegerField()
score = TextField(null=True)
subs_id = TextField(null=True)
subtitles_path = TextField(null=True)
timestamp = IntegerField()
video_path = TextField(null=True)
class Meta:
table_name = 'table_history_movie'
class TableLanguagesProfiles(BaseModel):
cutoff = IntegerField(null=True)
items = TextField()
name = TextField()
profileId = AutoField()
class Meta:
table_name = 'table_languages_profiles'
class TableMovies(BaseModel):
alternativeTitles = TextField(null=True)
audio_codec = TextField(null=True)
audio_language = TextField(null=True)
failedAttempts = TextField(null=True)
fanart = TextField(null=True)
ffprobe_cache = BlobField(null=True)
file_size = IntegerField(default=0, null=True)
format = TextField(null=True)
imdbId = TextField(null=True)
missing_subtitles = TextField(null=True)
monitored = TextField(null=True)
movie_file_id = IntegerField(null=True)
overview = TextField(null=True)
path = TextField(unique=True)
poster = TextField(null=True)
profileId = IntegerField(null=True)
radarrId = IntegerField(unique=True)
resolution = TextField(null=True)
sceneName = TextField(null=True)
sortTitle = TextField(null=True)
subtitles = TextField(null=True)
tags = TextField(null=True)
title = TextField()
tmdbId = TextField(primary_key=True)
video_codec = TextField(null=True)
year = TextField(null=True)
class Meta:
table_name = 'table_movies'
class TableMoviesRootfolder(BaseModel):
accessible = IntegerField(null=True)
error = TextField(null=True)
id = IntegerField(null=True)
path = TextField(null=True)
class Meta:
table_name = 'table_movies_rootfolder'
primary_key = False
class TableSettingsLanguages(BaseModel):
code2 = TextField(null=True)
code3 = TextField(primary_key=True)
code3b = TextField(null=True)
enabled = IntegerField(null=True)
name = TextField()
class Meta:
table_name = 'table_settings_languages'
class TableSettingsNotifier(BaseModel):
enabled = IntegerField(null=True)
name = TextField(null=True, primary_key=True)
url = TextField(null=True)
class Meta:
table_name = 'table_settings_notifier'
class TableShows(BaseModel):
alternateTitles = TextField(null=True)
audio_language = TextField(null=True)
fanart = TextField(null=True)
imdbId = TextField(default='""', null=True)
overview = TextField(null=True)
path = TextField(unique=True)
poster = TextField(null=True)
profileId = IntegerField(null=True)
seriesType = TextField(null=True)
sonarrSeriesId = IntegerField(unique=True)
sortTitle = TextField(null=True)
tags = TextField(null=True)
title = TextField()
tvdbId = AutoField()
year = TextField(null=True)
class Meta:
table_name = 'table_shows'
class TableShowsRootfolder(BaseModel):
accessible = IntegerField(null=True)
error = TextField(null=True)
id = IntegerField(null=True)
path = TextField(null=True)
class Meta:
table_name = 'table_shows_rootfolder'
primary_key = False
# Create tables if they don't exist.
database.create_tables([System,
TableBlacklist,
TableBlacklistMovie,
TableEpisodes,
TableHistory,
TableHistoryMovie,
TableLanguagesProfiles,
TableMovies,
TableMoviesRootfolder,
TableSettingsLanguages,
TableSettingsNotifier,
TableShows,
TableShowsRootfolder])
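create_tables() issues CREATE TABLE IF NOT EXISTS by default (safe=True), which is what lets these model definitions replace the deleted create_db.sql bootstrap; it does not add columns to existing tables, hence the db_upgrade() migration kept below. The call above is equivalent to passing the flag explicitly:

# safe=True (the default) makes each statement CREATE TABLE IF NOT EXISTS,
# so an existing bazarr.db is left untouched.
database.create_tables([System, TableShows, TableMovies], safe=True)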
class SqliteDictPathMapper:
def __init__(self):
pass
def path_replace(self, values_dict):
@staticmethod
def path_replace(values_dict):
if type(values_dict) is list:
for item in values_dict:
item['path'] = path_mappings.path_replace(item['path'])
@@ -81,7 +261,8 @@ class SqliteDictPathMapper:
else:
return path_mappings.path_replace(values_dict)
def path_replace_movie(self, values_dict):
@staticmethod
def path_replace_movie(values_dict):
if type(values_dict) is list:
for item in values_dict:
item['path'] = path_mappings.path_replace_movie(item['path'])
@@ -94,293 +275,49 @@ class SqliteDictPathMapper:
dict_mapper = SqliteDictPathMapper()
def db_upgrade():
columnToAdd = [
['table_shows', 'year', 'text'],
['table_shows', 'alternateTitles', 'text'],
['table_shows', 'tags', 'text', '[]'],
['table_shows', 'seriesType', 'text', ''],
['table_shows', 'imdbId', 'text', ''],
['table_shows', 'profileId', 'integer'],
['table_episodes', 'format', 'text'],
['table_episodes', 'resolution', 'text'],
['table_episodes', 'video_codec', 'text'],
['table_episodes', 'audio_codec', 'text'],
['table_episodes', 'episode_file_id', 'integer'],
['table_episodes', 'audio_language', 'text'],
['table_episodes', 'file_size', 'integer', '0'],
['table_episodes', 'ffprobe_cache', 'blob'],
['table_movies', 'sortTitle', 'text'],
['table_movies', 'year', 'text'],
['table_movies', 'alternativeTitles', 'text'],
['table_movies', 'format', 'text'],
['table_movies', 'resolution', 'text'],
['table_movies', 'video_codec', 'text'],
['table_movies', 'audio_codec', 'text'],
['table_movies', 'imdbId', 'text'],
['table_movies', 'movie_file_id', 'integer'],
['table_movies', 'tags', 'text', '[]'],
['table_movies', 'profileId', 'integer'],
['table_movies', 'file_size', 'integer', '0'],
['table_movies', 'ffprobe_cache', 'blob'],
['table_history', 'video_path', 'text'],
['table_history', 'language', 'text'],
['table_history', 'provider', 'text'],
['table_history', 'score', 'text'],
['table_history', 'subs_id', 'text'],
['table_history', 'subtitles_path', 'text'],
['table_history_movie', 'video_path', 'text'],
['table_history_movie', 'language', 'text'],
['table_history_movie', 'provider', 'text'],
['table_history_movie', 'score', 'text'],
['table_history_movie', 'subs_id', 'text'],
['table_history_movie', 'subtitles_path', 'text']
]
for column in columnToAdd:
try:
# Check if column already exist in table
columns_dict = database.execute('''PRAGMA table_info('{0}')'''.format(column[0]))
columns_names_list = [x['name'] for x in columns_dict]
if column[1] in columns_names_list:
continue
# Creating the missing column
if len(column) == 3:
database.execute('''ALTER TABLE {0} ADD COLUMN "{1}" "{2}"'''.format(column[0], column[1], column[2]))
else:
database.execute('''ALTER TABLE {0} ADD COLUMN "{1}" "{2}" DEFAULT "{3}"'''.format(column[0], column[1], column[2], column[3]))
logging.debug('BAZARR Database upgrade process added column {0} to table {1}.'.format(column[1], column[0]))
except:
pass
# Create blacklist tables
database.execute("CREATE TABLE IF NOT EXISTS table_blacklist (sonarr_series_id integer, sonarr_episode_id integer, "
"timestamp integer, provider text, subs_id text, language text)")
database.execute("CREATE TABLE IF NOT EXISTS table_blacklist_movie (radarr_id integer, timestamp integer, "
"provider text, subs_id text, language text)")
# Create rootfolder tables
database.execute("CREATE TABLE IF NOT EXISTS table_shows_rootfolder (id integer, path text, accessible integer, "
"error text)")
database.execute("CREATE TABLE IF NOT EXISTS table_movies_rootfolder (id integer, path text, accessible integer, "
"error text)")
# Create languages profiles table and populate it
lang_table_content = database.execute("SELECT * FROM table_languages_profiles")
if isinstance(lang_table_content, list):
lang_table_exist = True
else:
lang_table_exist = False
database.execute("CREATE TABLE IF NOT EXISTS table_languages_profiles ("
"profileId INTEGER NOT NULL PRIMARY KEY, name TEXT NOT NULL, "
"cutoff INTEGER, items TEXT NOT NULL)")
if not lang_table_exist:
series_default = []
try:
for language in ast.literal_eval(settings.general.serie_default_language):
if settings.general.serie_default_forced == 'Both':
series_default.append([language, 'True', settings.general.serie_default_hi])
series_default.append([language, 'False', settings.general.serie_default_hi])
else:
series_default.append([language, settings.general.serie_default_forced,
settings.general.serie_default_hi])
except ValueError:
pass
movies_default = []
try:
for language in ast.literal_eval(settings.general.movie_default_language):
if settings.general.movie_default_forced == 'Both':
movies_default.append([language, 'True', settings.general.movie_default_hi])
movies_default.append([language, 'False', settings.general.movie_default_hi])
else:
movies_default.append([language, settings.general.movie_default_forced,
settings.general.movie_default_hi])
except ValueError:
pass
profiles_to_create = database.execute("SELECT DISTINCT languages, hearing_impaired, forced "
"FROM (SELECT languages, hearing_impaired, forced FROM table_shows "
"UNION ALL SELECT languages, hearing_impaired, forced FROM table_movies) "
"a WHERE languages NOT null and languages NOT IN ('None', '[]')")
if isinstance(profiles_to_create, list):
for profile in profiles_to_create:
profile_items = []
languages_list = ast.literal_eval(profile['languages'])
for i, language in enumerate(languages_list, 1):
if profile['forced'] == 'Both':
profile_items.append({'id': i, 'language': language, 'forced': 'True',
'hi': profile['hearing_impaired'], 'audio_exclude': 'False'})
profile_items.append({'id': i, 'language': language, 'forced': 'False',
'hi': profile['hearing_impaired'], 'audio_exclude': 'False'})
else:
profile_items.append({'id': i, 'language': language, 'forced': profile['forced'],
'hi': profile['hearing_impaired'], 'audio_exclude': 'False'})
# Create profiles
new_profile_name = profile['languages'] + ' (' + profile['hearing_impaired'] + '/' + profile['forced'] + ')'
database.execute("INSERT INTO table_languages_profiles (name, cutoff, items) VALUES("
"?,null,?)", (new_profile_name, json.dumps(profile_items),))
created_profile_id = database.execute("SELECT profileId FROM table_languages_profiles WHERE name = ?",
(new_profile_name,), only_one=True)['profileId']
# Assign profiles to series and movies
database.execute("UPDATE table_shows SET profileId = ? WHERE languages = ? AND hearing_impaired = ? AND "
"forced = ?", (created_profile_id, profile['languages'], profile['hearing_impaired'],
profile['forced']))
database.execute("UPDATE table_movies SET profileId = ? WHERE languages = ? AND hearing_impaired = ? AND "
"forced = ?", (created_profile_id, profile['languages'], profile['hearing_impaired'],
profile['forced']))
# Save new defaults
profile_items_list = []
for item in profile_items:
profile_items_list.append([item['language'], item['forced'], item['hi']])
try:
if created_profile_id and profile_items_list == series_default:
settings.general.serie_default_profile = str(created_profile_id)
except:
pass
try:
if created_profile_id and profile_items_list == movies_default:
settings.general.movie_default_profile = str(created_profile_id)
except:
pass
# null languages, forced and hearing_impaired for all series and movies
database.execute("UPDATE table_shows SET languages = null, forced = null, hearing_impaired = null")
database.execute("UPDATE table_movies SET languages = null, forced = null, hearing_impaired = null")
# Force series, episodes and movies to sync with Sonarr and Radarr to get all the audio tracks from video files
# Set environment variable that is going to be used during the init process to run the sync once Bazarr is ready.
os.environ['BAZARR_AUDIO_PROFILES_MIGRATION'] = '1'
columnToRemove = [
['table_shows', 'languages'],
['table_shows', 'hearing_impaired'],
['table_shows', 'forced'],
['table_shows', 'sizeOnDisk'],
['table_episodes', 'file_ffprobe'],
['table_movies', 'languages'],
['table_movies', 'hearing_impaired'],
['table_movies', 'forced'],
['table_movies', 'file_ffprobe'],
]
for column in columnToRemove:
try:
# Check if column still exist in table
columns_dict = database.execute('''PRAGMA table_info('{0}')'''.format(column[0]))
columns_names_list = [x['name'] for x in columns_dict]
if column[1] not in columns_names_list:
continue
table_name = column[0]
column_name = column[1]
tables_query = database.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
tables = [table['name'] for table in tables_query]
if table_name not in tables:
# Table doesn't exist in database. Skipping.
continue
columns_dict = database.execute('''PRAGMA table_info('{0}')'''.format(column[0]))
columns_names_list = [x['name'] for x in columns_dict]
if column_name in columns_names_list:
columns_names_list.remove(column_name)
columns_names_string = ', '.join(columns_names_list)
if not columns_names_list:
logging.debug("BAZARR No more columns in {}. We won't create an empty table. "
"Exiting.".format(table_name))
continue
else:
logging.debug("BAZARR Column {} doesn't exist in {}".format(column_name, table_name))
continue
# get original sql statement used to create the table
original_sql_statement = database.execute("SELECT sql FROM sqlite_master WHERE type='table' AND "
"name='{}'".format(table_name))[0]['sql']
# pretty format sql statement
original_sql_statement = original_sql_statement.replace('\n, ', ',\n\t')
original_sql_statement = original_sql_statement.replace('", "', '",\n\t"')
original_sql_statement = original_sql_statement.rstrip(')') + '\n'
# generate sql statement for temp table
table_regex = re.compile(r"CREATE TABLE \"{}\"".format(table_name))
column_regex = re.compile(r".+\"{}\".+\n".format(column_name))
new_sql_statement = table_regex.sub("CREATE TABLE \"{}_temp\"".format(table_name), original_sql_statement)
new_sql_statement = column_regex.sub("", new_sql_statement).rstrip('\n').rstrip(',') + '\n)'
# remove leftover temp table from previous execution
database.execute('DROP TABLE IF EXISTS {}_temp'.format(table_name))
# create new temp table
create_error = database.execute(new_sql_statement)
if create_error:
logging.debug('BAZARR cannot create temp table.')
continue
# validate if row insertion worked as expected
new_table_rows = database.execute('INSERT INTO {0}_temp({1}) SELECT {1} FROM {0}'.format(table_name,
columns_names_string))
previous_table_rows = database.execute('SELECT COUNT(*) as count FROM {}'.format(table_name),
only_one=True)['count']
if new_table_rows == previous_table_rows:
drop_error = database.execute('DROP TABLE {}'.format(table_name))
if drop_error:
logging.debug('BAZARR cannot drop {} table before renaming the temp table'.format(table_name))
continue
else:
rename_error = database.execute('ALTER TABLE {0}_temp RENAME TO {0}'.format(table_name))
if rename_error:
logging.debug('BAZARR cannot rename {}_temp table'.format(table_name))
else:
logging.debug('BAZARR cannot insert existing rows to {} table.'.format(table_name))
continue
except:
pass
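# Why the temp-table dance above: SQLite before 3.35 has no
# ALTER TABLE ... DROP COLUMN, so the migration rewrites the original CREATE
# statement without the column, copies the surviving rows into a temp table
# and renames it over the original. A stripped-down sketch of the same idea,
# assuming a plain sqlite3 connection (note that, unlike the statement
# rewriting above, CREATE TABLE ... AS SELECT does not preserve constraints):
import sqlite3

def drop_column(conn, table, column):
    # Names of the columns that survive the drop (row[1] is the column name)
    cols = [row[1] for row in conn.execute("PRAGMA table_info('{}')".format(table))
            if row[1] != column]
    # Copy the surviving columns into a temp table, then swap it in
    conn.execute('CREATE TABLE {0}_temp AS SELECT {1} FROM {0}'.format(table, ', '.join(cols)))
    conn.execute('DROP TABLE {}'.format(table))
    conn.execute('ALTER TABLE {0}_temp RENAME TO {0}'.format(table))
    conn.commit()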
def get_exclusion_clause(type):
where_clause = ''
if type == 'series':
def get_exclusion_clause(exclusion_type):
where_clause = []
if exclusion_type == 'series':
tagsList = ast.literal_eval(settings.sonarr.excluded_tags)
for tag in tagsList:
where_clause += ' AND table_shows.tags NOT LIKE "%\'' + tag + '\'%"'
where_clause.append(~(TableShows.tags ** tag))
else:
tagsList = ast.literal_eval(settings.radarr.excluded_tags)
for tag in tagsList:
where_clause += ' AND table_movies.tags NOT LIKE "%\'' + tag + '\'%"'
where_clause.append(~(TableMovies.tags ** tag))
if type == 'series':
if exclusion_type == 'series':
monitoredOnly = settings.sonarr.getboolean('only_monitored')
if monitoredOnly:
where_clause += ' AND table_episodes.monitored = "True"'
where_clause.append((TableEpisodes.monitored == 'True'))
else:
monitoredOnly = settings.radarr.getboolean('only_monitored')
if monitoredOnly:
where_clause += ' AND table_movies.monitored = "True"'
where_clause.append((TableMovies.monitored == 'True'))
if type == 'series':
if exclusion_type == 'series':
typesList = get_array_from(settings.sonarr.excluded_series_types)
for type in typesList:
where_clause += ' AND table_shows.seriesType != "' + type + '"'
for item in typesList:
where_clause.append((TableShows.seriesType != item))
return where_clause
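# The function now returns a list of Peewee expressions rather than a SQL
# string fragment; callers AND them together before handing them to .where().
# A minimal usage sketch mirroring how the rest of this change consumes it
# (the series id 1 is just an example value):
import operator
from functools import reduce

conditions = [(TableEpisodes.sonarrSeriesId == 1)]
conditions += get_exclusion_clause('series')
episodes = TableEpisodes.select(TableEpisodes.path)\
    .join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
    .where(reduce(operator.and_, conditions))\
    .dicts()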
def update_profile_id_list():
global profile_id_list
profile_id_list = database.execute("SELECT profileId, name, cutoff, items FROM table_languages_profiles")
profile_id_list = TableLanguagesProfiles.select(TableLanguagesProfiles.profileId,
TableLanguagesProfiles.name,
TableLanguagesProfiles.cutoff,
TableLanguagesProfiles.items).dicts()
profile_id_list = list(profile_id_list)
for profile in profile_id_list:
profile['items'] = json.loads(profile['items'])
def get_profiles_list(profile_id=None):
if not len(profile_id_list):
try:
len(profile_id_list)
except NameError:
update_profile_id_list()
if profile_id and profile_id != 'null':
@ -452,50 +389,29 @@ def get_audio_profile_languages(series_id=None, episode_id=None, movie_id=None):
audio_languages = []
if series_id:
audio_languages_list_str = database.execute("SELECT audio_language FROM table_shows WHERE sonarrSeriesId=?",
(series_id,), only_one=True)['audio_language']
try:
audio_languages_list = ast.literal_eval(audio_languages_list_str)
except ValueError:
pass
else:
for language in audio_languages_list:
audio_languages.append(
{"name": language,
"code2": alpha2_from_language(language) or None,
"code3": alpha3_from_language(language) or None}
)
audio_languages_list_str = TableShows.get(TableShows.sonarrSeriesId == series_id).audio_language
elif episode_id:
audio_languages_list_str = database.execute("SELECT audio_language FROM table_episodes WHERE sonarrEpisodeId=?",
(episode_id,), only_one=True)['audio_language']
try:
audio_languages_list = ast.literal_eval(audio_languages_list_str)
except ValueError:
pass
else:
for language in audio_languages_list:
audio_languages.append(
{"name": language,
"code2": alpha2_from_language(language) or None,
"code3": alpha3_from_language(language) or None}
)
audio_languages_list_str = TableEpisodes.get(TableEpisodes.sonarrEpisodeId == episode_id).audio_language
elif movie_id:
audio_languages_list_str = database.execute("SELECT audio_language FROM table_movies WHERE radarrId=?",
(movie_id,), only_one=True)['audio_language']
try:
audio_languages_list = ast.literal_eval(audio_languages_list_str)
except ValueError:
pass
else:
for language in audio_languages_list:
audio_languages.append(
{"name": language,
"code2": alpha2_from_language(language) or None,
"code3": alpha3_from_language(language) or None}
)
audio_languages_list_str = TableMovies.get(TableMovies.radarrId == movie_id).audio_language
else:
return audio_languages
try:
audio_languages_list = ast.literal_eval(audio_languages_list_str)
except ValueError:
pass
else:
for language in audio_languages_list:
audio_languages.append(
{"name": language,
"code2": alpha2_from_language(language) or None,
"code3": alpha3_from_language(language) or None}
)
return audio_languages
def convert_list_to_clause(arr: list):
if isinstance(arr, list):
return f"({','.join(str(x) for x in arr)})"
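# Quick example: the helper renders a Python list as the parenthesized literal
# expected by a SQL IN () clause. As written it returns None for non-list
# input, so callers are expected to pass lists.
in_clause = convert_list_to_clause([1, 2, 3])  # '(1,2,3)'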

@ -7,7 +7,7 @@ from knowit import api
import enzyme
from enzyme.exceptions import MalformedMKVError
from enzyme.exceptions import MalformedMKVError
from database import database
from database import TableEpisodes, TableMovies
_FFPROBE_SPECIAL_LANGS = {
"zho": {
@ -77,27 +77,30 @@ def parse_video_metadata(file, file_size, episode_file_id=None, movie_file_id=No
# Get the actual cache value form database
if episode_file_id:
cache_key = database.execute('SELECT ffprobe_cache FROM table_episodes WHERE episode_file_id=? AND file_size=?',
(episode_file_id, file_size), only_one=True)
cache_key = TableEpisodes.select(TableEpisodes.ffprobe_cache)\
.where((TableEpisodes.episode_file_id == episode_file_id) &
(TableEpisodes.file_size == file_size))\
.dicts()\
.get()
elif movie_file_id:
cache_key = database.execute('SELECT ffprobe_cache FROM table_movies WHERE movie_file_id=? AND file_size=?',
(movie_file_id, file_size), only_one=True)
cache_key = TableMovies.select(TableMovies.ffprobe_cache)\
.where((TableMovies.movie_file_id == movie_file_id) &
(TableMovies.file_size == file_size))\
.dicts()\
.get()
else:
cache_key = None
# check if we have a value for that cache key
if not isinstance(cache_key, dict):
return data
try:
# Unpickle ffprobe cache
cached_value = pickle.loads(cache_key['ffprobe_cache'])
except:
pass
else:
# Check if file size and file id matches and if so, we return the cached value
if cached_value['file_size'] == file_size and cached_value['file_id'] in [episode_file_id, movie_file_id]:
return cached_value
# if not, we retrieve the metadata from the file
from utils import get_binary
@ -122,9 +125,11 @@ def parse_video_metadata(file, file_size, episode_file_id=None, movie_file_id=No
# we write to db the result and return the newly cached ffprobe dict
if episode_file_id:
database.execute('UPDATE table_episodes SET ffprobe_cache=? WHERE episode_file_id=?',
(pickle.dumps(data, pickle.HIGHEST_PROTOCOL), episode_file_id))
TableEpisodes.update({TableEpisodes.ffprobe_cache: pickle.dumps(data, pickle.HIGHEST_PROTOCOL)})\
.where(TableEpisodes.episode_file_id == episode_file_id)\
.execute()
elif movie_file_id:
database.execute('UPDATE table_movies SET ffprobe_cache=? WHERE movie_file_id=?',
(pickle.dumps(data, pickle.HIGHEST_PROTOCOL), movie_file_id))
TableMovies.update({TableMovies.ffprobe_cache: pickle.dumps(data, pickle.HIGHEST_PROTOCOL)})\
.where(TableMovies.movie_file_id == movie_file_id)\
.execute()
return data
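# For clarity, the cache handled above is just a pickled dict stored in a BLOB
# column and validated against the file id and size on read. A self-contained
# sketch of the round-trip (the 'ffprobe' key is an assumption; only file_id
# and file_size are checked by the code above):
import pickle

data = {'ffprobe': {}, 'file_id': 123, 'file_size': 456789}
blob = pickle.dumps(data, pickle.HIGHEST_PROTOCOL)  # value written to ffprobe_cache
cached_value = pickle.loads(blob)                   # value read back from the column
assert cached_value['file_size'] == 456789 and cached_value['file_id'] == 123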

@ -3,7 +3,7 @@
import os
import requests
import logging
from database import database, dict_converter, get_exclusion_clause
from database import get_exclusion_clause, TableEpisodes, TableShows
from config import settings, url_sonarr
from helper import path_mappings
@ -24,11 +24,11 @@ def sync_episodes(series_id=None, send_event=True):
apikey_sonarr = settings.sonarr.apikey
# Get current episodes id in DB
if series_id:
current_episodes_db = database.execute("SELECT sonarrEpisodeId, path, sonarrSeriesId FROM table_episodes WHERE "
"sonarrSeriesId = ?", (series_id,))
else:
current_episodes_db = database.execute("SELECT sonarrEpisodeId, path, sonarrSeriesId FROM table_episodes")
current_episodes_db = TableEpisodes.select(TableEpisodes.sonarrEpisodeId,
TableEpisodes.path,
TableEpisodes.sonarrSeriesId)\
.where((TableEpisodes.sonarrSeriesId == series_id) if series_id else None)\
.dicts()
current_episodes_db_list = [x['sonarrEpisodeId'] for x in current_episodes_db]
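# One subtlety above: when series_id is falsy, .where() receives a single None.
# Peewee reduces its where() arguments with operator.and_, and reducing a
# one-element sequence returns that element unchanged, so the query keeps no
# WHERE clause at all, which matches the old if/else behaviour. An equivalent,
# more explicit sketch:
query = TableEpisodes.select(TableEpisodes.sonarrEpisodeId,
                             TableEpisodes.path,
                             TableEpisodes.sonarrSeriesId)
if series_id:
    query = query.where(TableEpisodes.sonarrSeriesId == series_id)
current_episodes_db = query.dicts()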
@ -82,17 +82,31 @@ def sync_episodes(series_id=None, send_event=True):
removed_episodes = list(set(current_episodes_db_list) - set(current_episodes_sonarr))
for removed_episode in removed_episodes:
episode_to_delete = database.execute("SELECT sonarrSeriesId, sonarrEpisodeId FROM table_episodes WHERE "
"sonarrEpisodeId=?", (removed_episode,), only_one=True)
database.execute("DELETE FROM table_episodes WHERE sonarrEpisodeId=?", (removed_episode,))
episode_to_delete = TableEpisodes.select(TableEpisodes.sonarrSeriesId, TableEpisodes.sonarrEpisodeId)\
.where(TableEpisodes.sonarrEpisodeId == removed_episode)\
.dicts()\
.get()
TableEpisodes.delete().where(TableEpisodes.sonarrEpisodeId == removed_episode).execute()
if send_event:
event_stream(type='episode', action='delete', payload=episode_to_delete['sonarrEpisodeId'])
# Update existing episodes in DB
episode_in_db_list = []
episodes_in_db = database.execute("SELECT sonarrSeriesId, sonarrEpisodeId, title, path, season, episode, "
"scene_name, monitored, format, resolution, video_codec, audio_codec, "
"episode_file_id, audio_language, file_size FROM table_episodes")
episodes_in_db = TableEpisodes.select(TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.title,
TableEpisodes.path,
TableEpisodes.season,
TableEpisodes.episode,
TableEpisodes.scene_name,
TableEpisodes.monitored,
TableEpisodes.format,
TableEpisodes.resolution,
TableEpisodes.video_codec,
TableEpisodes.audio_codec,
TableEpisodes.episode_file_id,
TableEpisodes.audio_language,
TableEpisodes.file_size).dicts()
for item in episodes_in_db:
episode_in_db_list.append(item)
@ -100,19 +114,15 @@ def sync_episodes(series_id=None, send_event=True):
episodes_to_update_list = [i for i in episodes_to_update if i not in episode_in_db_list]
for updated_episode in episodes_to_update_list:
query = dict_converter.convert(updated_episode)
database.execute('''UPDATE table_episodes SET ''' + query.keys_update + ''' WHERE sonarrEpisodeId = ?''',
query.values + (updated_episode['sonarrEpisodeId'],))
TableEpisodes.update(updated_episode).where(TableEpisodes.sonarrEpisodeId ==
updated_episode['sonarrEpisodeId']).execute()
altered_episodes.append([updated_episode['sonarrEpisodeId'],
updated_episode['path'],
updated_episode['sonarrSeriesId']])
# Insert new episodes in DB
for added_episode in episodes_to_add:
query = dict_converter.convert(added_episode)
result = database.execute(
'''INSERT OR IGNORE INTO table_episodes(''' + query.keys_insert + ''') VALUES(''' + query.question_marks +
''')''', query.values)
result = TableEpisodes.insert(added_episode).on_conflict(action='IGNORE').execute()
if result > 0:
altered_episodes.append([added_episode['sonarrEpisodeId'],
added_episode['path'],
@ -134,8 +144,10 @@ def sync_one_episode(episode_id):
logging.debug('BAZARR syncing this specific episode from Sonarr: {}'.format(episode_id))
# Check if there's a row in database for this episode ID
existing_episode = database.execute('SELECT path FROM table_episodes WHERE sonarrEpisodeId = ?', (episode_id,),
only_one=True)
existing_episode = TableEpisodes.select(TableEpisodes.path)\
.where(TableEpisodes.sonarrEpisodeId == episode_id)\
.dicts()\
.get()
try:
# Get episode data from sonarr api
@ -156,38 +168,34 @@ def sync_one_episode(episode_id):
# Remove episode from DB
if not episode and existing_episode:
database.execute("DELETE FROM table_episodes WHERE sonarrEpisodeId=?", (episode_id,))
TableEpisodes.delete().where(TableEpisodes.sonarrEpisodeId == episode_id).execute()
event_stream(type='episode', action='delete', payload=int(episode_id))
logging.debug('BAZARR deleted this episode from the database:{}'.format(path_mappings.path_replace(
existing_episode['path'])))
return
# Update existing episodes in DB
elif episode and existing_episode:
query = dict_converter.convert(episode)
database.execute('''UPDATE table_episodes SET ''' + query.keys_update + ''' WHERE sonarrEpisodeId = ?''',
query.values + (episode['sonarrEpisodeId'],))
TableEpisodes.update(episode).where(TableEpisodes.sonarrEpisodeId == episode['sonarrEpisodeId']).execute()
event_stream(type='episode', action='update', payload=int(episode_id))
logging.debug('BAZARR updated this episode into the database:{}'.format(path_mappings.path_replace(
episode['path'])))
# Insert new episodes in DB
elif episode and not existing_episode:
query = dict_converter.convert(episode)
database.execute('''INSERT OR IGNORE INTO table_episodes(''' + query.keys_insert + ''') VALUES(''' +
query.question_marks + ''')''', query.values)
TableEpisodes.insert(episode).on_conflict(action='IGNORE').execute()
event_stream(type='episode', action='update', payload=int(episode_id))
logging.debug('BAZARR inserted this episode into the database:{}'.format(path_mappings.path_replace(
episode['path'])))
# Storing existing subtitles
logging.debug('BAZARR storing subtitles for this episode: {}'.format(path_mappings.path_replace(
episode['path'])))
store_subtitles(episode['path'], path_mappings.path_replace(episode['path']))
# Downloading missing subtitles
logging.debug('BAZARR downloading missing subtitles for this episode: {}'.format(path_mappings.path_replace(
episode['path'])))
episode_download_subtitles(episode_id)
@ -248,9 +256,7 @@ def episodeParser(episode):
if 'name' in item:
audio_language.append(item['name'])
else:
audio_language = database.execute("SELECT audio_language FROM table_shows WHERE "
"sonarrSeriesId=?", (episode['seriesId'],),
only_one=True)['audio_language']
audio_language = TableShows.get(TableShows.sonarrSeriesId == episode['seriesId']).audio_language
if 'mediaInfo' in episode['episodeFile']:
if 'videoCodec' in episode['episodeFile']['mediaInfo']:
@ -317,8 +323,9 @@ def get_series_from_sonarr_api(series_id, url, apikey_sonarr):
logging.exception("BAZARR Error trying to get series from Sonarr.")
return
else:
series_json = []
if series_id:
series_json = list(r.json())
series_json.append(r.json())
else:
series_json = r.json()
series_list = []

@ -3,7 +3,7 @@
import pycountry
from subzero.language import Language
from database import database
from database import database, TableSettingsLanguages
def load_language_in_db():
@ -13,21 +13,35 @@ def load_language_in_db():
if hasattr(lang, 'alpha_2')]
# Insert languages in database table
database.execute("INSERT OR IGNORE INTO table_settings_languages (code3, code2, name) VALUES (?, ?, ?)",
langs, execute_many=True)
database.execute("INSERT OR IGNORE INTO table_settings_languages (code3, code2, name) "
"VALUES ('pob', 'pb', 'Brazilian Portuguese')")
database.execute("INSERT OR IGNORE INTO table_settings_languages (code3, code2, name) "
"VALUES ('zht', 'zt', 'Chinese Traditional')")
TableSettingsLanguages.insert_many(langs,
fields=[TableSettingsLanguages.code3, TableSettingsLanguages.code2,
TableSettingsLanguages.name]) \
.on_conflict(action='IGNORE') \
.execute()
TableSettingsLanguages.insert({TableSettingsLanguages.code3: 'pob', TableSettingsLanguages.code2: 'pb',
TableSettingsLanguages.name: 'Brazilian Portuguese'}) \
.on_conflict(action='IGNORE') \
.execute()
# update/insert chinese languages
TableSettingsLanguages.update({TableSettingsLanguages.name: 'Chinese Simplified'}) \
.where(TableSettingsLanguages.code3 == 'zho')\
.execute()
TableSettingsLanguages.insert({TableSettingsLanguages.code3: 'zht', TableSettingsLanguages.code2: 'zt',
TableSettingsLanguages.name: 'Chinese Traditional'}) \
.on_conflict(action='IGNORE')\
.execute()
langs = [[lang.bibliographic, lang.alpha_3]
for lang in pycountry.languages
if hasattr(lang, 'alpha_2') and hasattr(lang, 'bibliographic')]
# Update languages in database table
database.execute("UPDATE table_settings_languages SET code3b=? WHERE code3=?", langs, execute_many=True)
for lang in langs:
TableSettingsLanguages.update({TableSettingsLanguages.code3b: lang[0]}) \
.where(TableSettingsLanguages.code3 == lang[1]) \
.execute()
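# The old execute_many call becomes one UPDATE per language here. The table is
# small, so that is fine, but the loop could be wrapped in a transaction to
# avoid a commit per row. A sketch, assuming the Peewee database handle from
# database.py is available as db (name assumed, not from this diff):
with db.atomic():
    for lang in langs:
        TableSettingsLanguages.update({TableSettingsLanguages.code3b: lang[0]}) \
            .where(TableSettingsLanguages.code3 == lang[1]) \
            .execute()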
# Create languages dictionary for faster conversion than calling database
create_languages_dict()
@ -35,10 +49,10 @@ def load_language_in_db():
def create_languages_dict():
global languages_dict
#replace chinese by chinese simplified
database.execute("UPDATE table_settings_languages SET name = 'Chinese Simplified' WHERE code3 = 'zho'")
languages_dict = database.execute("SELECT name, code2, code3, code3b FROM table_settings_languages")
languages_dict = TableSettingsLanguages.select(TableSettingsLanguages.name,
TableSettingsLanguages.code2,
TableSettingsLanguages.code3,
TableSettingsLanguages.code3b).dicts()
def language_from_alpha2(lang):
@ -68,7 +82,8 @@ def alpha3_from_language(lang):
def get_language_set():
languages = database.execute("SELECT code3 FROM table_settings_languages WHERE enabled=1")
languages = TableSettingsLanguages.select(TableSettingsLanguages.code3) \
.where(TableSettingsLanguages.enabled == 1).dicts()
language_set = set()

@ -3,6 +3,8 @@
import os
import requests
import logging
import operator
from functools import reduce
from config import settings, url_radarr
from helper import path_mappings
@ -11,7 +13,7 @@ from list_subtitles import store_subtitles_movie, movies_full_scan_subtitles
from get_rootfolder import check_radarr_rootfolder
from get_subtitle import movies_download_subtitles
from database import database, dict_converter, get_exclusion_clause
from database import get_exclusion_clause, TableMovies
from event_handler import event_stream, show_progress, hide_progress
headers = {"User-Agent": os.environ["SZ_USER_AGENT"]}
@ -50,7 +52,7 @@ def update_movies(send_event=True):
return
else:
# Get current movies in DB
current_movies_db = database.execute("SELECT tmdbId, path, radarrId FROM table_movies")
current_movies_db = TableMovies.select(TableMovies.tmdbId, TableMovies.path, TableMovies.radarrId).dicts()
current_movies_db_list = [x['tmdbId'] for x in current_movies_db]
@ -101,14 +103,31 @@ def update_movies(send_event=True):
removed_movies = list(set(current_movies_db_list) - set(current_movies_radarr))
for removed_movie in removed_movies:
database.execute("DELETE FROM table_movies WHERE tmdbId=?", (removed_movie,))
TableMovies.delete().where(TableMovies.tmdbId == removed_movie).execute()
# Update movies in DB
movies_in_db_list = []
movies_in_db = database.execute("SELECT radarrId, title, path, tmdbId, overview, poster, fanart, "
"audio_language, sceneName, monitored, sortTitle, year, "
"alternativeTitles, format, resolution, video_codec, audio_codec, imdbId,"
"movie_file_id, tags, file_size FROM table_movies")
movies_in_db = TableMovies.select(TableMovies.radarrId,
TableMovies.title,
TableMovies.path,
TableMovies.tmdbId,
TableMovies.overview,
TableMovies.poster,
TableMovies.fanart,
TableMovies.audio_language,
TableMovies.sceneName,
TableMovies.monitored,
TableMovies.sortTitle,
TableMovies.year,
TableMovies.alternativeTitles,
TableMovies.format,
TableMovies.resolution,
TableMovies.video_codec,
TableMovies.audio_codec,
TableMovies.imdbId,
TableMovies.movie_file_id,
TableMovies.tags,
TableMovies.file_size).dicts()
for item in movies_in_db:
movies_in_db_list.append(item)
@ -116,9 +135,7 @@ def update_movies(send_event=True):
movies_to_update_list = [i for i in movies_to_update if i not in movies_in_db_list]
for updated_movie in movies_to_update_list:
query = dict_converter.convert(updated_movie)
database.execute('''UPDATE table_movies SET ''' + query.keys_update + ''' WHERE tmdbId = ?''',
query.values + (updated_movie['tmdbId'],))
TableMovies.update(updated_movie).where(TableMovies.tmdbId == updated_movie['tmdbId']).execute()
altered_movies.append([updated_movie['tmdbId'],
updated_movie['path'],
updated_movie['radarrId'],
@ -126,10 +143,7 @@ def update_movies(send_event=True):
# Insert new movies in DB
for added_movie in movies_to_add:
query = dict_converter.convert(added_movie)
result = database.execute(
'''INSERT OR IGNORE INTO table_movies(''' + query.keys_insert + ''') VALUES(''' +
query.question_marks + ''')''', query.values)
result = TableMovies.insert(added_movie).on_conflict(action='IGNORE').execute()
if result > 0:
altered_movies.append([added_movie['tmdbId'],
added_movie['path'],
@ -151,8 +165,9 @@ def update_movies(send_event=True):
if len(altered_movies) <= 5:
logging.debug("BAZARR No more than 5 movies were added during this sync then we'll search for subtitles.")
for altered_movie in altered_movies:
data = database.execute("SELECT * FROM table_movies WHERE radarrId = ?" +
get_exclusion_clause('movie'), (altered_movie[2],), only_one=True)
conditions = [(TableMovies.radarrId == altered_movie[2])]
conditions += get_exclusion_clause('movie')
data = TableMovies.get_or_none(reduce(operator.and_, conditions))
if data:
movies_download_subtitles(data.radarrId)
else:
@ -165,15 +180,15 @@ def update_one_movie(movie_id, action):
logging.debug('BAZARR syncing this specific movie from Radarr: {}'.format(movie_id))
# Check if there's a row in database for this movie ID
existing_movie = database.execute('SELECT path FROM table_movies WHERE radarrId = ?', (movie_id,), only_one=True)
existing_movie = TableMovies.get_or_none(TableMovies.radarrId == movie_id)
# Remove movie from DB
if action == 'deleted':
if existing_movie:
database.execute("DELETE FROM table_movies WHERE radarrId=?", (movie_id,))
TableMovies.delete().where(TableMovies.radarrId == movie_id).execute()
event_stream(type='movie', action='delete', payload=int(movie_id))
logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie(
existing_movie['path'])))
existing_movie.path)))
return
radarr_version = get_radarr_version()
@ -215,26 +230,22 @@ def update_one_movie(movie_id, action):
# Remove movie from DB
if not movie and existing_movie:
database.execute("DELETE FROM table_movies WHERE radarrId=?", (movie_id,))
TableMovies.delete().where(TableMovies.radarrId == movie_id).execute()
event_stream(type='movie', action='delete', payload=int(movie_id))
logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie(
existing_movie['path'])))
existing_movie.path)))
return
# Update existing movie in DB
elif movie and existing_movie:
query = dict_converter.convert(movie)
database.execute('''UPDATE table_movies SET ''' + query.keys_update + ''' WHERE radarrId = ?''',
query.values + (movie['radarrId'],))
TableMovies.update(movie).where(TableMovies.radarrId == movie['radarrId']).execute()
event_stream(type='movie', action='update', payload=int(movie_id))
logging.debug('BAZARR updated this movie into the database:{}'.format(path_mappings.path_replace_movie(
movie['path'])))
# Insert new movie in DB
elif movie and not existing_movie:
query = dict_converter.convert(movie)
database.execute('''INSERT OR IGNORE INTO table_movies(''' + query.keys_insert + ''') VALUES(''' +
query.question_marks + ''')''', query.values)
TableMovies.insert(movie).on_conflict(action='IGNORE').execute()
event_stream(type='movie', action='update', payload=int(movie_id))
logging.debug('BAZARR inserted this movie into the database:{}'.format(path_mappings.path_replace_movie(
movie['path'])))
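# Worth noting in this file: two row shapes are in play. .dicts().get() yields
# plain dicts (indexed as row['path']), while get_or_none() returns a model
# instance or None, hence the attribute access on existing_movie above.
# A minimal illustration (sketch; .get() raises DoesNotExist on an empty table):
first_path = TableMovies.select(TableMovies.path).dicts().get()   # {'path': ...}
maybe_movie = TableMovies.get_or_none(TableMovies.radarrId == 1)  # model or None
if maybe_movie:
    print(first_path['path'], maybe_movie.path)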

@ -6,7 +6,7 @@ import logging
from config import settings, url_sonarr, url_radarr
from helper import path_mappings
from database import database
from database import TableShowsRootfolder, TableMoviesRootfolder
headers = {"User-Agent": os.environ["SZ_USER_AGENT"]}
@ -32,7 +32,7 @@ def get_sonarr_rootfolder():
else:
for folder in rootfolder.json():
sonarr_rootfolder.append({'id': folder['id'], 'path': folder['path']})
db_rootfolder = database.execute('SELECT id, path FROM table_shows_rootfolder')
db_rootfolder = TableShowsRootfolder.select(TableShowsRootfolder.id, TableShowsRootfolder.path).dicts()
rootfolder_to_remove = [x for x in db_rootfolder if not
next((item for item in sonarr_rootfolder if item['id'] == x['id']), False)]
rootfolder_to_update = [x for x in sonarr_rootfolder if
@ -41,26 +41,37 @@ def get_sonarr_rootfolder():
next((item for item in db_rootfolder if item['id'] == x['id']), False)]
for item in rootfolder_to_remove:
database.execute('DELETE FROM table_shows_rootfolder WHERE id = ?', (item['id'],))
TableShowsRootfolder.delete().where(TableShowsRootfolder.id == item['id']).execute()
for item in rootfolder_to_update:
database.execute('UPDATE table_shows_rootfolder SET path=? WHERE id = ?', (item['path'], item['id']))
TableShowsRootfolder.update({TableShowsRootfolder.path: item['path']})\
.where(TableShowsRootfolder.id == item['id'])\
.execute()
for item in rootfolder_to_insert:
database.execute('INSERT INTO table_shows_rootfolder (id, path) VALUES (?, ?)', (item['id'], item['path']))
TableShowsRootfolder.insert({TableShowsRootfolder.id: item['id'], TableShowsRootfolder.path: item['path']})\
.execute()
def check_sonarr_rootfolder():
get_sonarr_rootfolder()
rootfolder = database.execute('SELECT id, path FROM table_shows_rootfolder')
rootfolder = TableShowsRootfolder.select(TableShowsRootfolder.id, TableShowsRootfolder.path).dicts()
for item in rootfolder:
if not os.path.isdir(path_mappings.path_replace(item['path'])):
database.execute("UPDATE table_shows_rootfolder SET accessible = 0, error = 'This Sonarr root directory "
"does not seems to be accessible by Bazarr. Please check path mapping.' WHERE id = ?",
(item['id'],))
TableShowsRootfolder.update({TableShowsRootfolder.accessible: 0,
TableShowsRootfolder.error: 'This Sonarr root directory does not seem to '
'be accessible by Bazarr. Please check path '
'mapping.'})\
.where(TableShowsRootfolder.id == item['id'])\
.execute()
elif not os.access(path_mappings.path_replace(item['path']), os.W_OK):
database.execute("UPDATE table_shows_rootfolder SET accessible = 0, error = 'Bazarr cannot write to "
"this directory' WHERE id = ?", (item['id'],))
TableShowsRootfolder.update({TableShowsRootfolder.accessible: 0,
TableShowsRootfolder.error: 'Bazarr cannot write to this directory.'}) \
.where(TableShowsRootfolder.id == item['id']) \
.execute()
else:
database.execute("UPDATE table_shows_rootfolder SET accessible = 1, error = '' WHERE id = ?", (item['id'],))
TableShowsRootfolder.update({TableShowsRootfolder.accessible: 1,
TableShowsRootfolder.error: ''}) \
.where(TableShowsRootfolder.id == item['id']) \
.execute()
def get_radarr_rootfolder():
@ -84,7 +95,7 @@ def get_radarr_rootfolder():
else:
for folder in rootfolder.json():
radarr_rootfolder.append({'id': folder['id'], 'path': folder['path']})
db_rootfolder = database.execute('SELECT id, path FROM table_movies_rootfolder')
db_rootfolder = TableMoviesRootfolder.select(TableMoviesRootfolder.id, TableMoviesRootfolder.path).dicts()
rootfolder_to_remove = [x for x in db_rootfolder if not
next((item for item in radarr_rootfolder if item['id'] == x['id']), False)]
rootfolder_to_update = [x for x in radarr_rootfolder if
@ -93,24 +104,33 @@ def get_radarr_rootfolder():
next((item for item in db_rootfolder if item['id'] == x['id']), False)]
for item in rootfolder_to_remove:
database.execute('DELETE FROM table_movies_rootfolder WHERE id = ?', (item['id'],))
TableMoviesRootfolder.delete().where(TableMoviesRootfolder.id == item['id']).execute()
for item in rootfolder_to_update:
database.execute('UPDATE table_movies_rootfolder SET path=? WHERE id = ?', (item['path'], item['id']))
TableMoviesRootfolder.update({TableMoviesRootfolder.path: item['path']})\
.where(TableMoviesRootfolder.id == item['id']).execute()
for item in rootfolder_to_insert:
database.execute('INSERT INTO table_movies_rootfolder (id, path) VALUES (?, ?)', (item['id'], item['path']))
TableMoviesRootfolder.insert({TableMoviesRootfolder.id: item['id'],
TableMoviesRootfolder.path: item['path']}).execute()
def check_radarr_rootfolder():
get_radarr_rootfolder()
rootfolder = database.execute('SELECT id, path FROM table_movies_rootfolder')
rootfolder = TableMoviesRootfolder.select(TableMoviesRootfolder.id, TableMoviesRootfolder.path).dicts()
for item in rootfolder:
if not os.path.isdir(path_mappings.path_replace_movie(item['path'])):
database.execute("UPDATE table_movies_rootfolder SET accessible = 0, error = 'This Radarr root directory "
"does not seems to be accessible by Bazarr. Please check path mapping.' WHERE id = ?",
(item['id'],))
TableMoviesRootfolder.update({TableMoviesRootfolder.accessible: 0,
TableMoviesRootfolder.error: 'This Radarr root directory does not seem to '
'be accessible by Bazarr. Please check path '
'mapping.'}) \
.where(TableMoviesRootfolder.id == item['id']) \
.execute()
elif not os.access(path_mappings.path_replace_movie(item['path']), os.W_OK):
database.execute("UPDATE table_movies_rootfolder SET accessible = 0, error = 'Bazarr cannot write to "
"this directory' WHERE id = ?", (item['id'],))
TableMoviesRootfolder.update({TableMoviesRootfolder.accessible: 0,
TableMoviesRootfolder.error: 'Bazarr cannot write to this directory'}) \
.where(TableMoviesRootfolder.id == item['id']) \
.execute()
else:
database.execute("UPDATE table_movies_rootfolder SET accessible = 1, error = '' WHERE id = ?",
(item['id'],))
TableMoviesRootfolder.update({TableMoviesRootfolder.accessible: 1,
TableMoviesRootfolder.error: ''}) \
.where(TableMoviesRootfolder.id == item['id']) \
.execute()

@ -7,7 +7,7 @@ import logging
from config import settings, url_sonarr
from list_subtitles import list_missing_subtitles
from get_rootfolder import check_sonarr_rootfolder
from database import database, dict_converter
from database import TableShows
from utils import get_sonarr_version
from helper import path_mappings
from event_handler import event_stream, show_progress, hide_progress
@ -40,7 +40,7 @@ def update_series(send_event=True):
return
else:
# Get current shows in DB
current_shows_db = database.execute("SELECT sonarrSeriesId FROM table_shows")
current_shows_db = TableShows.select(TableShows.sonarrSeriesId).dicts()
current_shows_db_list = [x['sonarrSeriesId'] for x in current_shows_db]
current_shows_sonarr = []
@ -81,15 +81,26 @@ def update_series(send_event=True):
removed_series = list(set(current_shows_db_list) - set(current_shows_sonarr))
for series in removed_series:
database.execute("DELETE FROM table_shows WHERE sonarrSeriesId=?",(series,))
TableShows.delete().where(TableShows.sonarrSeriesId == series).execute()
if send_event:
event_stream(type='series', action='delete', payload=series)
# Update existing series in DB
series_in_db_list = []
series_in_db = database.execute("SELECT title, path, tvdbId, sonarrSeriesId, overview, poster, fanart, "
"audio_language, sortTitle, year, alternateTitles, tags, seriesType, imdbId "
"FROM table_shows")
series_in_db = TableShows.select(TableShows.title,
TableShows.path,
TableShows.tvdbId,
TableShows.sonarrSeriesId,
TableShows.overview,
TableShows.poster,
TableShows.fanart,
TableShows.audio_language,
TableShows.sortTitle,
TableShows.year,
TableShows.alternateTitles,
TableShows.tags,
TableShows.seriesType,
TableShows.imdbId).dicts()
for item in series_in_db:
series_in_db_list.append(item)
@ -97,18 +108,14 @@ def update_series(send_event=True):
series_to_update_list = [i for i in series_to_update if i not in series_in_db_list]
for updated_series in series_to_update_list:
query = dict_converter.convert(updated_series)
database.execute('''UPDATE table_shows SET ''' + query.keys_update + ''' WHERE sonarrSeriesId = ?''',
query.values + (updated_series['sonarrSeriesId'],))
TableShows.update(updated_series).where(TableShows.sonarrSeriesId ==
updated_series['sonarrSeriesId']).execute()
if send_event:
event_stream(type='series', payload=updated_series['sonarrSeriesId'])
# Insert new series in DB
for added_series in series_to_add:
query = dict_converter.convert(added_series)
result = database.execute(
'''INSERT OR IGNORE INTO table_shows(''' + query.keys_insert + ''') VALUES(''' +
query.question_marks + ''')''', query.values)
result = TableShows.insert(added_series).on_conflict(action='IGNORE').execute()
if result:
list_missing_subtitles(no=added_series['sonarrSeriesId'])
else:
@ -125,8 +132,7 @@ def update_one_series(series_id, action):
logging.debug('BAZARR syncing this specific series from Sonarr: {}'.format(series_id))
# Check if there's a row in database for this series ID
existing_series = database.execute('SELECT path FROM table_shows WHERE sonarrSeriesId = ?', (series_id,),
only_one=True)
existing_series = TableShows.get_or_none(TableShows.sonarrSeriesId == series_id)
sonarr_version = get_sonarr_version()
serie_default_enabled = settings.general.getboolean('serie_default_enabled')
@ -149,7 +155,7 @@ def update_one_series(series_id, action):
series_data = get_series_from_sonarr_api(url=url_sonarr(), apikey_sonarr=settings.sonarr.apikey,
sonarr_series_id=series_id)
except requests.exceptions.HTTPError:
database.execute("DELETE FROM table_shows WHERE sonarrSeriesId=?", (series_id,))
TableShows.delete().where(TableShows.sonarrSeriesId == series_id).execute()
event_stream(type='series', action='delete', payload=int(series_id))
return
@ -170,26 +176,22 @@ def update_one_series(series_id, action):
# Remove series from DB
if action == 'deleted':
database.execute("DELETE FROM table_shows WHERE sonarrSeriesId=?", (series_id,))
TableShows.delete().where(TableShows.sonarrSeriesId == series_id).execute()
event_stream(type='series', action='delete', payload=int(series_id))
logging.debug('BAZARR deleted this series from the database:{}'.format(path_mappings.path_replace(
existing_series['path'])))
existing_series.path)))
return
# Update existing series in DB
elif action == 'updated' and existing_series:
query = dict_converter.convert(series)
database.execute('''UPDATE table_shows SET ''' + query.keys_update + ''' WHERE sonarrSeriesId = ?''',
query.values + (series['sonarrSeriesId'],))
TableShows.update(series).where(TableShows.sonarrSeriesId == series['sonarrSeriesId']).execute()
event_stream(type='series', action='update', payload=int(series_id))
logging.debug('BAZARR updated this series into the database:{}'.format(path_mappings.path_replace(
series['path'])))
# Insert new series in DB
elif action == 'updated' and not existing_series:
query = dict_converter.convert(series)
database.execute('''INSERT OR IGNORE INTO table_shows(''' + query.keys_insert + ''') VALUES(''' +
query.question_marks + ''')''', query.values)
TableShows.insert(series).on_conflict(action='IGNORE').execute()
event_stream(type='series', action='update', payload=int(series_id))
logging.debug('BAZARR inserted this series into the database:{}'.format(path_mappings.path_replace(
series['path'])))

@ -11,6 +11,9 @@ import codecs
import re
import subliminal
import copy
import operator
from functools import reduce
from peewee import fn
from datetime import datetime, timedelta
from subzero.language import Language
from subzero.video import parse_video
@ -31,8 +34,8 @@ from get_providers import get_providers, get_providers_auth, provider_throttle,
from knowit import api
from subsyncer import subsync
from guessit import guessit
from database import database, dict_mapper, get_exclusion_clause, get_profiles_list, get_audio_profile_languages, \
get_desired_languages
from database import dict_mapper, get_exclusion_clause, get_profiles_list, get_audio_profile_languages, \
get_desired_languages, TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie
from event_handler import event_stream, show_progress, hide_progress
from embedded_subs_reader import parse_video_metadata
@ -253,10 +256,11 @@ def download_subtitle(path, language, audio_language, hi, forced, providers, pro
downloaded_provider + " with a score of " + str(percent_score) + "%."
if media_type == 'series':
episode_metadata = database.execute("SELECT sonarrSeriesId, sonarrEpisodeId FROM "
"table_episodes WHERE path = ?",
(path_mappings.path_replace_reverse(path),),
only_one=True)
episode_metadata = TableEpisodes.select(TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId)\
.where(TableEpisodes.path == path_mappings.path_replace_reverse(path))\
.dicts()\
.get()
series_id = episode_metadata['sonarrSeriesId']
episode_id = episode_metadata['sonarrEpisodeId']
sync_subtitles(video_path=path, srt_path=downloaded_path,
@ -265,9 +269,10 @@ def download_subtitle(path, language, audio_language, hi, forced, providers, pro
sonarr_series_id=episode_metadata['sonarrSeriesId'],
sonarr_episode_id=episode_metadata['sonarrEpisodeId'])
else:
movie_metadata = database.execute("SELECT radarrId FROM table_movies WHERE path = ?",
(path_mappings.path_replace_reverse_movie(path),),
only_one=True)
movie_metadata = TableMovies.select(TableMovies.radarrId)\
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\
.dicts()\
.get()
series_id = ""
episode_id = movie_metadata['radarrId']
sync_subtitles(video_path=path, srt_path=downloaded_path,
@ -580,10 +585,11 @@ def manual_download_subtitle(path, language, audio_language, hi, forced, subtitl
downloaded_provider + " with a score of " + str(score) + "% using manual search."
if media_type == 'series':
episode_metadata = database.execute("SELECT sonarrSeriesId, sonarrEpisodeId FROM "
"table_episodes WHERE path = ?",
(path_mappings.path_replace_reverse(path),),
only_one=True)
episode_metadata = TableEpisodes.select(TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId)\
.where(TableEpisodes.path == path_mappings.path_replace_reverse(path))\
.dicts()\
.get()
series_id = episode_metadata['sonarrSeriesId']
episode_id = episode_metadata['sonarrEpisodeId']
sync_subtitles(video_path=path, srt_path=downloaded_path,
@ -592,9 +598,10 @@ def manual_download_subtitle(path, language, audio_language, hi, forced, subtitl
sonarr_series_id=episode_metadata['sonarrSeriesId'],
sonarr_episode_id=episode_metadata['sonarrEpisodeId'])
else:
movie_metadata = database.execute("SELECT radarrId FROM table_movies WHERE path = ?",
(path_mappings.path_replace_reverse_movie(path),),
only_one=True)
movie_metadata = TableMovies.select(TableMovies.radarrId)\
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\
.dicts()\
.get()
series_id = ""
episode_id = movie_metadata['radarrId']
sync_subtitles(video_path=path, srt_path=downloaded_path,
@ -710,18 +717,20 @@ def manual_upload_subtitle(path, language, forced, title, scene_name, media_type
audio_language_code3 = alpha3_from_language(audio_language)
if media_type == 'series':
episode_metadata = database.execute("SELECT sonarrSeriesId, sonarrEpisodeId FROM table_episodes WHERE path = ?",
(path_mappings.path_replace_reverse(path),),
only_one=True)
episode_metadata = TableEpisodes.select(TableEpisodes.sonarrSeriesId, TableEpisodes.sonarrEpisodeId)\
.where(TableEpisodes.path == path_mappings.path_replace_reverse(path))\
.dicts()\
.get()
series_id = episode_metadata['sonarrSeriesId']
episode_id = episode_metadata['sonarrEpisodeId']
sync_subtitles(video_path=path, srt_path=subtitle_path, srt_lang=uploaded_language_code2, media_type=media_type,
percent_score=100, sonarr_series_id=episode_metadata['sonarrSeriesId'],
sonarr_episode_id=episode_metadata['sonarrEpisodeId'])
else:
movie_metadata = database.execute("SELECT radarrId FROM table_movies WHERE path = ?",
(path_mappings.path_replace_reverse_movie(path),),
only_one=True)
movie_metadata = TableMovies.select(TableMovies.radarrId)\
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\
.dicts()\
.get()
series_id = ""
episode_id = movie_metadata['radarrId']
sync_subtitles(video_path=path, srt_path=subtitle_path, srt_lang=uploaded_language_code2, media_type=media_type,
@ -746,13 +755,24 @@ def manual_upload_subtitle(path, language, forced, title, scene_name, media_type
def series_download_subtitles(no):
episodes_details = database.execute("SELECT table_episodes.path, table_episodes.missing_subtitles, monitored, "
"table_episodes.sonarrEpisodeId, table_episodes.scene_name, table_shows.tags, "
"table_shows.seriesType, table_episodes.audio_language, table_shows.title, "
"table_episodes.season, table_episodes.episode, table_episodes.title as episodeTitle "
"FROM table_episodes INNER JOIN table_shows on table_shows.sonarrSeriesId = "
"table_episodes.sonarrSeriesId WHERE table_episodes.sonarrSeriesId=? and "
"missing_subtitles!='[]'" + get_exclusion_clause('series'), (no,))
conditions = [(TableEpisodes.sonarrSeriesId == no),
(TableEpisodes.missing_subtitles != '[]')]
conditions += get_exclusion_clause('series')
episodes_details = TableEpisodes.select(TableEpisodes.path,
TableEpisodes.missing_subtitles,
TableEpisodes.monitored,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.scene_name,
TableShows.tags,
TableShows.seriesType,
TableEpisodes.audio_language,
TableShows.title,
TableEpisodes.season,
TableEpisodes.episode,
TableEpisodes.title.alias('episodeTitle'))\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(reduce(operator.and_, conditions))\
.dicts()
if not episodes_details:
logging.debug("BAZARR no episode for that sonarrSeriesId can be found in database: {}".format(no))
return
@ -821,16 +841,23 @@ def series_download_subtitles(no):
hide_progress(id='series_search_progress_{}'.format(no))
def episode_download_subtitles(no):
episodes_details = database.execute("SELECT table_episodes.path, table_episodes.missing_subtitles, monitored, "
"table_episodes.sonarrEpisodeId, table_episodes.scene_name, table_shows.tags, "
"table_shows.title, table_shows.sonarrSeriesId, table_episodes.audio_language, "
"table_shows.seriesType FROM table_episodes LEFT JOIN table_shows on "
"table_episodes.sonarrSeriesId = table_shows.sonarrSeriesId WHERE sonarrEpisodeId=?" +
get_exclusion_clause('series'), (no,))
conditions = [(TableEpisodes.sonarrEpisodeId == no)]
conditions += get_exclusion_clause('series')
episodes_details = TableEpisodes.select(TableEpisodes.path,
TableEpisodes.missing_subtitles,
TableEpisodes.monitored,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.scene_name,
TableShows.tags,
TableShows.title,
TableShows.sonarrSeriesId,
TableEpisodes.audio_language,
TableShows.seriesType)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(reduce(operator.and_, conditions))\
.dicts()
if not episodes_details:
logging.debug("BAZARR no episode with that sonarrEpisodeId can be found in database: {}".format(no))
return
@ -882,9 +909,18 @@ def episode_download_subtitles(no):
def movies_download_subtitles(no):
movies = database.execute(
"SELECT path, missing_subtitles, audio_language, radarrId, sceneName, title, tags, "
"monitored FROM table_movies WHERE radarrId=?" + get_exclusion_clause('movie'), (no,))
conditions = [(TableMovies.radarrId == no)]
conditions += get_exclusion_clause('movie')
movies = TableMovies.select(TableMovies.path,
TableMovies.missing_subtitles,
TableMovies.audio_language,
TableMovies.radarrId,
TableMovies.sceneName,
TableMovies.title,
TableMovies.tags,
TableMovies.monitored)\
.where(reduce(operator.and_, conditions))\
.dicts()
if not len(movies):
logging.debug("BAZARR no movie with that radarrId can be found in database: {}".format(no))
return
@ -955,14 +991,19 @@ def movies_download_subtitles(no):
def wanted_download_subtitles(path):
episodes_details = database.execute("SELECT table_episodes.path, table_episodes.missing_subtitles, "
"table_episodes.sonarrEpisodeId, table_episodes.sonarrSeriesId, "
"table_episodes.audio_language, table_episodes.scene_name,"
"table_episodes.failedAttempts, table_shows.title "
"FROM table_episodes LEFT JOIN table_shows on "
"table_episodes.sonarrSeriesId = table_shows.sonarrSeriesId "
"WHERE table_episodes.path=? and table_episodes.missing_subtitles!='[]'",
(path_mappings.path_replace_reverse(path),))
episodes_details = TableEpisodes.select(TableEpisodes.path,
TableEpisodes.missing_subtitles,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.sonarrSeriesId,
TableEpisodes.audio_language,
TableEpisodes.scene_name,
TableEpisodes.failedAttempts,
TableShows.title)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where((TableEpisodes.path == path_mappings.path_replace_reverse(path)) &
(TableEpisodes.missing_subtitles != '[]'))\
.dicts()
episodes_details = list(episodes_details)
providers_list = get_providers()
providers_auth = get_providers_auth()
@ -980,8 +1021,9 @@ def wanted_download_subtitles(path):
if language not in att:
attempt.append([language, time.time()])
database.execute("UPDATE table_episodes SET failedAttempts=? WHERE sonarrEpisodeId=?",
(str(attempt), episode['sonarrEpisodeId']))
TableEpisodes.update({TableEpisodes.failedAttempts: str(attempt)})\
.where(TableEpisodes.sonarrEpisodeId == episode['sonarrEpisodeId'])\
.execute()
for i in range(len(attempt)):
if attempt[i][0] == language:
@ -1028,10 +1070,17 @@ def wanted_download_subtitles(path):
def wanted_download_subtitles_movie(path):
movies_details = database.execute(
"SELECT path, missing_subtitles, radarrId, audio_language, sceneName, "
"failedAttempts, title FROM table_movies WHERE path = ? "
"AND missing_subtitles != '[]'", (path_mappings.path_replace_reverse_movie(path),))
movies_details = TableMovies.select(TableMovies.path,
TableMovies.missing_subtitles,
TableMovies.radarrId,
TableMovies.audio_language,
TableMovies.sceneName,
TableMovies.failedAttempts,
TableMovies.title)\
.where((TableMovies.path == path_mappings.path_replace_reverse_movie(path)) &
(TableMovies.missing_subtitles != '[]'))\
.dicts()
movies_details = list(movies_details)
providers_list = get_providers()
providers_auth = get_providers_auth()
@ -1049,8 +1098,9 @@ def wanted_download_subtitles_movie(path):
if language not in att:
attempt.append([language, time.time()])
database.execute("UPDATE table_movies SET failedAttempts=? WHERE radarrId=?",
(str(attempt), movie['radarrId']))
TableMovies.update({TableMovies.failedAttempts: str(attempt)})\
.where(TableMovies.radarrId == movie['radarrId'])\
.execute()
for i in range(len(attempt)):
if attempt[i][0] == language:
@ -1097,11 +1147,20 @@ def wanted_download_subtitles_movie(path):
def wanted_search_missing_subtitles_series():
episodes = database.execute("SELECT table_episodes.path, table_shows.tags, table_episodes.monitored, "
"table_shows.title, table_episodes.season, table_episodes.episode, table_episodes.title"
" as episodeTitle, table_shows.seriesType FROM table_episodes INNER JOIN table_shows on"
" table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE missing_subtitles !="
" '[]'" + get_exclusion_clause('series'))
conditions = [(TableEpisodes.missing_subtitles != '[]')]
conditions += get_exclusion_clause('series')
episodes = TableEpisodes.select(TableEpisodes.path,
TableShows.tags,
TableEpisodes.monitored,
TableShows.title,
TableEpisodes.season,
TableEpisodes.episode,
TableEpisodes.title.alias('episodeTitle'),
TableShows.seriesType)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(reduce(operator.and_, conditions))\
.dicts()
episodes = list(episodes)
# path_replace
dict_mapper.path_replace(episodes)
@ -1135,8 +1194,15 @@ def wanted_search_missing_subtitles_series():
def wanted_search_missing_subtitles_movies():
movies = database.execute("SELECT path, tags, monitored, title FROM table_movies WHERE missing_subtitles != '[]'" +
get_exclusion_clause('movie'))
conditions = [(TableMovies.missing_subtitles != '[]')]
conditions += get_exclusion_clause('movie')
movies = TableMovies.select(TableMovies.path,
TableMovies.tags,
TableMovies.monitored,
TableMovies.title)\
.where(reduce(operator.and_, conditions))\
.dicts()
movies = list(movies)
# path_replace
dict_mapper.path_replace_movie(movies)
@ -1194,16 +1260,25 @@ def convert_to_guessit(guessit_key, attr_from_db):
def refine_from_db(path, video):
if isinstance(video, Episode):
data = database.execute(
"SELECT table_shows.title as seriesTitle, table_episodes.season, table_episodes.episode, "
"table_episodes.title as episodeTitle, table_shows.year, table_shows.tvdbId, "
"table_shows.alternateTitles, table_episodes.format, table_episodes.resolution, "
"table_episodes.video_codec, table_episodes.audio_codec, table_episodes.path, table_shows.imdbId "
"FROM table_episodes INNER JOIN table_shows on "
"table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId "
"WHERE table_episodes.path = ?", (path_mappings.path_replace_reverse(path),), only_one=True)
if data:
data = TableEpisodes.select(TableShows.title.alias('seriesTitle'),
TableEpisodes.season,
TableEpisodes.episode,
TableEpisodes.title.alias('episodeTitle'),
TableShows.year,
TableShows.tvdbId,
TableShows.alternateTitles,
TableEpisodes.format,
TableEpisodes.resolution,
TableEpisodes.video_codec,
TableEpisodes.audio_codec,
TableEpisodes.path,
TableShows.imdbId)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where((TableEpisodes.path == path_mappings.path_replace_reverse(path)))\
.dicts()
if len(data):
data = data[0]
video.series = re.sub(r'\s(\(\d\d\d\d\))', '', data['seriesTitle'])
video.season = int(data['season'])
video.episode = int(data['episode'])
@ -1224,11 +1299,19 @@ def refine_from_db(path, video):
if not video.audio_codec:
if data['audio_codec']: video.audio_codec = convert_to_guessit('audio_codec', data['audio_codec'])
elif isinstance(video, Movie):
data = database.execute("SELECT title, year, alternativeTitles, format, resolution, video_codec, audio_codec, "
"imdbId FROM table_movies WHERE path = ?",
(path_mappings.path_replace_reverse_movie(path),), only_one=True)
if data:
data = TableMovies.select(TableMovies.title,
TableMovies.year,
TableMovies.alternativeTitles,
TableMovies.format,
TableMovies.resolution,
TableMovies.video_codec,
TableMovies.audio_codec,
TableMovies.imdbId)\
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\
.dicts()
if len(data):
data = data[0]
video.title = re.sub(r'\s(\(\d\d\d\d\))', '', data['title'])
# Commented out because Radarr provided so many bad year values
# if data['year']:
@ -1250,11 +1333,15 @@ def refine_from_db(path, video):
def refine_from_ffprobe(path, video):
if isinstance(video, Movie):
file_id = database.execute("SELECT movie_file_id FROM table_shows WHERE path = ?",
(path_mappings.path_replace_reverse_movie(path),), only_one=True)
file_id = TableMovies.select(TableMovies.movie_file_id, TableMovies.file_size)\
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\
.dicts()\
.get()
else:
file_id = database.execute("SELECT episode_file_id, file_size FROM table_episodes WHERE path = ?",
(path_mappings.path_replace_reverse(path),), only_one=True)
file_id = TableEpisodes.select(TableEpisodes.episode_file_id, TableEpisodes.file_size)\
.where(TableEpisodes.path == path_mappings.path_replace_reverse(path))\
.dicts()\
.get()
if not isinstance(file_id, dict):
return video
@ -1312,21 +1399,33 @@ def upgrade_subtitles():
query_actions = [1, 3]
if settings.general.getboolean('use_sonarr'):
upgradable_episodes = database.execute("SELECT table_history.video_path, table_history.language, "
"table_history.score, table_shows.tags, table_shows.profileId, "
"table_episodes.audio_language, table_episodes.scene_name, "
"table_episodes.title, table_episodes.sonarrSeriesId, table_history.action, "
"table_history.subtitles_path, table_episodes.sonarrEpisodeId, "
"MAX(table_history.timestamp) as timestamp, table_episodes.monitored, "
"table_episodes.season, table_episodes.episode, table_shows.title as seriesTitle, "
"table_shows.seriesType FROM table_history INNER JOIN table_shows on "
"table_shows.sonarrSeriesId = table_history.sonarrSeriesId INNER JOIN "
"table_episodes on table_episodes.sonarrEpisodeId = "
"table_history.sonarrEpisodeId WHERE action IN "
"(" + ','.join(map(str, query_actions)) + ") AND timestamp > ? AND "
"score is not null" + get_exclusion_clause('series') + " GROUP BY "
"table_history.video_path, table_history.language",
(minimum_timestamp,))
upgradable_episodes_conditions = [(TableHistory.action << query_actions),
(TableHistory.timestamp > minimum_timestamp),
(TableHistory.score.is_null(False))]
upgradable_episodes_conditions += get_exclusion_clause('series')
upgradable_episodes = TableHistory.select(TableHistory.video_path,
TableHistory.language,
TableHistory.score,
TableShows.tags,
TableShows.profileId,
TableEpisodes.audio_language,
TableEpisodes.scene_name,
TableEpisodes.title,
TableEpisodes.sonarrSeriesId,
TableHistory.action,
TableHistory.subtitles_path,
TableEpisodes.sonarrEpisodeId,
fn.MAX(TableHistory.timestamp).alias('timestamp'),
TableEpisodes.monitored,
TableEpisodes.season,
TableEpisodes.episode,
TableShows.title.alias('seriesTitle'),
TableShows.seriesType)\
.join(TableShows, on=(TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId))\
.join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId))\
.where(reduce(operator.and_, upgradable_episodes_conditions))\
.group_by(TableHistory.video_path, TableHistory.language)\
.dicts()
upgradable_episodes_not_perfect = []
for upgradable_episode in upgradable_episodes:
if upgradable_episode['timestamp'] > minimum_timestamp:
@ -1347,17 +1446,25 @@ def upgrade_subtitles():
count_episode_to_upgrade = len(episodes_to_upgrade)
if settings.general.getboolean('use_radarr'):
upgradable_movies = database.execute("SELECT table_history_movie.video_path, table_history_movie.language, "
"table_history_movie.score, table_movies.profileId, table_history_movie.action, "
"table_history_movie.subtitles_path, table_movies.audio_language, "
"table_movies.sceneName, table_movies.title, table_movies.radarrId, "
"MAX(table_history_movie.timestamp) as timestamp, table_movies.tags, "
"table_movies.monitored FROM table_history_movie INNER JOIN table_movies "
"on table_movies.radarrId = table_history_movie.radarrId WHERE action IN "
"(" + ','.join(map(str, query_actions)) + ") AND timestamp > ? AND score "
"is not null" + get_exclusion_clause('movie') + " GROUP BY "
"table_history_movie.video_path, table_history_movie.language",
(minimum_timestamp,))
upgradable_movies_conditions = [(TableHistoryMovie.action << query_actions),
(TableHistoryMovie.timestamp > minimum_timestamp),
(TableHistoryMovie.score.is_null(False))]
upgradable_movies_conditions += get_exclusion_clause('movie')
upgradable_movies = TableHistoryMovie.select(TableHistoryMovie.video_path,
TableHistoryMovie.language,
TableHistoryMovie.score,
TableMovies.profileId,
TableHistoryMovie.action,
TableHistoryMovie.subtitles_path,
TableMovies.audio_language,
TableMovies.sceneName,
TableMovies.title,
TableMovies.radarrId,
fn.MAX(TableHistoryMovie.timestamp).alias('timestamp'),
TableMovies.tags,
TableMovies.monitored)\
.join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId))\
.where(reduce(operator.and_, upgradable_movies_conditions))\
.group_by(TableHistoryMovie.video_path, TableHistoryMovie.language)\
.dicts()
upgradable_movies_not_perfect = []
for upgradable_movie in upgradable_movies:
if upgradable_movie['timestamp'] > minimum_timestamp:

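Both upgrade queries above share one construction: collect Peewee expressions in a plain list, then fold them into a single WHERE clause. A minimal, self-contained sketch of that pattern, using a hypothetical `Item` model rather than Bazarr's real tables:

```python
# Sketch of the reduce(operator.and_, conditions) pattern used by
# upgrade_subtitles(); the Item model here is hypothetical.
import operator
from functools import reduce

from peewee import IntegerField, Model, SqliteDatabase, TextField

db = SqliteDatabase(':memory:')

class Item(Model):
    tags = TextField(null=True)
    monitored = IntegerField(default=1)

    class Meta:
        database = db

db.create_tables([Item])
Item.insert_many([{'tags': 'a', 'monitored': 1},
                  {'tags': None, 'monitored': 0}]).execute()

# Collect expressions, then AND them together. Note .is_null(False):
# a plain "is not None" would be evaluated by Python, not by SQLite.
conditions = [Item.tags.is_null(False), Item.monitored == 1]
print(Item.select().where(reduce(operator.and_, conditions)).count())  # 1
```

The same fold with `operator.or_` would produce an OR chain, which is why the conditions are kept as a list of data until the last moment.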
@ -126,27 +126,6 @@ if os.path.isfile(package_info_file):
with open(os.path.join(args.config_dir, 'config', 'config.ini'), 'w+') as handle:
settings.write(handle)
# create database file
if not os.path.exists(os.path.join(args.config_dir, 'db', 'bazarr.db')):
import sqlite3
# Get SQL script from file
fd = open(os.path.join(os.path.dirname(__file__), 'create_db.sql'), 'r')
script = fd.read()
# Close SQL script file
fd.close()
# Open database connection
db = sqlite3.connect(os.path.join(args.config_dir, 'db', 'bazarr.db'), timeout=30)
c = db.cursor()
# Execute script and commit change to database
c.executescript(script)
# Close database connection
db.close()
logging.info('BAZARR Database created successfully')
# upgrade database schema
from database import db_upgrade
db_upgrade()
# Configure dogpile file caching for Subliminal request
register_cache_backend("subzero.cache.file", "subzero.cache_backends.file", "SZFileBackend")
subliminal.region.configure('subzero.cache.file', expiration_time=datetime.timedelta(days=30),
@ -175,41 +154,6 @@ with open(os.path.normpath(os.path.join(args.config_dir, 'config', 'config.ini')
settings.write(handle)
# Commenting out the password reset process as it could be having unwanted effects and most of the users have already
# moved to new password hashing algorithm.
# Reset form login password for Bazarr after migration from 0.8.x to 0.9. Password will be equal to username.
# if settings.auth.type == 'form' and \
# os.path.exists(os.path.normpath(os.path.join(args.config_dir, 'config', 'users.json'))):
# username = False
# with open(os.path.normpath(os.path.join(args.config_dir, 'config', 'users.json'))) as json_file:
# try:
# data = json.load(json_file)
# username = next(iter(data))
# except:
# logging.error('BAZARR is unable to migrate credentials. You should disable login by modifying config.ini '
# 'file and settings [auth]-->type = None')
# if username:
# settings.auth.username = username
# settings.auth.password = hashlib.md5(username.encode('utf-8')).hexdigest()
# with open(os.path.join(args.config_dir, 'config', 'config.ini'), 'w+') as handle:
# settings.write(handle)
# os.remove(os.path.normpath(os.path.join(args.config_dir, 'config', 'users.json')))
# os.remove(os.path.normpath(os.path.join(args.config_dir, 'config', 'roles.json')))
# os.remove(os.path.normpath(os.path.join(args.config_dir, 'config', 'register.json')))
# logging.info('BAZARR your login credentials have been migrated successfully and your password is now equal '
# 'to your username. Please change it as soon as possible in settings.')
# else:
# if os.path.exists(os.path.normpath(os.path.join(args.config_dir, 'config', 'users.json'))):
# try:
# os.remove(os.path.normpath(os.path.join(args.config_dir, 'config', 'users.json')))
# os.remove(os.path.normpath(os.path.join(args.config_dir, 'config', 'roles.json')))
# os.remove(os.path.normpath(os.path.join(args.config_dir, 'config', 'register.json')))
# except:
# logging.error("BAZARR cannot delete those file. Please do it manually: users.json, roles.json, "
# "register.json")
def init_binaries():
from utils import get_binary
exe = get_binary("unrar")

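The deleted block above bootstrapped `bazarr.db` from a hand-written `create_db.sql`. With Peewee the schema can come straight from the model definitions; a minimal sketch of that idea, with a hypothetical `Show` model standing in for the real tables:

```python
# Hypothetical sketch: creating the schema from Peewee models instead
# of executing a raw SQL script at startup.
from peewee import AutoField, Model, SqliteDatabase, TextField

db = SqliteDatabase('/tmp/example.db', pragmas={'journal_mode': 'wal'})

class Show(Model):
    id = AutoField()
    title = TextField()

    class Meta:
        database = db

db.connect()
db.create_tables([Show], safe=True)  # CREATE TABLE IF NOT EXISTS
db.close()
```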
@ -9,7 +9,7 @@ from guess_language import guess_language
from subliminal_patch import core, search_external_subtitles
from subzero.language import Language
from database import database, get_profiles_list, get_profile_cutoff
from database import get_profiles_list, get_profile_cutoff, TableEpisodes, TableShows, TableMovies
from get_languages import alpha2_from_alpha3, language_from_alpha2, get_language_set
from config import settings
from helper import path_mappings, get_subtitle_destination_folder
@ -31,8 +31,10 @@ def store_subtitles(original_path, reversed_path):
if settings.general.getboolean('use_embedded_subs'):
logging.debug("BAZARR is trying to index embedded subtitles.")
try:
item = database.execute('SELECT file_size, episode_file_id FROM table_episodes '
'WHERE path = ?', (original_path,), only_one=True)
item = TableEpisodes.select(TableEpisodes.episode_file_id, TableEpisodes.file_size)\
.where(TableEpisodes.path == original_path)\
.dicts()\
.get()
subtitle_languages = embedded_subs_reader(reversed_path,
file_size=item['file_size'],
episode_file_id=item['episode_file_id'])
@ -125,10 +127,12 @@ def store_subtitles(original_path, reversed_path):
logging.debug("BAZARR external subtitles detected: " + language_str)
actual_subtitles.append([language_str, path_mappings.path_replace_reverse(subtitle_path)])
database.execute("UPDATE table_episodes SET subtitles=? WHERE path=?",
(str(actual_subtitles), original_path))
matching_episodes = database.execute("SELECT sonarrEpisodeId, sonarrSeriesId FROM table_episodes WHERE path=?",
(original_path,))
TableEpisodes.update({TableEpisodes.subtitles: str(actual_subtitles)})\
.where(TableEpisodes.path == original_path)\
.execute()
matching_episodes = TableEpisodes.select(TableEpisodes.sonarrEpisodeId, TableEpisodes.sonarrSeriesId)\
.where(TableEpisodes.path == original_path)\
.dicts()
for episode in matching_episodes:
if episode:
@ -151,8 +155,10 @@ def store_subtitles_movie(original_path, reversed_path):
if settings.general.getboolean('use_embedded_subs'):
logging.debug("BAZARR is trying to index embedded subtitles.")
try:
item = database.execute('SELECT file_size, movie_file_id FROM table_movies '
'WHERE path = ?', (original_path,), only_one=True)
item = TableMovies.select(TableMovies.movie_file_id, TableMovies.file_size)\
.where(TableMovies.path == original_path)\
.dicts()\
.get()
subtitle_languages = embedded_subs_reader(reversed_path,
file_size=item['file_size'],
movie_file_id=item['movie_file_id'])
@ -237,9 +243,10 @@ def store_subtitles_movie(original_path, reversed_path):
logging.debug("BAZARR external subtitles detected: " + language_str)
actual_subtitles.append([language_str, path_mappings.path_replace_reverse_movie(subtitle_path)])
database.execute("UPDATE table_movies SET subtitles=? WHERE path=?",
(str(actual_subtitles), original_path))
matching_movies = database.execute("SELECT radarrId FROM table_movies WHERE path=?", (original_path,))
TableMovies.update({TableMovies.subtitles: str(actual_subtitles)})\
.where(TableMovies.path == original_path)\
.execute()
matching_movies = TableMovies.select(TableMovies.radarrId).where(TableMovies.path == original_path).dicts()
for movie in matching_movies:
if movie:
@ -257,16 +264,19 @@ def store_subtitles_movie(original_path, reversed_path):
def list_missing_subtitles(no=None, epno=None, send_event=True):
if epno is not None:
episodes_subtitles_clause = " WHERE table_episodes.sonarrEpisodeId=" + str(epno)
episodes_subtitles_clause = (TableEpisodes.sonarrEpisodeId == epno)
elif no is not None:
episodes_subtitles_clause = " WHERE table_episodes.sonarrSeriesId=" + str(no)
episodes_subtitles_clause = (TableEpisodes.sonarrSeriesId == no)
else:
episodes_subtitles_clause = ""
episodes_subtitles = database.execute("SELECT table_shows.sonarrSeriesId, table_episodes.sonarrEpisodeId, "
"table_episodes.subtitles, table_shows.profileId, "
"table_episodes.audio_language FROM table_episodes "
"LEFT JOIN table_shows on table_episodes.sonarrSeriesId = "
"table_shows.sonarrSeriesId" + episodes_subtitles_clause)
episodes_subtitles_clause = None
episodes_subtitles = TableEpisodes.select(TableShows.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.subtitles,
TableShows.profileId,
TableEpisodes.audio_language)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(episodes_subtitles_clause)\
.dicts()
if isinstance(episodes_subtitles, str):
logging.error("BAZARR list missing subtitles query to DB returned this instead of rows: " + episodes_subtitles)
return
@ -361,27 +371,26 @@ def list_missing_subtitles(no=None, epno=None, send_event=True):
missing_subtitles_text = str(missing_subtitles_output_list)
database.execute("UPDATE table_episodes SET missing_subtitles=? WHERE sonarrEpisodeId=?",
(missing_subtitles_text, episode_subtitles['sonarrEpisodeId']))
TableEpisodes.update({TableEpisodes.missing_subtitles: missing_subtitles_text})\
.where(TableEpisodes.sonarrEpisodeId == episode_subtitles['sonarrEpisodeId'])\
.execute()
if send_event:
event_stream(type='episode', payload=episode_subtitles['sonarrEpisodeId'])
event_stream(type='badges')
def list_missing_subtitles_movies(no=None, epno=None, send_event=True):
if no is not None:
movies_subtitles_clause = " WHERE radarrId=" + str(no)
else:
movies_subtitles_clause = ""
movies_subtitles = database.execute("SELECT radarrId, subtitles, profileId, audio_language FROM table_movies" +
movies_subtitles_clause)
def list_missing_subtitles_movies(no=None, send_event=True):
movies_subtitles = TableMovies.select(TableMovies.radarrId,
TableMovies.subtitles,
TableMovies.profileId,
TableMovies.audio_language)\
.where((TableMovies.radarrId == no) if no else None)\
.dicts()
if isinstance(movies_subtitles, str):
logging.error("BAZARR list missing subtitles query to DB returned this instead of rows: " + movies_subtitles)
return
use_embedded_subs = settings.general.getboolean('use_embedded_subs')
for movie_subtitles in movies_subtitles:
@ -470,8 +479,9 @@ def list_missing_subtitles_movies(no=None, epno=None, send_event=True):
missing_subtitles_text = str(missing_subtitles_output_list)
database.execute("UPDATE table_movies SET missing_subtitles=? WHERE radarrId=?",
(missing_subtitles_text, movie_subtitles['radarrId']))
TableMovies.update({TableMovies.missing_subtitles: missing_subtitles_text})\
.where(TableMovies.radarrId == movie_subtitles['radarrId'])\
.execute()
if send_event:
event_stream(type='movie', payload=movie_subtitles['radarrId'])
@ -479,7 +489,7 @@ def list_missing_subtitles_movies(no=None, epno=None, send_event=True):
def series_full_scan_subtitles():
episodes = database.execute("SELECT path FROM table_episodes")
episodes = TableEpisodes.select(TableEpisodes.path).dicts()
count_episodes = len(episodes)
for i, episode in enumerate(episodes, 1):
@ -502,7 +512,7 @@ def series_full_scan_subtitles():
def movies_full_scan_subtitles():
movies = database.execute("SELECT path FROM table_movies")
movies = TableMovies.select(TableMovies.path).dicts()
count_movies = len(movies)
for i, movie in enumerate(movies, 1):
@ -525,15 +535,20 @@ def movies_full_scan_subtitles():
def series_scan_subtitles(no):
episodes = database.execute("SELECT path FROM table_episodes WHERE sonarrSeriesId=? ORDER BY sonarrEpisodeId",
(no,))
episodes = TableEpisodes.select(TableEpisodes.path)\
.where(TableEpisodes.sonarrSeriesId == no)\
.order_by(TableEpisodes.sonarrEpisodeId)\
.dicts()
for episode in episodes:
store_subtitles(episode['path'], path_mappings.path_replace(episode['path']))
def movies_scan_subtitles(no):
movies = database.execute("SELECT path FROM table_movies WHERE radarrId=? ORDER BY radarrId", (no,))
movies = TableMovies.select(TableMovies.path)\
.where(TableMovies.radarrId == no)\
.order_by(TableMovies.radarrId)\
.dicts()
for movie in movies:
store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path']))
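The scan helpers above all use the same two verbs: `.select(...).where(...).dicts()` to read rows back as plain dictionaries, and `.update({...}).where(...).execute()` to write. A small sketch of that round-trip with a hypothetical `Episode` model:

```python
# Sketch of the select/update round-trip used by store_subtitles(),
# with a hypothetical Episode model in place of TableEpisodes.
from peewee import Model, SqliteDatabase, TextField

db = SqliteDatabase(':memory:')

class Episode(Model):
    path = TextField()
    subtitles = TextField(null=True)

    class Meta:
        database = db

db.create_tables([Episode])
Episode.create(path='/tv/e1.mkv')

Episode.update({Episode.subtitles: "[['en', '/tv/e1.en.srt']]"})\
       .where(Episode.path == '/tv/e1.mkv')\
       .execute()

for row in Episode.select(Episode.path, Episode.subtitles).dicts():
    print(row['path'], row['subtitles'])
```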

@ -26,7 +26,7 @@ from get_args import args
from config import settings, url_sonarr, url_radarr, configure_proxy_func, base_url
from init import *
from database import database
from database import System
from notifier import update_notifier
@ -53,7 +53,7 @@ check_releases()
configure_proxy_func()
# Reset the updated once Bazarr have been restarted after an update
database.execute("UPDATE system SET updated='0'")
System.update({System.updated: '0'}).execute()
# Load languages in database
load_language_in_db()
@ -100,14 +100,14 @@ def catch_all(path):
auth = False
try:
updated = database.execute("SELECT updated FROM system", only_one=True)['updated']
updated = System.select().where(System.updated == '1')
except:
updated = []
inject = dict()
inject["baseUrl"] = base_url
inject["canUpdate"] = not args.no_update
inject["hasUpdate"] = updated != '0'
inject["hasUpdate"] = len(updated)
if auth:
inject["apiKey"] = settings.auth.apikey
@ -164,7 +164,7 @@ def movies_images(url):
def configured():
database.execute("UPDATE system SET configured = 1")
System.update({System.configured: '1'}).execute()
@check_login

@ -3,7 +3,7 @@
import apprise
import logging
from database import database
from database import TableSettingsNotifier, TableEpisodes, TableShows, TableMovies
def update_notifier():
@ -16,7 +16,7 @@ def update_notifier():
notifiers_new = []
notifiers_old = []
notifiers_current_db = database.execute("SELECT name FROM table_settings_notifier")
notifiers_current_db = TableSettingsNotifier.select(TableSettingsNotifier.name).dicts()
notifiers_current = []
for notifier in notifiers_current_db:
@ -24,41 +24,51 @@ def update_notifier():
for x in results['schemas']:
if [x['service_name']] not in notifiers_current:
notifiers_new.append([x['service_name'], 0])
notifiers_new.append({'name': x['service_name'], 'enabled': 0})
logging.debug('Adding new notifier agent: ' + x['service_name'])
else:
notifiers_old.append([x['service_name']])
notifiers_to_delete = [item for item in notifiers_current if item not in notifiers_old]
database.execute("INSERT INTO table_settings_notifier (name, enabled) VALUES (?, ?)", notifiers_new,
execute_many=True)
TableSettingsNotifier.insert_many(notifiers_new).execute()
database.execute("DELETE FROM table_settings_notifier WHERE name=?", notifiers_to_delete, execute_many=True)
for item in notifiers_to_delete:
TableSettingsNotifier.delete().where(TableSettingsNotifier.name == item).execute()
def get_notifier_providers():
providers = database.execute("SELECT name, url FROM table_settings_notifier WHERE enabled=1")
providers = TableSettingsNotifier.select(TableSettingsNotifier.name,
TableSettingsNotifier.url)\
.where(TableSettingsNotifier.enabled == 1)\
.dicts()
return providers
def get_series(sonarr_series_id):
data = database.execute("SELECT title, year FROM table_shows WHERE sonarrSeriesId=?", (sonarr_series_id,),
only_one=True)
data = TableShows.select(TableShows.title, TableShows.year)\
.where(TableShows.sonarrSeriesId == sonarr_series_id)\
.dicts()\
.get()
return {'title': data['title'], 'year': data['year']}
def get_episode_name(sonarr_episode_id):
data = database.execute("SELECT title, season, episode FROM table_episodes WHERE sonarrEpisodeId=?",
(sonarr_episode_id,), only_one=True)
data = TableEpisodes.select(TableEpisodes.title, TableEpisodes.season, TableEpisodes.episode)\
.where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id)\
.dicts()\
.get()
return data['title'], data['season'], data['episode']
def get_movie(radarr_id):
data = database.execute("SELECT title, year FROM table_movies WHERE radarrId=?", (radarr_id,), only_one=True)
data = TableMovies.select(TableMovies.title, TableMovies.year)\
.where(TableMovies.radarrId == radarr_id)\
.dicts()\
.get()
return {'title': data['title'], 'year': data['year']}
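One detail worth calling out from `update_notifier()` above: `insert_many()` expects an iterable of row dicts, hence the switch from `[name, 0]` pairs to `{'name': ..., 'enabled': 0}`. A minimal sketch with a hypothetical `Notifier` model:

```python
# Sketch of the insert_many() usage in update_notifier(); the
# Notifier model is hypothetical.
from peewee import IntegerField, Model, SqliteDatabase, TextField

db = SqliteDatabase(':memory:')

class Notifier(Model):
    name = TextField()
    enabled = IntegerField(default=0)

    class Meta:
        database = db

db.create_tables([Notifier])

rows = [{'name': 'discord', 'enabled': 0},
        {'name': 'slack', 'enabled': 0}]
if rows:  # skip the call entirely when there is nothing to insert
    Notifier.insert_many(rows).execute()
print(Notifier.select().count())  # -> 2
```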

@ -56,8 +56,9 @@ class SonarrSignalrClient:
logging.info('BAZARR SignalR client for Sonarr is now disconnected.')
def restart(self):
if self.connection.is_open:
self.stop(log=False)
if self.connection:
if self.connection.is_open:
self.stop(log=False)
if settings.general.getboolean('use_sonarr'):
self.start()
@ -94,8 +95,9 @@ class RadarrSignalrClient:
self.connection.stop()
def restart(self):
if self.connection.transport.state.value in [0, 1, 2]:
self.stop()
if self.connection:
if self.connection.transport.state.value in [0, 1, 2]:
self.stop()
if settings.general.getboolean('use_radarr'):
self.start()
@ -131,35 +133,38 @@ class RadarrSignalrClient:
def dispatcher(data):
topic = media_id = action = None
episodesChanged = None
if isinstance(data, dict):
topic = data['name']
try:
media_id = data['body']['resource']['id']
action = data['body']['action']
if 'episodesChanged' in data['body']['resource']:
episodesChanged = data['body']['resource']['episodesChanged']
except KeyError:
return
elif isinstance(data, list):
topic = data[0]['name']
try:
media_id = data[0]['body']['resource']['id']
action = data[0]['body']['action']
except KeyError:
return
if topic == 'series':
update_one_series(series_id=media_id, action=action)
if episodesChanged:
# this will happen if an episode monitored status is changed.
sync_episodes(series_id=media_id, send_event=True)
elif topic == 'episode':
sync_one_episode(episode_id=media_id)
elif topic == 'movie':
update_one_movie(movie_id=media_id, action=action)
else:
try:
topic = media_id = action = None
episodesChanged = None
if isinstance(data, dict):
topic = data['name']
try:
media_id = data['body']['resource']['id']
action = data['body']['action']
if 'episodesChanged' in data['body']['resource']:
episodesChanged = data['body']['resource']['episodesChanged']
except KeyError:
return
elif isinstance(data, list):
topic = data[0]['name']
try:
media_id = data[0]['body']['resource']['id']
action = data[0]['body']['action']
except KeyError:
return
if topic == 'series':
update_one_series(series_id=media_id, action=action)
if episodesChanged:
# this will happen if an episode monitored status is changed.
sync_episodes(series_id=media_id, send_event=True)
elif topic == 'episode':
sync_one_episode(episode_id=media_id)
elif topic == 'movie':
update_one_movie(movie_id=media_id, action=action)
except Exception as e:
logging.debug('BAZARR an exception occurred while parsing SignalR feed: {}'.format(repr(e)))
finally:
return

@ -13,7 +13,8 @@ import stat
from whichcraft import which
from get_args import args
from config import settings, url_sonarr, url_radarr
from database import database
from database import TableHistory, TableHistoryMovie, TableBlacklist, TableBlacklistMovie, TableShowsRootfolder, \
TableMoviesRootfolder
from event_handler import event_stream
from get_languages import alpha2_from_alpha3, language_from_alpha3, language_from_alpha2, alpha3_from_alpha2
from helper import path_mappings
@ -37,51 +38,83 @@ class BinaryNotFound(Exception):
def history_log(action, sonarr_series_id, sonarr_episode_id, description, video_path=None, language=None, provider=None,
score=None, subs_id=None, subtitles_path=None):
database.execute("INSERT INTO table_history (action, sonarrSeriesId, sonarrEpisodeId, timestamp, description,"
"video_path, language, provider, score, subs_id, subtitles_path) VALUES (?,?,?,?,?,?,?,?,?,?,?)",
(action, sonarr_series_id, sonarr_episode_id, time.time(), description, video_path, language,
provider, score, subs_id, subtitles_path))
TableHistory.insert({
TableHistory.action: action,
TableHistory.sonarrSeriesId: sonarr_series_id,
TableHistory.sonarrEpisodeId: sonarr_episode_id,
TableHistory.timestamp: time.time(),
TableHistory.description: description,
TableHistory.video_path: video_path,
TableHistory.language: language,
TableHistory.provider: provider,
TableHistory.score: score,
TableHistory.subs_id: subs_id,
TableHistory.subtitles_path: subtitles_path
}).execute()
event_stream(type='episode-history')
def blacklist_log(sonarr_series_id, sonarr_episode_id, provider, subs_id, language):
database.execute("INSERT INTO table_blacklist (sonarr_series_id, sonarr_episode_id, timestamp, provider, "
"subs_id, language) VALUES (?,?,?,?,?,?)",
(sonarr_series_id, sonarr_episode_id, time.time(), provider, subs_id, language))
TableBlacklist.insert({
TableBlacklist.sonarr_series_id: sonarr_series_id,
TableBlacklist.sonarr_episode_id: sonarr_episode_id,
TableBlacklist.timestamp: time.time(),
TableBlacklist.provider: provider,
TableBlacklist.subs_id: subs_id,
TableBlacklist.language: language
}).execute()
event_stream(type='episode-blacklist')
def blacklist_delete(provider, subs_id):
database.execute("DELETE FROM table_blacklist WHERE provider=? AND subs_id=?", (provider, subs_id))
TableBlacklist.delete().where((TableBlacklist.provider == provider) &
(TableBlacklist.subs_id == subs_id))\
.execute()
event_stream(type='episode-blacklist', action='delete')
def blacklist_delete_all():
database.execute("DELETE FROM table_blacklist")
TableBlacklist.delete().execute()
event_stream(type='episode-blacklist', action='delete')
def history_log_movie(action, radarr_id, description, video_path=None, language=None, provider=None, score=None,
subs_id=None, subtitles_path=None):
database.execute("INSERT INTO table_history_movie (action, radarrId, timestamp, description, video_path, language, "
"provider, score, subs_id, subtitles_path) VALUES (?,?,?,?,?,?,?,?,?,?)",
(action, radarr_id, time.time(), description, video_path, language, provider, score, subs_id, subtitles_path))
TableHistoryMovie.insert({
TableHistoryMovie.action: action,
TableHistoryMovie.radarrId: radarr_id,
TableHistoryMovie.timestamp: time.time(),
TableHistoryMovie.description: description,
TableHistoryMovie.video_path: video_path,
TableHistoryMovie.language: language,
TableHistoryMovie.provider: provider,
TableHistoryMovie.score: score,
TableHistoryMovie.subs_id: subs_id,
TableHistoryMovie.subtitles_path: subtitles_path
}).execute()
event_stream(type='movie-history')
def blacklist_log_movie(radarr_id, provider, subs_id, language):
database.execute("INSERT INTO table_blacklist_movie (radarr_id, timestamp, provider, subs_id, language) "
"VALUES (?,?,?,?,?)", (radarr_id, time.time(), provider, subs_id, language))
TableBlacklistMovie.insert({
TableBlacklistMovie.radarr_id: radarr_id,
TableBlacklistMovie.timestamp: time.time(),
TableBlacklistMovie.provider: provider,
TableBlacklistMovie.subs_id: subs_id,
TableBlacklistMovie.language: language
}).execute()
event_stream(type='movie-blacklist')
def blacklist_delete_movie(provider, subs_id):
database.execute("DELETE FROM table_blacklist_movie WHERE provider=? AND subs_id=?", (provider, subs_id))
TableBlacklistMovie.delete().where((TableBlacklistMovie.provider == provider) &
(TableBlacklistMovie.subs_id == subs_id))\
.execute()
event_stream(type='movie-blacklist', action='delete')
def blacklist_delete_all_movie():
database.execute("DELETE FROM table_blacklist_movie")
TableBlacklistMovie.delete().execute()
event_stream(type='movie-blacklist', action='delete')
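A note on the compound delete conditions above: Peewee expressions must be combined with the overloaded `&`/`|` operators, never Python's `and`/`or`. Because an expression object is always truthy, `a and b` silently evaluates to `b`, dropping the first condition from the query. A sketch with a hypothetical `Blacklist` model:

```python
# Why the deletes above use "&": Python's "and" returns its last
# operand, so one condition would silently vanish from the SQL.
from peewee import Model, SqliteDatabase, TextField

db = SqliteDatabase(':memory:')

class Blacklist(Model):
    provider = TextField()
    subs_id = TextField()

    class Meta:
        database = db

db.create_tables([Blacklist])
Blacklist.create(provider='p1', subs_id='s1')
Blacklist.create(provider='p2', subs_id='s1')

# broken == (Blacklist.subs_id == 's1'); it would match BOTH rows.
broken = (Blacklist.provider == 'p1') and (Blacklist.subs_id == 's1')

Blacklist.delete().where((Blacklist.provider == 'p1') &
                         (Blacklist.subs_id == 's1')).execute()
print(Blacklist.select().count())  # -> 1; only the p1/s1 row is gone
```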
@ -167,9 +200,9 @@ def get_binary(name):
def get_blacklist(media_type):
if media_type == 'series':
blacklist_db = database.execute("SELECT provider, subs_id FROM table_blacklist")
blacklist_db = TableBlacklist.select(TableBlacklist.provider, TableBlacklist.subs_id).dicts()
else:
blacklist_db = database.execute("SELECT provider, subs_id FROM table_blacklist_movie")
blacklist_db = TableBlacklistMovie.select(TableBlacklistMovie.provider, TableBlacklistMovie.subs_id).dicts()
blacklist_list = []
for item in blacklist_db:
@ -427,15 +460,22 @@ def get_health_issues():
# get Sonarr rootfolder issues
if settings.general.getboolean('use_sonarr'):
rootfolder = database.execute('SELECT path, accessible, error FROM table_shows_rootfolder WHERE accessible = 0')
rootfolder = TableShowsRootfolder.select(TableShowsRootfolder.path,
TableShowsRootfolder.accessible,
TableShowsRootfolder.error)\
.where(TableShowsRootfolder.accessible == 0)\
.dicts()
for item in rootfolder:
health_issues.append({'object': path_mappings.path_replace(item['path']),
'issue': item['error']})
# get Radarr rootfolder issues
if settings.general.getboolean('use_radarr'):
rootfolder = database.execute('SELECT path, accessible, error FROM table_movies_rootfolder '
'WHERE accessible = 0')
rootfolder = TableMoviesRootfolder.select(TableMoviesRootfolder.path,
TableMoviesRootfolder.accessible,
TableMoviesRootfolder.error)\
.where(TableMoviesRootfolder.accessible == 0)\
.dicts()
for item in rootfolder:
health_issues.append({'object': path_mappings.path_replace_movie(item['path']),
'issue': item['error']})

File diff suppressed because it is too large

@ -0,0 +1,48 @@
## Playhouse
The `playhouse` namespace contains numerous extensions to Peewee. These include vendor-specific database extensions, high-level abstractions to simplify working with databases, and tools for low-level database operations and introspection.
### Vendor extensions
* [SQLite extensions](http://docs.peewee-orm.com/en/latest/peewee/sqlite_ext.html)
* Full-text search (FTS3/4/5)
* BM25 ranking algorithm implemented as SQLite C extension, backported to FTS4
* Virtual tables and C extensions
* Closure tables
* JSON extension support
* LSM1 (key/value database) support
* BLOB API
* Online backup API
* [APSW extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#apsw): use Peewee with the powerful [APSW](https://github.com/rogerbinns/apsw) SQLite driver.
* [SQLCipher](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#sqlcipher-ext): encrypted SQLite databases.
* [SqliteQ](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#sqliteq): dedicated writer thread for multi-threaded SQLite applications. [More info here](http://charlesleifer.com/blog/multi-threaded-sqlite-without-the-operationalerrors/).
* [Postgresql extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#postgres-ext)
* JSON and JSONB
* HStore
* Arrays
* Server-side cursors
* Full-text search
* [MySQL extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#mysql-ext)
### High-level libraries
* [Extra fields](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#extra-fields)
* Compressed field
* PickleField
* [Shortcuts / helpers](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#shortcuts)
* Model-to-dict serializer
* Dict-to-model deserializer
* [Hybrid attributes](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#hybrid)
* [Signals](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#signals): pre/post-save, pre/post-delete, pre-init.
* [Dataset](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#dataset): high-level API for working with databases popularized by the [project of the same name](https://dataset.readthedocs.io/).
* [Key/Value Store](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#kv): key/value store using SQLite. Supports *smart indexing*, for *Pandas*-style queries.
### Database management and framework support
* [pwiz](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#pwiz): generate model code from a pre-existing database.
* [Schema migrations](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#migrate): modify your schema using high-level APIs. Even supports dropping or renaming columns in SQLite.
* [Connection pool](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#pool): simple connection pooling.
* [Reflection](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#reflection): low-level, cross-platform database introspection.
* [Database URLs](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#db-url): use URLs to connect to a database (see the sketch at the end of this section).
* [Test utils](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#test-utils): helpers for unit-testing Peewee applications.
* [Flask utils](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#flask-utils): paginated object lists, database connection management, and more.
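
As a quick taste of the helpers listed above, a minimal sketch combining Database URLs with the model-to-dict shortcut (the `User` model is illustrative):

```python
# Minimal sketch: connect through a database URL, then serialize a
# model instance with playhouse.shortcuts.model_to_dict().
from peewee import Model, TextField
from playhouse.db_url import connect
from playhouse.shortcuts import model_to_dict

db = connect('sqlite:///:memory:')

class User(Model):
    username = TextField()

    class Meta:
        database = db

db.create_tables([User])
user = User.create(username='charlie')
print(model_to_dict(user))  # {'id': 1, 'username': 'charlie'}
```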

@ -0,0 +1,73 @@
/* cache.h - definitions for the LRU cache
*
* Copyright (C) 2004-2015 Gerhard Häring <gh@ghaering.de>
*
* This file is part of pysqlite.
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgment in the product documentation would be
* appreciated but is not required.
* 2. Altered source versions must be plainly marked as such, and must not be
* misrepresented as being the original software.
* 3. This notice may not be removed or altered from any source distribution.
*/
#ifndef PYSQLITE_CACHE_H
#define PYSQLITE_CACHE_H
#include "Python.h"
/* The LRU cache is implemented as a combination of a doubly-linked with a
* dictionary. The list items are of type 'Node' and the dictionary has the
* nodes as values. */
typedef struct _pysqlite_Node
{
PyObject_HEAD
PyObject* key;
PyObject* data;
long count;
struct _pysqlite_Node* prev;
struct _pysqlite_Node* next;
} pysqlite_Node;
typedef struct
{
PyObject_HEAD
int size;
/* a dictionary mapping keys to Node entries */
PyObject* mapping;
/* the factory callable */
PyObject* factory;
pysqlite_Node* first;
pysqlite_Node* last;
/* if set, decrement the factory function when the Cache is deallocated.
* this is almost always desirable, but not in the pysqlite context */
int decref_factory;
} pysqlite_Cache;
extern PyTypeObject pysqlite_NodeType;
extern PyTypeObject pysqlite_CacheType;
int pysqlite_node_init(pysqlite_Node* self, PyObject* args, PyObject* kwargs);
void pysqlite_node_dealloc(pysqlite_Node* self);
int pysqlite_cache_init(pysqlite_Cache* self, PyObject* args, PyObject* kwargs);
void pysqlite_cache_dealloc(pysqlite_Cache* self);
PyObject* pysqlite_cache_get(pysqlite_Cache* self, PyObject* args);
int pysqlite_cache_setup_types(void);
#endif
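The comment above describes the classic LRU layout: a dict for O(1) lookups whose values are nodes in a doubly-linked list ordered by recency. A compact Python analogue of that design (purely illustrative; it is not a binding for this C code):

```python
# Illustrative Python analogue of the dict + doubly-linked-list LRU
# design declared in cache.h; unrelated to the actual C bindings.
class Node:
    __slots__ = ('key', 'data', 'prev', 'next')

    def __init__(self, key, data):
        self.key, self.data = key, data
        self.prev = self.next = None

class LRUCache:
    def __init__(self, factory, size=10):
        self.factory = factory          # like pysqlite_Cache.factory
        self.size = size
        self.mapping = {}               # key -> Node, like Cache.mapping
        self.first = self.last = None   # most / least recently used

    def _unlink(self, node):
        if node.prev: node.prev.next = node.next
        else: self.first = node.next
        if node.next: node.next.prev = node.prev
        else: self.last = node.prev

    def _push_front(self, node):
        node.prev, node.next = None, self.first
        if self.first: self.first.prev = node
        self.first = node
        if self.last is None: self.last = node

    def get(self, key):
        node = self.mapping.get(key)
        if node is None:
            if len(self.mapping) >= self.size:  # evict least recent
                dead = self.last
                self._unlink(dead)
                del self.mapping[dead.key]
            node = Node(key, self.factory(key))
            self.mapping[key] = node
        else:
            self._unlink(node)
        self._push_front(node)
        return node.data

cache = LRUCache(factory=str.upper, size=2)
print(cache.get('a'), cache.get('b'), cache.get('a'))  # A B A
```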

@ -0,0 +1,129 @@
/* connection.h - definitions for the connection type
*
* Copyright (C) 2004-2015 Gerhard Häring <gh@ghaering.de>
*
* This file is part of pysqlite.
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgment in the product documentation would be
* appreciated but is not required.
* 2. Altered source versions must be plainly marked as such, and must not be
* misrepresented as being the original software.
* 3. This notice may not be removed or altered from any source distribution.
*/
#ifndef PYSQLITE_CONNECTION_H
#define PYSQLITE_CONNECTION_H
#include "Python.h"
#include "pythread.h"
#include "structmember.h"
#include "cache.h"
#include "module.h"
#include "sqlite3.h"
typedef struct
{
PyObject_HEAD
sqlite3* db;
/* the type detection mode. Only 0, PARSE_DECLTYPES, PARSE_COLNAMES or a
* bitwise combination thereof makes sense */
int detect_types;
/* the timeout value in seconds for database locks */
double timeout;
/* for internal use in the timeout handler: when did the timeout handler
* first get called with count=0? */
double timeout_started;
/* None for autocommit, otherwise a PyString with the isolation level */
PyObject* isolation_level;
/* NULL for autocommit, otherwise a string with the BEGIN statement; will be
* freed in connection destructor */
char* begin_statement;
/* 1 if a check should be performed for each API call if the connection is
* used from the same thread it was created in */
int check_same_thread;
int initialized;
/* thread identification of the thread the connection was created in */
long thread_ident;
pysqlite_Cache* statement_cache;
/* Lists of weak references to statements and cursors used within this connection */
PyObject* statements;
PyObject* cursors;
/* Counters for how many statements/cursors were created in the connection. May be
* reset to 0 at certain intervals */
int created_statements;
int created_cursors;
PyObject* row_factory;
/* Determines how bytestrings from SQLite are converted to Python objects:
* - PyUnicode_Type: Python Unicode objects are constructed from UTF-8 bytestrings
* - OptimizedUnicode: Like before, but for ASCII data, only PyStrings are created.
* - PyString_Type: PyStrings are created as-is.
* - Any custom callable: Any object returned from the callable called with the bytestring
* as single parameter.
*/
PyObject* text_factory;
/* remember references to functions/classes used in
* create_function/create/aggregate, use these as dictionary keys, so we
* can keep the total system refcount constant by clearing that dictionary
* in connection_dealloc */
PyObject* function_pinboard;
/* a dictionary of registered collation name => collation callable mappings */
PyObject* collations;
/* Exception objects */
PyObject* Warning;
PyObject* Error;
PyObject* InterfaceError;
PyObject* DatabaseError;
PyObject* DataError;
PyObject* OperationalError;
PyObject* IntegrityError;
PyObject* InternalError;
PyObject* ProgrammingError;
PyObject* NotSupportedError;
} pysqlite_Connection;
extern PyTypeObject pysqlite_ConnectionType;
PyObject* pysqlite_connection_alloc(PyTypeObject* type, int aware);
void pysqlite_connection_dealloc(pysqlite_Connection* self);
PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs);
PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args);
PyObject* _pysqlite_connection_begin(pysqlite_Connection* self);
PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args);
PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args);
PyObject* pysqlite_connection_new(PyTypeObject* type, PyObject* args, PyObject* kw);
int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject* kwargs);
int pysqlite_connection_register_cursor(pysqlite_Connection* connection, PyObject* cursor);
int pysqlite_check_thread(pysqlite_Connection* self);
int pysqlite_check_connection(pysqlite_Connection* con);
int pysqlite_connection_setup_types(void);
#endif

@ -0,0 +1,58 @@
/* module.h - definitions for the module
*
* Copyright (C) 2004-2015 Gerhard Häring <gh@ghaering.de>
*
* This file is part of pysqlite.
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgment in the product documentation would be
* appreciated but is not required.
* 2. Altered source versions must be plainly marked as such, and must not be
* misrepresented as being the original software.
* 3. This notice may not be removed or altered from any source distribution.
*/
#ifndef PYSQLITE_MODULE_H
#define PYSQLITE_MODULE_H
#include "Python.h"
#define PYSQLITE_VERSION "2.8.2"
extern PyObject* pysqlite_Error;
extern PyObject* pysqlite_Warning;
extern PyObject* pysqlite_InterfaceError;
extern PyObject* pysqlite_DatabaseError;
extern PyObject* pysqlite_InternalError;
extern PyObject* pysqlite_OperationalError;
extern PyObject* pysqlite_ProgrammingError;
extern PyObject* pysqlite_IntegrityError;
extern PyObject* pysqlite_DataError;
extern PyObject* pysqlite_NotSupportedError;
extern PyObject* pysqlite_OptimizedUnicode;
/* the functions time.time() and time.sleep() */
extern PyObject* time_time;
extern PyObject* time_sleep;
/* A dictionary, mapping colum types (INTEGER, VARCHAR, etc.) to converter
* functions, that convert the SQL value to the appropriate Python value.
* The key is uppercase.
*/
extern PyObject* converters;
extern int _enable_callback_tracebacks;
extern int pysqlite_BaseTypeAdapted;
#define PARSE_DECLTYPES 1
#define PARSE_COLNAMES 2
#endif

File diff suppressed because it is too large

@ -0,0 +1,137 @@
import sys
from difflib import SequenceMatcher
from random import randint
IS_PY3K = sys.version_info[0] == 3
# String UDF.
def damerau_levenshtein_dist(s1, s2):
cdef:
int i, j, del_cost, add_cost, sub_cost
int s1_len = len(s1), s2_len = len(s2)
list one_ago, two_ago, current_row
list zeroes = [0] * (s2_len + 1)
if IS_PY3K:
current_row = list(range(1, s2_len + 2))
else:
current_row = range(1, s2_len + 2)
current_row[-1] = 0
one_ago = None
for i in range(s1_len):
two_ago = one_ago
one_ago = current_row
current_row = list(zeroes)
current_row[-1] = i + 1
for j in range(s2_len):
del_cost = one_ago[j] + 1
add_cost = current_row[j - 1] + 1
sub_cost = one_ago[j - 1] + (s1[i] != s2[j])
current_row[j] = min(del_cost, add_cost, sub_cost)
# Handle transpositions.
if (i > 0 and j > 0 and s1[i] == s2[j - 1]
and s1[i-1] == s2[j] and s1[i] != s2[j]):
current_row[j] = min(current_row[j], two_ago[j - 2] + 1)
return current_row[s2_len - 1]
# String UDF.
def levenshtein_dist(a, b):
cdef:
int add, delete, change
int i, j
int n = len(a), m = len(b)
list current, previous
list zeroes
if n > m:
a, b = b, a
n, m = m, n
zeroes = [0] * (m + 1)
if IS_PY3K:
current = list(range(n + 1))
else:
current = range(n + 1)
for i in range(1, m + 1):
previous = current
current = list(zeroes)
current[0] = i
for j in range(1, n + 1):
add = previous[j] + 1
delete = current[j - 1] + 1
change = previous[j - 1]
if a[j - 1] != b[i - 1]:
change += 1
current[j] = min(add, delete, change)
return current[n]
# String UDF.
def str_dist(a, b):
cdef:
int t = 0
for i in SequenceMatcher(None, a, b).get_opcodes():
if i[0] == 'equal':
continue
t = t + max(i[4] - i[3], i[2] - i[1])
return t
# Math Aggregate.
cdef class median(object):
cdef:
int ct
list items
def __init__(self):
self.ct = 0
self.items = []
cdef selectKth(self, int k, int s=0, int e=-1):
cdef:
int idx
if e < 0:
e = len(self.items)
idx = randint(s, e-1)
idx = self.partition_k(idx, s, e)
if idx > k:
return self.selectKth(k, s, idx)
elif idx < k:
return self.selectKth(k, idx + 1, e)
else:
return self.items[idx]
cdef int partition_k(self, int pi, int s, int e):
cdef:
int i, x
val = self.items[pi]
# Swap pivot w/last item.
self.items[e - 1], self.items[pi] = self.items[pi], self.items[e - 1]
x = s
for i in range(s, e):
if self.items[i] < val:
self.items[i], self.items[x] = self.items[x], self.items[i]
x += 1
self.items[x], self.items[e-1] = self.items[e-1], self.items[x]
return x
def step(self, item):
self.items.append(item)
self.ct += 1
def finalize(self):
if self.ct == 0:
return None
elif self.ct < 3:
return self.items[0]
else:
return self.selectKth(self.ct // 2)
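
These Cython UDFs get compiled into a loadable extension, but the registration mechanics are easy to sketch in pure Python: peewee's `SqliteDatabase` can attach scalar functions (and, analogously, aggregate classes) to every connection. The `levenshtein` body below is a simplified stand-in, not the Cython version above:

```python
# Sketch: registering a scalar UDF on a peewee SqliteDatabase; the
# function body is a plain-Python stand-in for the Cython UDF above.
from peewee import SqliteDatabase

db = SqliteDatabase(':memory:')

@db.func('levenshtein')
def levenshtein(a, b):
    if not a or not b:
        return max(len(a or ''), len(b or ''))
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (ca != cb))) # substitution
        prev = cur
    return prev[-1]

db.connect()
print(db.execute_sql("SELECT levenshtein('kitten', 'sitting')").fetchone()[0])  # 3
```

Aggregates like `median` register the same way through `db.aggregate()`, which expects a class with `step()` and `finalize()` methods, exactly the shape used above.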

@ -0,0 +1,146 @@
"""
Peewee integration with APSW, "another python sqlite wrapper".
Project page: https://rogerbinns.github.io/apsw/
APSW is a really neat library that provides a thin wrapper on top of SQLite's
C interface.
Here are just a few reasons to use APSW, taken from the documentation:
* APSW gives all functionality of SQLite, including virtual tables, virtual
file system, blob i/o, backups and file control.
* Connections can be shared across threads without any additional locking.
* Transactions are managed explicitly by your code.
* APSW can handle nested transactions.
* Unicode is handled correctly.
* APSW is faster.
"""
import apsw
from peewee import *
from peewee import __exception_wrapper__
from peewee import BooleanField as _BooleanField
from peewee import DateField as _DateField
from peewee import DateTimeField as _DateTimeField
from peewee import DecimalField as _DecimalField
from peewee import TimeField as _TimeField
from peewee import logger
from playhouse.sqlite_ext import SqliteExtDatabase
class APSWDatabase(SqliteExtDatabase):
server_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.'))
def __init__(self, database, **kwargs):
self._modules = {}
super(APSWDatabase, self).__init__(database, **kwargs)
def register_module(self, mod_name, mod_inst):
self._modules[mod_name] = mod_inst
if not self.is_closed():
self.connection().createmodule(mod_name, mod_inst)
def unregister_module(self, mod_name):
del(self._modules[mod_name])
def _connect(self):
conn = apsw.Connection(self.database, **self.connect_params)
if self._timeout is not None:
conn.setbusytimeout(self._timeout * 1000)
try:
self._add_conn_hooks(conn)
except:
conn.close()
raise
return conn
def _add_conn_hooks(self, conn):
super(APSWDatabase, self)._add_conn_hooks(conn)
self._load_modules(conn) # APSW-only.
def _load_modules(self, conn):
for mod_name, mod_inst in self._modules.items():
conn.createmodule(mod_name, mod_inst)
return conn
def _load_aggregates(self, conn):
for name, (klass, num_params) in self._aggregates.items():
def make_aggregate():
return (klass(), klass.step, klass.finalize)
conn.createaggregatefunction(name, make_aggregate)
def _load_collations(self, conn):
for name, fn in self._collations.items():
conn.createcollation(name, fn)
def _load_functions(self, conn):
for name, (fn, num_params) in self._functions.items():
conn.createscalarfunction(name, fn, num_params)
def _load_extensions(self, conn):
conn.enableloadextension(True)
for extension in self._extensions:
conn.loadextension(extension)
def load_extension(self, extension):
self._extensions.add(extension)
if not self.is_closed():
conn = self.connection()
conn.enableloadextension(True)
conn.loadextension(extension)
def last_insert_id(self, cursor, query_type=None):
return cursor.getconnection().last_insert_rowid()
def rows_affected(self, cursor):
return cursor.getconnection().changes()
def begin(self, lock_type='deferred'):
self.cursor().execute('begin %s;' % lock_type)
def commit(self):
with __exception_wrapper__:
curs = self.cursor()
if curs.getconnection().getautocommit():
return False
curs.execute('commit;')
return True
def rollback(self):
with __exception_wrapper__:
curs = self.cursor()
if curs.getconnection().getautocommit():
return False
curs.execute('rollback;')
return True
def execute_sql(self, sql, params=None, commit=True):
logger.debug((sql, params))
with __exception_wrapper__:
cursor = self.cursor()
cursor.execute(sql, params or ())
return cursor
def nh(s, v):
if v is not None:
return str(v)
class BooleanField(_BooleanField):
def db_value(self, v):
v = super(BooleanField, self).db_value(v)
if v is not None:
return v and 1 or 0
class DateField(_DateField):
db_value = nh
class TimeField(_TimeField):
db_value = nh
class DateTimeField(_DateTimeField):
db_value = nh
class DecimalField(_DecimalField):
db_value = nh
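
In use, `APSWDatabase` is a drop-in replacement for `SqliteExtDatabase`; a minimal sketch (requires the `apsw` package):

```python
# Minimal sketch of APSWDatabase as a drop-in SQLite database class.
from peewee import Model, TextField
from playhouse.apsw_ext import APSWDatabase

db = APSWDatabase(':memory:')

class Note(Model):
    content = TextField()

    class Meta:
        database = db

db.create_tables([Note])
Note.create(content='hello')
print(Note.select().count())  # -> 1
```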

@ -0,0 +1,207 @@
import functools
import re
from peewee import *
from peewee import _atomic
from peewee import _manual
from peewee import ColumnMetadata # (name, data_type, null, primary_key, table, default)
from peewee import ForeignKeyMetadata # (column, dest_table, dest_column, table).
from peewee import IndexMetadata
from playhouse.pool import _PooledPostgresqlDatabase
try:
from playhouse.postgres_ext import ArrayField
from playhouse.postgres_ext import BinaryJSONField
from playhouse.postgres_ext import IntervalField
JSONField = BinaryJSONField
except ImportError: # psycopg2 not installed, ignore.
ArrayField = BinaryJSONField = IntervalField = JSONField = None
NESTED_TX_MIN_VERSION = 200100
TXN_ERR_MSG = ('CockroachDB does not support nested transactions. You may '
'alternatively use the @transaction context-manager/decorator, '
'which only wraps the outer-most block in transactional logic. '
'To run a transaction with automatic retries, use the '
'run_transaction() helper.')
class ExceededMaxAttempts(OperationalError): pass
class UUIDKeyField(UUIDField):
auto_increment = True
def __init__(self, *args, **kwargs):
if kwargs.get('constraints'):
raise ValueError('%s cannot specify constraints.' % type(self))
kwargs['constraints'] = [SQL('DEFAULT gen_random_uuid()')]
kwargs.setdefault('primary_key', True)
super(UUIDKeyField, self).__init__(*args, **kwargs)
class RowIDField(AutoField):
field_type = 'INT'
def __init__(self, *args, **kwargs):
if kwargs.get('constraints'):
raise ValueError('%s cannot specify constraints.' % type(self))
kwargs['constraints'] = [SQL('DEFAULT unique_rowid()')]
super(RowIDField, self).__init__(*args, **kwargs)
class CockroachDatabase(PostgresqlDatabase):
field_types = PostgresqlDatabase.field_types.copy()
field_types.update({
'BLOB': 'BYTES',
})
for_update = False
nulls_ordering = False
release_after_rollback = True
def __init__(self, *args, **kwargs):
kwargs.setdefault('user', 'root')
kwargs.setdefault('port', 26257)
super(CockroachDatabase, self).__init__(*args, **kwargs)
def _set_server_version(self, conn):
curs = conn.cursor()
curs.execute('select version()')
raw, = curs.fetchone()
match_obj = re.match(r'^CockroachDB.+?v(\d+)\.(\d+)\.(\d+)', raw)
if match_obj is not None:
clean = '%d%02d%02d' % tuple(int(i) for i in match_obj.groups())
self.server_version = int(clean) # 19.1.5 -> 190105.
else:
# Fallback to use whatever cockroachdb tells us via protocol.
super(CockroachDatabase, self)._set_server_version(conn)
def _get_pk_constraint(self, table, schema=None):
query = ('SELECT constraint_name '
'FROM information_schema.table_constraints '
'WHERE table_name = %s AND table_schema = %s '
'AND constraint_type = %s')
cursor = self.execute_sql(query, (table, schema or 'public',
'PRIMARY KEY'))
row = cursor.fetchone()
return row and row[0] or None
def get_indexes(self, table, schema=None):
# The primary-key index is returned by default, so we will just strip
# it out here.
indexes = super(CockroachDatabase, self).get_indexes(table, schema)
pkc = self._get_pk_constraint(table, schema)
return [idx for idx in indexes if (not pkc) or (idx.name != pkc)]
def conflict_statement(self, on_conflict, query):
if not on_conflict._action: return
action = on_conflict._action.lower()
if action in ('replace', 'upsert'):
return SQL('UPSERT')
elif action not in ('ignore', 'nothing', 'update'):
raise ValueError('Un-supported action for conflict resolution. '
'CockroachDB supports REPLACE (UPSERT), IGNORE '
'and UPDATE.')
def conflict_update(self, oc, query):
action = oc._action.lower() if oc._action else ''
if action in ('ignore', 'nothing'):
return SQL('ON CONFLICT DO NOTHING')
elif action in ('replace', 'upsert'):
# No special stuff is necessary, this is just indicated by starting
# the statement with UPSERT instead of INSERT.
return
elif oc._conflict_constraint:
raise ValueError('CockroachDB does not support the usage of a '
'constraint name. Use the column(s) instead.')
return super(CockroachDatabase, self).conflict_update(oc, query)
def extract_date(self, date_part, date_field):
return fn.extract(date_part, date_field)
def from_timestamp(self, date_field):
# CRDB does not allow casting a decimal/float to timestamp, so we first
# cast to int, then to timestamptz.
return date_field.cast('int').cast('timestamptz')
def begin(self, system_time=None, priority=None):
super(CockroachDatabase, self).begin()
if system_time is not None:
self.execute_sql('SET TRANSACTION AS OF SYSTEM TIME %s',
(system_time,), commit=False)
if priority is not None:
priority = priority.lower()
if priority not in ('low', 'normal', 'high'):
raise ValueError('priority must be low, normal or high')
self.execute_sql('SET TRANSACTION PRIORITY %s' % priority,
commit=False)
def atomic(self, system_time=None, priority=None):
if self.server_version < NESTED_TX_MIN_VERSION:
return _crdb_atomic(self, system_time, priority)
return super(CockroachDatabase, self).atomic(system_time, priority)
def savepoint(self):
if self.server_version < NESTED_TX_MIN_VERSION:
raise NotImplementedError(TXN_ERR_MSG)
return super(CockroachDatabase, self).savepoint()
def retry_transaction(self, max_attempts=None, system_time=None,
priority=None):
def deco(cb):
@functools.wraps(cb)
def new_fn():
return run_transaction(self, cb, max_attempts, system_time,
priority)
return new_fn
return deco
def run_transaction(self, cb, max_attempts=None, system_time=None,
priority=None):
return run_transaction(self, cb, max_attempts, system_time, priority)
class _crdb_atomic(_atomic):
def __enter__(self):
if self.db.transaction_depth() > 0:
if not isinstance(self.db.top_transaction(), _manual):
raise NotImplementedError(TXN_ERR_MSG)
return super(_crdb_atomic, self).__enter__()
def run_transaction(db, callback, max_attempts=None, system_time=None,
priority=None):
"""
Run transactional SQL in a transaction with automatic retries.
User-provided `callback`:
* Must accept one parameter, the `db` instance representing the connection
the transaction is running under.
* Must not attempt to commit, rollback or otherwise manage transactions.
* May be called more than once.
* Should ideally only contain SQL operations.
Additionally, the database must not have any open transaction at the time
this function is called, as CRDB does not support nested transactions.
"""
max_attempts = max_attempts or -1
with db.atomic(system_time=system_time, priority=priority) as txn:
db.execute_sql('SAVEPOINT cockroach_restart')
while max_attempts != 0:
try:
result = callback(db)
db.execute_sql('RELEASE SAVEPOINT cockroach_restart')
return result
except OperationalError as exc:
if exc.orig.pgcode == '40001':
max_attempts -= 1
db.execute_sql('ROLLBACK TO SAVEPOINT cockroach_restart')
continue
raise
raise ExceededMaxAttempts(None, 'unable to commit transaction')
class PooledCockroachDatabase(_PooledPostgresqlDatabase, CockroachDatabase):
pass
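
A usage sketch for `run_transaction()` as documented above; the connection settings and the `accounts` table are assumptions, not part of this module:

```python
# Sketch of run_transaction(); assumes a local CockroachDB cluster
# and an existing "accounts" table (both hypothetical here).
from playhouse.cockroachdb import CockroachDatabase, run_transaction

db = CockroachDatabase('bank', host='localhost')

def transfer_funds(db_ref):
    # Runs inside the retry loop: no manual commit/rollback, and it
    # may be invoked more than once on 40001 serialization errors.
    db_ref.execute_sql(
        'UPDATE accounts SET balance = balance - 100 WHERE id = 1')
    db_ref.execute_sql(
        'UPDATE accounts SET balance = balance + 100 WHERE id = 2')

run_transaction(db, transfer_funds, max_attempts=10)
```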

@ -0,0 +1,451 @@
import csv
import datetime
from decimal import Decimal
import json
import operator
try:
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse
import sys
from peewee import *
from playhouse.db_url import connect
from playhouse.migrate import migrate
from playhouse.migrate import SchemaMigrator
from playhouse.reflection import Introspector
if sys.version_info[0] == 3:
basestring = str
from functools import reduce
def open_file(f, mode):
return open(f, mode, encoding='utf8')
else:
open_file = open
class DataSet(object):
def __init__(self, url, **kwargs):
if isinstance(url, Database):
self._url = None
self._database = url
self._database_path = self._database.database
else:
self._url = url
parse_result = urlparse(url)
self._database_path = parse_result.path[1:]
# Connect to the database.
self._database = connect(url)
self._database.connect()
# Introspect the database and generate models.
self._introspector = Introspector.from_database(self._database)
self._models = self._introspector.generate_models(
skip_invalid=True,
literal_column_names=True,
**kwargs)
self._migrator = SchemaMigrator.from_database(self._database)
class BaseModel(Model):
class Meta:
database = self._database
self._base_model = BaseModel
self._export_formats = self.get_export_formats()
self._import_formats = self.get_import_formats()
def __repr__(self):
return '<DataSet: %s>' % self._database_path
def get_export_formats(self):
return {
'csv': CSVExporter,
'json': JSONExporter,
'tsv': TSVExporter}
def get_import_formats(self):
return {
'csv': CSVImporter,
'json': JSONImporter,
'tsv': TSVImporter}
def __getitem__(self, table):
if table not in self._models and table in self.tables:
self.update_cache(table)
return Table(self, table, self._models.get(table))
@property
def tables(self):
return self._database.get_tables()
def __contains__(self, table):
return table in self.tables
def connect(self):
self._database.connect()
def close(self):
self._database.close()
def update_cache(self, table=None):
if table:
dependencies = [table]
if table in self._models:
model_class = self._models[table]
dependencies.extend([
related._meta.table_name for _, related, _ in
model_class._meta.model_graph()])
else:
dependencies.extend(self.get_table_dependencies(table))
else:
dependencies = None # Update all tables.
self._models = {}
updated = self._introspector.generate_models(
skip_invalid=True,
table_names=dependencies,
literal_column_names=True)
self._models.update(updated)
def get_table_dependencies(self, table):
stack = [table]
accum = []
seen = set()
while stack:
table = stack.pop()
for fk_meta in self._database.get_foreign_keys(table):
dest = fk_meta.dest_table
if dest not in seen:
stack.append(dest)
accum.append(dest)
return accum
def __enter__(self):
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if not self._database.is_closed():
self.close()
def query(self, sql, params=None, commit=True):
return self._database.execute_sql(sql, params, commit)
def transaction(self):
if self._database.transaction_depth() == 0:
return self._database.transaction()
else:
return self._database.savepoint()
def _check_arguments(self, filename, file_obj, format, format_dict):
if filename and file_obj:
raise ValueError('file is over-specified. Please use either '
'filename or file_obj, but not both.')
if not filename and not file_obj:
raise ValueError('A filename or file-like object must be '
'specified.')
if format not in format_dict:
valid_formats = ', '.join(sorted(format_dict.keys()))
raise ValueError('Unsupported format "%s". Use one of %s.' % (
format, valid_formats))
def freeze(self, query, format='csv', filename=None, file_obj=None,
**kwargs):
self._check_arguments(filename, file_obj, format, self._export_formats)
if filename:
file_obj = open_file(filename, 'w')
exporter = self._export_formats[format](query)
exporter.export(file_obj, **kwargs)
if filename:
file_obj.close()
def thaw(self, table, format='csv', filename=None, file_obj=None,
strict=False, **kwargs):
self._check_arguments(filename, file_obj, format, self._export_formats)
if filename:
file_obj = open_file(filename, 'r')
importer = self._import_formats[format](self[table], strict)
count = importer.load(file_obj, **kwargs)
if filename:
file_obj.close()
return count
class Table(object):
def __init__(self, dataset, name, model_class):
self.dataset = dataset
self.name = name
if model_class is None:
model_class = self._create_model()
model_class.create_table()
self.dataset._models[name] = model_class
@property
def model_class(self):
return self.dataset._models[self.name]
def __repr__(self):
return '<Table: %s>' % self.name
def __len__(self):
return self.find().count()
def __iter__(self):
return iter(self.find().iterator())
def _create_model(self):
class Meta:
table_name = self.name
return type(
str(self.name),
(self.dataset._base_model,),
{'Meta': Meta})
def create_index(self, columns, unique=False):
index = ModelIndex(self.model_class, columns, unique=unique)
self.model_class.add_index(index)
self.dataset._database.execute(index)
def _guess_field_type(self, value):
if isinstance(value, basestring):
return TextField
if isinstance(value, (datetime.date, datetime.datetime)):
return DateTimeField
elif value is True or value is False:
return BooleanField
elif isinstance(value, int):
return IntegerField
elif isinstance(value, float):
return FloatField
elif isinstance(value, Decimal):
return DecimalField
return TextField
@property
def columns(self):
return [f.name for f in self.model_class._meta.sorted_fields]
def _migrate_new_columns(self, data):
new_keys = set(data) - set(self.model_class._meta.fields)
if new_keys:
operations = []
for key in new_keys:
field_class = self._guess_field_type(data[key])
field = field_class(null=True)
operations.append(
self.dataset._migrator.add_column(self.name, key, field))
field.bind(self.model_class, key)
migrate(*operations)
self.dataset.update_cache(self.name)
def __getitem__(self, item):
try:
return self.model_class[item]
except self.model_class.DoesNotExist:
pass
def __setitem__(self, item, value):
if not isinstance(value, dict):
raise ValueError('Table.__setitem__() value must be a dict')
pk = self.model_class._meta.primary_key
value[pk.name] = item
try:
with self.dataset.transaction() as txn:
self.insert(**value)
except IntegrityError:
self.dataset.update_cache(self.name)
self.update(columns=[pk.name], **value)
def __delitem__(self, item):
del self.model_class[item]
def insert(self, **data):
self._migrate_new_columns(data)
return self.model_class.insert(**data).execute()
def _apply_where(self, query, filters, conjunction=None):
conjunction = conjunction or operator.and_
if filters:
expressions = [
(self.model_class._meta.fields[column] == value)
for column, value in filters.items()]
query = query.where(reduce(conjunction, expressions))
return query
def update(self, columns=None, conjunction=None, **data):
self._migrate_new_columns(data)
filters = {}
if columns:
for column in columns:
filters[column] = data.pop(column)
return self._apply_where(
self.model_class.update(**data),
filters,
conjunction).execute()
def _query(self, **query):
return self._apply_where(self.model_class.select(), query)
def find(self, **query):
return self._query(**query).dicts()
def find_one(self, **query):
try:
return self.find(**query).get()
except self.model_class.DoesNotExist:
return None
def all(self):
return self.find()
def delete(self, **query):
return self._apply_where(self.model_class.delete(), query).execute()
def freeze(self, *args, **kwargs):
return self.dataset.freeze(self.all(), *args, **kwargs)
def thaw(self, *args, **kwargs):
return self.dataset.thaw(self.name, *args, **kwargs)
class Exporter(object):
def __init__(self, query):
self.query = query
def export(self, file_obj):
raise NotImplementedError
class JSONExporter(Exporter):
def __init__(self, query, iso8601_datetimes=False):
super(JSONExporter, self).__init__(query)
self.iso8601_datetimes = iso8601_datetimes
def _make_default(self):
datetime_types = (datetime.datetime, datetime.date, datetime.time)
if self.iso8601_datetimes:
def default(o):
if isinstance(o, datetime_types):
return o.isoformat()
elif isinstance(o, Decimal):
return str(o)
raise TypeError('Unable to serialize %r as JSON' % o)
else:
def default(o):
if isinstance(o, datetime_types + (Decimal,)):
return str(o)
raise TypeError('Unable to serialize %r as JSON' % o)
return default
def export(self, file_obj, **kwargs):
json.dump(
list(self.query),
file_obj,
default=self._make_default(),
**kwargs)
class CSVExporter(Exporter):
def export(self, file_obj, header=True, **kwargs):
writer = csv.writer(file_obj, **kwargs)
tuples = self.query.tuples().execute()
tuples.initialize()
if header and getattr(tuples, 'columns', None):
writer.writerow([column for column in tuples.columns])
for row in tuples:
writer.writerow(row)
class TSVExporter(CSVExporter):
def export(self, file_obj, header=True, **kwargs):
kwargs.setdefault('delimiter', '\t')
return super(TSVExporter, self).export(file_obj, header, **kwargs)
class Importer(object):
def __init__(self, table, strict=False):
self.table = table
self.strict = strict
model = self.table.model_class
self.columns = dict(model._meta.columns)  # Copy; don't mutate model metadata.
self.columns.update(model._meta.fields)
def load(self, file_obj):
raise NotImplementedError
class JSONImporter(Importer):
def load(self, file_obj, **kwargs):
data = json.load(file_obj, **kwargs)
count = 0
for row in data:
if self.strict:
obj = {}
for key in row:
field = self.columns.get(key)
if field is not None:
obj[field.name] = field.python_value(row[key])
else:
obj = row
if obj:
self.table.insert(**obj)
count += 1
return count
class CSVImporter(Importer):
def load(self, file_obj, header=True, **kwargs):
count = 0
reader = csv.reader(file_obj, **kwargs)
if header:
try:
header_keys = next(reader)
except StopIteration:
return count
if self.strict:
header_fields = []
for idx, key in enumerate(header_keys):
if key in self.columns:
header_fields.append((idx, self.columns[key]))
else:
header_fields = list(enumerate(header_keys))
else:
header_fields = list(enumerate(self.table.model_class._meta.sorted_fields))
if not header_fields:
return count
for row in reader:
obj = {}
for idx, field in header_fields:
if self.strict:
obj[field.name] = field.python_value(row[idx])
else:
obj[field] = row[idx]
self.table.insert(**obj)
count += 1
return count
class TSVImporter(CSVImporter):
def load(self, file_obj, header=True, **kwargs):
kwargs.setdefault('delimiter', '\t')
return super(TSVImporter, self).load(file_obj, header, **kwargs)
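
if __name__ == '__main__':
    # Illustrative usage sketch, not part of the original module. The methods
    # above belong to playhouse's DataSet/Table API: tables and columns are
    # created on demand, and freeze()/thaw() round-trip rows through a file
    # (written to the current directory here).
    from playhouse.dataset import DataSet

    db = DataSet('sqlite:///:memory:')
    users = db['user']                         # Table created lazily.
    users.insert(name='Huey', species='cat')   # New columns migrated in.
    users.insert(name='Mickey', species='dog')
    assert users.find_one(name='Huey')['species'] == 'cat'

    db.freeze(users.all(), format='json', filename='users.json')
    users.delete()                             # Remove every row.
    count = db.thaw('user', format='json', filename='users.json', strict=True)
    assert count == 2 and len(users) == 2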

@ -0,0 +1,130 @@
try:
from urlparse import parse_qsl, unquote, urlparse
except ImportError:
from urllib.parse import parse_qsl, unquote, urlparse
from peewee import *
from playhouse.cockroachdb import CockroachDatabase
from playhouse.cockroachdb import PooledCockroachDatabase
from playhouse.pool import PooledMySQLDatabase
from playhouse.pool import PooledPostgresqlDatabase
from playhouse.pool import PooledSqliteDatabase
from playhouse.pool import PooledSqliteExtDatabase
from playhouse.sqlite_ext import SqliteExtDatabase
schemes = {
'cockroachdb': CockroachDatabase,
'cockroachdb+pool': PooledCockroachDatabase,
'crdb': CockroachDatabase,
'crdb+pool': PooledCockroachDatabase,
'mysql': MySQLDatabase,
'mysql+pool': PooledMySQLDatabase,
'postgres': PostgresqlDatabase,
'postgresql': PostgresqlDatabase,
'postgres+pool': PooledPostgresqlDatabase,
'postgresql+pool': PooledPostgresqlDatabase,
'sqlite': SqliteDatabase,
'sqliteext': SqliteExtDatabase,
'sqlite+pool': PooledSqliteDatabase,
'sqliteext+pool': PooledSqliteExtDatabase,
}
def register_database(db_class, *names):
global schemes
for name in names:
schemes[name] = db_class
def parseresult_to_dict(parsed, unquote_password=False):
# urlparse in python 2.6 is broken so query will be empty and instead
# appended to path complete with '?'
path_parts = parsed.path[1:].split('?')
try:
query = path_parts[1]
except IndexError:
query = parsed.query
connect_kwargs = {'database': path_parts[0]}
if parsed.username:
connect_kwargs['user'] = parsed.username
if parsed.password:
connect_kwargs['password'] = parsed.password
if unquote_password:
connect_kwargs['password'] = unquote(connect_kwargs['password'])
if parsed.hostname:
connect_kwargs['host'] = parsed.hostname
if parsed.port:
connect_kwargs['port'] = parsed.port
# Adjust parameters for MySQL.
if parsed.scheme == 'mysql' and 'password' in connect_kwargs:
connect_kwargs['passwd'] = connect_kwargs.pop('password')
elif 'sqlite' in parsed.scheme and not connect_kwargs['database']:
connect_kwargs['database'] = ':memory:'
# Get additional connection args from the query string
qs_args = parse_qsl(query, keep_blank_values=True)
for key, value in qs_args:
if value.lower() == 'false':
value = False
elif value.lower() == 'true':
value = True
elif value.isdigit():
value = int(value)
elif '.' in value and all(p.isdigit() for p in value.split('.', 1)):
try:
value = float(value)
except ValueError:
pass
elif value.lower() in ('null', 'none'):
value = None
connect_kwargs[key] = value
return connect_kwargs
def parse(url, unquote_password=False):
parsed = urlparse(url)
return parseresult_to_dict(parsed, unquote_password)
def connect(url, unquote_password=False, **connect_params):
parsed = urlparse(url)
connect_kwargs = parseresult_to_dict(parsed, unquote_password)
connect_kwargs.update(connect_params)
database_class = schemes.get(parsed.scheme)
if database_class is None:
if parsed.scheme in schemes:
raise RuntimeError('Attempted to use "%s" but a required library '
'could not be imported.' % parsed.scheme)
else:
raise RuntimeError('Unrecognized or unsupported scheme: "%s".' %
parsed.scheme)
return database_class(**connect_kwargs)
# Conditionally register additional databases.
try:
from playhouse.pool import PooledPostgresqlExtDatabase
except ImportError:
pass
else:
register_database(
PooledPostgresqlExtDatabase,
'postgresext+pool',
'postgresqlext+pool')
try:
from playhouse.apsw_ext import APSWDatabase
except ImportError:
pass
else:
register_database(APSWDatabase, 'apsw')
try:
from playhouse.postgres_ext import PostgresqlExtDatabase
except ImportError:
pass
else:
register_database(PostgresqlExtDatabase, 'postgresext', 'postgresqlext')
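
if __name__ == '__main__':
    # Illustrative sketch, not part of the original module: connect() maps the
    # URL scheme to a Database class, and parse() shows the extracted kwargs
    # (the postgres URL below is a made-up example).
    db = connect('sqlite:///:memory:')
    print(type(db).__name__, repr(db.database))   # SqliteDatabase ':memory:'
    print(parse('postgres://app:secret@db.example.com:5432/appdb'))
    # {'database': 'appdb', 'user': 'app', 'password': 'secret',
    #  'host': 'db.example.com', 'port': 5432}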

@ -0,0 +1,64 @@
try:
import bz2
except ImportError:
bz2 = None
try:
import zlib
except ImportError:
zlib = None
try:
import cPickle as pickle
except ImportError:
import pickle
import sys
from peewee import BlobField
from peewee import buffer_type
PY2 = sys.version_info[0] == 2
class CompressedField(BlobField):
ZLIB = 'zlib'
BZ2 = 'bz2'
algorithm_to_import = {
ZLIB: zlib,
BZ2: bz2,
}
def __init__(self, compression_level=6, algorithm=ZLIB, *args,
**kwargs):
self.compression_level = compression_level
if algorithm not in self.algorithm_to_import:
raise ValueError('Unrecognized algorithm %s' % algorithm)
compress_module = self.algorithm_to_import[algorithm]
if compress_module is None:
raise ValueError('Missing library required for %s.' % algorithm)
self.algorithm = algorithm
self.compress = compress_module.compress
self.decompress = compress_module.decompress
super(CompressedField, self).__init__(*args, **kwargs)
def python_value(self, value):
if value is not None:
return self.decompress(value)
def db_value(self, value):
if value is not None:
return self._constructor(
self.compress(value, self.compression_level))
class PickleField(BlobField):
def python_value(self, value):
if value is not None:
if isinstance(value, buffer_type):
value = bytes(value)
return pickle.loads(value)
def db_value(self, value):
if value is not None:
pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
return self._constructor(pickled)
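
if __name__ == '__main__':
    # Illustrative sketch, not part of the original module (the Stash model is
    # hypothetical): CompressedField transparently stores zlib-compressed
    # bytes, and PickleField stores arbitrary picklable Python objects.
    from peewee import Model, SqliteDatabase

    db = SqliteDatabase(':memory:')

    class Stash(Model):
        blob = CompressedField()     # zlib, compression level 6 by default.
        obj = PickleField(null=True)

        class Meta:
            database = db

    db.create_tables([Stash])
    Stash.create(blob=b'x' * 10000, obj={'answer': 42})
    row = Stash.get()
    assert row.blob == b'x' * 10000 and row.obj == {'answer': 42}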

@ -0,0 +1,185 @@
import math
import sys
from flask import abort
from flask import render_template
from flask import request
from peewee import Database
from peewee import DoesNotExist
from peewee import Model
from peewee import Proxy
from peewee import SelectQuery
from playhouse.db_url import connect as db_url_connect
class PaginatedQuery(object):
def __init__(self, query_or_model, paginate_by, page_var='page', page=None,
check_bounds=False):
self.paginate_by = paginate_by
self.page_var = page_var
self.page = page or None
self.check_bounds = check_bounds
if isinstance(query_or_model, SelectQuery):
self.query = query_or_model
self.model = self.query.model
else:
self.model = query_or_model
self.query = self.model.select()
def get_page(self):
if self.page is not None:
return self.page
curr_page = request.args.get(self.page_var)
if curr_page and curr_page.isdigit():
return max(1, int(curr_page))
return 1
def get_page_count(self):
if not hasattr(self, '_page_count'):
self._page_count = int(math.ceil(
float(self.query.count()) / self.paginate_by))
return self._page_count
def get_object_list(self):
if self.check_bounds and self.get_page() > self.get_page_count():
abort(404)
return self.query.paginate(self.get_page(), self.paginate_by)
def get_object_or_404(query_or_model, *query):
if not isinstance(query_or_model, SelectQuery):
query_or_model = query_or_model.select()
try:
return query_or_model.where(*query).get()
except DoesNotExist:
abort(404)
def object_list(template_name, query, context_variable='object_list',
paginate_by=20, page_var='page', page=None, check_bounds=True,
**kwargs):
paginated_query = PaginatedQuery(
query,
paginate_by=paginate_by,
page_var=page_var,
page=page,
check_bounds=check_bounds)
kwargs[context_variable] = paginated_query.get_object_list()
return render_template(
template_name,
pagination=paginated_query,
page=paginated_query.get_page(),
**kwargs)
def get_current_url():
if not request.query_string:
return request.path
return '%s?%s' % (request.path, request.query_string)
def get_next_url(default='/'):
if request.args.get('next'):
return request.args['next']
elif request.form.get('next'):
return request.form['next']
return default
class FlaskDB(object):
def __init__(self, app=None, database=None, model_class=Model):
self.database = None # Reference to actual Peewee database instance.
self.base_model_class = model_class
self._app = app
self._db = database # dict, url, Database, or None (default).
if app is not None:
self.init_app(app)
def init_app(self, app):
self._app = app
if self._db is None:
if 'DATABASE' in app.config:
initial_db = app.config['DATABASE']
elif 'DATABASE_URL' in app.config:
initial_db = app.config['DATABASE_URL']
else:
raise ValueError('Missing required configuration data for '
'database: DATABASE or DATABASE_URL.')
else:
initial_db = self._db
self._load_database(app, initial_db)
self._register_handlers(app)
def _load_database(self, app, config_value):
if isinstance(config_value, Database):
database = config_value
elif isinstance(config_value, dict):
database = self._load_from_config_dict(dict(config_value))
else:
# Assume a database connection URL.
database = db_url_connect(config_value)
if isinstance(self.database, Proxy):
self.database.initialize(database)
else:
self.database = database
def _load_from_config_dict(self, config_dict):
try:
name = config_dict.pop('name')
engine = config_dict.pop('engine')
except KeyError:
raise RuntimeError('DATABASE configuration must specify a '
'`name` and `engine`.')
if '.' in engine:
path, class_name = engine.rsplit('.', 1)
else:
path, class_name = 'peewee', engine
try:
__import__(path)
module = sys.modules[path]
database_class = getattr(module, class_name)
assert issubclass(database_class, Database)
except ImportError:
raise RuntimeError('Unable to import %s' % engine)
except AttributeError:
raise RuntimeError('Database engine not found %s' % engine)
except AssertionError:
raise RuntimeError('Database engine not a subclass of '
'peewee.Database: %s' % engine)
return database_class(name, **config_dict)
def _register_handlers(self, app):
app.before_request(self.connect_db)
app.teardown_request(self.close_db)
def get_model_class(self):
if self.database is None:
raise RuntimeError('Database must be initialized.')
class BaseModel(self.base_model_class):
class Meta:
database = self.database
return BaseModel
@property
def Model(self):
if self._app is None:
database = getattr(self, 'database', None)
if database is None:
self.database = Proxy()
if not hasattr(self, '_model_class'):
self._model_class = self.get_model_class()
return self._model_class
def connect_db(self):
self.database.connect()
def close_db(self, exc):
if not self.database.is_closed():
self.database.close()
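
if __name__ == '__main__':
    # Illustrative sketch, not part of the original module (model, route and
    # template names are hypothetical): FlaskDB opens a connection before each
    # request and closes it afterwards; object_list() paginates a query into
    # a template.
    from flask import Flask
    from peewee import CharField

    app = Flask(__name__)
    app.config['DATABASE'] = 'sqlite:///notes.db'
    db_wrapper = FlaskDB(app)

    class Note(db_wrapper.Model):
        content = CharField()

    db_wrapper.database.create_tables([Note], safe=True)

    @app.route('/notes/')
    def note_list():
        # notes.html receives `object_list` plus `pagination`/`page` context.
        return object_list('notes.html', Note.select(), paginate_by=20)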

@ -0,0 +1,53 @@
from peewee import ModelDescriptor
# Hybrid methods/attributes, based on similar functionality in SQLAlchemy:
# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html
class hybrid_method(ModelDescriptor):
def __init__(self, func, expr=None):
self.func = func
self.expr = expr or func
def __get__(self, instance, instance_type):
if instance is None:
return self.expr.__get__(instance_type, instance_type.__class__)
return self.func.__get__(instance, instance_type)
def expression(self, expr):
self.expr = expr
return self
class hybrid_property(ModelDescriptor):
def __init__(self, fget, fset=None, fdel=None, expr=None):
self.fget = fget
self.fset = fset
self.fdel = fdel
self.expr = expr or fget
def __get__(self, instance, instance_type):
if instance is None:
return self.expr(instance_type)
return self.fget(instance)
def __set__(self, instance, value):
if self.fset is None:
raise AttributeError('Cannot set attribute.')
self.fset(instance, value)
def __delete__(self, instance):
if self.fdel is None:
raise AttributeError('Cannot delete attribute.')
self.fdel(instance)
def setter(self, fset):
self.fset = fset
return self
def deleter(self, fdel):
self.fdel = fdel
return self
def expression(self, expr):
self.expr = expr
return self
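
if __name__ == '__main__':
    # Illustrative sketch, not part of the original module (Interval is a
    # hypothetical model): the same definition computes in Python on an
    # instance and renders as a SQL expression when accessed on the class.
    from peewee import IntegerField, Model, SqliteDatabase

    db = SqliteDatabase(':memory:')

    class Interval(Model):
        start = IntegerField()
        stop = IntegerField()

        class Meta:
            database = db

        @hybrid_property
        def length(self):
            return self.stop - self.start

    db.create_tables([Interval])
    Interval.create(start=1, stop=4)
    assert Interval.get().length == 3            # Evaluated in Python.
    print(Interval.select().where(Interval.length > 2).sql())  # SQL side.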

@ -0,0 +1,172 @@
import operator
from peewee import *
from peewee import Expression
from playhouse.fields import PickleField
try:
from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase
except ImportError:
from playhouse.sqlite_ext import SqliteExtDatabase
Sentinel = type('Sentinel', (object,), {})
class KeyValue(object):
"""
Persistent dictionary.
:param Field key_field: field to use for key. Defaults to CharField.
:param Field value_field: field to use for value. Defaults to PickleField.
:param bool ordered: data should be returned in key-sorted order.
:param Database database: database where key/value data is stored.
:param str table_name: table name for data.
"""
def __init__(self, key_field=None, value_field=None, ordered=False,
database=None, table_name='keyvalue'):
if key_field is None:
key_field = CharField(max_length=255, primary_key=True)
if not key_field.primary_key:
raise ValueError('key_field must have primary_key=True.')
if value_field is None:
value_field = PickleField()
self._key_field = key_field
self._value_field = value_field
self._ordered = ordered
self._database = database or SqliteExtDatabase(':memory:')
self._table_name = table_name
if isinstance(self._database, PostgresqlDatabase):
self.upsert = self._postgres_upsert
self.update = self._postgres_update
else:
self.upsert = self._upsert
self.update = self._update
self.model = self.create_model()
self.key = self.model.key
self.value = self.model.value
# Ensure table exists.
self.model.create_table()
def create_model(self):
class KeyValue(Model):
key = self._key_field
value = self._value_field
class Meta:
database = self._database
table_name = self._table_name
return KeyValue
def query(self, *select):
query = self.model.select(*select).tuples()
if self._ordered:
query = query.order_by(self.key)
return query
def convert_expression(self, expr):
if not isinstance(expr, Expression):
return (self.key == expr), True
return expr, False
def __contains__(self, key):
expr, _ = self.convert_expression(key)
return self.model.select().where(expr).exists()
def __len__(self):
return len(self.model)
def __getitem__(self, expr):
converted, is_single = self.convert_expression(expr)
query = self.query(self.value).where(converted)
item_getter = operator.itemgetter(0)
result = [item_getter(row) for row in query]
if len(result) == 0 and is_single:
raise KeyError(expr)
elif is_single:
return result[0]
return result
def _upsert(self, key, value):
(self.model
.insert(key=key, value=value)
.on_conflict('replace')
.execute())
def _postgres_upsert(self, key, value):
(self.model
.insert(key=key, value=value)
.on_conflict(conflict_target=[self.key],
preserve=[self.value])
.execute())
def __setitem__(self, expr, value):
if isinstance(expr, Expression):
self.model.update(value=value).where(expr).execute()
else:
self.upsert(expr, value)
def __delitem__(self, expr):
converted, _ = self.convert_expression(expr)
self.model.delete().where(converted).execute()
def __iter__(self):
return iter(self.query().execute())
def keys(self):
return map(operator.itemgetter(0), self.query(self.key))
def values(self):
return map(operator.itemgetter(0), self.query(self.value))
def items(self):
return iter(self.query().execute())
def _update(self, __data=None, **mapping):
if __data is not None:
mapping.update(__data)
return (self.model
.insert_many(list(mapping.items()),
fields=[self.key, self.value])
.on_conflict('replace')
.execute())
def _postgres_update(self, __data=None, **mapping):
if __data is not None:
mapping.update(__data)
return (self.model
.insert_many(list(mapping.items()),
fields=[self.key, self.value])
.on_conflict(conflict_target=[self.key],
preserve=[self.value])
.execute())
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def setdefault(self, key, default=None):
try:
return self[key]
except KeyError:
self[key] = default
return default
def pop(self, key, default=Sentinel):
with self._database.atomic():
try:
result = self[key]
except KeyError:
if default is Sentinel:
raise
return default
del self[key]
return result
def clear(self):
self.model.delete().execute()
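
if __name__ == '__main__':
    # Illustrative sketch, not part of the original module: KeyValue behaves
    # like a persistent dict (in-memory SQLite by default), and keys may also
    # be peewee expressions for range reads and bulk deletes.
    kv = KeyValue()
    kv['k1'] = 'v1'
    kv['k2'] = {'nested': [1, 2, 3]}       # Values are pickled by default.
    kv.update(k3='v3', k4='v4')
    assert kv['k1'] == 'v1' and 'k2' in kv
    assert kv[kv.key > 'k3'] == ['v4']     # Expression key -> list of values.
    del kv[kv.key >= 'k3']                 # Bulk delete via expression.
    assert sorted(kv.keys()) == ['k1', 'k2']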

@ -0,0 +1,886 @@
"""
Lightweight schema migrations.
NOTE: Currently tested with SQLite and Postgresql. MySQL may be missing some
features.
Example Usage
-------------
Instantiate a migrator:
# Postgres example:
my_db = PostgresqlDatabase(...)
migrator = PostgresqlMigrator(my_db)
# SQLite example:
my_db = SqliteDatabase('my_database.db')
migrator = SqliteMigrator(my_db)
Then you will use the `migrate` function to run various `Operation`s which
are generated by the migrator:
migrate(
migrator.add_column('some_table', 'column_name', CharField(default=''))
)
Migrations are not run inside a transaction, so if you wish the migration to
run in a transaction you will need to wrap the call to `migrate` in a
transaction block, e.g.:
with my_db.transaction():
migrate(...)
Supported Operations
--------------------
Add new field(s) to an existing model:
# Create your field instances. For non-null fields you must specify a
# default value.
pubdate_field = DateTimeField(null=True)
comment_field = TextField(default='')
# Run the migration, specifying the database table, field name and field.
migrate(
migrator.add_column('comment_tbl', 'pub_date', pubdate_field),
migrator.add_column('comment_tbl', 'comment', comment_field),
)
Renaming a field:
# Specify the table, original name of the column, and its new name.
migrate(
migrator.rename_column('story', 'pub_date', 'publish_date'),
migrator.rename_column('story', 'mod_date', 'modified_date'),
)
Dropping a field:
migrate(
migrator.drop_column('story', 'some_old_field'),
)
Making a field nullable or not nullable:
# Note that when making a field not null that field must not have any
# NULL values present.
migrate(
# Make `pub_date` allow NULL values.
migrator.drop_not_null('story', 'pub_date'),
# Prevent `modified_date` from containing NULL values.
migrator.add_not_null('story', 'modified_date'),
)
Renaming a table:
migrate(
migrator.rename_table('story', 'stories_tbl'),
)
Adding an index:
# Specify the table, column names, and whether the index should be
# UNIQUE or not.
migrate(
# Create an index on the `pub_date` column.
migrator.add_index('story', ('pub_date',), False),
# Create a multi-column index on the `pub_date` and `status` fields.
migrator.add_index('story', ('pub_date', 'status'), False),
# Create a unique index on the category and title fields.
migrator.add_index('story', ('category_id', 'title'), True),
)
Dropping an index:
# Specify the index name.
migrate(migrator.drop_index('story', 'story_pub_date_status'))
Adding or dropping table constraints:
.. code-block:: python
# Add a CHECK() constraint to enforce the price cannot be negative.
migrate(migrator.add_constraint(
'products',
'price_check',
Check('price >= 0')))
# Remove the price check constraint.
migrate(migrator.drop_constraint('products', 'price_check'))
# Add a UNIQUE constraint on the first and last names.
migrate(migrator.add_unique('person', 'first_name', 'last_name'))
"""
from collections import namedtuple
import functools
import hashlib
import re
from peewee import *
from peewee import CommaNodeList
from peewee import EnclosedNodeList
from peewee import Entity
from peewee import Expression
from peewee import Node
from peewee import NodeList
from peewee import OP
from peewee import callable_
from peewee import sort_models
from peewee import _truncate_constraint_name
try:
from playhouse.cockroachdb import CockroachDatabase
except ImportError:
CockroachDatabase = None
class Operation(object):
"""Encapsulate a single schema altering operation."""
def __init__(self, migrator, method, *args, **kwargs):
self.migrator = migrator
self.method = method
self.args = args
self.kwargs = kwargs
def execute(self, node):
self.migrator.database.execute(node)
def _handle_result(self, result):
if isinstance(result, (Node, Context)):
self.execute(result)
elif isinstance(result, Operation):
result.run()
elif isinstance(result, (list, tuple)):
for item in result:
self._handle_result(item)
def run(self):
kwargs = self.kwargs.copy()
kwargs['with_context'] = True
method = getattr(self.migrator, self.method)
self._handle_result(method(*self.args, **kwargs))
def operation(fn):
@functools.wraps(fn)
def inner(self, *args, **kwargs):
with_context = kwargs.pop('with_context', False)
if with_context:
return fn(self, *args, **kwargs)
return Operation(self, fn.__name__, *args, **kwargs)
return inner
def make_index_name(table_name, columns):
index_name = '_'.join((table_name,) + tuple(columns))
if len(index_name) > 64:
index_hash = hashlib.md5(index_name.encode('utf-8')).hexdigest()
index_name = '%s_%s' % (index_name[:56], index_hash[:7])
return index_name
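# For example, make_index_name('story', ('pub_date', 'status')) returns
# 'story_pub_date_status'; longer names are truncated to 56 characters and
# suffixed with a short md5 hash so they stay unique under the 64-char cap.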
class SchemaMigrator(object):
explicit_create_foreign_key = False
explicit_delete_foreign_key = False
def __init__(self, database):
self.database = database
def make_context(self):
return self.database.get_sql_context()
@classmethod
def from_database(cls, database):
if CockroachDatabase and isinstance(database, CockroachDatabase):
return CockroachDBMigrator(database)
elif isinstance(database, PostgresqlDatabase):
return PostgresqlMigrator(database)
elif isinstance(database, MySQLDatabase):
return MySQLMigrator(database)
elif isinstance(database, SqliteDatabase):
return SqliteMigrator(database)
raise ValueError('Unsupported database: %s' % database)
@operation
def apply_default(self, table, column_name, field):
default = field.default
if callable_(default):
default = default()
return (self.make_context()
.literal('UPDATE ')
.sql(Entity(table))
.literal(' SET ')
.sql(Expression(
Entity(column_name),
OP.EQ,
field.db_value(default),
flat=True)))
def _alter_table(self, ctx, table):
return ctx.literal('ALTER TABLE ').sql(Entity(table))
def _alter_column(self, ctx, table, column):
return (self
._alter_table(ctx, table)
.literal(' ALTER COLUMN ')
.sql(Entity(column)))
@operation
def alter_add_column(self, table, column_name, field):
# Make field null at first.
ctx = self.make_context()
field_null, field.null = field.null, True
# Set the field's column-name and name, if it is not set or doesn't
# match the new value.
if field.column_name != column_name:
field.name = field.column_name = column_name
(self
._alter_table(ctx, table)
.literal(' ADD COLUMN ')
.sql(field.ddl(ctx)))
field.null = field_null
if isinstance(field, ForeignKeyField):
self.add_inline_fk_sql(ctx, field)
return ctx
@operation
def add_constraint(self, table, name, constraint):
return (self
._alter_table(self.make_context(), table)
.literal(' ADD CONSTRAINT ')
.sql(Entity(name))
.literal(' ')
.sql(constraint))
@operation
def add_unique(self, table, *column_names):
constraint_name = 'uniq_%s' % '_'.join(column_names)
constraint = NodeList((
SQL('UNIQUE'),
EnclosedNodeList([Entity(column) for column in column_names])))
return self.add_constraint(table, constraint_name, constraint)
@operation
def drop_constraint(self, table, name):
return (self
._alter_table(self.make_context(), table)
.literal(' DROP CONSTRAINT ')
.sql(Entity(name)))
def add_inline_fk_sql(self, ctx, field):
ctx = (ctx
.literal(' REFERENCES ')
.sql(Entity(field.rel_model._meta.table_name))
.literal(' ')
.sql(EnclosedNodeList((Entity(field.rel_field.column_name),))))
if field.on_delete is not None:
ctx = ctx.literal(' ON DELETE %s' % field.on_delete)
if field.on_update is not None:
ctx = ctx.literal(' ON UPDATE %s' % field.on_update)
return ctx
@operation
def add_foreign_key_constraint(self, table, column_name, rel, rel_column,
on_delete=None, on_update=None):
constraint = 'fk_%s_%s_refs_%s' % (table, column_name, rel)
ctx = (self
.make_context()
.literal('ALTER TABLE ')
.sql(Entity(table))
.literal(' ADD CONSTRAINT ')
.sql(Entity(_truncate_constraint_name(constraint)))
.literal(' FOREIGN KEY ')
.sql(EnclosedNodeList((Entity(column_name),)))
.literal(' REFERENCES ')
.sql(Entity(rel))
.literal(' (')
.sql(Entity(rel_column))
.literal(')'))
if on_delete is not None:
ctx = ctx.literal(' ON DELETE %s' % on_delete)
if on_update is not None:
ctx = ctx.literal(' ON UPDATE %s' % on_update)
return ctx
@operation
def add_column(self, table, column_name, field):
# Adding a column is complicated by the fact that if there are rows
# present and the field is non-null, then we need to first add the
# column as a nullable field, then set the value, then add a not null
# constraint.
if not field.null and field.default is None:
raise ValueError('%s is not null but has no default' % column_name)
is_foreign_key = isinstance(field, ForeignKeyField)
if is_foreign_key and not field.rel_field:
raise ValueError('Foreign keys must specify a `field`.')
operations = [self.alter_add_column(table, column_name, field)]
# In the event the field is *not* nullable, update with the default
# value and set not null.
if not field.null:
operations.extend([
self.apply_default(table, column_name, field),
self.add_not_null(table, column_name)])
if is_foreign_key and self.explicit_create_foreign_key:
operations.append(
self.add_foreign_key_constraint(
table,
column_name,
field.rel_model._meta.table_name,
field.rel_field.column_name,
field.on_delete,
field.on_update))
if field.index or field.unique:
using = getattr(field, 'index_type', None)
operations.append(self.add_index(table, (column_name,),
field.unique, using))
return operations
@operation
def drop_foreign_key_constraint(self, table, column_name):
raise NotImplementedError
@operation
def drop_column(self, table, column_name, cascade=True):
ctx = self.make_context()
(self._alter_table(ctx, table)
.literal(' DROP COLUMN ')
.sql(Entity(column_name)))
if cascade:
ctx.literal(' CASCADE')
fk_columns = [
foreign_key.column
for foreign_key in self.database.get_foreign_keys(table)]
if column_name in fk_columns and self.explicit_delete_foreign_key:
return [self.drop_foreign_key_constraint(table, column_name), ctx]
return ctx
@operation
def rename_column(self, table, old_name, new_name):
return (self
._alter_table(self.make_context(), table)
.literal(' RENAME COLUMN ')
.sql(Entity(old_name))
.literal(' TO ')
.sql(Entity(new_name)))
@operation
def add_not_null(self, table, column):
return (self
._alter_column(self.make_context(), table, column)
.literal(' SET NOT NULL'))
@operation
def drop_not_null(self, table, column):
return (self
._alter_column(self.make_context(), table, column)
.literal(' DROP NOT NULL'))
@operation
def alter_column_type(self, table, column, field, cast=None):
# ALTER TABLE <table> ALTER COLUMN <column>
ctx = self.make_context()
ctx = (self
._alter_column(ctx, table, column)
.literal(' TYPE ')
.sql(field.ddl_datatype(ctx)))
if cast is not None:
if not isinstance(cast, Node):
cast = SQL(cast)
ctx = ctx.literal(' USING ').sql(cast)
return ctx
@operation
def rename_table(self, old_name, new_name):
return (self
._alter_table(self.make_context(), old_name)
.literal(' RENAME TO ')
.sql(Entity(new_name)))
@operation
def add_index(self, table, columns, unique=False, using=None):
ctx = self.make_context()
index_name = make_index_name(table, columns)
table_obj = Table(table)
cols = [getattr(table_obj.c, column) for column in columns]
index = Index(index_name, table_obj, cols, unique=unique, using=using)
return ctx.sql(index)
@operation
def drop_index(self, table, index_name):
return (self
.make_context()
.literal('DROP INDEX ')
.sql(Entity(index_name)))
class PostgresqlMigrator(SchemaMigrator):
def _primary_key_columns(self, tbl):
query = """
SELECT pg_attribute.attname
FROM pg_index, pg_class, pg_attribute
WHERE
pg_class.oid = '%s'::regclass AND
indrelid = pg_class.oid AND
pg_attribute.attrelid = pg_class.oid AND
pg_attribute.attnum = any(pg_index.indkey) AND
indisprimary;
"""
cursor = self.database.execute_sql(query % tbl)
return [row[0] for row in cursor.fetchall()]
@operation
def set_search_path(self, schema_name):
return (self
.make_context()
.literal('SET search_path TO %s' % schema_name))
@operation
def rename_table(self, old_name, new_name):
pk_names = self._primary_key_columns(old_name)
ParentClass = super(PostgresqlMigrator, self)
operations = [
ParentClass.rename_table(old_name, new_name, with_context=True)]
if len(pk_names) == 1:
# Check for existence of primary key sequence.
seq_name = '%s_%s_seq' % (old_name, pk_names[0])
query = """
SELECT 1
FROM information_schema.sequences
WHERE LOWER(sequence_name) = LOWER(%s)
"""
cursor = self.database.execute_sql(query, (seq_name,))
if bool(cursor.fetchone()):
new_seq_name = '%s_%s_seq' % (new_name, pk_names[0])
operations.append(ParentClass.rename_table(
seq_name, new_seq_name))
return operations
class CockroachDBMigrator(PostgresqlMigrator):
explicit_create_foreign_key = True
def add_inline_fk_sql(self, ctx, field):
pass
@operation
def drop_index(self, table, index_name):
return (self
.make_context()
.literal('DROP INDEX ')
.sql(Entity(index_name))
.literal(' CASCADE'))
class MySQLColumn(namedtuple('_Column', ('name', 'definition', 'null', 'pk',
'default', 'extra'))):
@property
def is_pk(self):
return self.pk == 'PRI'
@property
def is_unique(self):
return self.pk == 'UNI'
@property
def is_null(self):
return self.null == 'YES'
def sql(self, column_name=None, is_null=None):
if is_null is None:
is_null = self.is_null
if column_name is None:
column_name = self.name
parts = [
Entity(column_name),
SQL(self.definition)]
if self.is_unique:
parts.append(SQL('UNIQUE'))
if is_null:
parts.append(SQL('NULL'))
else:
parts.append(SQL('NOT NULL'))
if self.is_pk:
parts.append(SQL('PRIMARY KEY'))
if self.extra:
parts.append(SQL(self.extra))
return NodeList(parts)
class MySQLMigrator(SchemaMigrator):
explicit_create_foreign_key = True
explicit_delete_foreign_key = True
def _alter_column(self, ctx, table, column):
return (self
._alter_table(ctx, table)
.literal(' MODIFY ')
.sql(Entity(column)))
@operation
def rename_table(self, old_name, new_name):
return (self
.make_context()
.literal('RENAME TABLE ')
.sql(Entity(old_name))
.literal(' TO ')
.sql(Entity(new_name)))
def _get_column_definition(self, table, column_name):
cursor = self.database.execute_sql('DESCRIBE `%s`;' % table)
rows = cursor.fetchall()
for row in rows:
column = MySQLColumn(*row)
if column.name == column_name:
return column
return False
def get_foreign_key_constraint(self, table, column_name):
cursor = self.database.execute_sql(
('SELECT constraint_name '
'FROM information_schema.key_column_usage WHERE '
'table_schema = DATABASE() AND '
'table_name = %s AND '
'column_name = %s AND '
'referenced_table_name IS NOT NULL AND '
'referenced_column_name IS NOT NULL;'),
(table, column_name))
result = cursor.fetchone()
if not result:
raise AttributeError(
'Unable to find foreign key constraint for '
'"%s" on table "%s".' % (column_name, table))
return result[0]
@operation
def drop_foreign_key_constraint(self, table, column_name):
fk_constraint = self.get_foreign_key_constraint(table, column_name)
return (self
._alter_table(self.make_context(), table)
.literal(' DROP FOREIGN KEY ')
.sql(Entity(fk_constraint)))
def add_inline_fk_sql(self, ctx, field):
pass
@operation
def add_not_null(self, table, column):
column_def = self._get_column_definition(table, column)
add_not_null = (self
._alter_table(self.make_context(), table)
.literal(' MODIFY ')
.sql(column_def.sql(is_null=False)))
fk_objects = dict(
(fk.column, fk)
for fk in self.database.get_foreign_keys(table))
if column not in fk_objects:
return add_not_null
fk_metadata = fk_objects[column]
return (self.drop_foreign_key_constraint(table, column),
add_not_null,
self.add_foreign_key_constraint(
table,
column,
fk_metadata.dest_table,
fk_metadata.dest_column))
@operation
def drop_not_null(self, table, column):
column = self._get_column_definition(table, column)
if column.is_pk:
raise ValueError('Primary keys can not be null')
return (self
._alter_table(self.make_context(), table)
.literal(' MODIFY ')
.sql(column.sql(is_null=True)))
@operation
def rename_column(self, table, old_name, new_name):
fk_objects = dict(
(fk.column, fk)
for fk in self.database.get_foreign_keys(table))
is_foreign_key = old_name in fk_objects
column = self._get_column_definition(table, old_name)
rename_ctx = (self
._alter_table(self.make_context(), table)
.literal(' CHANGE ')
.sql(Entity(old_name))
.literal(' ')
.sql(column.sql(column_name=new_name)))
if is_foreign_key:
fk_metadata = fk_objects[old_name]
return [
self.drop_foreign_key_constraint(table, old_name),
rename_ctx,
self.add_foreign_key_constraint(
table,
new_name,
fk_metadata.dest_table,
fk_metadata.dest_column),
]
else:
return rename_ctx
@operation
def alter_column_type(self, table, column, field, cast=None):
if cast is not None:
raise ValueError('alter_column_type() does not support cast with '
'MySQL.')
ctx = self.make_context()
return (self
._alter_table(ctx, table)
.literal(' MODIFY ')
.sql(Entity(column))
.literal(' ')
.sql(field.ddl(ctx)))
@operation
def drop_index(self, table, index_name):
return (self
.make_context()
.literal('DROP INDEX ')
.sql(Entity(index_name))
.literal(' ON ')
.sql(Entity(table)))
class SqliteMigrator(SchemaMigrator):
"""
SQLite supports only a subset of ALTER TABLE queries; see the docs for
the full details: http://sqlite.org/lang_altertable.html
"""
column_re = re.compile(r'(.+?)\((.+)\)')
column_split_re = re.compile(r'(?:[^,(]|\([^)]*\))+')
column_name_re = re.compile(r'''["`']?([\w]+)''')
fk_re = re.compile(r'FOREIGN KEY\s+\("?([\w]+)"?\)\s+', re.I)
def _get_column_names(self, table):
res = self.database.execute_sql('select * from "%s" limit 1' % table)
return [item[0] for item in res.description]
def _get_create_table(self, table):
res = self.database.execute_sql(
('select name, sql from sqlite_master '
'where type=? and LOWER(name)=?'),
['table', table.lower()])
return res.fetchone()
@operation
def _update_column(self, table, column_to_update, fn):
columns = set(column.name.lower()
for column in self.database.get_columns(table))
if column_to_update.lower() not in columns:
raise ValueError('Column "%s" does not exist on "%s"' %
(column_to_update, table))
# Get the SQL used to create the given table.
table, create_table = self._get_create_table(table)
# Get the indexes and SQL to re-create indexes.
indexes = self.database.get_indexes(table)
# Find any foreign keys we may need to remove.
self.database.get_foreign_keys(table)
# Make sure the create_table does not contain any newlines or tabs,
# allowing the regex to work correctly.
create_table = re.sub(r'\s+', ' ', create_table)
# Parse out the `CREATE TABLE` and column list portions of the query.
raw_create, raw_columns = self.column_re.search(create_table).groups()
# Clean up the individual column definitions.
split_columns = self.column_split_re.findall(raw_columns)
column_defs = [col.strip() for col in split_columns]
new_column_defs = []
new_column_names = []
original_column_names = []
constraint_terms = ('foreign ', 'primary ', 'constraint ', 'check ')
for column_def in column_defs:
column_name, = self.column_name_re.match(column_def).groups()
if column_name == column_to_update:
new_column_def = fn(column_name, column_def)
if new_column_def:
new_column_defs.append(new_column_def)
original_column_names.append(column_name)
column_name, = self.column_name_re.match(
new_column_def).groups()
new_column_names.append(column_name)
else:
new_column_defs.append(column_def)
# Avoid treating constraints as columns.
if not column_def.lower().startswith(constraint_terms):
new_column_names.append(column_name)
original_column_names.append(column_name)
# Create a mapping of original columns to new columns.
original_to_new = dict(zip(original_column_names, new_column_names))
new_column = original_to_new.get(column_to_update)
fk_filter_fn = lambda column_def: column_def
if not new_column:
# Remove any foreign keys associated with this column.
fk_filter_fn = lambda column_def: None
elif new_column != column_to_update:
# Update any foreign keys for this column.
fk_filter_fn = lambda column_def: self.fk_re.sub(
'FOREIGN KEY ("%s") ' % new_column,
column_def)
cleaned_columns = []
for column_def in new_column_defs:
match = self.fk_re.match(column_def)
if match is not None and match.groups()[0] == column_to_update:
column_def = fk_filter_fn(column_def)
if column_def:
cleaned_columns.append(column_def)
# Update the name of the new CREATE TABLE query.
temp_table = table + '__tmp__'
rgx = re.compile('("?)%s("?)' % table, re.I)
create = rgx.sub(
'\\1%s\\2' % temp_table,
raw_create)
# Create the new table.
columns = ', '.join(cleaned_columns)
queries = [
NodeList([SQL('DROP TABLE IF EXISTS'), Entity(temp_table)]),
SQL('%s (%s)' % (create.strip(), columns))]
# Populate new table.
populate_table = NodeList((
SQL('INSERT INTO'),
Entity(temp_table),
EnclosedNodeList([Entity(col) for col in new_column_names]),
SQL('SELECT'),
CommaNodeList([Entity(col) for col in original_column_names]),
SQL('FROM'),
Entity(table)))
drop_original = NodeList([SQL('DROP TABLE'), Entity(table)])
# Drop existing table and rename temp table.
queries += [
populate_table,
drop_original,
self.rename_table(temp_table, table)]
# Re-create user-defined indexes. User-defined indexes will have a
# non-empty SQL attribute.
for index in filter(lambda idx: idx.sql, indexes):
if column_to_update not in index.columns:
queries.append(SQL(index.sql))
elif new_column:
sql = self._fix_index(index.sql, column_to_update, new_column)
if sql is not None:
queries.append(SQL(sql))
return queries
def _fix_index(self, sql, column_to_update, new_column):
# Split on the name of the column to update. If it splits into two
# pieces, then there's no ambiguity and we can simply replace the
# old with the new.
parts = sql.split(column_to_update)
if len(parts) == 2:
return sql.replace(column_to_update, new_column)
# Find the list of columns in the index expression.
lhs, rhs = sql.rsplit('(', 1)
# Apply the same "split in two" logic to the column list portion of
# the query.
if len(rhs.split(column_to_update)) == 2:
return '%s(%s' % (lhs, rhs.replace(column_to_update, new_column))
# Strip off the trailing parentheses and go through each column.
parts = rhs.rsplit(')', 1)[0].split(',')
columns = [part.strip('"`[]\' ') for part in parts]
# `columns` looks something like: ['status', 'timestamp" DESC']
# https://www.sqlite.org/lang_keywords.html
# Strip out any junk after the column name.
clean = []
for column in columns:
if re.match(r'%s(?:[\'"`\]]?\s|$)' % column_to_update, column):
column = new_column + column[len(column_to_update):]
clean.append(column)
return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean))
@operation
def drop_column(self, table, column_name, cascade=True):
return self._update_column(table, column_name, lambda a, b: None)
@operation
def rename_column(self, table, old_name, new_name):
def _rename(column_name, column_def):
return column_def.replace(column_name, new_name)
return self._update_column(table, old_name, _rename)
@operation
def add_not_null(self, table, column):
def _add_not_null(column_name, column_def):
return column_def + ' NOT NULL'
return self._update_column(table, column, _add_not_null)
@operation
def drop_not_null(self, table, column):
def _drop_not_null(column_name, column_def):
return column_def.replace('NOT NULL', '')
return self._update_column(table, column, _drop_not_null)
@operation
def alter_column_type(self, table, column, field, cast=None):
if cast is not None:
raise ValueError('alter_column_type() does not support cast with '
'Sqlite.')
ctx = self.make_context()
def _alter_column_type(column_name, column_def):
node_list = field.ddl(ctx)
sql, _ = ctx.sql(Entity(column)).sql(node_list).query()
return sql
return self._update_column(table, column, _alter_column_type)
@operation
def add_constraint(self, table, name, constraint):
raise NotImplementedError
@operation
def drop_constraint(self, table, name):
raise NotImplementedError
@operation
def add_foreign_key_constraint(self, table, column_name, field,
on_delete=None, on_update=None):
raise NotImplementedError
def migrate(*operations, **kwargs):
for operation in operations:
operation.run()
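
if __name__ == '__main__':
    # Illustrative sketch, not part of the original module (table and column
    # names are hypothetical): add a nullable column to an existing SQLite
    # table, then confirm it is present.
    db = SqliteDatabase(':memory:')
    db.execute_sql('CREATE TABLE story (id INTEGER PRIMARY KEY, title TEXT)')
    migrator = SqliteMigrator(db)
    migrate(migrator.add_column('story', 'pub_date', DateTimeField(null=True)))
    print([col.name for col in db.get_columns('story')])
    # ['id', 'title', 'pub_date']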

@ -0,0 +1,49 @@
import json
try:
import mysql.connector as mysql_connector
except ImportError:
mysql_connector = None
from peewee import ImproperlyConfigured
from peewee import InterfaceError
from peewee import MySQLDatabase
from peewee import NodeList
from peewee import SQL
from peewee import TextField
from peewee import fn
class MySQLConnectorDatabase(MySQLDatabase):
def _connect(self):
if mysql_connector is None:
raise ImproperlyConfigured('MySQL connector not installed!')
return mysql_connector.connect(db=self.database, **self.connect_params)
def cursor(self, commit=None):
if self.is_closed():
if self.autoconnect:
self.connect()
else:
raise InterfaceError('Error, database connection not opened.')
return self._state.conn.cursor(buffered=True)
class JSONField(TextField):
field_type = 'JSON'
def db_value(self, value):
if value is not None:
return json.dumps(value)
def python_value(self, value):
if value is not None:
return json.loads(value)
def Match(columns, expr, modifier=None):
if isinstance(columns, (list, tuple)):
match = fn.MATCH(*columns) # Tuple of one or more columns / fields.
else:
match = fn.MATCH(columns) # Single column / field.
args = expr if modifier is None else NodeList((expr, SQL(modifier)))
return NodeList((match, fn.AGAINST(args)))
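
if __name__ == '__main__':
    # Illustrative sketch, not part of the original module (assumes a MySQL
    # server and a FULLTEXT index on Post.content; all names are
    # hypothetical). No connection is made; this only builds the query SQL.
    from peewee import Model

    db = MySQLConnectorDatabase('blog', host='127.0.0.1', user='root')

    class Post(Model):
        content = TextField()

        class Meta:
            database = db

    query = Post.select().where(
        Match(Post.content, 'peewee orm', 'IN BOOLEAN MODE'))
    print(query.sql())  # SELECT ... WHERE MATCH(...) AGAINST(...)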

@ -0,0 +1,318 @@
"""
Lightweight connection pooling for peewee.
In a multi-threaded application, up to `max_connections` will be opened. Each
thread (or, if using gevent, greenlet) will have its own connection.
In a single-threaded application, only one connection will be created. It will
be continually recycled until either it exceeds the stale timeout or is closed
explicitly (using `.manual_close()`).
By default, all your application needs to do is ensure that connections are
closed when you are finished with them, and they will be returned to the pool.
For web applications, this typically means that at the beginning of a request,
you will open a connection, and when you return a response, you will close the
connection.
Simple Postgres pool example code:
# Use the special postgresql extensions.
from playhouse.pool import PooledPostgresqlExtDatabase
db = PooledPostgresqlExtDatabase(
'my_app',
max_connections=32,
stale_timeout=300, # 5 minutes.
user='postgres')
class BaseModel(Model):
class Meta:
database = db
That's it!
"""
import heapq
import logging
import random
import time
from collections import namedtuple
from itertools import chain
try:
from psycopg2.extensions import TRANSACTION_STATUS_IDLE
from psycopg2.extensions import TRANSACTION_STATUS_INERROR
from psycopg2.extensions import TRANSACTION_STATUS_UNKNOWN
except ImportError:
TRANSACTION_STATUS_IDLE = \
TRANSACTION_STATUS_INERROR = \
TRANSACTION_STATUS_UNKNOWN = None
from peewee import MySQLDatabase
from peewee import PostgresqlDatabase
from peewee import SqliteDatabase
logger = logging.getLogger('peewee.pool')
def make_int(val):
if val is not None and not isinstance(val, (int, float)):
return int(val)
return val
class MaxConnectionsExceeded(ValueError): pass
PoolConnection = namedtuple('PoolConnection', ('timestamp', 'connection',
'checked_out'))
class PooledDatabase(object):
def __init__(self, database, max_connections=20, stale_timeout=None,
timeout=None, **kwargs):
self._max_connections = make_int(max_connections)
self._stale_timeout = make_int(stale_timeout)
self._wait_timeout = make_int(timeout)
if self._wait_timeout == 0:
self._wait_timeout = float('inf')
# Available / idle connections stored in a heap, sorted oldest first.
self._connections = []
# Mapping of connection id to PoolConnection. Ordinarily we would want
# to use something like a WeakKeyDictionary, but Python typically won't
# allow us to create weak references to connection objects.
self._in_use = {}
# Use the memory address of the connection as the key in the event the
# connection object is not hashable. Connections will not get
# garbage-collected, however, because a reference to them will persist
# in "_in_use" as long as the conn has not been closed.
self.conn_key = id
super(PooledDatabase, self).__init__(database, **kwargs)
def init(self, database, max_connections=None, stale_timeout=None,
timeout=None, **connect_kwargs):
super(PooledDatabase, self).init(database, **connect_kwargs)
if max_connections is not None:
self._max_connections = make_int(max_connections)
if stale_timeout is not None:
self._stale_timeout = make_int(stale_timeout)
if timeout is not None:
self._wait_timeout = make_int(timeout)
if self._wait_timeout == 0:
self._wait_timeout = float('inf')
def connect(self, reuse_if_open=False):
if not self._wait_timeout:
return super(PooledDatabase, self).connect(reuse_if_open)
expires = time.time() + self._wait_timeout
while expires > time.time():
try:
ret = super(PooledDatabase, self).connect(reuse_if_open)
except MaxConnectionsExceeded:
time.sleep(0.1)
else:
return ret
raise MaxConnectionsExceeded('Max connections exceeded, timed out '
'attempting to connect.')
def _connect(self):
while True:
try:
# Remove the oldest connection from the heap.
ts, conn = heapq.heappop(self._connections)
key = self.conn_key(conn)
except IndexError:
ts = conn = None
logger.debug('No connection available in pool.')
break
else:
if self._is_closed(conn):
# This connection was closed, but since it was not stale
# it got added back to the queue of available conns. We
# then closed it and marked it as explicitly closed, so
# it's safe to throw it away now.
# (Because Database.close() calls Database._close()).
logger.debug('Connection %s was closed.', key)
ts = conn = None
elif self._stale_timeout and self._is_stale(ts):
# If we are attempting to check out a stale connection,
# then close it. We don't need to mark it in the "closed"
# set, because it is not in the list of available conns
# anymore.
logger.debug('Connection %s was stale, closing.', key)
self._close(conn, True)
ts = conn = None
else:
break
if conn is None:
if self._max_connections and (
len(self._in_use) >= self._max_connections):
raise MaxConnectionsExceeded('Exceeded maximum connections.')
conn = super(PooledDatabase, self)._connect()
ts = time.time() - random.random() / 1000
key = self.conn_key(conn)
logger.debug('Created new connection %s.', key)
self._in_use[key] = PoolConnection(ts, conn, time.time())
return conn
def _is_stale(self, timestamp):
# Called on check-out and check-in to ensure the connection has
# not outlived the stale timeout.
return (time.time() - timestamp) > self._stale_timeout
def _is_closed(self, conn):
return False
def _can_reuse(self, conn):
# Called on check-in to make sure the connection can be re-used.
return True
def _close(self, conn, close_conn=False):
key = self.conn_key(conn)
if close_conn:
super(PooledDatabase, self)._close(conn)
elif key in self._in_use:
pool_conn = self._in_use.pop(key)
if self._stale_timeout and self._is_stale(pool_conn.timestamp):
logger.debug('Closing stale connection %s.', key)
super(PooledDatabase, self)._close(conn)
elif self._can_reuse(conn):
logger.debug('Returning %s to pool.', key)
heapq.heappush(self._connections, (pool_conn.timestamp, conn))
else:
logger.debug('Closed %s.', key)
def manual_close(self):
"""
Close the underlying connection without returning it to the pool.
"""
if self.is_closed():
return False
# Obtain reference to the connection in-use by the calling thread.
conn = self.connection()
# A connection will only be re-added to the available list if it is
# marked as "in use" at the time it is closed. We will explicitly
# remove it from the "in use" list, call "close()" for the
# side-effects, and then explicitly close the connection.
self._in_use.pop(self.conn_key(conn), None)
self.close()
self._close(conn, close_conn=True)
def close_idle(self):
# Close any open connections that are not currently in-use.
with self._lock:
for _, conn in self._connections:
self._close(conn, close_conn=True)
self._connections = []
def close_stale(self, age=600):
# Close any connections that are in-use but were checked out quite some
# time ago and can be considered stale.
with self._lock:
in_use = {}
cutoff = time.time() - age
n = 0
for key, pool_conn in self._in_use.items():
if pool_conn.checked_out < cutoff:
self._close(pool_conn.connection, close_conn=True)
n += 1
else:
in_use[key] = pool_conn
self._in_use = in_use
return n
def close_all(self):
# Close all connections -- available and in-use. Warning: may break any
# active connections used by other threads.
self.close()
with self._lock:
for _, conn in self._connections:
self._close(conn, close_conn=True)
for pool_conn in self._in_use.values():
self._close(pool_conn.connection, close_conn=True)
self._connections = []
self._in_use = {}
class PooledMySQLDatabase(PooledDatabase, MySQLDatabase):
def _is_closed(self, conn):
try:
conn.ping(False)
except:
return True
else:
return False
class _PooledPostgresqlDatabase(PooledDatabase):
def _is_closed(self, conn):
if conn.closed:
return True
txn_status = conn.get_transaction_status()
if txn_status == TRANSACTION_STATUS_UNKNOWN:
return True
elif txn_status != TRANSACTION_STATUS_IDLE:
conn.rollback()
return False
def _can_reuse(self, conn):
txn_status = conn.get_transaction_status()
# Do not return connection in an error state, as subsequent queries
# will all fail. If the status is unknown then we lost the connection
# to the server and the connection should not be re-used.
if txn_status == TRANSACTION_STATUS_UNKNOWN:
return False
elif txn_status == TRANSACTION_STATUS_INERROR:
conn.reset()
elif txn_status != TRANSACTION_STATUS_IDLE:
conn.rollback()
return True
class PooledPostgresqlDatabase(_PooledPostgresqlDatabase, PostgresqlDatabase):
pass
try:
from playhouse.postgres_ext import PostgresqlExtDatabase
class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase, PostgresqlExtDatabase):
pass
except ImportError:
PooledPostgresqlExtDatabase = None
class _PooledSqliteDatabase(PooledDatabase):
def _is_closed(self, conn):
try:
conn.total_changes
except:
return True
else:
return False
class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase):
pass
try:
from playhouse.sqlite_ext import SqliteExtDatabase
class PooledSqliteExtDatabase(_PooledSqliteDatabase, SqliteExtDatabase):
pass
except ImportError:
PooledSqliteExtDatabase = None
try:
from playhouse.sqlite_ext import CSqliteExtDatabase
class PooledCSqliteExtDatabase(_PooledSqliteDatabase, CSqliteExtDatabase):
pass
except ImportError:
PooledCSqliteExtDatabase = None
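
if __name__ == '__main__':
    # Illustrative sketch, not part of the original module: close() checks
    # the connection back into the pool; manual_close() really closes it.
    db = PooledSqliteDatabase(':memory:', max_connections=4, stale_timeout=300)
    db.connect()
    db.execute_sql('SELECT 1')
    db.close()          # Returned to the pool, not closed.
    db.connect()        # Re-uses the pooled connection.
    db.manual_close()   # Closes the underlying sqlite connection for real.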

@ -0,0 +1,493 @@
"""
Collection of postgres-specific extensions, currently including:
* Support for hstore, a key/value type storage
"""
import json
import logging
import uuid
from peewee import *
from peewee import ColumnBase
from peewee import Expression
from peewee import Node
from peewee import NodeList
from peewee import SENTINEL
from peewee import __exception_wrapper__
try:
from psycopg2cffi import compat
compat.register()
except ImportError:
pass
try:
from psycopg2.extras import register_hstore
except ImportError:
def register_hstore(c, globally):
pass
try:
from psycopg2.extras import Json
except:
Json = None
logger = logging.getLogger('peewee')
HCONTAINS_DICT = '@>'
HCONTAINS_KEYS = '?&'
HCONTAINS_KEY = '?'
HCONTAINS_ANY_KEY = '?|'
HKEY = '->'
HUPDATE = '||'
ACONTAINS = '@>'
ACONTAINS_ANY = '&&'
TS_MATCH = '@@'
JSONB_CONTAINS = '@>'
JSONB_CONTAINED_BY = '<@'
JSONB_CONTAINS_KEY = '?'
JSONB_CONTAINS_ANY_KEY = '?|'
JSONB_CONTAINS_ALL_KEYS = '?&'
JSONB_EXISTS = '?'
JSONB_REMOVE = '-'
class _LookupNode(ColumnBase):
def __init__(self, node, parts):
self.node = node
self.parts = parts
super(_LookupNode, self).__init__()
def clone(self):
return type(self)(self.node, list(self.parts))
def __hash__(self):
return hash((self.__class__.__name__, id(self)))
class _JsonLookupBase(_LookupNode):
def __init__(self, node, parts, as_json=False):
super(_JsonLookupBase, self).__init__(node, parts)
self._as_json = as_json
def clone(self):
return type(self)(self.node, list(self.parts), self._as_json)
@Node.copy
def as_json(self, as_json=True):
self._as_json = as_json
def concat(self, rhs):
if not isinstance(rhs, Node):
rhs = Json(rhs)
return Expression(self.as_json(True), OP.CONCAT, rhs)
def contains(self, other):
clone = self.as_json(True)
if isinstance(other, (list, dict)):
return Expression(clone, JSONB_CONTAINS, Json(other))
return Expression(clone, JSONB_EXISTS, other)
def contains_any(self, *keys):
return Expression(
self.as_json(True),
JSONB_CONTAINS_ANY_KEY,
Value(list(keys), unpack=False))
def contains_all(self, *keys):
return Expression(
self.as_json(True),
JSONB_CONTAINS_ALL_KEYS,
Value(list(keys), unpack=False))
def has_key(self, key):
return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key)
class JsonLookup(_JsonLookupBase):
def __getitem__(self, value):
return JsonLookup(self.node, self.parts + [value], self._as_json)
def __sql__(self, ctx):
ctx.sql(self.node)
for part in self.parts[:-1]:
ctx.literal('->').sql(part)
if self.parts:
(ctx
.literal('->' if self._as_json else '->>')
.sql(self.parts[-1]))
return ctx
class JsonPath(_JsonLookupBase):
def __sql__(self, ctx):
return (ctx
.sql(self.node)
.literal('#>' if self._as_json else '#>>')
.sql(Value('{%s}' % ','.join(map(str, self.parts)))))
class ObjectSlice(_LookupNode):
@classmethod
def create(cls, node, value):
if isinstance(value, slice):
parts = [value.start or 0, value.stop or 0]
elif isinstance(value, int):
parts = [value]
elif isinstance(value, Node):
parts = value
else:
# Assumes colon-separated integer indexes.
parts = [int(i) for i in value.split(':')]
return cls(node, parts)
def __sql__(self, ctx):
ctx.sql(self.node)
if isinstance(self.parts, Node):
ctx.literal('[').sql(self.parts).literal(']')
else:
ctx.literal('[%s]' % ':'.join(str(p + 1) for p in self.parts))
return ctx
def __getitem__(self, value):
return ObjectSlice.create(self, value)
class IndexedFieldMixin(object):
default_index_type = 'GIN'
def __init__(self, *args, **kwargs):
kwargs.setdefault('index', True) # By default, use an index.
super(IndexedFieldMixin, self).__init__(*args, **kwargs)
class ArrayField(IndexedFieldMixin, Field):
passthrough = True
def __init__(self, field_class=IntegerField, field_kwargs=None,
dimensions=1, convert_values=False, *args, **kwargs):
self.__field = field_class(**(field_kwargs or {}))
self.dimensions = dimensions
self.convert_values = convert_values
self.field_type = self.__field.field_type
super(ArrayField, self).__init__(*args, **kwargs)
def bind(self, model, name, set_attribute=True):
ret = super(ArrayField, self).bind(model, name, set_attribute)
self.__field.bind(model, '__array_%s' % name, False)
return ret
def ddl_datatype(self, ctx):
data_type = self.__field.ddl_datatype(ctx)
return NodeList((data_type, SQL('[]' * self.dimensions)), glue='')
def db_value(self, value):
if value is None or isinstance(value, Node):
return value
elif self.convert_values:
return self._process(self.__field.db_value, value, self.dimensions)
else:
return value if isinstance(value, list) else list(value)
def python_value(self, value):
if self.convert_values and value is not None:
conv = self.__field.python_value
if isinstance(value, list):
return self._process(conv, value, self.dimensions)
else:
return conv(value)
else:
return value
def _process(self, conv, value, dimensions):
dimensions -= 1
if dimensions == 0:
return [conv(v) for v in value]
else:
return [self._process(conv, v, dimensions) for v in value]
def __getitem__(self, value):
return ObjectSlice.create(self, value)
def _e(op):
def inner(self, rhs):
return Expression(self, op, ArrayValue(self, rhs))
return inner
__eq__ = _e(OP.EQ)
__ne__ = _e(OP.NE)
__gt__ = _e(OP.GT)
__ge__ = _e(OP.GTE)
__lt__ = _e(OP.LT)
__le__ = _e(OP.LTE)
__hash__ = Field.__hash__
def contains(self, *items):
return Expression(self, ACONTAINS, ArrayValue(self, items))
def contains_any(self, *items):
return Expression(self, ACONTAINS_ANY, ArrayValue(self, items))
class ArrayValue(Node):
def __init__(self, field, value):
self.field = field
self.value = value
def __sql__(self, ctx):
return (ctx
.sql(Value(self.value, unpack=False))
.literal('::')
.sql(self.field.ddl_datatype(ctx)))
class DateTimeTZField(DateTimeField):
field_type = 'TIMESTAMPTZ'
class HStoreField(IndexedFieldMixin, Field):
field_type = 'HSTORE'
__hash__ = Field.__hash__
def __getitem__(self, key):
return Expression(self, HKEY, Value(key))
def keys(self):
return fn.akeys(self)
def values(self):
return fn.avals(self)
def items(self):
return fn.hstore_to_matrix(self)
def slice(self, *args):
return fn.slice(self, Value(list(args), unpack=False))
def exists(self, key):
return fn.exist(self, key)
def defined(self, key):
return fn.defined(self, key)
def update(self, **data):
return Expression(self, HUPDATE, data)
def delete(self, *keys):
return fn.delete(self, Value(list(keys), unpack=False))
def contains(self, value):
if isinstance(value, dict):
rhs = Value(value, unpack=False)
return Expression(self, HCONTAINS_DICT, rhs)
elif isinstance(value, (list, tuple)):
rhs = Value(value, unpack=False)
return Expression(self, HCONTAINS_KEYS, rhs)
return Expression(self, HCONTAINS_KEY, value)
def contains_any(self, *keys):
return Expression(self, HCONTAINS_ANY_KEY, Value(list(keys),
unpack=False))
class JSONField(Field):
field_type = 'JSON'
_json_datatype = 'json'
def __init__(self, dumps=None, *args, **kwargs):
if Json is None:
raise Exception('Your version of psycopg2 does not support JSON.')
self.dumps = dumps or json.dumps
super(JSONField, self).__init__(*args, **kwargs)
def db_value(self, value):
if value is None:
return value
if not isinstance(value, Json):
return Cast(self.dumps(value), self._json_datatype)
return value
def __getitem__(self, value):
return JsonLookup(self, [value])
def path(self, *keys):
return JsonPath(self, keys)
def concat(self, value):
if not isinstance(value, Node):
value = Json(value)
return super(JSONField, self).concat(value)
def cast_jsonb(node):
return NodeList((node, SQL('::jsonb')), glue='')
class BinaryJSONField(IndexedFieldMixin, JSONField):
field_type = 'JSONB'
_json_datatype = 'jsonb'
__hash__ = Field.__hash__
def contains(self, other):
if isinstance(other, (list, dict)):
return Expression(self, JSONB_CONTAINS, Json(other))
elif isinstance(other, JSONField):
return Expression(self, JSONB_CONTAINS, other)
return Expression(cast_jsonb(self), JSONB_EXISTS, other)
def contained_by(self, other):
return Expression(cast_jsonb(self), JSONB_CONTAINED_BY, Json(other))
def contains_any(self, *items):
return Expression(
cast_jsonb(self),
JSONB_CONTAINS_ANY_KEY,
Value(list(items), unpack=False))
def contains_all(self, *items):
return Expression(
cast_jsonb(self),
JSONB_CONTAINS_ALL_KEYS,
Value(list(items), unpack=False))
def has_key(self, key):
return Expression(cast_jsonb(self), JSONB_CONTAINS_KEY, key)
def remove(self, *items):
return Expression(
cast_jsonb(self),
JSONB_REMOVE,
Value(list(items), unpack=False))
class TSVectorField(IndexedFieldMixin, TextField):
field_type = 'TSVECTOR'
__hash__ = Field.__hash__
def match(self, query, language=None, plain=False):
params = (language, query) if language is not None else (query,)
func = fn.plainto_tsquery if plain else fn.to_tsquery
return Expression(self, TS_MATCH, func(*params))
def Match(field, query, language=None):
params = (language, query) if language is not None else (query,)
field_params = (language, field) if language is not None else (field,)
return Expression(
fn.to_tsvector(*field_params),
TS_MATCH,
fn.to_tsquery(*params))
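# Example: the full-text helpers above, as a sketch assuming a hypothetical
# model "Post" with a TSVectorField "search_content" and a TextField "body":
#
#     Post.select().where(Post.search_content.match('peewee & orm'))
#     Post.select().where(Match(Post.body, 'peewee & orm'))
#
# Both compile to the @@ operator; Match() additionally wraps the plain
# column in to_tsvector(), so no dedicated tsvector column is required.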
class IntervalField(Field):
field_type = 'INTERVAL'
class FetchManyCursor(object):
__slots__ = ('cursor', 'array_size', 'exhausted', 'iterable')
def __init__(self, cursor, array_size=None):
self.cursor = cursor
self.array_size = array_size or cursor.itersize
self.exhausted = False
self.iterable = self.row_gen()
@property
def description(self):
return self.cursor.description
def close(self):
self.cursor.close()
def row_gen(self):
while True:
rows = self.cursor.fetchmany(self.array_size)
if not rows:
return
for row in rows:
yield row
def fetchone(self):
if self.exhausted:
return
try:
return next(self.iterable)
except StopIteration:
self.exhausted = True
class ServerSideQuery(Node):
def __init__(self, query, array_size=None):
self.query = query
self.array_size = array_size
self._cursor_wrapper = None
def __sql__(self, ctx):
return self.query.__sql__(ctx)
def __iter__(self):
if self._cursor_wrapper is None:
self._execute(self.query._database)
return iter(self._cursor_wrapper.iterator())
def _execute(self, database):
if self._cursor_wrapper is None:
cursor = database.execute(self.query, named_cursor=True,
array_size=self.array_size)
self._cursor_wrapper = self.query._get_cursor_wrapper(cursor)
return self._cursor_wrapper
def ServerSide(query, database=None, array_size=None):
if database is None:
database = query._database
with database.transaction():
server_side_query = ServerSideQuery(query, array_size=array_size)
for row in server_side_query:
yield row
class _empty_object(object):
__slots__ = ()
def __nonzero__(self):
return False
__bool__ = __nonzero__
__named_cursor__ = _empty_object()
class PostgresqlExtDatabase(PostgresqlDatabase):
def __init__(self, *args, **kwargs):
self._register_hstore = kwargs.pop('register_hstore', False)
self._server_side_cursors = kwargs.pop('server_side_cursors', False)
super(PostgresqlExtDatabase, self).__init__(*args, **kwargs)
def _connect(self):
conn = super(PostgresqlExtDatabase, self)._connect()
if self._register_hstore:
register_hstore(conn, globally=True)
return conn
def cursor(self, commit=None):
if self.is_closed():
if self.autoconnect:
self.connect()
else:
raise InterfaceError('Error, database connection not opened.')
if commit is __named_cursor__:
return self._state.conn.cursor(name=str(uuid.uuid1()))
return self._state.conn.cursor()
def execute(self, query, commit=SENTINEL, named_cursor=False,
array_size=None, **context_options):
ctx = self.get_sql_context(**context_options)
sql, params = ctx.sql(query).query()
named_cursor = named_cursor or (self._server_side_cursors and
sql[:6].lower() == 'select')
if named_cursor:
commit = __named_cursor__
cursor = self.execute_sql(sql, params, commit=commit)
if named_cursor:
cursor = FetchManyCursor(cursor, array_size)
return cursor
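# Example usage: a minimal sketch, assuming a reachable Postgres database
# named "peewee_test" with the hstore extension installed (model and column
# names are illustrative).
if __name__ == '__main__':
    db = PostgresqlExtDatabase('peewee_test', register_hstore=True)

    class Event(Model):
        tags = ArrayField(CharField)       # array-of-varchar column
        metadata = BinaryJSONField()       # jsonb column
        properties = HStoreField()         # hstore column

        class Meta:
            database = db

    db.create_tables([Event])
    # jsonb containment: events whose metadata contains {"status": "ok"}.
    query = Event.select().where(Event.metadata.contains({'status': 'ok'}))
    # Stream a large result set through a server-side (named) cursor.
    for event in ServerSide(Event.select()):
        pass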

@ -0,0 +1,833 @@
try:
from collections import OrderedDict
except ImportError:
OrderedDict = dict
from collections import namedtuple
from inspect import isclass
import re
from peewee import *
from peewee import _StringField
from peewee import _query_val_transform
from peewee import CommaNodeList
from peewee import SCOPE_VALUES
from peewee import make_snake_case
from peewee import text_type
try:
from pymysql.constants import FIELD_TYPE
except ImportError:
try:
from MySQLdb.constants import FIELD_TYPE
except ImportError:
FIELD_TYPE = None
try:
from playhouse import postgres_ext
except ImportError:
postgres_ext = None
try:
from playhouse.cockroachdb import CockroachDatabase
except ImportError:
CockroachDatabase = None
RESERVED_WORDS = set([
'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif',
'else', 'except', 'exec', 'finally', 'for', 'from', 'global', 'if',
'import', 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', 'raise',
'return', 'try', 'while', 'with', 'yield',
])
class UnknownField(object):
pass
class Column(object):
"""
Store metadata about a database column.
"""
primary_key_types = (IntegerField, AutoField)
def __init__(self, name, field_class, raw_column_type, nullable,
primary_key=False, column_name=None, index=False,
unique=False, default=None, extra_parameters=None):
self.name = name
self.field_class = field_class
self.raw_column_type = raw_column_type
self.nullable = nullable
self.primary_key = primary_key
self.column_name = column_name
self.index = index
self.unique = unique
self.default = default
self.extra_parameters = extra_parameters
# Foreign key metadata.
self.rel_model = None
self.related_name = None
self.to_field = None
def __repr__(self):
attrs = [
'field_class',
'raw_column_type',
'nullable',
'primary_key',
'column_name']
keyword_args = ', '.join(
'%s=%s' % (attr, getattr(self, attr))
for attr in attrs)
return 'Column(%s, %s)' % (self.name, keyword_args)
def get_field_parameters(self):
params = {}
if self.extra_parameters is not None:
params.update(self.extra_parameters)
# Set up default attributes.
if self.nullable:
params['null'] = True
if self.field_class is ForeignKeyField or self.name != self.column_name:
params['column_name'] = "'%s'" % self.column_name
if self.primary_key and not issubclass(self.field_class, AutoField):
params['primary_key'] = True
if self.default is not None:
params['constraints'] = '[SQL("DEFAULT %s")]' % self.default
# Handle ForeignKeyField-specific attributes.
if self.is_foreign_key():
params['model'] = self.rel_model
if self.to_field:
params['field'] = "'%s'" % self.to_field
if self.related_name:
params['backref'] = "'%s'" % self.related_name
# Handle indexes on column.
if not self.is_primary_key():
if self.unique:
params['unique'] = 'True'
elif self.index and not self.is_foreign_key():
params['index'] = 'True'
return params
def is_primary_key(self):
return self.field_class is AutoField or self.primary_key
def is_foreign_key(self):
return self.field_class is ForeignKeyField
def is_self_referential_fk(self):
return (self.field_class is ForeignKeyField and
self.rel_model == "'self'")
def set_foreign_key(self, foreign_key, model_names, dest=None,
related_name=None):
self.foreign_key = foreign_key
self.field_class = ForeignKeyField
if foreign_key.dest_table == foreign_key.table:
self.rel_model = "'self'"
else:
self.rel_model = model_names[foreign_key.dest_table]
self.to_field = dest and dest.name or None
self.related_name = related_name or None
def get_field(self):
# Generate the field definition for this column.
field_params = {}
for key, value in self.get_field_parameters().items():
if isclass(value) and issubclass(value, Field):
value = value.__name__
field_params[key] = value
param_str = ', '.join('%s=%s' % (k, v)
for k, v in sorted(field_params.items()))
field = '%s = %s(%s)' % (
self.name,
self.field_class.__name__,
param_str)
if self.field_class is UnknownField:
field = '%s # %s' % (field, self.raw_column_type)
return field
class Metadata(object):
column_map = {}
extension_import = ''
def __init__(self, database):
self.database = database
self.requires_extension = False
def execute(self, sql, *params):
return self.database.execute_sql(sql, params)
def get_columns(self, table, schema=None):
metadata = OrderedDict(
(metadata.name, metadata)
for metadata in self.database.get_columns(table, schema))
# Look up the actual column type for each column.
column_types, extra_params = self.get_column_types(table, schema)
# Look up the primary keys.
pk_names = self.get_primary_keys(table, schema)
if len(pk_names) == 1:
pk = pk_names[0]
if column_types[pk] is IntegerField:
column_types[pk] = AutoField
elif column_types[pk] is BigIntegerField:
column_types[pk] = BigAutoField
columns = OrderedDict()
for name, column_data in metadata.items():
field_class = column_types[name]
default = self._clean_default(field_class, column_data.default)
columns[name] = Column(
name,
field_class=field_class,
raw_column_type=column_data.data_type,
nullable=column_data.null,
primary_key=column_data.primary_key,
column_name=name,
default=default,
extra_parameters=extra_params.get(name))
return columns
def get_column_types(self, table, schema=None):
raise NotImplementedError
def _clean_default(self, field_class, default):
if default is None or field_class in (AutoField, BigAutoField) or \
default.lower() == 'null':
return
if issubclass(field_class, _StringField) and \
isinstance(default, text_type) and not default.startswith("'"):
default = "'%s'" % default
return default or "''"
def get_foreign_keys(self, table, schema=None):
return self.database.get_foreign_keys(table, schema)
def get_primary_keys(self, table, schema=None):
return self.database.get_primary_keys(table, schema)
def get_indexes(self, table, schema=None):
return self.database.get_indexes(table, schema)
class PostgresqlMetadata(Metadata):
column_map = {
16: BooleanField,
17: BlobField,
20: BigIntegerField,
21: SmallIntegerField,
23: IntegerField,
25: TextField,
700: FloatField,
701: DoubleField,
1042: CharField, # blank-padded CHAR
1043: CharField,
1082: DateField,
1114: DateTimeField,
1184: DateTimeField,
1083: TimeField,
1266: TimeField,
1700: DecimalField,
2950: UUIDField, # UUID
}
array_types = {
1000: BooleanField,
1001: BlobField,
1005: SmallIntegerField,
1007: IntegerField,
1009: TextField,
1014: CharField,
1015: CharField,
1016: BigIntegerField,
1115: DateTimeField,
1182: DateField,
1183: TimeField,
2951: UUIDField,
}
extension_import = 'from playhouse.postgres_ext import *'
def __init__(self, database):
super(PostgresqlMetadata, self).__init__(database)
if postgres_ext is not None:
# Attempt to add types like HStore and JSON.
cursor = self.execute('select oid, typname, format_type(oid, NULL)'
' from pg_type;')
results = cursor.fetchall()
for oid, typname, formatted_type in results:
if typname == 'json':
self.column_map[oid] = postgres_ext.JSONField
elif typname == 'jsonb':
self.column_map[oid] = postgres_ext.BinaryJSONField
elif typname == 'hstore':
self.column_map[oid] = postgres_ext.HStoreField
elif typname == 'tsvector':
self.column_map[oid] = postgres_ext.TSVectorField
for oid in self.array_types:
self.column_map[oid] = postgres_ext.ArrayField
def get_column_types(self, table, schema):
column_types = {}
extra_params = {}
extension_types = set((
postgres_ext.ArrayField,
postgres_ext.BinaryJSONField,
postgres_ext.JSONField,
postgres_ext.TSVectorField,
postgres_ext.HStoreField)) if postgres_ext is not None else set()
# Look up the actual column type for each column.
identifier = '%s."%s"' % (schema, table)
cursor = self.execute(
'SELECT attname, atttypid FROM pg_catalog.pg_attribute '
'WHERE attrelid = %s::regclass AND attnum > %s', identifier, 0)
# Store column metadata in dictionary keyed by column name.
for name, oid in cursor.fetchall():
column_types[name] = self.column_map.get(oid, UnknownField)
if column_types[name] in extension_types:
self.requires_extension = True
if oid in self.array_types:
extra_params[name] = {'field_class': self.array_types[oid]}
return column_types, extra_params
def get_columns(self, table, schema=None):
schema = schema or 'public'
return super(PostgresqlMetadata, self).get_columns(table, schema)
def get_foreign_keys(self, table, schema=None):
schema = schema or 'public'
return super(PostgresqlMetadata, self).get_foreign_keys(table, schema)
def get_primary_keys(self, table, schema=None):
schema = schema or 'public'
return super(PostgresqlMetadata, self).get_primary_keys(table, schema)
def get_indexes(self, table, schema=None):
schema = schema or 'public'
return super(PostgresqlMetadata, self).get_indexes(table, schema)
class CockroachDBMetadata(PostgresqlMetadata):
# CRDB treats INT the same as BIGINT, so we just map bigint type OIDs to
# regular IntegerField.
column_map = PostgresqlMetadata.column_map.copy()
column_map[20] = IntegerField
array_types = PostgresqlMetadata.array_types.copy()
array_types[1016] = IntegerField
extension_import = 'from playhouse.cockroachdb import *'
def __init__(self, database):
Metadata.__init__(self, database)
self.requires_extension = True
if postgres_ext is not None:
# Attempt to add JSON types.
cursor = self.execute('select oid, typname, format_type(oid, NULL)'
' from pg_type;')
results = cursor.fetchall()
for oid, typname, formatted_type in results:
if typname == 'jsonb':
self.column_map[oid] = postgres_ext.BinaryJSONField
for oid in self.array_types:
self.column_map[oid] = postgres_ext.ArrayField
class MySQLMetadata(Metadata):
if FIELD_TYPE is None:
column_map = {}
else:
column_map = {
FIELD_TYPE.BLOB: TextField,
FIELD_TYPE.CHAR: CharField,
FIELD_TYPE.DATE: DateField,
FIELD_TYPE.DATETIME: DateTimeField,
FIELD_TYPE.DECIMAL: DecimalField,
FIELD_TYPE.DOUBLE: FloatField,
FIELD_TYPE.FLOAT: FloatField,
FIELD_TYPE.INT24: IntegerField,
FIELD_TYPE.LONG_BLOB: TextField,
FIELD_TYPE.LONG: IntegerField,
FIELD_TYPE.LONGLONG: BigIntegerField,
FIELD_TYPE.MEDIUM_BLOB: TextField,
FIELD_TYPE.NEWDECIMAL: DecimalField,
FIELD_TYPE.SHORT: IntegerField,
FIELD_TYPE.STRING: CharField,
FIELD_TYPE.TIMESTAMP: DateTimeField,
FIELD_TYPE.TIME: TimeField,
FIELD_TYPE.TINY_BLOB: TextField,
FIELD_TYPE.TINY: IntegerField,
FIELD_TYPE.VAR_STRING: CharField,
}
def __init__(self, database, **kwargs):
if 'password' in kwargs:
kwargs['passwd'] = kwargs.pop('password')
super(MySQLMetadata, self).__init__(database, **kwargs)
def get_column_types(self, table, schema=None):
column_types = {}
# Look up the actual column type for each column.
cursor = self.execute('SELECT * FROM `%s` LIMIT 1' % table)
# Store column metadata in dictionary keyed by column name.
for column_description in cursor.description:
name, type_code = column_description[:2]
column_types[name] = self.column_map.get(type_code, UnknownField)
return column_types, {}
class SqliteMetadata(Metadata):
column_map = {
'bigint': BigIntegerField,
'blob': BlobField,
'bool': BooleanField,
'boolean': BooleanField,
'char': CharField,
'date': DateField,
'datetime': DateTimeField,
'decimal': DecimalField,
'float': FloatField,
'integer': IntegerField,
'integer unsigned': IntegerField,
'int': IntegerField,
'long': BigIntegerField,
'numeric': DecimalField,
'real': FloatField,
'smallinteger': IntegerField,
'smallint': IntegerField,
'smallint unsigned': IntegerField,
'text': TextField,
'time': TimeField,
'varchar': CharField,
}
begin = '(?:["\[\(]+)?'
end = '(?:["\]\)]+)?'
re_foreign_key = (
'(?:FOREIGN KEY\s*)?'
'{begin}(.+?){end}\s+(?:.+\s+)?'
'references\s+{begin}(.+?){end}'
'\s*\(["|\[]?(.+?)["|\]]?\)').format(begin=begin, end=end)
re_varchar = r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$'
def _map_col(self, column_type):
raw_column_type = column_type.lower()
if raw_column_type in self.column_map:
field_class = self.column_map[raw_column_type]
elif re.search(self.re_varchar, raw_column_type):
field_class = CharField
else:
column_type = re.sub('\(.+\)', '', raw_column_type)
if column_type == '':
field_class = BareField
else:
field_class = self.column_map.get(column_type, UnknownField)
return field_class
def get_column_types(self, table, schema=None):
column_types = {}
columns = self.database.get_columns(table)
for column in columns:
column_types[column.name] = self._map_col(column.data_type)
return column_types, {}
_DatabaseMetadata = namedtuple('_DatabaseMetadata', (
'columns',
'primary_keys',
'foreign_keys',
'model_names',
'indexes'))
class DatabaseMetadata(_DatabaseMetadata):
def multi_column_indexes(self, table):
accum = []
for index in self.indexes[table]:
if len(index.columns) > 1:
field_names = [self.columns[table][column].name
for column in index.columns
if column in self.columns[table]]
accum.append((field_names, index.unique))
return accum
def column_indexes(self, table):
accum = {}
for index in self.indexes[table]:
if len(index.columns) == 1:
accum[index.columns[0]] = index.unique
return accum
class Introspector(object):
pk_classes = [AutoField, IntegerField]
def __init__(self, metadata, schema=None):
self.metadata = metadata
self.schema = schema
def __repr__(self):
return '<Introspector: %s>' % self.metadata.database
@classmethod
def from_database(cls, database, schema=None):
if CockroachDatabase and isinstance(database, CockroachDatabase):
metadata = CockroachDBMetadata(database)
elif isinstance(database, PostgresqlDatabase):
metadata = PostgresqlMetadata(database)
elif isinstance(database, MySQLDatabase):
metadata = MySQLMetadata(database)
elif isinstance(database, SqliteDatabase):
metadata = SqliteMetadata(database)
else:
raise ValueError('Introspection not supported for %r' % database)
return cls(metadata, schema=schema)
def get_database_class(self):
return type(self.metadata.database)
def get_database_name(self):
return self.metadata.database.database
def get_database_kwargs(self):
return self.metadata.database.connect_params
def get_additional_imports(self):
if self.metadata.requires_extension:
return '\n' + self.metadata.extension_import
return ''
def make_model_name(self, table, snake_case=True):
if snake_case:
table = make_snake_case(table)
model = re.sub(r'[^\w]+', '', table)
model_name = ''.join(sub.title() for sub in model.split('_'))
if not model_name[0].isalpha():
model_name = 'T' + model_name
return model_name
def make_column_name(self, column, is_foreign_key=False, snake_case=True):
column = column.strip()
if snake_case:
column = make_snake_case(column)
column = column.lower()
if is_foreign_key:
# Strip "_id" from foreign keys, unless the foreign-key happens to
# be named "_id", in which case the name is retained.
column = re.sub('_id$', '', column) or column
# Remove characters that are invalid for Python identifiers.
column = re.sub(r'[^\w]+', '_', column)
if column in RESERVED_WORDS:
column += '_'
if len(column) and column[0].isdigit():
column = '_' + column
return column
def introspect(self, table_names=None, literal_column_names=False,
include_views=False, snake_case=True):
# Retrieve all the tables in the database.
tables = self.metadata.database.get_tables(schema=self.schema)
if include_views:
views = self.metadata.database.get_views(schema=self.schema)
tables.extend([view.name for view in views])
if table_names is not None:
tables = [table for table in tables if table in table_names]
table_set = set(tables)
# Store a mapping of table name -> dictionary of columns.
columns = {}
# Store a mapping of table name -> set of primary key columns.
primary_keys = {}
# Store a mapping of table -> foreign keys.
foreign_keys = {}
# Store a mapping of table name -> model name.
model_names = {}
# Store a mapping of table name -> indexes.
indexes = {}
# Gather the columns for each table.
for table in tables:
table_indexes = self.metadata.get_indexes(table, self.schema)
table_columns = self.metadata.get_columns(table, self.schema)
try:
foreign_keys[table] = self.metadata.get_foreign_keys(
table, self.schema)
except ValueError as exc:
err(*exc.args)
foreign_keys[table] = []
else:
# If there is a possibility we could exclude a dependent table,
# ensure that we introspect it so FKs will work.
if table_names is not None:
for foreign_key in foreign_keys[table]:
if foreign_key.dest_table not in table_set:
tables.append(foreign_key.dest_table)
table_set.add(foreign_key.dest_table)
model_names[table] = self.make_model_name(table, snake_case)
# Collect sets of all the column names as well as all the
# foreign-key column names.
lower_col_names = set(column_name.lower()
for column_name in table_columns)
fks = set(fk_col.column for fk_col in foreign_keys[table])
for col_name, column in table_columns.items():
if literal_column_names:
new_name = re.sub(r'[^\w]+', '_', col_name)
else:
new_name = self.make_column_name(col_name, col_name in fks,
snake_case)
# If we have two columns, "parent" and "parent_id", ensure
# that we don't introduce naming conflicts.
lower_name = col_name.lower()
if lower_name.endswith('_id') and new_name in lower_col_names:
new_name = col_name.lower()
column.name = new_name
for index in table_indexes:
if len(index.columns) == 1:
column = index.columns[0]
if column in table_columns:
table_columns[column].unique = index.unique
table_columns[column].index = True
primary_keys[table] = self.metadata.get_primary_keys(
table, self.schema)
columns[table] = table_columns
indexes[table] = table_indexes
# Gather all instances where we might have a `related_name` conflict,
# either due to multiple FKs on a table pointing to the same table,
# or a related_name that would conflict with an existing field.
related_names = {}
sort_fn = lambda foreign_key: foreign_key.column
for table in tables:
models_referenced = set()
for foreign_key in sorted(foreign_keys[table], key=sort_fn):
try:
column = columns[table][foreign_key.column]
except KeyError:
continue
dest_table = foreign_key.dest_table
if dest_table in models_referenced:
related_names[column] = '%s_%s_set' % (
dest_table,
column.name)
else:
models_referenced.add(dest_table)
# On the second pass convert all foreign keys.
for table in tables:
for foreign_key in foreign_keys[table]:
src = columns[foreign_key.table][foreign_key.column]
try:
dest = columns[foreign_key.dest_table][
foreign_key.dest_column]
except KeyError:
dest = None
src.set_foreign_key(
foreign_key=foreign_key,
model_names=model_names,
dest=dest,
related_name=related_names.get(src))
return DatabaseMetadata(
columns,
primary_keys,
foreign_keys,
model_names,
indexes)
def generate_models(self, skip_invalid=False, table_names=None,
literal_column_names=False, bare_fields=False,
include_views=False):
database = self.introspect(table_names, literal_column_names,
include_views)
models = {}
class BaseModel(Model):
class Meta:
database = self.metadata.database
schema = self.schema
def _create_model(table, models):
for foreign_key in database.foreign_keys[table]:
dest = foreign_key.dest_table
if dest not in models and dest != table:
_create_model(dest, models)
primary_keys = []
columns = database.columns[table]
for column_name, column in columns.items():
if column.primary_key:
primary_keys.append(column.name)
multi_column_indexes = database.multi_column_indexes(table)
column_indexes = database.column_indexes(table)
class Meta:
indexes = multi_column_indexes
table_name = table
# Fix models with multi-column primary keys.
composite_key = False
if len(primary_keys) == 0:
primary_keys = columns.keys()
if len(primary_keys) > 1:
Meta.primary_key = CompositeKey(*[
field.name for col, field in columns.items()
if col in primary_keys])
composite_key = True
attrs = {'Meta': Meta}
for column_name, column in columns.items():
FieldClass = column.field_class
if FieldClass is not ForeignKeyField and bare_fields:
FieldClass = BareField
elif FieldClass is UnknownField:
FieldClass = BareField
params = {
'column_name': column_name,
'null': column.nullable}
if column.primary_key and composite_key:
if FieldClass is AutoField:
FieldClass = IntegerField
params['primary_key'] = False
elif column.primary_key and FieldClass is not AutoField:
params['primary_key'] = True
if column.is_foreign_key():
if column.is_self_referential_fk():
params['model'] = 'self'
else:
dest_table = column.foreign_key.dest_table
params['model'] = models[dest_table]
if column.to_field:
params['field'] = column.to_field
# Generate a unique related name.
params['backref'] = '%s_%s_rel' % (table, column_name)
if column.default is not None:
constraint = SQL('DEFAULT %s' % column.default)
params['constraints'] = [constraint]
if column_name in column_indexes and not \
column.is_primary_key():
if column_indexes[column_name]:
params['unique'] = True
elif not column.is_foreign_key():
params['index'] = True
attrs[column.name] = FieldClass(**params)
try:
models[table] = type(str(table), (BaseModel,), attrs)
except ValueError:
if not skip_invalid:
raise
# Actually generate Model classes.
for table, model in sorted(database.model_names.items()):
if table not in models:
_create_model(table, models)
return models
def introspect(database, schema=None):
introspector = Introspector.from_database(database, schema=schema)
return introspector.introspect()
def generate_models(database, schema=None, **options):
introspector = Introspector.from_database(database, schema=schema)
return introspector.generate_models(**options)
def print_model(model, indexes=True, inline_indexes=False):
print(model._meta.name)
for field in model._meta.sorted_fields:
parts = [' %s %s' % (field.name, field.field_type)]
if field.primary_key:
parts.append(' PK')
elif inline_indexes:
if field.unique:
parts.append(' UNIQUE')
elif field.index:
parts.append(' INDEX')
if isinstance(field, ForeignKeyField):
parts.append(' FK: %s.%s' % (field.rel_model.__name__,
field.rel_field.name))
print(''.join(parts))
if indexes:
index_list = model._meta.fields_to_index()
if not index_list:
return
print('\nindex(es)')
for index in index_list:
parts = [' ']
ctx = model._meta.database.get_sql_context()
with ctx.scope_values(param='%s', quote='""'):
ctx.sql(CommaNodeList(index._expressions))
if index._where:
ctx.literal(' WHERE ')
ctx.sql(index._where)
sql, params = ctx.query()
clean = sql % tuple(map(_query_val_transform, params))
parts.append(clean.replace('"', ''))
if index._unique:
parts.append(' UNIQUE')
print(''.join(parts))
def get_table_sql(model):
sql, params = model._schema._create_table().query()
if model._meta.database.param != '%s':
sql = sql.replace(model._meta.database.param, '%s')
# Format and indent the table declaration, simplest possible approach.
match_obj = re.match('^(.+?\()(.+)(\).*)', sql)
create, columns, extra = match_obj.groups()
indented = ',\n'.join(' %s' % column for column in columns.split(', '))
clean = '\n'.join((create, indented, extra)).strip()
return clean % tuple(map(_query_val_transform, params))
def print_table_sql(model):
print(get_table_sql(model))
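# Example usage: a minimal sketch of introspecting an existing database into
# model classes, assuming a SQLite file "legacy.db" that already contains a
# "user" table.
if __name__ == '__main__':
    db = SqliteDatabase('legacy.db')
    models = generate_models(db)          # {table_name: Model subclass}
    User = models['user']
    print_model(User)                     # fields, PK/FK/index annotations
    print_table_sql(User)                 # CREATE TABLE peewee would emit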

@ -0,0 +1,252 @@
from peewee import *
from peewee import Alias
from peewee import CompoundSelectQuery
from peewee import SENTINEL
from peewee import callable_
_clone_set = lambda s: set(s) if s else set()
def model_to_dict(model, recurse=True, backrefs=False, only=None,
exclude=None, seen=None, extra_attrs=None,
fields_from_query=None, max_depth=None, manytomany=False):
"""
Convert a model instance (and any related objects) to a dictionary.
:param bool recurse: Whether foreign-keys should be recursed.
:param bool backrefs: Whether lists of related objects should be recursed.
:param only: A list (or set) of field instances indicating which fields
should be included.
:param exclude: A list (or set) of field instances that should be
excluded from the dictionary.
:param list extra_attrs: Names of model instance attributes or methods
that should be included.
:param SelectQuery fields_from_query: Query that was source of model. Take
fields explicitly selected by the query and serialize them.
:param int max_depth: Maximum depth to recurse, value <= 0 means no max.
:param bool manytomany: Process many-to-many fields.
"""
max_depth = -1 if max_depth is None else max_depth
if max_depth == 0:
recurse = False
only = _clone_set(only)
extra_attrs = _clone_set(extra_attrs)
should_skip = lambda n: (n in exclude) or (only and (n not in only))
if fields_from_query is not None:
for item in fields_from_query._returning:
if isinstance(item, Field):
only.add(item)
elif isinstance(item, Alias):
extra_attrs.add(item._alias)
data = {}
exclude = _clone_set(exclude)
seen = _clone_set(seen)
exclude |= seen
model_class = type(model)
if manytomany:
for name, m2m in model._meta.manytomany.items():
if should_skip(name):
continue
exclude.update((m2m, m2m.rel_model._meta.manytomany[m2m.backref]))
for fkf in m2m.through_model._meta.refs:
exclude.add(fkf)
accum = []
for rel_obj in getattr(model, name):
accum.append(model_to_dict(
rel_obj,
recurse=recurse,
backrefs=backrefs,
only=only,
exclude=exclude,
max_depth=max_depth - 1))
data[name] = accum
for field in model._meta.sorted_fields:
if should_skip(field):
continue
field_data = model.__data__.get(field.name)
if isinstance(field, ForeignKeyField) and recurse:
if field_data is not None:
seen.add(field)
rel_obj = getattr(model, field.name)
field_data = model_to_dict(
rel_obj,
recurse=recurse,
backrefs=backrefs,
only=only,
exclude=exclude,
seen=seen,
max_depth=max_depth - 1)
else:
field_data = None
data[field.name] = field_data
if extra_attrs:
for attr_name in extra_attrs:
attr = getattr(model, attr_name)
if callable_(attr):
data[attr_name] = attr()
else:
data[attr_name] = attr
if backrefs and recurse:
for foreign_key, rel_model in model._meta.backrefs.items():
if foreign_key.backref == '+': continue
descriptor = getattr(model_class, foreign_key.backref)
if descriptor in exclude or foreign_key in exclude:
continue
if only and (descriptor not in only) and (foreign_key not in only):
continue
accum = []
exclude.add(foreign_key)
related_query = getattr(model, foreign_key.backref)
for rel_obj in related_query:
accum.append(model_to_dict(
rel_obj,
recurse=recurse,
backrefs=backrefs,
only=only,
exclude=exclude,
max_depth=max_depth - 1))
data[foreign_key.backref] = accum
return data
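# Example: serializing an instance, as a sketch assuming hypothetical models
# "User" and "Tweet" where Tweet.user is a ForeignKeyField:
#
#     >>> model_to_dict(tweet)
#     {'id': 1, 'content': 'hello', 'user': {'id': 7, 'username': 'alice'}}
#     >>> model_to_dict(tweet, recurse=False)  # keep the raw FK value
#     {'id': 1, 'content': 'hello', 'user': 7}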
def update_model_from_dict(instance, data, ignore_unknown=False):
meta = instance._meta
backrefs = dict([(fk.backref, fk) for fk in meta.backrefs])
for key, value in data.items():
if key in meta.combined:
field = meta.combined[key]
is_backref = False
elif key in backrefs:
field = backrefs[key]
is_backref = True
elif ignore_unknown:
setattr(instance, key, value)
continue
else:
raise AttributeError('Unrecognized attribute "%s" for model '
'class %s.' % (key, type(instance)))
is_foreign_key = isinstance(field, ForeignKeyField)
if not is_backref and is_foreign_key and isinstance(value, dict):
try:
rel_instance = instance.__rel__[field.name]
except KeyError:
rel_instance = field.rel_model()
setattr(
instance,
field.name,
update_model_from_dict(rel_instance, value, ignore_unknown))
elif is_backref and isinstance(value, (list, tuple)):
instances = [
dict_to_model(field.model, row_data, ignore_unknown)
for row_data in value]
for rel_instance in instances:
setattr(rel_instance, field.name, instance)
setattr(instance, field.backref, instances)
else:
setattr(instance, field.name, value)
return instance
def dict_to_model(model_class, data, ignore_unknown=False):
return update_model_from_dict(model_class(), data, ignore_unknown)
class ReconnectMixin(object):
"""
Mixin class that attempts to automatically reconnect to the database under
certain error conditions.
For example, MySQL servers will typically close connections that are idle
for 28800 seconds ("wait_timeout" setting). If your application makes use
of long-lived connections, you may find your connections are closed after
a period of no activity. This mixin will attempt to reconnect automatically
when these errors occur.
This mixin class probably should not be used with Postgres (unless you
REALLY know what you are doing) and definitely has no business being used
with Sqlite. If you wish to use with Postgres, you will need to adapt the
`reconnect_errors` attribute to something appropriate for Postgres.
"""
reconnect_errors = (
# Error class, error message fragment (or empty string for all).
(OperationalError, '2006'), # MySQL server has gone away.
(OperationalError, '2013'), # Lost connection to MySQL server.
(OperationalError, '2014'), # Commands out of sync.
# mysql-connector raises a slightly different error when an idle
# connection is terminated by the server. This is equivalent to 2013.
(OperationalError, 'MySQL Connection not available.'),
)
def __init__(self, *args, **kwargs):
super(ReconnectMixin, self).__init__(*args, **kwargs)
# Normalize the reconnect errors to a more efficient data-structure.
self._reconnect_errors = {}
for exc_class, err_fragment in self.reconnect_errors:
self._reconnect_errors.setdefault(exc_class, [])
self._reconnect_errors[exc_class].append(err_fragment.lower())
def execute_sql(self, sql, params=None, commit=SENTINEL):
try:
return super(ReconnectMixin, self).execute_sql(sql, params, commit)
except Exception as exc:
exc_class = type(exc)
if exc_class not in self._reconnect_errors:
raise exc
exc_repr = str(exc).lower()
for err_fragment in self._reconnect_errors[exc_class]:
if err_fragment in exc_repr:
break
else:
raise exc
if not self.is_closed():
self.close()
self.connect()
return super(ReconnectMixin, self).execute_sql(sql, params, commit)
def resolve_multimodel_query(query, key='_model_identifier'):
mapping = {}
accum = [query]
while accum:
curr = accum.pop()
if isinstance(curr, CompoundSelectQuery):
accum.extend((curr.lhs, curr.rhs))
continue
model_class = curr.model
name = model_class._meta.table_name
mapping[name] = model_class
curr._returning.append(Value(name).alias(key))
def wrapped_iterator():
for row in query.dicts().iterator():
identifier = row.pop(key)
model = mapping[identifier]
yield model(**row)
return wrapped_iterator()
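# Example usage: composing ReconnectMixin with a concrete database class, as a
# sketch assuming a MySQL server that reaps idle connections (wait_timeout).
# The mixin must come first in the bases so its execute_sql() wraps the
# database's own.
if __name__ == '__main__':
    class ReconnectMySQLDatabase(ReconnectMixin, MySQLDatabase):
        pass

    db = ReconnectMySQLDatabase('my_app', host='127.0.0.1', user='app')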

@ -0,0 +1,79 @@
"""
Provide django-style hooks for model events.
"""
from peewee import Model as _Model
class Signal(object):
def __init__(self):
self._flush()
def _flush(self):
self._receivers = set()
self._receiver_list = []
def connect(self, receiver, name=None, sender=None):
name = name or receiver.__name__
key = (name, sender)
if key not in self._receivers:
self._receivers.add(key)
self._receiver_list.append((name, receiver, sender))
else:
raise ValueError('receiver named %s (for sender=%s) already '
'connected' % (name, sender or 'any'))
def disconnect(self, receiver=None, name=None, sender=None):
if receiver:
name = name or receiver.__name__
if not name:
raise ValueError('a receiver or a name must be provided')
key = (name, sender)
if key not in self._receivers:
raise ValueError('receiver named %s for sender=%s not found.' %
(name, sender or 'any'))
self._receivers.remove(key)
self._receiver_list = [(n, r, s) for n, r, s in self._receiver_list
if n != name and s != sender]
def __call__(self, name=None, sender=None):
def decorator(fn):
self.connect(fn, name, sender)
return fn
return decorator
def send(self, instance, *args, **kwargs):
sender = type(instance)
responses = []
for n, r, s in self._receiver_list:
if s is None or isinstance(instance, s):
responses.append((r, r(sender, instance, *args, **kwargs)))
return responses
pre_save = Signal()
post_save = Signal()
pre_delete = Signal()
post_delete = Signal()
pre_init = Signal()
class Model(_Model):
def __init__(self, *args, **kwargs):
super(Model, self).__init__(*args, **kwargs)
pre_init.send(self)
def save(self, *args, **kwargs):
pk_value = self._pk if self._meta.primary_key else True
created = kwargs.get('force_insert', False) or not bool(pk_value)
pre_save.send(self, created=created)
ret = super(Model, self).save(*args, **kwargs)
post_save.send(self, created=created)
return ret
def delete_instance(self, *args, **kwargs):
pre_delete.send(self)
ret = super(Model, self).delete_instance(*args, **kwargs)
post_delete.send(self)
return ret
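# Example usage: a minimal sketch of connecting a receiver to one of the
# signals above. Receivers are called as fn(sender, instance, *args, **kwargs);
# post_save additionally passes created=True/False.
if __name__ == '__main__':
    class User(Model):
        pass

    @post_save(sender=User)
    def on_user_saved(sender, instance, created):
        print('saved %r (created=%s)' % (instance, created))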

@ -0,0 +1,103 @@
"""
Peewee integration with pysqlcipher.
Project page: https://github.com/leapcode/pysqlcipher/
**WARNING!!! EXPERIMENTAL!!!**
* Although this extension's code is short, it has not been properly
peer-reviewed yet and may have introduced vulnerabilities.
Also note that this code relies on pysqlcipher and sqlcipher, and
the code there might have vulnerabilities as well, but since these
are widely used crypto modules, we can expect "short zero days" there.
Example usage:
from peewee.playground.ciphersql_ext import SqlCipherDatabase
db = SqlCipherDatabase('/path/to/my.db', passphrase="don'tuseme4real")
* `passphrase`: should be "long enough".
Note that *length beats vocabulary* (the advantage is exponential in length), and even
a lowercase-only passphrase like easytorememberyethardforotherstoguess
packs more noise than 8 random printable characters and *can* be memorized.
When opening an existing database, passphrase should be the one used when the
database was created. If the passphrase is incorrect, an exception will only be
raised **when you access the database**.
If you need to ask for an interactive passphrase, here's example code you can
put after the `db = ...` line:
try: # Just access the database so that it checks the encryption.
db.get_tables()
# We're looking for a DatabaseError with a specific error message.
except peewee.DatabaseError as e:
# Check whether the message *means* "passphrase is wrong"
if e.args[0] == 'file is encrypted or is not a database':
raise Exception('Developer should prompt user for passphrase '
'again.')
else:
# A different DatabaseError. Raise it.
raise e
See a more elaborate example with this code at
https://gist.github.com/thedod/11048875
"""
import datetime
import decimal
import sys
from peewee import *
from playhouse.sqlite_ext import SqliteExtDatabase
if sys.version_info[0] != 3:
from pysqlcipher import dbapi2 as sqlcipher
else:
try:
from sqlcipher3 import dbapi2 as sqlcipher
except ImportError:
from pysqlcipher3 import dbapi2 as sqlcipher
sqlcipher.register_adapter(decimal.Decimal, str)
sqlcipher.register_adapter(datetime.date, str)
sqlcipher.register_adapter(datetime.time, str)
class _SqlCipherDatabase(object):
def _connect(self):
params = dict(self.connect_params)
passphrase = params.pop('passphrase', '').replace("'", "''")
conn = sqlcipher.connect(self.database, isolation_level=None, **params)
try:
if passphrase:
conn.execute("PRAGMA key='%s'" % passphrase)
self._add_conn_hooks(conn)
except:
conn.close()
raise
return conn
def set_passphrase(self, passphrase):
if not self.is_closed():
raise ImproperlyConfigured('Cannot set passphrase when database '
'is open. To change passphrase of an '
'open database use the rekey() method.')
self.connect_params['passphrase'] = passphrase
def rekey(self, passphrase):
if self.is_closed():
self.connect()
self.execute_sql("PRAGMA rekey='%s'" % passphrase.replace("'", "''"))
self.connect_params['passphrase'] = passphrase
return True
class SqlCipherDatabase(_SqlCipherDatabase, SqliteDatabase):
pass
class SqlCipherExtDatabase(_SqlCipherDatabase, SqliteExtDatabase):
pass
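# Example usage: a minimal sketch, assuming sqlcipher3 or pysqlcipher3 is
# installed. The passphrase is only verified lazily, on first database access.
if __name__ == '__main__':
    db = SqlCipherDatabase('encrypted.db', passphrase='correct horse battery')
    db.get_tables()               # Forces the key to be checked.
    db.rekey('new passphrase')    # Re-encrypts the database under a new key.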

@ -0,0 +1,123 @@
from peewee import *
from playhouse.sqlite_ext import JSONField
class BaseChangeLog(Model):
timestamp = DateTimeField(constraints=[SQL('DEFAULT CURRENT_TIMESTAMP')])
action = TextField()
table = TextField()
primary_key = IntegerField()
changes = JSONField()
class ChangeLog(object):
# Model class that will serve as the base for the changelog. This model
# will be subclassed and mapped to your application database.
base_model = BaseChangeLog
# Template for the triggers that handle updating the changelog table.
# table: table name
# action: insert / update / delete
# new_old: NEW or OLD (OLD is for DELETE)
# primary_key: table primary key column name
# column_array: output of build_column_array()
# change_table: changelog table name
template = """CREATE TRIGGER IF NOT EXISTS %(table)s_changes_%(action)s
AFTER %(action)s ON %(table)s
BEGIN
INSERT INTO %(change_table)s
("action", "table", "primary_key", "changes")
SELECT
'%(action)s', '%(table)s', %(new_old)s."%(primary_key)s", "changes"
FROM (
SELECT json_group_object(
col,
json_array("oldval", "newval")) AS "changes"
FROM (
SELECT json_extract(value, '$[0]') as "col",
json_extract(value, '$[1]') as "oldval",
json_extract(value, '$[2]') as "newval"
FROM json_each(json_array(%(column_array)s))
WHERE "oldval" IS NOT "newval"
)
);
END;"""
drop_template = 'DROP TRIGGER IF EXISTS %(table)s_changes_%(action)s'
_actions = ('INSERT', 'UPDATE', 'DELETE')
def __init__(self, db, table_name='changelog'):
self.db = db
self.table_name = table_name
def _build_column_array(self, model, use_old, use_new, skip_fields=None):
# Builds a list of SQL expressions for each field we are tracking. This
# is used as the data source for change tracking in our trigger.
col_array = []
for field in model._meta.sorted_fields:
if field.primary_key:
continue
if skip_fields is not None and field.name in skip_fields:
continue
column = field.column_name
new = 'NULL' if not use_new else 'NEW."%s"' % column
old = 'NULL' if not use_old else 'OLD."%s"' % column
if isinstance(field, JSONField):
# Ensure that values are cast to JSON so that the serialization
# is preserved when calculating the old / new.
if use_old: old = 'json(%s)' % old
if use_new: new = 'json(%s)' % new
col_array.append("json_array('%s', %s, %s)" % (column, old, new))
return ', '.join(col_array)
def trigger_sql(self, model, action, skip_fields=None):
assert action in self._actions
use_old = action != 'INSERT'
use_new = action != 'DELETE'
cols = self._build_column_array(model, use_old, use_new, skip_fields)
return self.template % {
'table': model._meta.table_name,
'action': action,
'new_old': 'NEW' if action != 'DELETE' else 'OLD',
'primary_key': model._meta.primary_key.column_name,
'column_array': cols,
'change_table': self.table_name}
def drop_trigger_sql(self, model, action):
assert action in self._actions
return self.drop_template % {
'table': model._meta.table_name,
'action': action}
@property
def model(self):
if not hasattr(self, '_changelog_model'):
class ChangeLog(self.base_model):
class Meta:
database = self.db
table_name = self.table_name
self._changelog_model = ChangeLog
return self._changelog_model
def install(self, model, skip_fields=None, drop=True, insert=True,
update=True, delete=True, create_table=True):
ChangeLog = self.model
if create_table:
ChangeLog.create_table()
actions = list(zip((insert, update, delete), self._actions))
if drop:
for _, action in actions:
self.db.execute_sql(self.drop_trigger_sql(model, action))
for enabled, action in actions:
if enabled:
sql = self.trigger_sql(model, action, skip_fields)
self.db.execute_sql(sql)
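# Example usage: a minimal sketch, assuming a SQLite build with the JSON1
# extension (required by the triggers) and an illustrative "Note" model.
if __name__ == '__main__':
    from playhouse.sqlite_ext import SqliteExtDatabase

    db = SqliteExtDatabase('app.db')

    class Note(Model):
        content = TextField()

        class Meta:
            database = db

    db.create_tables([Note])
    ChangeLog(db).install(Note)      # Creates the changelog table + triggers.
    Note.create(content='hello')     # Logged as an INSERT row in "changelog".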

File diff suppressed because it is too large

@ -0,0 +1,536 @@
import datetime
import hashlib
import heapq
import math
import os
import random
import re
import sys
import threading
import zlib
try:
from collections import Counter
except ImportError:
Counter = None
try:
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse
try:
from playhouse._sqlite_ext import TableFunction
except ImportError:
TableFunction = None
SQLITE_DATETIME_FORMATS = (
'%Y-%m-%d %H:%M:%S',
'%Y-%m-%d %H:%M:%S.%f',
'%Y-%m-%d',
'%H:%M:%S',
'%H:%M:%S.%f',
'%H:%M')
from peewee import format_date_time
def format_date_time_sqlite(date_value):
return format_date_time(date_value, SQLITE_DATETIME_FORMATS)
try:
from playhouse import _sqlite_udf as cython_udf
except ImportError:
cython_udf = None
# Group UDFs by category of functionality.
CONTROL_FLOW = 'control_flow'
DATE = 'date'
FILE = 'file'
HELPER = 'helpers'
MATH = 'math'
STRING = 'string'
AGGREGATE_COLLECTION = {}
TABLE_FUNCTION_COLLECTION = {}
UDF_COLLECTION = {}
class synchronized_dict(dict):
def __init__(self, *args, **kwargs):
super(synchronized_dict, self).__init__(*args, **kwargs)
self._lock = threading.Lock()
def __getitem__(self, key):
with self._lock:
return super(synchronized_dict, self).__getitem__(key)
def __setitem__(self, key, value):
with self._lock:
return super(synchronized_dict, self).__setitem__(key, value)
def __delitem__(self, key):
with self._lock:
return super(synchronized_dict, self).__delitem__(key)
STATE = synchronized_dict()
SETTINGS = synchronized_dict()
# Class and function decorators.
def aggregate(*groups):
def decorator(klass):
for group in groups:
AGGREGATE_COLLECTION.setdefault(group, [])
AGGREGATE_COLLECTION[group].append(klass)
return klass
return decorator
def table_function(*groups):
def decorator(klass):
for group in groups:
TABLE_FUNCTION_COLLECTION.setdefault(group, [])
TABLE_FUNCTION_COLLECTION[group].append(klass)
return klass
return decorator
def udf(*groups):
def decorator(fn):
for group in groups:
UDF_COLLECTION.setdefault(group, [])
UDF_COLLECTION[group].append(fn)
return fn
return decorator
# Register aggregates / functions with connection.
def register_aggregate_groups(db, *groups):
seen = set()
for group in groups:
klasses = AGGREGATE_COLLECTION.get(group, ())
for klass in klasses:
name = getattr(klass, 'name', klass.__name__)
if name not in seen:
seen.add(name)
db.register_aggregate(klass, name)
def register_table_function_groups(db, *groups):
seen = set()
for group in groups:
klasses = TABLE_FUNCTION_COLLECTION.get(group, ())
for klass in klasses:
if klass.name not in seen:
seen.add(klass.name)
db.register_table_function(klass)
def register_udf_groups(db, *groups):
seen = set()
for group in groups:
functions = UDF_COLLECTION.get(group, ())
for function in functions:
name = function.__name__
if name not in seen:
seen.add(name)
db.register_function(function, name)
def register_groups(db, *groups):
register_aggregate_groups(db, *groups)
register_table_function_groups(db, *groups)
register_udf_groups(db, *groups)
def register_all(db):
register_aggregate_groups(db, *AGGREGATE_COLLECTION)
register_table_function_groups(db, *TABLE_FUNCTION_COLLECTION)
register_udf_groups(db, *UDF_COLLECTION)
# Begin actual user-defined functions and aggregates.
# Scalar functions.
@udf(CONTROL_FLOW)
def if_then_else(cond, truthy, falsey=None):
if cond:
return truthy
return falsey
@udf(DATE)
def strip_tz(date_str):
date_str = date_str.replace('T', ' ')
tz_idx1 = date_str.find('+')
if tz_idx1 != -1:
return date_str[:tz_idx1]
tz_idx2 = date_str.find('-')
if tz_idx2 > 13:
return date_str[:tz_idx2]
return date_str
@udf(DATE)
def human_delta(nseconds, glue=', '):
parts = (
(86400 * 365, 'year'),
(86400 * 30, 'month'),
(86400 * 7, 'week'),
(86400, 'day'),
(3600, 'hour'),
(60, 'minute'),
(1, 'second'),
)
accum = []
for offset, name in parts:
val, nseconds = divmod(nseconds, offset)
if val:
suffix = val != 1 and 's' or ''
accum.append('%s %s%s' % (val, name, suffix))
if not accum:
return '0 seconds'
return glue.join(accum)
@udf(FILE)
def file_ext(filename):
try:
res = os.path.splitext(filename)
except ValueError:
return None
return res[1]
@udf(FILE)
def file_read(filename):
try:
with open(filename) as fh:
return fh.read()
except:
pass
if sys.version_info[0] == 2:
@udf(HELPER)
def gzip(data, compression=9):
return buffer(zlib.compress(data, compression))
@udf(HELPER)
def gunzip(data):
return zlib.decompress(data)
else:
@udf(HELPER)
def gzip(data, compression=9):
if isinstance(data, str):
data = bytes(data.encode('raw_unicode_escape'))
return zlib.compress(data, compression)
@udf(HELPER)
def gunzip(data):
return zlib.decompress(data)
@udf(HELPER)
def hostname(url):
parse_result = urlparse(url)
if parse_result:
return parse_result.netloc
@udf(HELPER)
def toggle(key):
key = key.lower()
STATE[key] = ret = not STATE.get(key)
return ret
@udf(HELPER)
def setting(key, value=None):
if value is None:
return SETTINGS.get(key)
else:
SETTINGS[key] = value
return value
@udf(HELPER)
def clear_settings():
SETTINGS.clear()
@udf(HELPER)
def clear_toggles():
STATE.clear()
@udf(MATH)
def randomrange(start, end=None, step=None):
if end is None:
start, end = 0, start
elif step is None:
step = 1
return random.randrange(start, end, step)
@udf(MATH)
def gauss_distribution(mean, sigma):
try:
return random.gauss(mean, sigma)
except ValueError:
return None
@udf(MATH)
def sqrt(n):
try:
return math.sqrt(n)
except ValueError:
return None
@udf(MATH)
def tonumber(s):
try:
return int(s)
except ValueError:
try:
return float(s)
except:
return None
@udf(STRING)
def substr_count(haystack, needle):
if not haystack or not needle:
return 0
return haystack.count(needle)
@udf(STRING)
def strip_chars(haystack, chars):
return haystack.strip(chars)
def _hash(constructor, *args):
hash_obj = constructor()
for arg in args:
hash_obj.update(arg)
return hash_obj.hexdigest()
# Aggregates.
class _heap_agg(object):
def __init__(self):
self.heap = []
self.ct = 0
def process(self, value):
return value
def step(self, value):
self.ct += 1
heapq.heappush(self.heap, self.process(value))
class _datetime_heap_agg(_heap_agg):
def process(self, value):
return format_date_time_sqlite(value)
if sys.version_info[:2] == (2, 6):
def total_seconds(td):
return (td.seconds +
(td.days * 86400) +
(td.microseconds / (10.**6)))
else:
total_seconds = lambda td: td.total_seconds()
@aggregate(DATE)
class mintdiff(_datetime_heap_agg):
def finalize(self):
dtp = min_diff = None
while self.heap:
if min_diff is None:
if dtp is None:
dtp = heapq.heappop(self.heap)
continue
dt = heapq.heappop(self.heap)
diff = dt - dtp
if min_diff is None or min_diff > diff:
min_diff = diff
dtp = dt
if min_diff is not None:
return total_seconds(min_diff)
@aggregate(DATE)
class avgtdiff(_datetime_heap_agg):
def finalize(self):
if self.ct < 1:
return
elif self.ct == 1:
return 0
total = ct = 0
dtp = None
while self.heap:
if total == 0:
if dtp is None:
dtp = heapq.heappop(self.heap)
continue
dt = heapq.heappop(self.heap)
diff = dt - dtp
ct += 1
total += total_seconds(diff)
dtp = dt
return float(total) / ct
@aggregate(DATE)
class duration(object):
def __init__(self):
self._min = self._max = None
def step(self, value):
dt = format_date_time_sqlite(value)
if self._min is None or dt < self._min:
self._min = dt
if self._max is None or dt > self._max:
self._max = dt
def finalize(self):
if self._min and self._max:
td = (self._max - self._min)
return total_seconds(td)
return None
@aggregate(MATH)
class mode(object):
if Counter:
def __init__(self):
self.items = Counter()
def step(self, *args):
self.items.update(args)
def finalize(self):
if self.items:
return self.items.most_common(1)[0][0]
else:
def __init__(self):
self.items = []
def step(self, item):
self.items.append(item)
def finalize(self):
if self.items:
return max(set(self.items), key=self.items.count)
@aggregate(MATH)
class minrange(_heap_agg):
def finalize(self):
if self.ct == 0:
return
elif self.ct == 1:
return 0
prev = min_diff = None
while self.heap:
if min_diff is None:
if prev is None:
prev = heapq.heappop(self.heap)
continue
curr = heapq.heappop(self.heap)
diff = curr - prev
if min_diff is None or min_diff > diff:
min_diff = diff
prev = curr
return min_diff
@aggregate(MATH)
class avgrange(_heap_agg):
def finalize(self):
if self.ct == 0:
return
elif self.ct == 1:
return 0
total = ct = 0
prev = None
while self.heap:
if total == 0:
if prev is None:
prev = heapq.heappop(self.heap)
continue
curr = heapq.heappop(self.heap)
diff = curr - prev
ct += 1
total += diff
prev = curr
return float(total) / ct
@aggregate(MATH)
class _range(object):
name = 'range'
def __init__(self):
self._min = self._max = None
def step(self, value):
if self._min is None or value < self._min:
self._min = value
if self._max is None or value > self._max:
self._max = value
def finalize(self):
if self._min is not None and self._max is not None:
return self._max - self._min
return None
@aggregate(MATH)
class stddev(object):
def __init__(self):
self.n = 0
self.values = []
def step(self, v):
self.n += 1
self.values.append(v)
def finalize(self):
if self.n <= 1:
return 0
mean = sum(self.values) / self.n
return math.sqrt(sum((i - mean) ** 2 for i in self.values) / (self.n - 1))
if cython_udf is not None:
damerau_levenshtein_dist = udf(STRING)(cython_udf.damerau_levenshtein_dist)
levenshtein_dist = udf(STRING)(cython_udf.levenshtein_dist)
str_dist = udf(STRING)(cython_udf.str_dist)
median = aggregate(MATH)(cython_udf.median)
if TableFunction is not None:
@table_function(STRING)
class RegexSearch(TableFunction):
params = ['regex', 'search_string']
columns = ['match']
name = 'regex_search'
def initialize(self, regex=None, search_string=None):
self._iter = re.finditer(regex, search_string)
def iterate(self, idx):
return (next(self._iter).group(0),)
@table_function(DATE)
class DateSeries(TableFunction):
params = ['start', 'stop', 'step_seconds']
columns = ['date']
name = 'date_series'
def initialize(self, start, stop, step_seconds=86400):
self.start = format_date_time_sqlite(start)
self.stop = format_date_time_sqlite(stop)
step_seconds = int(step_seconds)
self.step_seconds = datetime.timedelta(seconds=step_seconds)
if (self.start.hour == 0 and
self.start.minute == 0 and
self.start.second == 0 and
step_seconds >= 86400):
self.format = '%Y-%m-%d'
elif (self.start.year == 1900 and
self.start.month == 1 and
self.start.day == 1 and
self.stop.year == 1900 and
self.stop.month == 1 and
self.stop.day == 1 and
step_seconds < 86400):
self.format = '%H:%M:%S'
else:
self.format = '%Y-%m-%d %H:%M:%S'
def iterate(self, idx):
if self.start > self.stop:
raise StopIteration
current = self.start
self.start += self.step_seconds
return (current.strftime(self.format),)
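
For orientation, none of the helpers above take effect until they are registered on a connection. Below is a minimal sketch of one way to wire them up, assuming the module is importable as playhouse.sqlite_udf and relying on peewee's standard register_function()/register_aggregate() hooks; the samples table and its rows are invented for illustration:

from peewee import SqliteDatabase
from playhouse import sqlite_udf

db = SqliteDatabase(':memory:')
db.register_function(sqlite_udf.substr_count, 'substr_count')
db.register_aggregate(sqlite_udf.stddev, 'stddev', 1)

db.connect()
db.execute_sql('CREATE TABLE samples (name TEXT, value REAL)')
db.execute_sql("INSERT INTO samples VALUES ('aabab', 1.0), ('abc', 2.0), ('b', 4.0)")

# Scalar UDF: per-row substring counts.
cursor = db.execute_sql("SELECT substr_count(name, 'a') FROM samples")
print([row[0] for row in cursor.fetchall()])

# Aggregate: sample standard deviation across the column.
cursor = db.execute_sql('SELECT stddev(value) FROM samples')
print(cursor.fetchone()[0])

Table functions such as RegexSearch and DateSeries register analogously, via register_table_function() on a SqliteExtDatabase.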

@ -0,0 +1,331 @@
import logging
import weakref
from threading import local as thread_local
from threading import Event
from threading import Lock
from threading import Thread
try:
from Queue import Queue
except ImportError:
from queue import Queue
try:
import gevent
from gevent import Greenlet as GThread
from gevent.event import Event as GEvent
from gevent.local import local as greenlet_local
from gevent.queue import Queue as GQueue
except ImportError:
GThread = GQueue = GEvent = None
from peewee import SENTINEL
from playhouse.sqlite_ext import SqliteExtDatabase
logger = logging.getLogger('peewee.sqliteq')
class ResultTimeout(Exception):
pass
class WriterPaused(Exception):
pass
class ShutdownException(Exception):
pass
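# AsyncCursor is the placeholder returned to callers for queued queries: it
# blocks on an Event until the writer thread attaches the real cursor (or an
# exception), then exposes rows, lastrowid and rowcount like a regular cursor.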
class AsyncCursor(object):
__slots__ = ('sql', 'params', 'commit', 'timeout',
'_event', '_cursor', '_exc', '_idx', '_rows', '_ready')
def __init__(self, event, sql, params, commit, timeout):
self._event = event
self.sql = sql
self.params = params
self.commit = commit
self.timeout = timeout
self._cursor = self._exc = self._idx = self._rows = None
self._ready = False
def set_result(self, cursor, exc=None):
self._cursor = cursor
self._exc = exc
self._idx = 0
self._rows = cursor.fetchall() if exc is None else []
self._event.set()
return self
def _wait(self, timeout=None):
timeout = timeout if timeout is not None else self.timeout
if not self._event.wait(timeout=timeout) and timeout:
raise ResultTimeout('results not ready, timed out.')
if self._exc is not None:
raise self._exc
self._ready = True
def __iter__(self):
if not self._ready:
self._wait()
if self._exc is not None:
raise self._exc
return self
def next(self):
if not self._ready:
self._wait()
try:
obj = self._rows[self._idx]
except IndexError:
raise StopIteration
else:
self._idx += 1
return obj
__next__ = next
@property
def lastrowid(self):
if not self._ready:
self._wait()
return self._cursor.lastrowid
@property
def rowcount(self):
if not self._ready:
self._wait()
return self._cursor.rowcount
@property
def description(self):
return self._cursor.description
def close(self):
self._cursor.close()
def fetchall(self):
return list(self) # Iterating implies waiting until populated.
def fetchone(self):
if not self._ready:
self._wait()
try:
return next(self)
except StopIteration:
return None
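# Sentinels passed through the write queue to control the writer thread.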
SHUTDOWN = StopIteration
PAUSE = object()
UNPAUSE = object()
class Writer(object):
__slots__ = ('database', 'queue')
def __init__(self, database, queue):
self.database = database
self.queue = queue
def run(self):
conn = self.database.connection()
try:
while True:
try:
if conn is None: # Paused.
if self.wait_unpause():
conn = self.database.connection()
else:
conn = self.loop(conn)
except ShutdownException:
logger.info('writer received shutdown request, exiting.')
return
finally:
if conn is not None:
self.database._close(conn)
self.database._state.reset()
def wait_unpause(self):
obj = self.queue.get()
if obj is UNPAUSE:
logger.info('writer unpaused - reconnecting to database.')
return True
elif obj is SHUTDOWN:
raise ShutdownException()
elif obj is PAUSE:
logger.error('writer received pause, but is already paused.')
else:
obj.set_result(None, WriterPaused())
logger.warning('writer paused, not handling %s', obj)
def loop(self, conn):
obj = self.queue.get()
if isinstance(obj, AsyncCursor):
self.execute(obj)
elif obj is PAUSE:
logger.info('writer paused - closing database connection.')
self.database._close(conn)
self.database._state.reset()
return
elif obj is UNPAUSE:
logger.error('writer received unpause, but is already running.')
elif obj is SHUTDOWN:
raise ShutdownException()
else:
logger.error('writer received unsupported object: %s', obj)
return conn
def execute(self, obj):
logger.debug('received query %s', obj.sql)
try:
cursor = self.database._execute(obj.sql, obj.params, obj.commit)
        except Exception as execute_err:
            cursor = None
            # Python 3 clears the exception variable when the except block
            # exits, so keep a reference for set_result() below.
            exc = execute_err
else:
exc = None
return obj.set_result(cursor, exc)
class SqliteQueueDatabase(SqliteExtDatabase):
WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL '
'journal mode when using this feature. WAL mode '
'allows one or more readers to continue reading '
'while another connection writes to the '
'database.')
def __init__(self, database, use_gevent=False, autostart=True,
queue_max_size=None, results_timeout=None, *args, **kwargs):
kwargs['check_same_thread'] = False
# Ensure that journal_mode is WAL. This value is passed to the parent
# class constructor below.
pragmas = self._validate_journal_mode(kwargs.pop('pragmas', None))
# Reference to execute_sql on the parent class. Since we've overridden
# execute_sql(), this is just a handy way to reference the real
# implementation.
Parent = super(SqliteQueueDatabase, self)
self._execute = Parent.execute_sql
# Call the parent class constructor with our modified pragmas.
Parent.__init__(database, pragmas=pragmas, *args, **kwargs)
self._autostart = autostart
self._results_timeout = results_timeout
        self._is_stopped = True

        # Guards starting and stopping the writer thread (see start()/stop()).
        self._lock = Lock()
# Get different objects depending on the threading implementation.
self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size)
# Create the writer thread, optionally starting it.
self._create_write_queue()
if self._autostart:
self.start()
def get_thread_impl(self, use_gevent):
return GreenletHelper if use_gevent else ThreadHelper
def _validate_journal_mode(self, pragmas=None):
if not pragmas:
return {'journal_mode': 'wal'}
if not isinstance(pragmas, dict):
pragmas = dict((k.lower(), v) for (k, v) in pragmas)
if pragmas.get('journal_mode', 'wal').lower() != 'wal':
raise ValueError(self.WAL_MODE_ERROR_MESSAGE)
pragmas['journal_mode'] = 'wal'
return pragmas
def _create_write_queue(self):
self._write_queue = self._thread_helper.queue()
def queue_size(self):
return self._write_queue.qsize()
    def execute_sql(self, sql, params=None, commit=SENTINEL, timeout=None):
        # By default anything other than a SELECT counts as a write and must
        # be serialized through the writer thread.
        if commit is SENTINEL:
            commit = not sql.lower().startswith('select')
        if not commit:
            # Reads execute immediately on the calling thread's connection.
            return self._execute(sql, params, commit=commit)
cursor = AsyncCursor(
event=self._thread_helper.event(),
sql=sql,
params=params,
commit=commit,
timeout=self._results_timeout if timeout is None else timeout)
self._write_queue.put(cursor)
return cursor
def start(self):
with self._lock:
if not self._is_stopped:
return False
def run():
writer = Writer(self, self._write_queue)
writer.run()
self._writer = self._thread_helper.thread(run)
self._writer.start()
self._is_stopped = False
return True
def stop(self):
logger.debug('environment stop requested.')
with self._lock:
if self._is_stopped:
return False
self._write_queue.put(SHUTDOWN)
self._writer.join()
self._is_stopped = True
return True
def is_stopped(self):
with self._lock:
return self._is_stopped
def pause(self):
with self._lock:
self._write_queue.put(PAUSE)
def unpause(self):
with self._lock:
self._write_queue.put(UNPAUSE)
def __unsupported__(self, *args, **kwargs):
raise ValueError('This method is not supported by %r.' % type(self))
atomic = transaction = savepoint = __unsupported__
class ThreadHelper(object):
__slots__ = ('queue_max_size',)
def __init__(self, queue_max_size=None):
self.queue_max_size = queue_max_size
def event(self): return Event()
def queue(self, max_size=None):
max_size = max_size if max_size is not None else self.queue_max_size
return Queue(maxsize=max_size or 0)
def thread(self, fn, *args, **kwargs):
thread = Thread(target=fn, args=args, kwargs=kwargs)
thread.daemon = True
return thread
class GreenletHelper(ThreadHelper):
__slots__ = ()
def event(self): return GEvent()
def queue(self, max_size=None):
max_size = max_size if max_size is not None else self.queue_max_size
return GQueue(maxsize=max_size or 0)
def thread(self, fn, *args, **kwargs):
def wrap(*a, **k):
gevent.sleep()
return fn(*a, **k)
return GThread(wrap, *args, **kwargs)
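
Taken together, SqliteQueueDatabase behaves like a drop-in SqliteExtDatabase: reads run synchronously on the calling thread, while writes are funneled through the single writer and block only when the caller consumes their result. A rough usage sketch under those assumptions follows; the filename and Note model are invented, and since atomic()/transaction() are deliberately unsupported, nothing below opens a transaction:

from peewee import Model, TextField
from playhouse.sqliteq import SqliteQueueDatabase

db = SqliteQueueDatabase('app.db', autostart=True, results_timeout=5.0)

class Note(Model):
    content = TextField()

    class Meta:
        database = db

db.create_tables([Note])           # DDL statements are queued to the writer.
note = Note.create(content='hi')   # Safe from any thread; blocks on lastrowid.
print(Note.select().count())       # Reads execute immediately on this thread.
db.stop()                          # Drain the queue and join the writer.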

@ -0,0 +1,62 @@
from functools import wraps
import logging
logger = logging.getLogger('peewee')
class _QueryLogHandler(logging.Handler):
def __init__(self, *args, **kwargs):
self.queries = []
logging.Handler.__init__(self, *args, **kwargs)
def emit(self, record):
self.queries.append(record)
class count_queries(object):
def __init__(self, only_select=False):
self.only_select = only_select
self.count = 0
def get_queries(self):
return self._handler.queries
def __enter__(self):
self._handler = _QueryLogHandler()
logger.setLevel(logging.DEBUG)
logger.addHandler(self._handler)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
logger.removeHandler(self._handler)
if self.only_select:
self.count = len([q for q in self._handler.queries
if q.msg[0].startswith('SELECT ')])
else:
self.count = len(self._handler.queries)
class assert_query_count(count_queries):
def __init__(self, expected, only_select=False):
super(assert_query_count, self).__init__(only_select=only_select)
self.expected = expected
def __call__(self, f):
@wraps(f)
def decorated(*args, **kwds):
with self:
ret = f(*args, **kwds)
self._assert_count()
return ret
return decorated
def _assert_count(self):
error_msg = '%s != %s' % (self.count, self.expected)
assert self.count == self.expected, error_msg
def __exit__(self, exc_type, exc_val, exc_tb):
super(assert_query_count, self).__exit__(exc_type, exc_val, exc_tb)
self._assert_count()
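
Both helpers hook the module-level 'peewee' logger rather than any particular database, so they count queries regardless of backend. A small self-contained sketch of the intended test usage (the User model is invented for illustration):

from peewee import Model, SqliteDatabase, TextField
from playhouse.test_utils import assert_query_count, count_queries

db = SqliteDatabase(':memory:')

class User(Model):
    name = TextField()

    class Meta:
        database = db

db.create_tables([User])

with count_queries(only_select=True) as counter:
    list(User.select())
print(counter.count)  # 1 -- only the SELECT is counted.

@assert_query_count(2)
def create_and_fetch(name):
    User.create(name=name)              # query 1: INSERT
    return User.get(User.name == name)  # query 2: SELECT

create_and_fetch('alice')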

@ -1,219 +0,0 @@
# Copyright (c) 2014 Palantir Technologies
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
"""Thread safe sqlite3 interface."""
__author__ = "Shawn Lee"
__email__ = "shawnl@palantir.com"
__license__ = "MIT"
import logging
try:
import queue as Queue # module re-named in Python 3
except ImportError:
import Queue
import sqlite3
import threading
import time
import uuid
LOGGER = logging.getLogger('sqlite3worker')
class Sqlite3Worker(threading.Thread):
"""Sqlite thread safe object.
Example:
from sqlite3worker import Sqlite3Worker
sql_worker = Sqlite3Worker("/tmp/test.sqlite")
sql_worker.execute(
"CREATE TABLE tester (timestamp DATETIME, uuid TEXT)")
sql_worker.execute(
"INSERT into tester values (?, ?)", ("2010-01-01 13:00:00", "bow"))
sql_worker.execute(
"INSERT into tester values (?, ?)", ("2011-02-02 14:14:14", "dog"))
sql_worker.execute("SELECT * from tester")
sql_worker.close()
"""
def __init__(self, file_name, max_queue_size=100, as_dict=False):
"""Automatically starts the thread.
Args:
file_name: The name of the file.
max_queue_size: The max queries that will be queued.
as_dict: Return result as a dictionary.
"""
threading.Thread.__init__(self)
self.daemon = True
self.sqlite3_conn = sqlite3.connect(
file_name, check_same_thread=False,
detect_types=sqlite3.PARSE_DECLTYPES)
if as_dict:
self.sqlite3_conn.row_factory = dict_factory
self.sqlite3_cursor = self.sqlite3_conn.cursor()
self.sql_queue = Queue.Queue(maxsize=max_queue_size)
self.results = {}
self.max_queue_size = max_queue_size
self.exit_set = False
# Token that is put into queue when close() is called.
self.exit_token = str(uuid.uuid4())
self.start()
self.thread_running = True
def run(self):
"""Thread loop.
This is an infinite loop. The iter method calls self.sql_queue.get()
which blocks if there are not values in the queue. As soon as values
are placed into the queue the process will continue.
If many executes happen at once it will churn through them all before
calling commit() to speed things up by reducing the number of times
commit is called.
"""
LOGGER.debug("run: Thread started")
execute_count = 0
for token, query, values, only_one, execute_many in iter(self.sql_queue.get, None):
LOGGER.debug("sql_queue: %s", self.sql_queue.qsize())
if token != self.exit_token:
LOGGER.debug("run: %s, %s", query, values)
self.run_query(token, query, values, only_one, execute_many)
execute_count += 1
# Let the executes build up a little before committing to disk
# to speed things up.
if (
self.sql_queue.empty() or
execute_count == self.max_queue_size):
LOGGER.debug("run: commit")
self.sqlite3_conn.commit()
execute_count = 0
# Only exit if the queue is empty. Otherwise keep getting
# through the queue until it's empty.
if self.exit_set and self.sql_queue.empty():
self.sqlite3_conn.commit()
self.sqlite3_conn.close()
self.thread_running = False
return
def run_query(self, token, query, values, only_one=False, execute_many=False):
"""Run a query.
Args:
token: A uuid object of the query you want returned.
query: A sql query with ? placeholders for values.
values: A tuple of values to replace "?" in query.
"""
if query.lower().strip().startswith(("select", "pragma")):
try:
self.sqlite3_cursor.execute(query, values)
if only_one:
self.results[token] = self.sqlite3_cursor.fetchone()
else:
self.results[token] = self.sqlite3_cursor.fetchall()
except sqlite3.Error as err:
                # Store the error in the results dict, since the caller is
                # waiting for a response.
self.results[token] = (
"Query returned error: %s: %s: %s" % (query, values, err))
LOGGER.error(
"Query returned error: %s: %s: %s", query, values, err)
else:
try:
if execute_many:
self.sqlite3_cursor.executemany(query, values)
if query.lower().strip().startswith(("insert", "update", "delete")):
self.results[token] = self.sqlite3_cursor.rowcount
else:
self.sqlite3_cursor.execute(query, values)
if query.lower().strip().startswith(("insert", "update", "delete")):
self.results[token] = self.sqlite3_cursor.rowcount
except sqlite3.Error as err:
self.results[token] = (
"Query returned error: %s: %s: %s" % (query, values, err))
LOGGER.error(
"Query returned error: %s: %s: %s", query, values, err)
def close(self):
"""Close down the thread and close the sqlite3 database file."""
self.exit_set = True
self.sql_queue.put((self.exit_token, "", "", "", ""), timeout=5)
# Sleep and check that the thread is done before returning.
while self.thread_running:
time.sleep(.01) # Don't kill the CPU waiting.
@property
def queue_size(self):
"""Return the queue size."""
return self.sql_queue.qsize()
def query_results(self, token):
"""Get the query results for a specific token.
Args:
token: A uuid object of the query you want returned.
Returns:
Return the results of the query when it's executed by the thread.
"""
delay = .001
while True:
if token in self.results:
return_val = self.results[token]
del self.results[token]
return return_val
            # Double the delay on each pass, up to a max of 8 seconds. This
            # prevents a long-lived select statement from thrashing the CPU
            # while this loop waits for the query results.
LOGGER.debug("Sleeping: %s %s", delay, token)
time.sleep(delay)
if delay < 8:
delay += delay
def execute(self, query, values=None, only_one=False, execute_many=False):
"""Execute a query.
Args:
query: The sql string using ? for placeholders of dynamic values.
values: A tuple of values to be replaced into the ? of the query.
Returns:
If it's a select query it will return the results of the query.
"""
if self.exit_set:
LOGGER.debug("Exit set, not running: %s, %s", query, values)
return "Exit Called"
LOGGER.debug("execute: %s, %s", query, values)
values = values or []
# A token to track this query with.
token = str(uuid.uuid4())
        # Queries that produce a response are queued with a token that keys
        # their results in the shared results dict, so each caller can pick
        # out its own rows.
if query.lower().strip().startswith(("select", "insert", "update", "delete", "pragma")):
self.sql_queue.put((token, query, values, only_one, execute_many), timeout=5)
return self.query_results(token)
else:
self.sql_queue.put((token, query, values, only_one, execute_many), timeout=5)
def dict_factory(cursor, row):
d = {}
for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx]
return d

@ -20,6 +20,7 @@ guess_language-spirit=0.5.3
Js2Py=0.63 <-- modified: manually merged from upstream: https://github.com/PiotrDabkowski/Js2Py/pull/192/files
knowit=0.3.0-dev
msgpack=1.0.2
peewee=3.14.4
py-pretty=1
pycountry=18.2.23
pyga=2.6.1
