From 73df77984de9a7e371b25647c29e8b6a8d640736 Mon Sep 17 00:00:00 2001 From: morpheus65535 Date: Sun, 13 Oct 2019 22:41:34 -0400 Subject: [PATCH 1/8] Fix for #630 --- libs/pyprobe/ffprobeparsers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/pyprobe/ffprobeparsers.py b/libs/pyprobe/ffprobeparsers.py index 36a395c53..64ad05d10 100644 --- a/libs/pyprobe/ffprobeparsers.py +++ b/libs/pyprobe/ffprobeparsers.py @@ -121,7 +121,7 @@ class SubtitleStreamParser(BaseParser): """Returns a string """ tags = data.get("tags", None) if tags: - info = tags.get("language", None) + info = tags.get("language", None) or tags.get("LANGUAGE", None) return info, (info or "null") return None, "null" From b7049c79cea00aa0d179a05cce293f12bde48ec8 Mon Sep 17 00:00:00 2001 From: morpheus65535 Date: Mon, 14 Oct 2019 08:45:04 -0400 Subject: [PATCH 2/8] Fix for #620 --- bazarr/check_update.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bazarr/check_update.py b/bazarr/check_update.py index 08b1946fc..4e800edba 100644 --- a/bazarr/check_update.py +++ b/bazarr/check_update.py @@ -294,7 +294,8 @@ def request_json(url, **kwargs): def updated(restart=True): if settings.general.getboolean('update_restart') and restart: try: - requests.get(bazarr_url + 'restart') + from main import restart + restart() except requests.ConnectionError: logging.info('BAZARR Restart failed, please restart Bazarr manualy') updated(restart=False) From 983ffe2199941923983a42aeae1c88b6c8204c32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Louis=20V=C3=A9zina?= <5130500+morpheus65535@users.noreply.github.com> Date: Tue, 15 Oct 2019 11:03:47 -0400 Subject: [PATCH 3/8] Fix for #635. --- bazarr/get_subtitle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bazarr/get_subtitle.py b/bazarr/get_subtitle.py index 95aee2b40..65fc00e84 100644 --- a/bazarr/get_subtitle.py +++ b/bazarr/get_subtitle.py @@ -1181,7 +1181,7 @@ def upgrade_subtitles(): notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long') logging.info("BAZARR All providers are throttled") return - if episode['languages'] != "None": + if episode['languages']: desired_languages = ast.literal_eval(str(episode['languages'])) if episode['forced'] == "True": forced_languages = [l + ":forced" for l in desired_languages] @@ -1230,7 +1230,7 @@ def upgrade_subtitles(): notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long') logging.info("BAZARR All providers are throttled") return - if movie['languages'] != "None": + if movie['languages']: desired_languages = ast.literal_eval(str(movie['languages'])) if movie['forced'] == "True": forced_languages = [l + ":forced" for l in desired_languages] From 55037cfde26ac72d74f1828551438751955b9f32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Louis=20V=C3=A9zina?= <5130500+morpheus65535@users.noreply.github.com> Date: Tue, 15 Oct 2019 17:18:51 -0400 Subject: [PATCH 4/8] Major fixes to database subsystem. 
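
This replaces the SqliteQueueDatabase wrapper with a plain SqliteDatabase
opened at import time, applies the WAL-journal pragmas directly, creates the
tables on load, and turns the unconditional migrate() calls that used to sit
inside the model bodies into guarded migrations that only run when the column
is actually missing. A minimal sketch of that guard pattern (illustrative
only; the field type here is assumed, the real names are in the diff below):

    from peewee import SqliteDatabase, TextField
    from playhouse.migrate import SqliteMigrator, migrate

    database = SqliteDatabase('bazarr.db')
    migrator = SqliteMigrator(database)

    # get_columns() reports the live schema, so re-running this at every
    # startup is safe: the column is only added the first time.
    existing = [column.name for column in database.get_columns('table_shows')]
    if 'forced' not in existing:
        migrate(migrator.add_column('table_shows', 'forced',
                                    TextField(default='False')))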
--- bazarr/database.py | 67 ++++++++++++++++++++-------------------- bazarr/get_subtitle.py | 4 +-- bazarr/list_subtitles.py | 10 +++--- bazarr/main.py | 9 ++---- 4 files changed, 44 insertions(+), 46 deletions(-) diff --git a/bazarr/database.py b/bazarr/database.py index 94846abd2..9c96b7772 100644 --- a/bazarr/database.py +++ b/bazarr/database.py @@ -3,19 +3,16 @@ import atexit from get_args import args from peewee import * -from playhouse.sqliteq import SqliteQueueDatabase from playhouse.migrate import * from helper import path_replace, path_replace_movie, path_replace_reverse, path_replace_reverse_movie -database = SqliteQueueDatabase( - None, - use_gevent=False, - autostart=False, - queue_max_size=256, # Max. # of pending writes that can accumulate. - results_timeout=30.0) # Max. time to wait for query to be executed. - -migrator = SqliteMigrator(database) +database = SqliteDatabase(os.path.join(args.config_dir, 'db', 'bazarr.db')) +database.pragma('wal_checkpoint', 'TRUNCATE') # Run a checkpoint and merge remaining wal-journal. +database.timeout = 30 # Number of second to wait for database +database.cache_size = -1024 # Number of KB of cache for wal-journal. + # Must be negative because positive means number of pages. +database.wal_autocheckpoint = 50 # Run an automatic checkpoint every 50 write transactions. @database.func('path_substitution') @@ -61,10 +58,6 @@ class TableShows(BaseModel): tvdb_id = IntegerField(column_name='tvdbId', null=True, unique=True, primary_key=True) year = TextField(null=True) - migrate( - migrator.add_column('table_shows', 'forced', forced), - ) - class Meta: table_name = 'table_shows' @@ -87,10 +80,6 @@ class TableEpisodes(BaseModel): video_codec = TextField(null=True) episode_file_id = IntegerField(null=True) - migrate( - migrator.add_column('table_episodes', 'episode_file_id', episode_file_id), - ) - class Meta: table_name = 'table_episodes' primary_key = False @@ -123,11 +112,6 @@ class TableMovies(BaseModel): year = TextField(null=True) movie_file_id = IntegerField(null=True) - migrate( - migrator.add_column('table_movies', 'forced', forced), - migrator.add_column('table_movies', 'movie_file_id', movie_file_id), - ) - class Meta: table_name = 'table_movies' @@ -183,20 +167,37 @@ class TableSettingsNotifier(BaseModel): table_name = 'table_settings_notifier' -def database_init(): - database.init(os.path.join(args.config_dir, 'db', 'bazarr.db')) - database.start() - database.connect() +# Database tables creation if they don't exists +models_list = [TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie, TableSettingsLanguages, + TableSettingsNotifier, System] +database.create_tables(models_list, safe=True) - database.pragma('wal_checkpoint', 'TRUNCATE') # Run a checkpoint and merge remaining wal-journal. - database.cache_size = -1024 # Number of KB of cache for wal-journal. - # Must be negative because positive means number of pages. - database.wal_autocheckpoint = 50 # Run an automatic checkpoint every 50 write transactions. 
- models_list = [TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie, TableSettingsLanguages, - TableSettingsNotifier, System] +# Database migration +migrator = SqliteMigrator(database) - database.create_tables(models_list, safe=True) +# TableShows migration +table_shows_columns = [] +for column in database.get_columns('table_shows'): + table_shows_columns.append(column.name) +if 'forced' not in table_shows_columns: + migrate(migrator.add_column('table_shows', 'forced', TableShows.forced)) + +# TableEpisodes migration +table_episodes_columns = [] +for column in database.get_columns('table_episodes'): + table_episodes_columns.append(column.name) +if 'episode_file_id' not in table_episodes_columns: + migrate(migrator.add_column('table_episodes', 'episode_file_id', TableEpisodes.episode_file_id)) + +# TableMovies migration +table_movies_columns = [] +for column in database.get_columns('table_movies'): + table_movies_columns.append(column.name) +if 'forced' not in table_movies_columns: + migrate(migrator.add_column('table_movies', 'forced', TableMovies.forced)) +if 'movie_file_id' not in table_movies_columns: + migrate(migrator.add_column('table_movies', 'movie_file_id', TableMovies.movie_file_id)) def wal_cleaning(): diff --git a/bazarr/get_subtitle.py b/bazarr/get_subtitle.py index 65fc00e84..9207e6b9e 100644 --- a/bazarr/get_subtitle.py +++ b/bazarr/get_subtitle.py @@ -33,7 +33,7 @@ from get_providers import get_providers, get_providers_auth, provider_throttle, from get_args import args from queueconfig import notifications from pyprobe.pyprobe import VideoFileParser -from database import TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie +from database import database, TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie from peewee import fn, JOIN from analytics import track_event @@ -620,7 +620,7 @@ def series_download_subtitles(no): def episode_download_subtitles(no): episodes_details_clause = [ - (TableEpisodes.sonarr_series_id == no) + (TableEpisodes.sonarr_episode_id == no) ] if settings.sonarr.getboolean('only_monitored'): episodes_details_clause.append( diff --git a/bazarr/list_subtitles.py b/bazarr/list_subtitles.py index 2c15d4a31..58abb1762 100644 --- a/bazarr/list_subtitles.py +++ b/bazarr/list_subtitles.py @@ -222,11 +222,11 @@ def store_subtitles_movie(file): def list_missing_subtitles(no=None): - episodes_subtitles_clause = {TableShows.sonarr_series_id.is_null(False)} + episodes_subtitles_clause = (TableShows.sonarr_series_id.is_null(False)) if no is not None: - episodes_subtitles_clause = {TableShows.sonarr_series_id ** no} - + episodes_subtitles_clause = (TableShows.sonarr_series_id == no) episodes_subtitles = TableEpisodes.select( + TableShows.sonarr_series_id, TableEpisodes.sonarr_episode_id, TableEpisodes.subtitles, TableShows.languages, @@ -288,9 +288,9 @@ def list_missing_subtitles(no=None): def list_missing_subtitles_movies(no=None): - movies_subtitles_clause = {TableMovies.radarr_id.is_null(False)} + movies_subtitles_clause = (TableMovies.radarr_id.is_null(False)) if no is not None: - movies_subtitles_clause = {TableMovies.radarr_id ** no} + movies_subtitles_clause = (TableMovies.radarr_id == no) movies_subtitles = TableMovies.select( TableMovies.radarr_id, diff --git a/bazarr/main.py b/bazarr/main.py index 9b42594a4..b52688e47 100644 --- a/bazarr/main.py +++ b/bazarr/main.py @@ -23,12 +23,9 @@ from calendar import day_name from get_args import args from init import * -from database import database, 
database_init, TableEpisodes, TableShows, TableMovies, TableHistory, TableHistoryMovie, \ +from database import database, TableEpisodes, TableShows, TableMovies, TableHistory, TableHistoryMovie, \ TableSettingsLanguages, TableSettingsNotifier, System -# Initiate database -database_init() - from notifier import update_notifier from logger import configure_logging, empty_log @@ -736,7 +733,7 @@ def edit_series(no): TableShows.forced: forced } ).where( - TableShows.sonarr_series_id ** no + TableShows.sonarr_series_id == no ).execute() list_missing_subtitles(no) @@ -809,7 +806,7 @@ def episodes(no): fn.path_substitution(TableShows.path).alias('path'), TableShows.forced ).where( - TableShows.sonarr_series_id ** str(no) + TableShows.sonarr_series_id == no ).limit(1) for series in series_details: tvdbid = series.tvdb_id From ae6f7117fca34896d2dfc2c1007f9637e029ec19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Louis=20V=C3=A9zina?= <5130500+morpheus65535@users.noreply.github.com> Date: Wed, 16 Oct 2019 10:20:35 -0400 Subject: [PATCH 5/8] Major fixes to database subsystem. --- bazarr/database.py | 11 +- libs/peewee.py | 325 ++++++++++++++++++++++++++------- libs/playhouse/dataset.py | 41 ++++- libs/playhouse/hybrid.py | 7 +- libs/playhouse/migrate.py | 2 +- libs/playhouse/mysql_ext.py | 17 +- libs/playhouse/pool.py | 3 +- libs/playhouse/postgres_ext.py | 12 +- libs/playhouse/reflection.py | 2 +- libs/playhouse/shortcuts.py | 6 +- libs/playhouse/signals.py | 2 +- libs/playhouse/sqlite_ext.py | 36 +++- 12 files changed, 377 insertions(+), 87 deletions(-) diff --git a/bazarr/database.py b/bazarr/database.py index 9c96b7772..c347a5933 100644 --- a/bazarr/database.py +++ b/bazarr/database.py @@ -3,13 +3,20 @@ import atexit from get_args import args from peewee import * +from playhouse.sqliteq import SqliteQueueDatabase from playhouse.migrate import * from helper import path_replace, path_replace_movie, path_replace_reverse, path_replace_reverse_movie -database = SqliteDatabase(os.path.join(args.config_dir, 'db', 'bazarr.db')) +database = SqliteQueueDatabase( + os.path.join(args.config_dir, 'db', 'bazarr.db'), + use_gevent=False, + autostart=True, + queue_max_size=256, # Max. # of pending writes that can accumulate. + results_timeout=30.0 # Max. time to wait for query to be executed. +) + database.pragma('wal_checkpoint', 'TRUNCATE') # Run a checkpoint and merge remaining wal-journal. -database.timeout = 30 # Number of second to wait for database database.cache_size = -1024 # Number of KB of cache for wal-journal. # Must be negative because positive means number of pages. database.wal_autocheckpoint = 50 # Run an automatic checkpoint every 50 write transactions. 
diff --git a/libs/peewee.py b/libs/peewee.py index 3204edb34..c41dc7135 100644 --- a/libs/peewee.py +++ b/libs/peewee.py @@ -65,7 +65,7 @@ except ImportError: mysql = None -__version__ = '3.9.6' +__version__ = '3.11.2' __all__ = [ 'AsIs', 'AutoField', @@ -206,15 +206,15 @@ __sqlite_datetime_formats__ = ( '%H:%M') __sqlite_date_trunc__ = { - 'year': '%Y', - 'month': '%Y-%m', - 'day': '%Y-%m-%d', - 'hour': '%Y-%m-%d %H', - 'minute': '%Y-%m-%d %H:%M', + 'year': '%Y-01-01 00:00:00', + 'month': '%Y-%m-01 00:00:00', + 'day': '%Y-%m-%d 00:00:00', + 'hour': '%Y-%m-%d %H:00:00', + 'minute': '%Y-%m-%d %H:%M:00', 'second': '%Y-%m-%d %H:%M:%S'} __mysql_date_trunc__ = __sqlite_date_trunc__.copy() -__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i' +__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i:00' __mysql_date_trunc__['second'] = '%Y-%m-%d %H:%i:%S' def _sqlite_date_part(lookup_type, datetime_string): @@ -460,6 +460,9 @@ class DatabaseProxy(Proxy): return _savepoint(self) +class ModelDescriptor(object): pass + + # SQL Generation. @@ -1141,11 +1144,24 @@ class ColumnBase(Node): op = OP.IS if is_null else OP.IS_NOT return Expression(self, op, None) def contains(self, rhs): - return Expression(self, OP.ILIKE, '%%%s%%' % rhs) + if isinstance(rhs, Node): + rhs = Expression('%', OP.CONCAT, + Expression(rhs, OP.CONCAT, '%')) + else: + rhs = '%%%s%%' % rhs + return Expression(self, OP.ILIKE, rhs) def startswith(self, rhs): - return Expression(self, OP.ILIKE, '%s%%' % rhs) + if isinstance(rhs, Node): + rhs = Expression(rhs, OP.CONCAT, '%') + else: + rhs = '%s%%' % rhs + return Expression(self, OP.ILIKE, rhs) def endswith(self, rhs): - return Expression(self, OP.ILIKE, '%%%s' % rhs) + if isinstance(rhs, Node): + rhs = Expression('%', OP.CONCAT, rhs) + else: + rhs = '%%%s' % rhs + return Expression(self, OP.ILIKE, rhs) def between(self, lo, hi): return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi))) def concat(self, rhs): @@ -1229,6 +1245,9 @@ class Alias(WrappedNode): super(Alias, self).__init__(node) self._alias = alias + def __hash__(self): + return hash(self._alias) + def alias(self, alias=None): if alias is None: return self.node @@ -1376,8 +1395,16 @@ class Expression(ColumnBase): def __sql__(self, ctx): overrides = {'parentheses': not self.flat, 'in_expr': True} - if isinstance(self.lhs, Field): - overrides['converter'] = self.lhs.db_value + + # First attempt to unwrap the node on the left-hand-side, so that we + # can get at the underlying Field if one is present. + node = raw_node = self.lhs + if isinstance(raw_node, WrappedNode): + node = raw_node.unwrap() + + # Set up the appropriate converter if we have a field on the left side. + if isinstance(node, Field) and raw_node._coerce: + overrides['converter'] = node.db_value else: overrides['converter'] = None @@ -2090,6 +2117,11 @@ class CompoundSelectQuery(SelectBase): def _returning(self): return self.lhs._returning + @database_required + def exists(self, database): + query = Select((self.limit(1),), (SQL('1'),)).bind(database) + return bool(query.scalar()) + def _get_query_key(self): return (self.lhs.get_query_key(), self.rhs.get_query_key()) @@ -2101,6 +2133,14 @@ class CompoundSelectQuery(SelectBase): elif csq_setting == CSQ_PARENTHESES_ALWAYS: return True elif csq_setting == CSQ_PARENTHESES_UNNESTED: + if ctx.state.in_expr or ctx.state.in_function: + # If this compound select query is being used inside an + # expression, e.g., an IN or EXISTS(). 
+ return False + + # If the query on the left or right is itself a compound select + # query, then we do not apply parentheses. However, if it is a + # regular SELECT query, we will apply parentheses. return not isinstance(subq, CompoundSelectQuery) def __sql__(self, ctx): @@ -2433,7 +2473,6 @@ class Insert(_WriteQuery): # Load and organize column defaults (if provided). defaults = self.get_default_data() - value_lookups = {} # First figure out what columns are being inserted (if they weren't # specified explicitly). Resulting columns are normalized and ordered. @@ -2443,47 +2482,48 @@ class Insert(_WriteQuery): except StopIteration: raise self.DefaultValuesException('Error: no rows to insert.') - if not isinstance(row, dict): + if not isinstance(row, Mapping): columns = self.get_default_columns() if columns is None: raise ValueError('Bulk insert must specify columns.') else: # Infer column names from the dict of data being inserted. accum = [] - uses_strings = False # Are the dict keys strings or columns? - for key in row: - if isinstance(key, basestring): - column = getattr(self.table, key) - uses_strings = True - else: - column = key + for column in row: + if isinstance(column, basestring): + column = getattr(self.table, column) accum.append(column) - value_lookups[column] = key # Add any columns present in the default data that are not # accounted for by the dictionary of row data. column_set = set(accum) for col in (set(defaults) - column_set): accum.append(col) - value_lookups[col] = col.name if uses_strings else col columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx)) rows_iter = itertools.chain(iter((row,)), rows_iter) else: clean_columns = [] + seen = set() for column in columns: if isinstance(column, basestring): column_obj = getattr(self.table, column) else: column_obj = column - value_lookups[column_obj] = column clean_columns.append(column_obj) + seen.add(column_obj) columns = clean_columns for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)): - if col not in value_lookups: + if col not in seen: columns.append(col) - value_lookups[col] = col + + value_lookups = {} + for column in columns: + lookups = [column, column.name] + if isinstance(column, Field) and column.name != column.column_name: + lookups.append(column.column_name) + value_lookups[column] = lookups ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ') columns_converters = [ @@ -2497,7 +2537,18 @@ class Insert(_WriteQuery): for i, (column, converter) in enumerate(columns_converters): try: if is_dict: - val = row[value_lookups[column]] + # The logic is a bit convoluted, but in order to be + # flexible in what we accept (dict keyed by + # column/field, field name, or underlying column name), + # we try accessing the row data dict using each + # possible key. If no match is found, throw an error. 
+ for lookup in value_lookups[column]: + try: + val = row[lookup] + except KeyError: pass + else: break + else: + raise KeyError else: val = row[i] except (KeyError, IndexError): @@ -2544,7 +2595,7 @@ class Insert(_WriteQuery): .sql(self.table) .literal(' ')) - if isinstance(self._insert, dict) and not self._columns: + if isinstance(self._insert, Mapping) and not self._columns: try: self._simple_insert(ctx) except self.DefaultValuesException: @@ -2817,7 +2868,8 @@ class Database(_callable_context_manager): truncate_table = True def __init__(self, database, thread_safe=True, autorollback=False, - field_types=None, operations=None, autocommit=None, **kwargs): + field_types=None, operations=None, autocommit=None, + autoconnect=True, **kwargs): self._field_types = merge_dict(FIELD, self.field_types) self._operations = merge_dict(OP, self.operations) if field_types: @@ -2825,6 +2877,7 @@ class Database(_callable_context_manager): if operations: self._operations.update(operations) + self.autoconnect = autoconnect self.autorollback = autorollback self.thread_safe = thread_safe if thread_safe: @@ -2930,7 +2983,10 @@ class Database(_callable_context_manager): def cursor(self, commit=None): if self.is_closed(): - self.connect() + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') return self._state.conn.cursor() def execute_sql(self, sql, params=None, commit=SENTINEL): @@ -3141,6 +3197,15 @@ class Database(_callable_context_manager): def truncate_date(self, date_part, date_field): raise NotImplementedError + def to_timestamp(self, date_field): + raise NotImplementedError + + def from_timestamp(self, date_field): + raise NotImplementedError + + def random(self): + return fn.random() + def bind(self, models, bind_refs=True, bind_backrefs=True): for model in models: model.bind(self, bind_refs=bind_refs, bind_backrefs=bind_backrefs) @@ -3524,10 +3589,17 @@ class SqliteDatabase(Database): return self._build_on_conflict_update(oc, query) def extract_date(self, date_part, date_field): - return fn.date_part(date_part, date_field) + return fn.date_part(date_part, date_field, python_value=int) def truncate_date(self, date_part, date_field): - return fn.date_trunc(date_part, date_field) + return fn.date_trunc(date_part, date_field, + python_value=simple_date_time) + + def to_timestamp(self, date_field): + return fn.strftime('%s', date_field).cast('integer') + + def from_timestamp(self, date_field): + return fn.datetime(date_field, 'unixepoch') class PostgresqlDatabase(Database): @@ -3552,9 +3624,11 @@ class PostgresqlDatabase(Database): safe_create_index = False sequences = True - def init(self, database, register_unicode=True, encoding=None, **kwargs): + def init(self, database, register_unicode=True, encoding=None, + isolation_level=None, **kwargs): self._register_unicode = register_unicode self._encoding = encoding + self._isolation_level = isolation_level super(PostgresqlDatabase, self).init(database, **kwargs) def _connect(self): @@ -3566,6 +3640,8 @@ class PostgresqlDatabase(Database): pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn) if self._encoding: conn.set_client_encoding(self._encoding) + if self._isolation_level: + conn.set_isolation_level(self._isolation_level) return conn def _set_server_version(self, conn): @@ -3695,6 +3771,13 @@ class PostgresqlDatabase(Database): def truncate_date(self, date_part, date_field): return fn.DATE_TRUNC(date_part, date_field) + def to_timestamp(self, date_field): + return 
self.extract_date('EPOCH', date_field) + + def from_timestamp(self, date_field): + # Ironically, here, Postgres means "to the Postgresql timestamp type". + return fn.to_timestamp(date_field) + def get_noop_select(self, ctx): return ctx.sql(Select().columns(SQL('0')).where(SQL('false'))) @@ -3727,9 +3810,13 @@ class MySQLDatabase(Database): limit_max = 2 ** 64 - 1 safe_create_index = False safe_drop_index = False + sql_mode = 'PIPES_AS_CONCAT' def init(self, database, **kwargs): - params = {'charset': 'utf8', 'use_unicode': True} + params = { + 'charset': 'utf8', + 'sql_mode': self.sql_mode, + 'use_unicode': True} params.update(kwargs) if 'password' in params and mysql_passwd: params['passwd'] = params.pop('password') @@ -3876,7 +3963,17 @@ class MySQLDatabase(Database): return fn.EXTRACT(NodeList((SQL(date_part), SQL('FROM'), date_field))) def truncate_date(self, date_part, date_field): - return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part]) + return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part], + python_value=simple_date_time) + + def to_timestamp(self, date_field): + return fn.UNIX_TIMESTAMP(date_field) + + def from_timestamp(self, date_field): + return fn.FROM_UNIXTIME(date_field) + + def random(self): + return fn.rand() def get_noop_select(self, ctx): return ctx.literal('DO 0') @@ -4274,7 +4371,7 @@ class Field(ColumnBase): def bind(self, model, name, set_attribute=True): self.model = model - self.name = name + self.name = self.safe_name = name self.column_name = self.column_name or name if set_attribute: setattr(model, name, self.accessor_class(model, self, name)) @@ -4299,7 +4396,7 @@ class Field(ColumnBase): return ctx.sql(self.column) def get_modifiers(self): - return + pass def ddl_datatype(self, ctx): if ctx and ctx.state.field_types: @@ -4337,7 +4434,12 @@ class Field(ColumnBase): class IntegerField(Field): field_type = 'INT' - adapt = int + + def adapt(self, value): + try: + return int(value) + except ValueError: + return value class BigIntegerField(IntegerField): @@ -4377,7 +4479,12 @@ class PrimaryKeyField(AutoField): class FloatField(Field): field_type = 'FLOAT' - adapt = float + + def adapt(self, value): + try: + return float(value) + except ValueError: + return value class DoubleField(FloatField): @@ -4393,6 +4500,7 @@ class DecimalField(Field): self.decimal_places = decimal_places self.auto_round = auto_round self.rounding = rounding or decimal.DefaultContext.rounding + self._exp = decimal.Decimal(10) ** (-self.decimal_places) super(DecimalField, self).__init__(*args, **kwargs) def get_modifiers(self): @@ -4403,9 +4511,8 @@ class DecimalField(Field): if not value: return value if value is None else D(0) if self.auto_round: - exp = D(10) ** (-self.decimal_places) - rounding = self.rounding - return D(text_type(value)).quantize(exp, rounding=rounding) + decimal_value = D(text_type(value)) + return decimal_value.quantize(self._exp, rounding=self.rounding) return value def python_value(self, value): @@ -4423,8 +4530,8 @@ class _StringField(Field): return value.decode('utf-8') return text_type(value) - def __add__(self, other): return self.concat(other) - def __radd__(self, other): return other.concat(self) + def __add__(self, other): return StringExpression(self, OP.CONCAT, other) + def __radd__(self, other): return StringExpression(other, OP.CONCAT, self) class CharField(_StringField): @@ -4652,6 +4759,12 @@ def format_date_time(value, formats, post_process=None): pass return value +def simple_date_time(value): + try: + return 
datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') + except (TypeError, ValueError): + return value + class _BaseFormattedField(Field): formats = None @@ -4675,6 +4788,12 @@ class DateTimeField(_BaseFormattedField): return format_date_time(value, self.formats) return value + def to_timestamp(self): + return self.model._meta.database.to_timestamp(self) + + def truncate(self, part): + return self.model._meta.database.truncate_date(part, self) + year = property(_date_part('year')) month = property(_date_part('month')) day = property(_date_part('day')) @@ -4699,6 +4818,12 @@ class DateField(_BaseFormattedField): return value.date() return value + def to_timestamp(self): + return self.model._meta.database.to_timestamp(self) + + def truncate(self, part): + return self.model._meta.database.truncate_date(part, self) + year = property(_date_part('year')) month = property(_date_part('month')) day = property(_date_part('day')) @@ -4730,19 +4855,30 @@ class TimeField(_BaseFormattedField): second = property(_date_part('second')) +def _timestamp_date_part(date_part): + def dec(self): + db = self.model._meta.database + expr = ((self / Value(self.resolution, converter=False)) + if self.resolution > 1 else self) + return db.extract_date(date_part, db.from_timestamp(expr)) + return dec + + class TimestampField(BigIntegerField): # Support second -> microsecond resolution. valid_resolutions = [10**i for i in range(7)] def __init__(self, *args, **kwargs): self.resolution = kwargs.pop('resolution', None) + if not self.resolution: self.resolution = 1 - elif self.resolution in range(7): + elif self.resolution in range(2, 7): self.resolution = 10 ** self.resolution elif self.resolution not in self.valid_resolutions: raise ValueError('TimestampField resolution must be one of: %s' % ', '.join(str(i) for i in self.valid_resolutions)) + self.ticks_to_microsecond = 1000000 // self.resolution self.utc = kwargs.pop('utc', False) or False dflt = datetime.datetime.utcnow if self.utc else datetime.datetime.now @@ -4764,6 +4900,13 @@ class TimestampField(BigIntegerField): ts = calendar.timegm(dt.utctimetuple()) return datetime.datetime.fromtimestamp(ts) + def get_timestamp(self, value): + if self.utc: + # If utc-mode is on, then we assume all naive datetimes are in UTC. + return calendar.timegm(value.utctimetuple()) + else: + return time.mktime(value.timetuple()) + def db_value(self, value): if value is None: return @@ -4775,12 +4918,7 @@ class TimestampField(BigIntegerField): else: return int(round(value * self.resolution)) - if self.utc: - # If utc-mode is on, then we assume all naive datetimes are in UTC. 
- timestamp = calendar.timegm(value.utctimetuple()) - else: - timestamp = time.mktime(value.timetuple()) - + timestamp = self.get_timestamp(value) if self.resolution > 1: timestamp += (value.microsecond * .000001) timestamp *= self.resolution @@ -4789,9 +4927,8 @@ class TimestampField(BigIntegerField): def python_value(self, value): if value is not None and isinstance(value, (int, float, long)): if self.resolution > 1: - ticks_to_microsecond = 1000000 // self.resolution value, ticks = divmod(value, self.resolution) - microseconds = int(ticks * ticks_to_microsecond) + microseconds = int(ticks * self.ticks_to_microsecond) else: microseconds = 0 @@ -4805,6 +4942,18 @@ class TimestampField(BigIntegerField): return value + def from_timestamp(self): + expr = ((self / Value(self.resolution, converter=False)) + if self.resolution > 1 else self) + return self.model._meta.database.from_timestamp(expr) + + year = property(_timestamp_date_part('year')) + month = property(_timestamp_date_part('month')) + day = property(_timestamp_date_part('day')) + hour = property(_timestamp_date_part('hour')) + minute = property(_timestamp_date_part('minute')) + second = property(_timestamp_date_part('second')) + class IPField(BigIntegerField): def db_value(self, val): @@ -4860,6 +5009,7 @@ class ForeignKeyField(Field): '"backref" for Field objects.') backref = related_name + self._is_self_reference = model == 'self' self.rel_model = model self.rel_field = field self.declared_backref = backref @@ -4908,7 +5058,7 @@ class ForeignKeyField(Field): raise ValueError('ForeignKeyField "%s"."%s" specifies an ' 'object_id_name that conflicts with its field ' 'name.' % (model._meta.name, name)) - if self.rel_model == 'self': + if self._is_self_reference: self.rel_model = model if isinstance(self.rel_field, basestring): self.rel_field = getattr(self.rel_model, self.rel_field) @@ -4918,6 +5068,7 @@ class ForeignKeyField(Field): # Bind field before assigning backref, so field is bound when # calling declared_backref() (if callable). super(ForeignKeyField, self).bind(model, name, set_attribute) + self.safe_name = self.object_id_name if callable_(self.declared_backref): self.backref = self.declared_backref(self) @@ -5143,7 +5294,7 @@ class VirtualField(MetaField): def bind(self, model, name, set_attribute=True): self.model = model - self.column_name = self.name = name + self.column_name = self.name = self.safe_name = name setattr(model, name, self.accessor_class(model, self, name)) @@ -5190,7 +5341,7 @@ class CompositeKey(MetaField): def bind(self, model, name, set_attribute=True): self.model = model - self.column_name = self.name = name + self.column_name = self.name = self.safe_name = name setattr(model, self.name, self) @@ -6139,7 +6290,12 @@ class Model(with_metaclass(ModelBase, Node)): return cls.select().filter(*dq_nodes, **filters) def get_id(self): - return getattr(self, self._meta.primary_key.name) + # Using getattr(self, pk-name) could accidentally trigger a query if + # the primary-key is a foreign-key. So we use the safe_name attribute, + # which defaults to the field-name, but will be the object_id_name for + # foreign-key fields. 
+ if self._meta.primary_key is not False: + return getattr(self, self._meta.primary_key.safe_name) _pk = property(get_id) @@ -6254,13 +6410,14 @@ class Model(with_metaclass(ModelBase, Node)): return ( other.__class__ == self.__class__ and self._pk is not None and - other._pk == self._pk) + self._pk == other._pk) def __ne__(self, other): return not self == other def __sql__(self, ctx): - return ctx.sql(getattr(self, self._meta.primary_key.name)) + return ctx.sql(Value(getattr(self, self._meta.primary_key.name), + converter=self._meta.primary_key.db_value)) @classmethod def bind(cls, database, bind_refs=True, bind_backrefs=True): @@ -6327,6 +6484,18 @@ class ModelAlias(Node): self.__dict__['alias'] = alias def __getattr__(self, attr): + # Hack to work-around the fact that properties or other objects + # implementing the descriptor protocol (on the model being aliased), + # will not work correctly when we use getattr(). So we explicitly pass + # the model alias to the descriptor's getter. + try: + obj = self.model.__dict__[attr] + except KeyError: + pass + else: + if isinstance(obj, ModelDescriptor): + return obj.__get__(None, self) + model_attr = getattr(self.model, attr) if isinstance(model_attr, Field): self.__dict__[attr] = FieldAlias.create(self, model_attr) @@ -6675,6 +6844,13 @@ class ModelSelect(BaseModelSelect, Select): return fk_fields[0], is_backref if on is None: + # If multiple foreign-keys exist, try using the FK whose name + # matches that of the related model. If not, raise an error as this + # is ambiguous. + for fk in fk_fields: + if fk.name == dest._meta.name: + return fk, is_backref + raise ValueError('More than one foreign key between %s and %s.' ' Please specify which you are joining on.' % (src, dest)) @@ -6710,7 +6886,7 @@ class ModelSelect(BaseModelSelect, Select): on, attr, constructor = self._normalize_join(src, dest, on, attr) if attr: self._joins.setdefault(src, []) - self._joins[src].append((dest, attr, constructor)) + self._joins[src].append((dest, attr, constructor, join_type)) elif on is not None: raise ValueError('Cannot specify on clause with cross join.') @@ -6732,7 +6908,7 @@ class ModelSelect(BaseModelSelect, Select): def ensure_join(self, lm, rm, on=None, **join_kwargs): join_ctx = self._join_ctx - for dest, attr, constructor in self._joins.get(lm, []): + for dest, _, constructor, _ in self._joins.get(lm, []): if dest == rm: return self return self.switch(lm).join(rm, on=on, **join_kwargs).switch(join_ctx) @@ -6757,7 +6933,7 @@ class ModelSelect(BaseModelSelect, Select): model_attr = getattr(curr, key) else: for piece in key.split('__'): - for dest, attr, _ in self._joins.get(curr, ()): + for dest, attr, _, _ in self._joins.get(curr, ()): if attr == piece or (isinstance(dest, ModelAlias) and dest.alias == piece): curr = dest @@ -6999,8 +7175,7 @@ class BaseModelCursorWrapper(DictCursorWrapper): if raw_node._coerce: converters[idx] = node.python_value fields[idx] = node - if (column == node.name or column == node.column_name) and \ - not raw_node.is_alias(): + if not raw_node.is_alias(): self.columns[idx] = node.name elif isinstance(node, Function) and node._coerce: if node._python_value is not None: @@ -7100,6 +7275,7 @@ class ModelCursorWrapper(BaseModelCursorWrapper): self.src_to_dest = [] accum = collections.deque(self.from_list) dests = set() + while accum: curr = accum.popleft() if isinstance(curr, Join): @@ -7110,11 +7286,14 @@ class ModelCursorWrapper(BaseModelCursorWrapper): if curr not in self.joins: continue - for key, attr, constructor in 
self.joins[curr]: + is_dict = isinstance(curr, dict) + for key, attr, constructor, join_type in self.joins[curr]: if key not in self.key_to_constructor: self.key_to_constructor[key] = constructor - self.src_to_dest.append((curr, attr, key, - isinstance(curr, dict))) + + # (src, attr, dest, is_dict, join_type). + self.src_to_dest.append((curr, attr, key, is_dict, + join_type)) dests.add(key) accum.append(key) @@ -7127,7 +7306,7 @@ class ModelCursorWrapper(BaseModelCursorWrapper): self.key_to_constructor[src] = src.model # Indicate which sources are also dests. - for src, _, dest, _ in self.src_to_dest: + for src, _, dest, _, _ in self.src_to_dest: self.src_is_dest[src] = src in dests and (dest in selected_src or src in selected_src) @@ -7171,7 +7350,7 @@ class ModelCursorWrapper(BaseModelCursorWrapper): setattr(instance, column, value) # Need to do some analysis on the joins before this. - for (src, attr, dest, is_dict) in self.src_to_dest: + for (src, attr, dest, is_dict, join_type) in self.src_to_dest: instance = objects[src] try: joined_instance = objects[dest] @@ -7184,6 +7363,12 @@ class ModelCursorWrapper(BaseModelCursorWrapper): (dest not in set_keys and not self.src_is_dest.get(dest)): continue + # If no fields were set on either the source or the destination, + # then we have nothing to do here. + if instance not in set_keys and dest not in set_keys \ + and join_type.endswith('OUTER'): + continue + if is_dict: instance[attr] = joined_instance else: @@ -7309,7 +7494,7 @@ def prefetch(sq, *subqueries): rel_map.setdefault(rel_model, []) rel_map[rel_model].append(pq) - deps[query_model] = {} + deps.setdefault(query_model, {}) id_map = deps[query_model] has_relations = bool(rel_map.get(query_model)) diff --git a/libs/playhouse/dataset.py b/libs/playhouse/dataset.py index 27f8189bb..f5bbf8b28 100644 --- a/libs/playhouse/dataset.py +++ b/libs/playhouse/dataset.py @@ -61,12 +61,14 @@ class DataSet(object): def get_export_formats(self): return { 'csv': CSVExporter, - 'json': JSONExporter} + 'json': JSONExporter, + 'tsv': TSVExporter} def get_import_formats(self): return { 'csv': CSVImporter, - 'json': JSONImporter} + 'json': JSONImporter, + 'tsv': TSVImporter} def __getitem__(self, table): if table not in self._models and table in self.tables: @@ -244,6 +246,29 @@ class Table(object): self.dataset.update_cache(self.name) + def __getitem__(self, item): + try: + return self.model_class[item] + except self.model_class.DoesNotExist: + pass + + def __setitem__(self, item, value): + if not isinstance(value, dict): + raise ValueError('Table.__setitem__() value must be a dict') + + pk = self.model_class._meta.primary_key + value[pk.name] = item + + try: + with self.dataset.transaction() as txn: + self.insert(**value) + except IntegrityError: + self.dataset.update_cache(self.name) + self.update(columns=[pk.name], **value) + + def __delitem__(self, item): + del self.model_class[item] + def insert(self, **data): self._migrate_new_columns(data) return self.model_class.insert(**data).execute() @@ -343,6 +368,12 @@ class CSVExporter(Exporter): writer.writerow(row) +class TSVExporter(CSVExporter): + def export(self, file_obj, header=True, **kwargs): + kwargs.setdefault('delimiter', '\t') + return super(TSVExporter, self).export(file_obj, header, **kwargs) + + class Importer(object): def __init__(self, table, strict=False): self.table = table @@ -413,3 +444,9 @@ class CSVImporter(Importer): count += 1 return count + + +class TSVImporter(CSVImporter): + def load(self, file_obj, header=True, **kwargs): + 
kwargs.setdefault('delimiter', '\t') + return super(TSVImporter, self).load(file_obj, header, **kwargs) diff --git a/libs/playhouse/hybrid.py b/libs/playhouse/hybrid.py index 53f226288..50531cc35 100644 --- a/libs/playhouse/hybrid.py +++ b/libs/playhouse/hybrid.py @@ -1,6 +1,9 @@ +from peewee import ModelDescriptor + + # Hybrid methods/attributes, based on similar functionality in SQLAlchemy: # http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html -class hybrid_method(object): +class hybrid_method(ModelDescriptor): def __init__(self, func, expr=None): self.func = func self.expr = expr or func @@ -15,7 +18,7 @@ class hybrid_method(object): return self -class hybrid_property(object): +class hybrid_property(ModelDescriptor): def __init__(self, fget, fset=None, fdel=None, expr=None): self.fget = fget self.fset = fset diff --git a/libs/playhouse/migrate.py b/libs/playhouse/migrate.py index 0abde2123..4d90b70ec 100644 --- a/libs/playhouse/migrate.py +++ b/libs/playhouse/migrate.py @@ -777,7 +777,7 @@ class SqliteMigrator(SchemaMigrator): clean = [] for column in columns: if re.match('%s(?:[\'"`\]]?\s|$)' % column_to_update, column): - column = new_columne + column[len(column_to_update):] + column = new_column + column[len(column_to_update):] clean.append(column) return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean)) diff --git a/libs/playhouse/mysql_ext.py b/libs/playhouse/mysql_ext.py index 8eb2a43fa..9ee265573 100644 --- a/libs/playhouse/mysql_ext.py +++ b/libs/playhouse/mysql_ext.py @@ -7,7 +7,10 @@ except ImportError: from peewee import ImproperlyConfigured from peewee import MySQLDatabase +from peewee import NodeList +from peewee import SQL from peewee import TextField +from peewee import fn class MySQLConnectorDatabase(MySQLDatabase): @@ -18,7 +21,10 @@ class MySQLConnectorDatabase(MySQLDatabase): def cursor(self, commit=None): if self.is_closed(): - self.connect() + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') return self._state.conn.cursor(buffered=True) @@ -32,3 +38,12 @@ class JSONField(TextField): def python_value(self, value): if value is not None: return json.loads(value) + + +def Match(columns, expr, modifier=None): + if isinstance(columns, (list, tuple)): + match = fn.MATCH(*columns) # Tuple of one or more columns / fields. + else: + match = fn.MATCH(columns) # Single column / field. + args = expr if modifier is None else NodeList((expr, SQL(modifier))) + return NodeList((match, fn.AGAINST(args))) diff --git a/libs/playhouse/pool.py b/libs/playhouse/pool.py index 9ade1da94..2ee3b486f 100644 --- a/libs/playhouse/pool.py +++ b/libs/playhouse/pool.py @@ -33,6 +33,7 @@ That's it! 
""" import heapq import logging +import random import time from collections import namedtuple from itertools import chain @@ -153,7 +154,7 @@ class PooledDatabase(object): len(self._in_use) >= self._max_connections): raise MaxConnectionsExceeded('Exceeded maximum connections.') conn = super(PooledDatabase, self)._connect() - ts = time.time() + ts = time.time() - random.random() / 1000 key = self.conn_key(conn) logger.debug('Created new connection %s.', key) diff --git a/libs/playhouse/postgres_ext.py b/libs/playhouse/postgres_ext.py index 6a2893eb5..64f44073f 100644 --- a/libs/playhouse/postgres_ext.py +++ b/libs/playhouse/postgres_ext.py @@ -3,6 +3,7 @@ Collection of postgres-specific extensions, currently including: * Support for hstore, a key/value type storage """ +import json import logging import uuid @@ -277,18 +278,19 @@ class HStoreField(IndexedFieldMixin, Field): class JSONField(Field): field_type = 'JSON' + _json_datatype = 'json' def __init__(self, dumps=None, *args, **kwargs): if Json is None: raise Exception('Your version of psycopg2 does not support JSON.') - self.dumps = dumps + self.dumps = dumps or json.dumps super(JSONField, self).__init__(*args, **kwargs) def db_value(self, value): if value is None: return value if not isinstance(value, Json): - return Json(value, dumps=self.dumps) + return Cast(self.dumps(value), self._json_datatype) return value def __getitem__(self, value): @@ -307,6 +309,7 @@ def cast_jsonb(node): class BinaryJSONField(IndexedFieldMixin, JSONField): field_type = 'JSONB' + _json_datatype = 'jsonb' __hash__ = Field.__hash__ def contains(self, other): @@ -449,7 +452,10 @@ class PostgresqlExtDatabase(PostgresqlDatabase): def cursor(self, commit=None): if self.is_closed(): - self.connect() + if self.autoconnect: + self.connect() + else: + raise InterfaceError('Error, database connection not opened.') if commit is __named_cursor__: return self._state.conn.cursor(name=str(uuid.uuid1())) return self._state.conn.cursor() diff --git a/libs/playhouse/reflection.py b/libs/playhouse/reflection.py index 14ab1ba4b..3a8f525eb 100644 --- a/libs/playhouse/reflection.py +++ b/libs/playhouse/reflection.py @@ -224,7 +224,7 @@ class PostgresqlMetadata(Metadata): 23: IntegerField, 25: TextField, 700: FloatField, - 701: FloatField, + 701: DoubleField, 1042: CharField, # blank-padded CHAR 1043: CharField, 1082: DateField, diff --git a/libs/playhouse/shortcuts.py b/libs/playhouse/shortcuts.py index e1851b181..1772cf1d3 100644 --- a/libs/playhouse/shortcuts.py +++ b/libs/playhouse/shortcuts.py @@ -73,7 +73,7 @@ def model_to_dict(model, recurse=True, backrefs=False, only=None, field_data = model.__data__.get(field.name) if isinstance(field, ForeignKeyField) and recurse: - if field_data: + if field_data is not None: seen.add(field) rel_obj = getattr(model, field.name) field_data = model_to_dict( @@ -191,6 +191,10 @@ class ReconnectMixin(object): (OperationalError, '2006'), # MySQL server has gone away. (OperationalError, '2013'), # Lost connection to MySQL server. (OperationalError, '2014'), # Commands out of sync. + + # mysql-connector raises a slightly different error when an idle + # connection is terminated by the server. This is equivalent to 2013. 
+ (OperationalError, 'MySQL Connection not available.'), ) def __init__(self, *args, **kwargs): diff --git a/libs/playhouse/signals.py b/libs/playhouse/signals.py index f070bdfdb..4e92872e5 100644 --- a/libs/playhouse/signals.py +++ b/libs/playhouse/signals.py @@ -65,7 +65,7 @@ class Model(_Model): pre_init.send(self) def save(self, *args, **kwargs): - pk_value = self._pk + pk_value = self._pk if self._meta.primary_key else True created = kwargs.get('force_insert', False) or not bool(pk_value) pre_save.send(self, created=created) ret = super(Model, self).save(*args, **kwargs) diff --git a/libs/playhouse/sqlite_ext.py b/libs/playhouse/sqlite_ext.py index c97cbd252..d9504c5fd 100644 --- a/libs/playhouse/sqlite_ext.py +++ b/libs/playhouse/sqlite_ext.py @@ -69,6 +69,11 @@ class AutoIncrementField(AutoField): return NodeList((node_list, SQL('AUTOINCREMENT'))) +class TDecimalField(DecimalField): + field_type = 'TEXT' + def get_modifiers(self): pass + + class JSONPath(ColumnBase): def __init__(self, field, path=None): super(JSONPath, self).__init__() @@ -1190,6 +1195,33 @@ def bm25(raw_match_info, *args): L_O = A_O + col_count X_O = L_O + col_count + # Worked example of pcnalx for two columns and two phrases, 100 docs total. + # { + # p = 2 + # c = 2 + # n = 100 + # a0 = 4 -- avg number of tokens for col0, e.g. title + # a1 = 40 -- avg number of tokens for col1, e.g. body + # l0 = 5 -- curr doc has 5 tokens in col0 + # l1 = 30 -- curr doc has 30 tokens in col1 + # + # x000 -- hits this row for phrase0, col0 + # x001 -- hits all rows for phrase0, col0 + # x002 -- rows with phrase0 in col0 at least once + # + # x010 -- hits this row for phrase0, col1 + # x011 -- hits all rows for phrase0, col1 + # x012 -- rows with phrase0 in col1 at least once + # + # x100 -- hits this row for phrase1, col0 + # x101 -- hits all rows for phrase1, col0 + # x102 -- rows with phrase1 in col0 at least once + # + # x110 -- hits this row for phrase1, col1 + # x111 -- hits all rows for phrase1, col1 + # x112 -- rows with phrase1 in col1 at least once + # } + weights = get_weights(col_count, args) for i in range(term_count): @@ -1213,8 +1245,8 @@ def bm25(raw_match_info, *args): avg_length = float(match_info[A_O + j]) or 1. 
# avgdl ratio = doc_length / avg_length - num = term_frequency * (K + 1) - b_part = 1 - B + (B * ratio) + num = term_frequency * (K + 1.0) + b_part = 1.0 - B + (B * ratio) denom = term_frequency + (K * b_part) pc_score = idf * (num / denom) From a9c26691a22c33f40377623c77015ca51ca26234 Mon Sep 17 00:00:00 2001 From: Halali Date: Sat, 19 Oct 2019 21:21:23 +0200 Subject: [PATCH 6/8] Change ordering source for wanted items --- bazarr/database.py | 2 ++ bazarr/main.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bazarr/database.py b/bazarr/database.py index 94846abd2..95168a980 100644 --- a/bazarr/database.py +++ b/bazarr/database.py @@ -70,6 +70,7 @@ class TableShows(BaseModel): class TableEpisodes(BaseModel): + rowid = IntegerField() audio_codec = TextField(null=True) episode = IntegerField(null=False) failed_attempts = TextField(column_name='failedAttempts', null=True) @@ -97,6 +98,7 @@ class TableEpisodes(BaseModel): class TableMovies(BaseModel): + rowid = IntegerField() alternative_titles = TextField(column_name='alternativeTitles', null=True) audio_codec = TextField(null=True) audio_language = TextField(null=True) diff --git a/bazarr/main.py b/bazarr/main.py index 9b42594a4..cce6435e3 100644 --- a/bazarr/main.py +++ b/bazarr/main.py @@ -1395,7 +1395,7 @@ def wantedseries(): ).where( reduce(operator.and_, missing_subtitles_clause) ).order_by( - TableEpisodes.sonarr_episode_id.desc() + TableEpisodes.rowid.desc() ).paginate( int(page), page_size @@ -1439,7 +1439,7 @@ def wantedmovies(): ).where( reduce(operator.and_, missing_subtitles_clause) ).order_by( - TableMovies.radarr_id.desc() + TableMovies.rowid.desc() ).paginate( int(page), page_size From 698bca14b440b026edccea6cadf43fb1194fd5bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Louis=20V=C3=A9zina?= <5130500+morpheus65535@users.noreply.github.com> Date: Sat, 19 Oct 2019 16:37:40 -0400 Subject: [PATCH 7/8] Major fixes to database subsystem. 
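
Recomputing the missing-subtitles list is now triggered from inside
store_subtitles() and store_subtitles_movie() instead of by every caller:
the two functions look up the affected episode or movie by path and refresh
only that record, and list_missing_subtitles() grows an epno= argument for
the single-episode case. Roughly, the flow reduces to this sketch (simplified
signature; the real code is in list_subtitles.py below):

    def store_subtitles(file):
        # ... index embedded and external subtitles for this file ...
        episode = TableEpisodes.select(
            TableEpisodes.sonarr_episode_id
        ).where(
            TableEpisodes.path == path_replace_reverse(file)
        ).first()
        # Refresh just the affected episode; callers no longer have to
        # call list_missing_subtitles() themselves.
        list_missing_subtitles(epno=episode.sonarr_episode_id)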
--- bazarr/get_episodes.py | 6 +----- bazarr/get_movies.py | 3 --- bazarr/get_series.py | 2 +- bazarr/get_subtitle.py | 5 ----- bazarr/list_subtitles.py | 28 +++++++++++++++++++++------- bazarr/main.py | 14 +++----------- 6 files changed, 26 insertions(+), 32 deletions(-) diff --git a/bazarr/get_episodes.py b/bazarr/get_episodes.py index d7a91a4bc..4173fe4bd 100644 --- a/bazarr/get_episodes.py +++ b/bazarr/get_episodes.py @@ -16,8 +16,6 @@ from get_subtitle import episode_download_subtitles def update_all_episodes(): series_full_scan_subtitles() logging.info('BAZARR All existing episode subtitles indexed from disk.') - list_missing_subtitles() - logging.info('BAZARR All missing episode subtitles updated in database.') wal_cleaning() @@ -172,8 +170,7 @@ def sync_episodes(): added_episode ).on_conflict_ignore().execute() altered_episodes.append([added_episode['sonarr_episode_id'], - added_episode['path'], - added_episode['sonarr_series_id']]) + added_episode['path']]) # Remove old episodes from DB removed_episodes = list(set(current_episodes_db_list) - set(current_episodes_sonarr)) @@ -188,7 +185,6 @@ def sync_episodes(): notifications.write(msg='Indexing episodes embedded subtitles...', queue='get_episodes', item=i, length=len(altered_episodes)) store_subtitles(path_replace(altered_episode[1])) - list_missing_subtitles(altered_episode[2]) logging.debug('BAZARR All episodes synced from Sonarr into database.') diff --git a/bazarr/get_movies.py b/bazarr/get_movies.py index f39592db8..d8a54f56b 100644 --- a/bazarr/get_movies.py +++ b/bazarr/get_movies.py @@ -18,8 +18,6 @@ from database import TableMovies, wal_cleaning def update_all_movies(): movies_full_scan_subtitles() logging.info('BAZARR All existing movie subtitles indexed from disk.') - list_missing_subtitles_movies() - logging.info('BAZARR All missing movie subtitles updated in database.') wal_cleaning() @@ -269,7 +267,6 @@ def update_movies(): notifications.write(msg='Indexing movies embedded subtitles...', queue='get_movies', item=i, length=len(altered_movies)) store_subtitles_movie(path_replace_movie(altered_movie[1])) - list_missing_subtitles_movies(altered_movie[2]) logging.debug('BAZARR All movies synced from Radarr into database.') diff --git a/bazarr/get_series.py b/bazarr/get_series.py index fc56b5a67..eb8d2699c 100644 --- a/bazarr/get_series.py +++ b/bazarr/get_series.py @@ -156,7 +156,7 @@ def update_series(): TableShows.insert( added_series ).on_conflict_ignore().execute() - list_missing_subtitles(added_series['sonarr_series_id']) + list_missing_subtitles(no=added_series['sonarr_series_id']) # Remove old series from DB removed_series = list(set(current_shows_db_list) - set(current_shows_sonarr)) diff --git a/bazarr/get_subtitle.py b/bazarr/get_subtitle.py index 9207e6b9e..67adc7cf4 100644 --- a/bazarr/get_subtitle.py +++ b/bazarr/get_subtitle.py @@ -611,7 +611,6 @@ def series_download_subtitles(no): notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long') logging.info("BAZARR All providers are throttled") break - list_missing_subtitles(no) if count_episodes_details: notifications.write(msg='Search Complete. 
Please Reload The Page.', type='success', duration='permanent', @@ -671,7 +670,6 @@ def episode_download_subtitles(no): store_subtitles(path_replace(episode.path)) history_log(1, episode.sonarr_series_id, episode.sonarr_episode_id, message, path, language_code, provider, score) send_notifications(episode.sonarr_series_id, episode.sonarr_episode_id, message) - list_missing_subtitles(episode.sonarr_series_id) else: notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long') logging.info("BAZARR All providers are throttled") @@ -724,7 +722,6 @@ def movies_download_subtitles(no): notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long') logging.info("BAZARR All providers are throttled") break - list_missing_subtitles_movies(no) if count_movie: notifications.write(msg='Search Complete. Please Reload The Page.', type='success', duration='permanent', @@ -796,7 +793,6 @@ def wanted_download_subtitles(path, l, count_episodes): provider = result[3] score = result[4] store_subtitles(path_replace(episode.path)) - list_missing_subtitles(episode.sonarr_series_id.sonarr_series_id) history_log(1, episode.sonarr_series_id.sonarr_series_id, episode.sonarr_episode_id, message, path, language_code, provider, score) send_notifications(episode.sonarr_series_id.sonarr_series_id, episode.sonarr_episode_id, message) else: @@ -865,7 +861,6 @@ def wanted_download_subtitles_movie(path, l, count_movies): provider = result[3] score = result[4] store_subtitles_movie(path_replace_movie(movie.path)) - list_missing_subtitles_movies(movie.radarr_id) history_log_movie(1, movie.radarr_id, message, path, language_code, provider, score) send_notifications_movie(movie.radarr_id, message) else: diff --git a/bazarr/list_subtitles.py b/bazarr/list_subtitles.py index 58abb1762..1fd38eb4b 100644 --- a/bazarr/list_subtitles.py +++ b/bazarr/list_subtitles.py @@ -120,7 +120,15 @@ def store_subtitles(file): logging.debug("BAZARR this file doesn't seems to exist or isn't accessible.") logging.debug('BAZARR ended subtitles indexing for this file: ' + file) - + + episode = TableEpisodes.select( + TableEpisodes.sonarr_episode_id + ).where( + TableEpisodes.path == path_replace_reverse(file) + ).first() + + list_missing_subtitles(epno=episode.sonarr_episode_id) + return actual_subtitles @@ -217,14 +225,24 @@ def store_subtitles_movie(file): logging.debug("BAZARR this file doesn't seems to exist or isn't accessible.") logging.debug('BAZARR ended subtitles indexing for this file: ' + file) - + + movie = TableMovies.select( + TableMovies.radarr_id + ).where( + TableMovies.path == path_replace_reverse_movie(file) + ).first() + + list_missing_subtitles_movies(no=movie.radarr_id) + return actual_subtitles -def list_missing_subtitles(no=None): +def list_missing_subtitles(no=None, epno=None): episodes_subtitles_clause = (TableShows.sonarr_series_id.is_null(False)) if no is not None: episodes_subtitles_clause = (TableShows.sonarr_series_id == no) + elif epno is not None: + episodes_subtitles_clause = (TableEpisodes.sonarr_episode_id == epno) episodes_subtitles = TableEpisodes.select( TableShows.sonarr_series_id, TableEpisodes.sonarr_episode_id, @@ -388,8 +406,6 @@ def series_scan_subtitles(no): for episode in episodes: store_subtitles(path_replace(episode.path)) - - list_missing_subtitles(no) def movies_scan_subtitles(no): @@ -401,8 +417,6 @@ def movies_scan_subtitles(no): for movie in movies: store_subtitles_movie(path_replace_movie(movie.path)) - - 
list_missing_subtitles_movies(no) def get_external_subtitles_path(file, subtitle): diff --git a/bazarr/main.py b/bazarr/main.py index b52688e47..0d6bad36b 100644 --- a/bazarr/main.py +++ b/bazarr/main.py @@ -1,6 +1,6 @@ # coding=utf-8 -bazarr_version = '0.8.2.4' +bazarr_version = '0.8.2.5' import gc import sys @@ -736,7 +736,7 @@ def edit_series(no): TableShows.sonarr_series_id == no ).execute() - list_missing_subtitles(no) + list_missing_subtitles(no=no) redirect(ref) @@ -784,7 +784,7 @@ def edit_serieseditor(): ).execute() for serie in series: - list_missing_subtitles(serie) + list_missing_subtitles(no=serie) redirect(ref) @@ -2050,7 +2050,6 @@ def remove_subtitles(): except OSError as e: logging.exception('BAZARR cannot delete subtitles file: ' + subtitlesPath) store_subtitles(unicode(episodePath)) - list_missing_subtitles(sonarrSeriesId) @route(base_url + 'remove_subtitles_movie', method='POST') @@ -2069,7 +2068,6 @@ def remove_subtitles_movie(): except OSError as e: logging.exception('BAZARR cannot delete subtitles file: ' + subtitlesPath) store_subtitles_movie(unicode(moviePath)) - list_missing_subtitles_movies(radarrId) @route(base_url + 'get_subtitle', method='POST') @@ -2103,7 +2101,6 @@ def get_subtitle(): history_log(1, sonarrSeriesId, sonarrEpisodeId, message, path, language_code, provider, score) send_notifications(sonarrSeriesId, sonarrEpisodeId, message) store_subtitles(unicode(episodePath)) - list_missing_subtitles(sonarrSeriesId) redirect(ref) except OSError: pass @@ -2161,7 +2158,6 @@ def manual_get_subtitle(): history_log(2, sonarrSeriesId, sonarrEpisodeId, message, path, language_code, provider, score) send_notifications(sonarrSeriesId, sonarrEpisodeId, message) store_subtitles(unicode(episodePath)) - list_missing_subtitles(sonarrSeriesId) redirect(ref) except OSError: pass @@ -2205,7 +2201,6 @@ def perform_manual_upload_subtitle(): history_log(4, sonarrSeriesId, sonarrEpisodeId, message, path, language_code, provider, score) send_notifications(sonarrSeriesId, sonarrEpisodeId, message) store_subtitles(unicode(episodePath)) - list_missing_subtitles(sonarrSeriesId) redirect(ref) except OSError: @@ -2242,7 +2237,6 @@ def get_subtitle_movie(): history_log_movie(1, radarrId, message, path, language_code, provider, score) send_notifications_movie(radarrId, message) store_subtitles_movie(unicode(moviePath)) - list_missing_subtitles_movies(radarrId) redirect(ref) except OSError: pass @@ -2298,7 +2292,6 @@ def manual_get_subtitle_movie(): history_log_movie(2, radarrId, message, path, language_code, provider, score) send_notifications_movie(radarrId, message) store_subtitles_movie(unicode(moviePath)) - list_missing_subtitles_movies(radarrId) redirect(ref) except OSError: pass @@ -2341,7 +2334,6 @@ def perform_manual_upload_subtitle_movie(): history_log_movie(4, radarrId, message, path, language_code, provider, score) send_notifications_movie(radarrId, message) store_subtitles_movie(unicode(moviePath)) - list_missing_subtitles_movies(radarrId) redirect(ref) except OSError: From e452394841197fbae72195b87c67078385565de6 Mon Sep 17 00:00:00 2001 From: panni Date: Sat, 19 Oct 2019 23:20:12 +0200 Subject: [PATCH 8/8] core: update to subliminal_patch:head; addic7ed: fix captcha solving; fix getting show list --- libs/subliminal_patch/core.py | 35 +++++++-------------- libs/subliminal_patch/providers/addic7ed.py | 23 ++++++++------ 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/libs/subliminal_patch/core.py b/libs/subliminal_patch/core.py index 9fe4f4d35..46d701dc8 
diff --git a/libs/subliminal_patch/providers/addic7ed.py b/libs/subliminal_patch/providers/addic7ed.py
index 8507052d3..1e04821b0 100644
--- a/libs/subliminal_patch/providers/addic7ed.py
+++ b/libs/subliminal_patch/providers/addic7ed.py
@@ -10,7 +10,7 @@ from requests import Session
 from subliminal.cache import region
 from subliminal.exceptions import DownloadLimitExceeded, AuthenticationError
 from subliminal.providers.addic7ed import Addic7edProvider as _Addic7edProvider, \
-    Addic7edSubtitle as _Addic7edSubtitle, ParserBeautifulSoup, show_cells_re
+    Addic7edSubtitle as _Addic7edSubtitle, ParserBeautifulSoup
 from subliminal.subtitle import fix_line_ending
 from subliminal_patch.utils import sanitize
 from subliminal_patch.exceptions import TooManyRequests
@@ -19,6 +19,8 @@ from subzero.language import Language
 
 logger = logging.getLogger(__name__)
 
+show_cells_re = re.compile(b'<td class="(?:version|vr)">.*?</td>', re.DOTALL)
+
 #: Series header parsing regex
 series_year_re = re.compile(r'^(?P<series>[ \w\'.:(),*&!?-]+?)(?: \((?P<year>\d{4})\))?$')
 
@@ -103,11 +105,15 @@ class Addic7edProvider(_Addic7edProvider):
         tries = 0
         while tries < 3:
             r = self.session.get(self.server_url + 'login.php', timeout=10, headers={"Referer": self.server_url})
-            if "grecaptcha" in r.content:
+            if "g-recaptcha" in r.content or "grecaptcha" in r.content:
                 logger.info('Addic7ed: Solving captcha. This might take a couple of minutes, but should only '
                             'happen once every so often')
 
-                site_key = re.search(r'grecaptcha.execute\(\'(.+?)\',', r.content).group(1)
+                for g, s in (("g-recaptcha-response", r'g-recaptcha.+?data-sitekey=\"(.+?)\"'),
+                             ("recaptcha_response", r'grecaptcha.execute\(\'(.+?)\',')):
+                    site_key = re.search(s, r.content).group(1)
+                    if site_key:
+                        break
                 if not site_key:
                     logger.error("Addic7ed: Captcha site-key not found!")
                     return
@@ -121,7 +127,7 @@ class Addic7edProvider(_Addic7edProvider):
                 if not result:
                     raise Exception("Addic7ed: Couldn't solve captcha!")
 
-                data["recaptcha_response"] = result
+                data[g] = result
 
             r = self.session.post(self.server_url + 'dologin.php', data, allow_redirects=False, timeout=10,
                                   headers={"Referer": self.server_url + "login.php"})
@@ -129,12 +135,11 @@ class Addic7edProvider(_Addic7edProvider):
             if "relax, slow down" in r.content:
                 raise TooManyRequests(self.username)
 
-            if r.status_code != 302:
-                if "User doesn't exist" in r.content and tries <= 2:
-                    logger.info("Addic7ed: Error, trying again. (%s/%s)", tries+1, 3)
-                    tries += 1
-                    continue
+            if "Try again" in r.content or "Wrong password" in r.content:
+                raise AuthenticationError(self.username)
 
+            if r.status_code != 302:
+                logger.error("Addic7ed: Something went wrong when logging in")
                 raise AuthenticationError(self.username)
 
             break
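
Note on the addic7ed login change: the provider now handles both captcha
variants on the login page -- the checkbox-style g-recaptcha, whose site key
sits in a data-sitekey attribute and whose solution is posted as
"g-recaptcha-response", and the older invisible grecaptcha.execute() flow,
posted as "recaptcha_response". A standalone sketch of that probing logic
(pure stdlib; find_site_key is a hypothetical helper name, the patterns are
the two from the diff):

    import re

    # (form field to post the solution under, pattern capturing the site key)
    CAPTCHA_VARIANTS = (
        ("g-recaptcha-response", r'g-recaptcha.+?data-sitekey=\"(.+?)\"'),
        ("recaptcha_response", r'grecaptcha.execute\(\'(.+?)\','),
    )

    def find_site_key(content):
        """Return (field, site_key) for the first matching variant, else None."""
        for field, pattern in CAPTCHA_VARIANTS:
            match = re.search(pattern, content)
            if match:  # guard before .group(1); a missed search returns None
                return field, match.group(1)
        return None

Note that the loop in the patch itself calls .group(1) on the search result
directly, so a login page matching only the second pattern would raise
AttributeError on the first iteration instead of falling through to it; the
guard above shows how a defensive version might look.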