Merge branch 'development' into python3

# Conflicts:
#	bazarr/get_episodes.py
#	bazarr/get_movies.py
#	bazarr/get_subtitle.py
#	bazarr/list_subtitles.py
#	bazarr/main.py
#	libs/subliminal_patch/core.py
#	libs/subliminal_patch/providers/addic7ed.py
pull/684/head
Louis Vézina 5 years ago
commit 995b9ac9ae

@ -295,7 +295,8 @@ def request_json(url, **kwargs):
def updated(restart=True):
if settings.general.getboolean('update_restart') and restart:
try:
requests.get(bazarr_url + 'restart')
from main import restart
restart()
except requests.ConnectionError:
logging.info('BAZARR Restart failed, please restart Bazarr manualy')
updated(restart=False)

@ -8,13 +8,17 @@ from playhouse.sqliteq import SqliteQueueDatabase
from playhouse.migrate import *
database = SqliteQueueDatabase(
None,
os.path.join(args.config_dir, 'db', 'bazarr.db'),
use_gevent=False,
autostart=False,
autostart=True,
queue_max_size=256, # Max. # of pending writes that can accumulate.
results_timeout=30.0) # Max. time to wait for query to be executed.
results_timeout=30.0 # Max. time to wait for query to be executed.
)
migrator = SqliteMigrator(database)
database.pragma('wal_checkpoint', 'TRUNCATE') # Run a checkpoint and merge remaining wal-journal.
database.cache_size = -1024 # Number of KB of cache for wal-journal.
# Must be negative because positive means number of pages.
database.wal_autocheckpoint = 50 # Run an automatic checkpoint every 50 write transactions.
@database.func('path_substitution')
@ -62,15 +66,12 @@ class TableShows(BaseModel):
tvdb_id = IntegerField(column_name='tvdbId', null=True, unique=True, primary_key=True)
year = TextField(null=True)
migrate(
migrator.add_column('table_shows', 'forced', forced),
)
class Meta:
table_name = 'table_shows'
class TableEpisodes(BaseModel):
rowid = IntegerField()
audio_codec = TextField(null=True)
episode = IntegerField(null=False)
failed_attempts = TextField(column_name='failedAttempts', null=True)
@ -88,16 +89,13 @@ class TableEpisodes(BaseModel):
video_codec = TextField(null=True)
episode_file_id = IntegerField(null=True)
migrate(
migrator.add_column('table_episodes', 'episode_file_id', episode_file_id),
)
class Meta:
table_name = 'table_episodes'
primary_key = False
class TableMovies(BaseModel):
rowid = IntegerField()
alternative_titles = TextField(column_name='alternativeTitles', null=True)
audio_codec = TextField(null=True)
audio_language = TextField(null=True)
@ -124,11 +122,6 @@ class TableMovies(BaseModel):
year = TextField(null=True)
movie_file_id = IntegerField(null=True)
migrate(
migrator.add_column('table_movies', 'forced', forced),
migrator.add_column('table_movies', 'movie_file_id', movie_file_id),
)
class Meta:
table_name = 'table_movies'
@ -184,22 +177,39 @@ class TableSettingsNotifier(BaseModel):
table_name = 'table_settings_notifier'
def database_init():
database.init(os.path.join(args.config_dir, 'db', 'bazarr.db'))
database.start()
database.connect()
database.pragma('wal_checkpoint', 'TRUNCATE') # Run a checkpoint and merge remaining wal-journal.
database.cache_size = -1024 # Number of KB of cache for wal-journal.
# Must be negative because positive means number of pages.
database.wal_autocheckpoint = 50 # Run an automatic checkpoint every 50 write transactions.
# Database tables creation if they don't exists
models_list = [TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie, TableSettingsLanguages,
TableSettingsNotifier, System]
database.create_tables(models_list, safe=True)
# Database migration
migrator = SqliteMigrator(database)
# TableShows migration
table_shows_columns = []
for column in database.get_columns('table_shows'):
table_shows_columns.append(column.name)
if 'forced' not in table_shows_columns:
migrate(migrator.add_column('table_shows', 'forced', TableShows.forced))
# TableEpisodes migration
table_episodes_columns = []
for column in database.get_columns('table_episodes'):
table_episodes_columns.append(column.name)
if 'episode_file_id' not in table_episodes_columns:
migrate(migrator.add_column('table_episodes', 'episode_file_id', TableEpisodes.episode_file_id))
# TableMovies migration
table_movies_columns = []
for column in database.get_columns('table_movies'):
table_movies_columns.append(column.name)
if 'forced' not in table_movies_columns:
migrate(migrator.add_column('table_movies', 'forced', TableMovies.forced))
if 'movie_file_id' not in table_movies_columns:
migrate(migrator.add_column('table_movies', 'movie_file_id', TableMovies.movie_file_id))
def wal_cleaning():
database.pragma('wal_checkpoint', 'TRUNCATE') # Run a checkpoint and merge remaining wal-journal.
database.wal_autocheckpoint = 50 # Run an automatic checkpoint every 50 write transactions.

@ -17,8 +17,6 @@ from get_subtitle import episode_download_subtitles
def update_all_episodes():
series_full_scan_subtitles()
logging.info('BAZARR All existing episode subtitles indexed from disk.')
list_missing_subtitles()
logging.info('BAZARR All missing episode subtitles updated in database.')
wal_cleaning()
@ -173,8 +171,7 @@ def sync_episodes():
added_episode
).on_conflict_ignore().execute()
altered_episodes.append([added_episode['sonarr_episode_id'],
added_episode['path'],
added_episode['sonarr_series_id']])
added_episode['path']])
# Remove old episodes from DB
removed_episodes = list(set(current_episodes_db_list) - set(current_episodes_sonarr))
@ -189,7 +186,6 @@ def sync_episodes():
notifications.write(msg='Indexing episodes embedded subtitles...', queue='get_episodes', item=i,
length=len(altered_episodes))
store_subtitles(altered_episode[1], path_replace(altered_episode[1]))
list_missing_subtitles(altered_episode[2])
logging.debug('BAZARR All episodes synced from Sonarr into database.')

@ -20,8 +20,6 @@ import six
def update_all_movies():
movies_full_scan_subtitles()
logging.info('BAZARR All existing movie subtitles indexed from disk.')
list_missing_subtitles_movies()
logging.info('BAZARR All missing movie subtitles updated in database.')
wal_cleaning()
@ -271,7 +269,6 @@ def update_movies():
notifications.write(msg='Indexing movies embedded subtitles...', queue='get_movies', item=i,
length=len(altered_movies))
store_subtitles_movie(altered_movie[1], path_replace_movie(altered_movie[1]))
list_missing_subtitles_movies(altered_movie[2])
logging.debug('BAZARR All movies synced from Radarr into database.')

@ -159,7 +159,7 @@ def update_series():
TableShows.insert(
added_series
).on_conflict_ignore().execute()
list_missing_subtitles(added_series['sonarr_series_id'])
list_missing_subtitles(no=added_series['sonarr_series_id'])
# Remove old series from DB
removed_series = list(set(current_shows_db_list) - set(current_shows_sonarr))

@ -34,7 +34,7 @@ from get_providers import get_providers, get_providers_auth, provider_throttle,
from get_args import args
from queueconfig import notifications
from pyprobe.pyprobe import VideoFileParser
from database import TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie
from database import database, TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie
from peewee import fn, JOIN
from analytics import track_event
@ -614,7 +614,6 @@ def series_download_subtitles(no):
notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long')
logging.info("BAZARR All providers are throttled")
break
list_missing_subtitles(no)
if count_episodes_details:
notifications.write(msg='Search Complete. Please Reload The Page.', type='success', duration='permanent',
@ -623,7 +622,7 @@ def series_download_subtitles(no):
def episode_download_subtitles(no):
episodes_details_clause = [
(TableEpisodes.sonarr_series_id == no)
(TableEpisodes.sonarr_episode_id == no)
]
if settings.sonarr.getboolean('only_monitored'):
episodes_details_clause.append(
@ -674,7 +673,6 @@ def episode_download_subtitles(no):
store_subtitles(episode.path, path_replace(episode.path))
history_log(1, episode.sonarr_series_id, episode.sonarr_episode_id, message, path, language_code, provider, score)
send_notifications(episode.sonarr_series_id, episode.sonarr_episode_id, message)
list_missing_subtitles(episode.sonarr_series_id)
else:
notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long')
logging.info("BAZARR All providers are throttled")
@ -727,7 +725,6 @@ def movies_download_subtitles(no):
notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long')
logging.info("BAZARR All providers are throttled")
break
list_missing_subtitles_movies(no)
if count_movie:
notifications.write(msg='Search Complete. Please Reload The Page.', type='success', duration='permanent',
@ -799,7 +796,6 @@ def wanted_download_subtitles(path, l, count_episodes):
provider = result[3]
score = result[4]
store_subtitles(episode.path, path_replace(episode.path))
list_missing_subtitles(episode.sonarr_series_id.sonarr_series_id)
history_log(1, episode.sonarr_series_id.sonarr_series_id, episode.sonarr_episode_id, message, path, language_code, provider, score)
send_notifications(episode.sonarr_series_id.sonarr_series_id, episode.sonarr_episode_id, message)
else:
@ -868,7 +864,6 @@ def wanted_download_subtitles_movie(path, l, count_movies):
provider = result[3]
score = result[4]
store_subtitles_movie(movie.path, path_replace_movie(movie.path))
list_missing_subtitles_movies(movie.radarr_id)
history_log_movie(1, movie.radarr_id, message, path, language_code, provider, score)
send_notifications_movie(movie.radarr_id, message)
else:
@ -1184,7 +1179,7 @@ def upgrade_subtitles():
notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long')
logging.info("BAZARR All providers are throttled")
return
if episode['languages'] != "None":
if episode['languages']:
desired_languages = ast.literal_eval(str(episode['languages']))
if episode['forced'] == "True":
forced_languages = [l + ":forced" for l in desired_languages]
@ -1233,7 +1228,7 @@ def upgrade_subtitles():
notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long')
logging.info("BAZARR All providers are throttled")
return
if movie['languages'] != "None":
if movie['languages']:
desired_languages = ast.literal_eval(str(movie['languages']))
if movie['forced'] == "True":
forced_languages = [l + ":forced" for l in desired_languages]

@ -124,6 +124,14 @@ def store_subtitles(original_path, reversed_path):
logging.debug('BAZARR ended subtitles indexing for this file: ' + reversed_path)
episode = TableEpisodes.select(
TableEpisodes.sonarr_episode_id
).where(
TableEpisodes.path == path_replace_reverse(file)
).first()
list_missing_subtitles(epno=episode.sonarr_episode_id)
return actual_subtitles
@ -224,15 +232,25 @@ def store_subtitles_movie(original_path, reversed_path):
logging.debug('BAZARR ended subtitles indexing for this file: ' + reversed_path)
movie = TableMovies.select(
TableMovies.radarr_id
).where(
TableMovies.path == path_replace_reverse_movie(file)
).first()
list_missing_subtitles_movies(no=movie.radarr_id)
return actual_subtitles
def list_missing_subtitles(no=None):
episodes_subtitles_clause = [(TableShows.sonarr_series_id.is_null(False))]
def list_missing_subtitles(no=None, epno=None):
episodes_subtitles_clause = (TableShows.sonarr_series_id.is_null(False))
if no is not None:
episodes_subtitles_clause.append((TableShows.sonarr_series_id ** no))
episodes_subtitles_clause = (TableShows.sonarr_series_id == no)
elif epno is not None:
episodes_subtitles_clause = (TableEpisodes.sonarr_episode_id == epno)
episodes_subtitles = TableEpisodes.select(
TableShows.sonarr_series_id,
TableEpisodes.sonarr_episode_id,
TableEpisodes.subtitles,
TableShows.languages,
@ -294,9 +312,9 @@ def list_missing_subtitles(no=None):
def list_missing_subtitles_movies(no=None):
movies_subtitles_clause = [(TableMovies.radarr_id.is_null(False))]
movies_subtitles_clause = (TableMovies.radarr_id.is_null(False))
if no is not None:
movies_subtitles_clause.append((TableMovies.radarr_id == no))
movies_subtitles_clause = (TableMovies.radarr_id == no)
movies_subtitles = TableMovies.select(
TableMovies.radarr_id,
@ -395,8 +413,6 @@ def series_scan_subtitles(no):
for episode in episodes:
store_subtitles(episode.path, path_replace(episode.path))
list_missing_subtitles(no)
def movies_scan_subtitles(no):
movies = TableMovies.select(
@ -408,8 +424,6 @@ def movies_scan_subtitles(no):
for movie in movies:
store_subtitles_movie(movie.path, path_replace_movie(movie.path))
list_missing_subtitles_movies(no)
def get_external_subtitles_path(file, subtitle):
fld = os.path.dirname(file)

@ -28,12 +28,9 @@ from calendar import day_name
from get_args import args
from init import *
from database import database, database_init, TableEpisodes, TableShows, TableMovies, TableHistory, TableHistoryMovie, \
from database import database, TableEpisodes, TableShows, TableMovies, TableHistory, TableHistoryMovie, \
TableSettingsLanguages, TableSettingsNotifier, System
# Initiate database
database_init()
from notifier import update_notifier
from logger import configure_logging, empty_log
@ -742,10 +739,10 @@ def edit_series(no):
TableShows.forced: forced
}
).where(
TableShows.sonarr_series_id ** no
TableShows.sonarr_series_id == no
).execute()
list_missing_subtitles(no)
list_missing_subtitles(no=no)
redirect(ref)
@ -793,7 +790,7 @@ def edit_serieseditor():
).execute()
for serie in series:
list_missing_subtitles(serie)
list_missing_subtitles(no=serie)
redirect(ref)
@ -815,7 +812,7 @@ def episodes(no):
fn.path_substitution(TableShows.path).alias('path'),
TableShows.forced
).where(
TableShows.sonarr_series_id ** str(no)
TableShows.sonarr_series_id == no
).limit(1)
for series in series_details:
tvdbid = series.tvdb_id
@ -1401,7 +1398,7 @@ def wantedseries():
).where(
reduce(operator.and_, missing_subtitles_clause)
).order_by(
TableEpisodes.sonarr_episode_id.desc()
TableEpisodes.rowid.desc()
).paginate(
int(page),
page_size
@ -1445,7 +1442,7 @@ def wantedmovies():
).where(
reduce(operator.and_, missing_subtitles_clause)
).order_by(
TableMovies.radarr_id.desc()
TableMovies.rowid.desc()
).paginate(
int(page),
page_size
@ -2060,8 +2057,7 @@ def remove_subtitles():
history_log(0, sonarrSeriesId, sonarrEpisodeId, result)
except OSError as e:
logging.exception('BAZARR cannot delete subtitles file: ' + subtitlesPath)
store_subtitles(episodePath, six.text_type(episodePath))
list_missing_subtitles(sonarrSeriesId)
store_subtitles(episodePath, unicode(episodePath))
@route(base_url + 'remove_subtitles_movie', method='POST')
@ -2080,7 +2076,6 @@ def remove_subtitles_movie():
except OSError as e:
logging.exception('BAZARR cannot delete subtitles file: ' + subtitlesPath)
store_subtitles_movie(moviePath, six.text_type(moviePath))
list_missing_subtitles_movies(radarrId)
@route(base_url + 'get_subtitle', method='POST')
@ -2114,7 +2109,6 @@ def get_subtitle():
history_log(1, sonarrSeriesId, sonarrEpisodeId, message, path, language_code, provider, score)
send_notifications(sonarrSeriesId, sonarrEpisodeId, message)
store_subtitles(episodePath, six.text_type(episodePath))
list_missing_subtitles(sonarrSeriesId)
redirect(ref)
except OSError:
pass
@ -2172,7 +2166,6 @@ def manual_get_subtitle():
history_log(2, sonarrSeriesId, sonarrEpisodeId, message, path, language_code, provider, score)
send_notifications(sonarrSeriesId, sonarrEpisodeId, message)
store_subtitles(episodePath, six.text_type(episodePath))
list_missing_subtitles(sonarrSeriesId)
redirect(ref)
except OSError:
pass
@ -2216,7 +2209,6 @@ def perform_manual_upload_subtitle():
history_log(4, sonarrSeriesId, sonarrEpisodeId, message, path, language_code, provider, score)
send_notifications(sonarrSeriesId, sonarrEpisodeId, message)
store_subtitles(episodePath, six.text_type(episodePath))
list_missing_subtitles(sonarrSeriesId)
redirect(ref)
except OSError:
@ -2253,7 +2245,6 @@ def get_subtitle_movie():
history_log_movie(1, radarrId, message, path, language_code, provider, score)
send_notifications_movie(radarrId, message)
store_subtitles_movie(moviePath, six.text_type(moviePath))
list_missing_subtitles_movies(radarrId)
redirect(ref)
except OSError:
pass
@ -2309,7 +2300,6 @@ def manual_get_subtitle_movie():
history_log_movie(2, radarrId, message, path, language_code, provider, score)
send_notifications_movie(radarrId, message)
store_subtitles_movie(moviePath, six.text_type(moviePath))
list_missing_subtitles_movies(radarrId)
redirect(ref)
except OSError:
pass
@ -2352,7 +2342,6 @@ def perform_manual_upload_subtitle_movie():
history_log_movie(4, radarrId, message, path, language_code, provider, score)
send_notifications_movie(radarrId, message)
store_subtitles_movie(moviePath, six.text_type(moviePath))
list_missing_subtitles_movies(radarrId)
redirect(ref)
except OSError:

@ -65,7 +65,7 @@ except ImportError:
mysql = None
__version__ = '3.9.6'
__version__ = '3.11.2'
__all__ = [
'AsIs',
'AutoField',
@ -206,15 +206,15 @@ __sqlite_datetime_formats__ = (
'%H:%M')
__sqlite_date_trunc__ = {
'year': '%Y',
'month': '%Y-%m',
'day': '%Y-%m-%d',
'hour': '%Y-%m-%d %H',
'minute': '%Y-%m-%d %H:%M',
'year': '%Y-01-01 00:00:00',
'month': '%Y-%m-01 00:00:00',
'day': '%Y-%m-%d 00:00:00',
'hour': '%Y-%m-%d %H:00:00',
'minute': '%Y-%m-%d %H:%M:00',
'second': '%Y-%m-%d %H:%M:%S'}
__mysql_date_trunc__ = __sqlite_date_trunc__.copy()
__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i'
__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i:00'
__mysql_date_trunc__['second'] = '%Y-%m-%d %H:%i:%S'
def _sqlite_date_part(lookup_type, datetime_string):
@ -460,6 +460,9 @@ class DatabaseProxy(Proxy):
return _savepoint(self)
class ModelDescriptor(object): pass
# SQL Generation.
@ -1141,11 +1144,24 @@ class ColumnBase(Node):
op = OP.IS if is_null else OP.IS_NOT
return Expression(self, op, None)
def contains(self, rhs):
return Expression(self, OP.ILIKE, '%%%s%%' % rhs)
if isinstance(rhs, Node):
rhs = Expression('%', OP.CONCAT,
Expression(rhs, OP.CONCAT, '%'))
else:
rhs = '%%%s%%' % rhs
return Expression(self, OP.ILIKE, rhs)
def startswith(self, rhs):
return Expression(self, OP.ILIKE, '%s%%' % rhs)
if isinstance(rhs, Node):
rhs = Expression(rhs, OP.CONCAT, '%')
else:
rhs = '%s%%' % rhs
return Expression(self, OP.ILIKE, rhs)
def endswith(self, rhs):
return Expression(self, OP.ILIKE, '%%%s' % rhs)
if isinstance(rhs, Node):
rhs = Expression('%', OP.CONCAT, rhs)
else:
rhs = '%%%s' % rhs
return Expression(self, OP.ILIKE, rhs)
def between(self, lo, hi):
return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi)))
def concat(self, rhs):
@ -1229,6 +1245,9 @@ class Alias(WrappedNode):
super(Alias, self).__init__(node)
self._alias = alias
def __hash__(self):
return hash(self._alias)
def alias(self, alias=None):
if alias is None:
return self.node
@ -1376,8 +1395,16 @@ class Expression(ColumnBase):
def __sql__(self, ctx):
overrides = {'parentheses': not self.flat, 'in_expr': True}
if isinstance(self.lhs, Field):
overrides['converter'] = self.lhs.db_value
# First attempt to unwrap the node on the left-hand-side, so that we
# can get at the underlying Field if one is present.
node = raw_node = self.lhs
if isinstance(raw_node, WrappedNode):
node = raw_node.unwrap()
# Set up the appropriate converter if we have a field on the left side.
if isinstance(node, Field) and raw_node._coerce:
overrides['converter'] = node.db_value
else:
overrides['converter'] = None
@ -2090,6 +2117,11 @@ class CompoundSelectQuery(SelectBase):
def _returning(self):
return self.lhs._returning
@database_required
def exists(self, database):
query = Select((self.limit(1),), (SQL('1'),)).bind(database)
return bool(query.scalar())
def _get_query_key(self):
return (self.lhs.get_query_key(), self.rhs.get_query_key())
@ -2101,6 +2133,14 @@ class CompoundSelectQuery(SelectBase):
elif csq_setting == CSQ_PARENTHESES_ALWAYS:
return True
elif csq_setting == CSQ_PARENTHESES_UNNESTED:
if ctx.state.in_expr or ctx.state.in_function:
# If this compound select query is being used inside an
# expression, e.g., an IN or EXISTS().
return False
# If the query on the left or right is itself a compound select
# query, then we do not apply parentheses. However, if it is a
# regular SELECT query, we will apply parentheses.
return not isinstance(subq, CompoundSelectQuery)
def __sql__(self, ctx):
@ -2433,7 +2473,6 @@ class Insert(_WriteQuery):
# Load and organize column defaults (if provided).
defaults = self.get_default_data()
value_lookups = {}
# First figure out what columns are being inserted (if they weren't
# specified explicitly). Resulting columns are normalized and ordered.
@ -2443,47 +2482,48 @@ class Insert(_WriteQuery):
except StopIteration:
raise self.DefaultValuesException('Error: no rows to insert.')
if not isinstance(row, dict):
if not isinstance(row, Mapping):
columns = self.get_default_columns()
if columns is None:
raise ValueError('Bulk insert must specify columns.')
else:
# Infer column names from the dict of data being inserted.
accum = []
uses_strings = False # Are the dict keys strings or columns?
for key in row:
if isinstance(key, basestring):
column = getattr(self.table, key)
uses_strings = True
else:
column = key
for column in row:
if isinstance(column, basestring):
column = getattr(self.table, column)
accum.append(column)
value_lookups[column] = key
# Add any columns present in the default data that are not
# accounted for by the dictionary of row data.
column_set = set(accum)
for col in (set(defaults) - column_set):
accum.append(col)
value_lookups[col] = col.name if uses_strings else col
columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx))
rows_iter = itertools.chain(iter((row,)), rows_iter)
else:
clean_columns = []
seen = set()
for column in columns:
if isinstance(column, basestring):
column_obj = getattr(self.table, column)
else:
column_obj = column
value_lookups[column_obj] = column
clean_columns.append(column_obj)
seen.add(column_obj)
columns = clean_columns
for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)):
if col not in value_lookups:
if col not in seen:
columns.append(col)
value_lookups[col] = col
value_lookups = {}
for column in columns:
lookups = [column, column.name]
if isinstance(column, Field) and column.name != column.column_name:
lookups.append(column.column_name)
value_lookups[column] = lookups
ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ')
columns_converters = [
@ -2497,7 +2537,18 @@ class Insert(_WriteQuery):
for i, (column, converter) in enumerate(columns_converters):
try:
if is_dict:
val = row[value_lookups[column]]
# The logic is a bit convoluted, but in order to be
# flexible in what we accept (dict keyed by
# column/field, field name, or underlying column name),
# we try accessing the row data dict using each
# possible key. If no match is found, throw an error.
for lookup in value_lookups[column]:
try:
val = row[lookup]
except KeyError: pass
else: break
else:
raise KeyError
else:
val = row[i]
except (KeyError, IndexError):
@ -2544,7 +2595,7 @@ class Insert(_WriteQuery):
.sql(self.table)
.literal(' '))
if isinstance(self._insert, dict) and not self._columns:
if isinstance(self._insert, Mapping) and not self._columns:
try:
self._simple_insert(ctx)
except self.DefaultValuesException:
@ -2817,7 +2868,8 @@ class Database(_callable_context_manager):
truncate_table = True
def __init__(self, database, thread_safe=True, autorollback=False,
field_types=None, operations=None, autocommit=None, **kwargs):
field_types=None, operations=None, autocommit=None,
autoconnect=True, **kwargs):
self._field_types = merge_dict(FIELD, self.field_types)
self._operations = merge_dict(OP, self.operations)
if field_types:
@ -2825,6 +2877,7 @@ class Database(_callable_context_manager):
if operations:
self._operations.update(operations)
self.autoconnect = autoconnect
self.autorollback = autorollback
self.thread_safe = thread_safe
if thread_safe:
@ -2930,7 +2983,10 @@ class Database(_callable_context_manager):
def cursor(self, commit=None):
if self.is_closed():
if self.autoconnect:
self.connect()
else:
raise InterfaceError('Error, database connection not opened.')
return self._state.conn.cursor()
def execute_sql(self, sql, params=None, commit=SENTINEL):
@ -3141,6 +3197,15 @@ class Database(_callable_context_manager):
def truncate_date(self, date_part, date_field):
raise NotImplementedError
def to_timestamp(self, date_field):
raise NotImplementedError
def from_timestamp(self, date_field):
raise NotImplementedError
def random(self):
return fn.random()
def bind(self, models, bind_refs=True, bind_backrefs=True):
for model in models:
model.bind(self, bind_refs=bind_refs, bind_backrefs=bind_backrefs)
@ -3524,10 +3589,17 @@ class SqliteDatabase(Database):
return self._build_on_conflict_update(oc, query)
def extract_date(self, date_part, date_field):
return fn.date_part(date_part, date_field)
return fn.date_part(date_part, date_field, python_value=int)
def truncate_date(self, date_part, date_field):
return fn.date_trunc(date_part, date_field)
return fn.date_trunc(date_part, date_field,
python_value=simple_date_time)
def to_timestamp(self, date_field):
return fn.strftime('%s', date_field).cast('integer')
def from_timestamp(self, date_field):
return fn.datetime(date_field, 'unixepoch')
class PostgresqlDatabase(Database):
@ -3552,9 +3624,11 @@ class PostgresqlDatabase(Database):
safe_create_index = False
sequences = True
def init(self, database, register_unicode=True, encoding=None, **kwargs):
def init(self, database, register_unicode=True, encoding=None,
isolation_level=None, **kwargs):
self._register_unicode = register_unicode
self._encoding = encoding
self._isolation_level = isolation_level
super(PostgresqlDatabase, self).init(database, **kwargs)
def _connect(self):
@ -3566,6 +3640,8 @@ class PostgresqlDatabase(Database):
pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn)
if self._encoding:
conn.set_client_encoding(self._encoding)
if self._isolation_level:
conn.set_isolation_level(self._isolation_level)
return conn
def _set_server_version(self, conn):
@ -3695,6 +3771,13 @@ class PostgresqlDatabase(Database):
def truncate_date(self, date_part, date_field):
return fn.DATE_TRUNC(date_part, date_field)
def to_timestamp(self, date_field):
return self.extract_date('EPOCH', date_field)
def from_timestamp(self, date_field):
# Ironically, here, Postgres means "to the Postgresql timestamp type".
return fn.to_timestamp(date_field)
def get_noop_select(self, ctx):
return ctx.sql(Select().columns(SQL('0')).where(SQL('false')))
@ -3727,9 +3810,13 @@ class MySQLDatabase(Database):
limit_max = 2 ** 64 - 1
safe_create_index = False
safe_drop_index = False
sql_mode = 'PIPES_AS_CONCAT'
def init(self, database, **kwargs):
params = {'charset': 'utf8', 'use_unicode': True}
params = {
'charset': 'utf8',
'sql_mode': self.sql_mode,
'use_unicode': True}
params.update(kwargs)
if 'password' in params and mysql_passwd:
params['passwd'] = params.pop('password')
@ -3876,7 +3963,17 @@ class MySQLDatabase(Database):
return fn.EXTRACT(NodeList((SQL(date_part), SQL('FROM'), date_field)))
def truncate_date(self, date_part, date_field):
return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part])
return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part],
python_value=simple_date_time)
def to_timestamp(self, date_field):
return fn.UNIX_TIMESTAMP(date_field)
def from_timestamp(self, date_field):
return fn.FROM_UNIXTIME(date_field)
def random(self):
return fn.rand()
def get_noop_select(self, ctx):
return ctx.literal('DO 0')
@ -4274,7 +4371,7 @@ class Field(ColumnBase):
def bind(self, model, name, set_attribute=True):
self.model = model
self.name = name
self.name = self.safe_name = name
self.column_name = self.column_name or name
if set_attribute:
setattr(model, name, self.accessor_class(model, self, name))
@ -4299,7 +4396,7 @@ class Field(ColumnBase):
return ctx.sql(self.column)
def get_modifiers(self):
return
pass
def ddl_datatype(self, ctx):
if ctx and ctx.state.field_types:
@ -4337,7 +4434,12 @@ class Field(ColumnBase):
class IntegerField(Field):
field_type = 'INT'
adapt = int
def adapt(self, value):
try:
return int(value)
except ValueError:
return value
class BigIntegerField(IntegerField):
@ -4377,7 +4479,12 @@ class PrimaryKeyField(AutoField):
class FloatField(Field):
field_type = 'FLOAT'
adapt = float
def adapt(self, value):
try:
return float(value)
except ValueError:
return value
class DoubleField(FloatField):
@ -4393,6 +4500,7 @@ class DecimalField(Field):
self.decimal_places = decimal_places
self.auto_round = auto_round
self.rounding = rounding or decimal.DefaultContext.rounding
self._exp = decimal.Decimal(10) ** (-self.decimal_places)
super(DecimalField, self).__init__(*args, **kwargs)
def get_modifiers(self):
@ -4403,9 +4511,8 @@ class DecimalField(Field):
if not value:
return value if value is None else D(0)
if self.auto_round:
exp = D(10) ** (-self.decimal_places)
rounding = self.rounding
return D(text_type(value)).quantize(exp, rounding=rounding)
decimal_value = D(text_type(value))
return decimal_value.quantize(self._exp, rounding=self.rounding)
return value
def python_value(self, value):
@ -4423,8 +4530,8 @@ class _StringField(Field):
return value.decode('utf-8')
return text_type(value)
def __add__(self, other): return self.concat(other)
def __radd__(self, other): return other.concat(self)
def __add__(self, other): return StringExpression(self, OP.CONCAT, other)
def __radd__(self, other): return StringExpression(other, OP.CONCAT, self)
class CharField(_StringField):
@ -4652,6 +4759,12 @@ def format_date_time(value, formats, post_process=None):
pass
return value
def simple_date_time(value):
try:
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
except (TypeError, ValueError):
return value
class _BaseFormattedField(Field):
formats = None
@ -4675,6 +4788,12 @@ class DateTimeField(_BaseFormattedField):
return format_date_time(value, self.formats)
return value
def to_timestamp(self):
return self.model._meta.database.to_timestamp(self)
def truncate(self, part):
return self.model._meta.database.truncate_date(part, self)
year = property(_date_part('year'))
month = property(_date_part('month'))
day = property(_date_part('day'))
@ -4699,6 +4818,12 @@ class DateField(_BaseFormattedField):
return value.date()
return value
def to_timestamp(self):
return self.model._meta.database.to_timestamp(self)
def truncate(self, part):
return self.model._meta.database.truncate_date(part, self)
year = property(_date_part('year'))
month = property(_date_part('month'))
day = property(_date_part('day'))
@ -4730,19 +4855,30 @@ class TimeField(_BaseFormattedField):
second = property(_date_part('second'))
def _timestamp_date_part(date_part):
def dec(self):
db = self.model._meta.database
expr = ((self / Value(self.resolution, converter=False))
if self.resolution > 1 else self)
return db.extract_date(date_part, db.from_timestamp(expr))
return dec
class TimestampField(BigIntegerField):
# Support second -> microsecond resolution.
valid_resolutions = [10**i for i in range(7)]
def __init__(self, *args, **kwargs):
self.resolution = kwargs.pop('resolution', None)
if not self.resolution:
self.resolution = 1
elif self.resolution in range(7):
elif self.resolution in range(2, 7):
self.resolution = 10 ** self.resolution
elif self.resolution not in self.valid_resolutions:
raise ValueError('TimestampField resolution must be one of: %s' %
', '.join(str(i) for i in self.valid_resolutions))
self.ticks_to_microsecond = 1000000 // self.resolution
self.utc = kwargs.pop('utc', False) or False
dflt = datetime.datetime.utcnow if self.utc else datetime.datetime.now
@ -4764,6 +4900,13 @@ class TimestampField(BigIntegerField):
ts = calendar.timegm(dt.utctimetuple())
return datetime.datetime.fromtimestamp(ts)
def get_timestamp(self, value):
if self.utc:
# If utc-mode is on, then we assume all naive datetimes are in UTC.
return calendar.timegm(value.utctimetuple())
else:
return time.mktime(value.timetuple())
def db_value(self, value):
if value is None:
return
@ -4775,12 +4918,7 @@ class TimestampField(BigIntegerField):
else:
return int(round(value * self.resolution))
if self.utc:
# If utc-mode is on, then we assume all naive datetimes are in UTC.
timestamp = calendar.timegm(value.utctimetuple())
else:
timestamp = time.mktime(value.timetuple())
timestamp = self.get_timestamp(value)
if self.resolution > 1:
timestamp += (value.microsecond * .000001)
timestamp *= self.resolution
@ -4789,9 +4927,8 @@ class TimestampField(BigIntegerField):
def python_value(self, value):
if value is not None and isinstance(value, (int, float, long)):
if self.resolution > 1:
ticks_to_microsecond = 1000000 // self.resolution
value, ticks = divmod(value, self.resolution)
microseconds = int(ticks * ticks_to_microsecond)
microseconds = int(ticks * self.ticks_to_microsecond)
else:
microseconds = 0
@ -4805,6 +4942,18 @@ class TimestampField(BigIntegerField):
return value
def from_timestamp(self):
expr = ((self / Value(self.resolution, converter=False))
if self.resolution > 1 else self)
return self.model._meta.database.from_timestamp(expr)
year = property(_timestamp_date_part('year'))
month = property(_timestamp_date_part('month'))
day = property(_timestamp_date_part('day'))
hour = property(_timestamp_date_part('hour'))
minute = property(_timestamp_date_part('minute'))
second = property(_timestamp_date_part('second'))
class IPField(BigIntegerField):
def db_value(self, val):
@ -4860,6 +5009,7 @@ class ForeignKeyField(Field):
'"backref" for Field objects.')
backref = related_name
self._is_self_reference = model == 'self'
self.rel_model = model
self.rel_field = field
self.declared_backref = backref
@ -4908,7 +5058,7 @@ class ForeignKeyField(Field):
raise ValueError('ForeignKeyField "%s"."%s" specifies an '
'object_id_name that conflicts with its field '
'name.' % (model._meta.name, name))
if self.rel_model == 'self':
if self._is_self_reference:
self.rel_model = model
if isinstance(self.rel_field, basestring):
self.rel_field = getattr(self.rel_model, self.rel_field)
@ -4918,6 +5068,7 @@ class ForeignKeyField(Field):
# Bind field before assigning backref, so field is bound when
# calling declared_backref() (if callable).
super(ForeignKeyField, self).bind(model, name, set_attribute)
self.safe_name = self.object_id_name
if callable_(self.declared_backref):
self.backref = self.declared_backref(self)
@ -5143,7 +5294,7 @@ class VirtualField(MetaField):
def bind(self, model, name, set_attribute=True):
self.model = model
self.column_name = self.name = name
self.column_name = self.name = self.safe_name = name
setattr(model, name, self.accessor_class(model, self, name))
@ -5190,7 +5341,7 @@ class CompositeKey(MetaField):
def bind(self, model, name, set_attribute=True):
self.model = model
self.column_name = self.name = name
self.column_name = self.name = self.safe_name = name
setattr(model, self.name, self)
@ -6139,7 +6290,12 @@ class Model(with_metaclass(ModelBase, Node)):
return cls.select().filter(*dq_nodes, **filters)
def get_id(self):
return getattr(self, self._meta.primary_key.name)
# Using getattr(self, pk-name) could accidentally trigger a query if
# the primary-key is a foreign-key. So we use the safe_name attribute,
# which defaults to the field-name, but will be the object_id_name for
# foreign-key fields.
if self._meta.primary_key is not False:
return getattr(self, self._meta.primary_key.safe_name)
_pk = property(get_id)
@ -6254,13 +6410,14 @@ class Model(with_metaclass(ModelBase, Node)):
return (
other.__class__ == self.__class__ and
self._pk is not None and
other._pk == self._pk)
self._pk == other._pk)
def __ne__(self, other):
return not self == other
def __sql__(self, ctx):
return ctx.sql(getattr(self, self._meta.primary_key.name))
return ctx.sql(Value(getattr(self, self._meta.primary_key.name),
converter=self._meta.primary_key.db_value))
@classmethod
def bind(cls, database, bind_refs=True, bind_backrefs=True):
@ -6327,6 +6484,18 @@ class ModelAlias(Node):
self.__dict__['alias'] = alias
def __getattr__(self, attr):
# Hack to work-around the fact that properties or other objects
# implementing the descriptor protocol (on the model being aliased),
# will not work correctly when we use getattr(). So we explicitly pass
# the model alias to the descriptor's getter.
try:
obj = self.model.__dict__[attr]
except KeyError:
pass
else:
if isinstance(obj, ModelDescriptor):
return obj.__get__(None, self)
model_attr = getattr(self.model, attr)
if isinstance(model_attr, Field):
self.__dict__[attr] = FieldAlias.create(self, model_attr)
@ -6675,6 +6844,13 @@ class ModelSelect(BaseModelSelect, Select):
return fk_fields[0], is_backref
if on is None:
# If multiple foreign-keys exist, try using the FK whose name
# matches that of the related model. If not, raise an error as this
# is ambiguous.
for fk in fk_fields:
if fk.name == dest._meta.name:
return fk, is_backref
raise ValueError('More than one foreign key between %s and %s.'
' Please specify which you are joining on.' %
(src, dest))
@ -6710,7 +6886,7 @@ class ModelSelect(BaseModelSelect, Select):
on, attr, constructor = self._normalize_join(src, dest, on, attr)
if attr:
self._joins.setdefault(src, [])
self._joins[src].append((dest, attr, constructor))
self._joins[src].append((dest, attr, constructor, join_type))
elif on is not None:
raise ValueError('Cannot specify on clause with cross join.')
@ -6732,7 +6908,7 @@ class ModelSelect(BaseModelSelect, Select):
def ensure_join(self, lm, rm, on=None, **join_kwargs):
join_ctx = self._join_ctx
for dest, attr, constructor in self._joins.get(lm, []):
for dest, _, constructor, _ in self._joins.get(lm, []):
if dest == rm:
return self
return self.switch(lm).join(rm, on=on, **join_kwargs).switch(join_ctx)
@ -6757,7 +6933,7 @@ class ModelSelect(BaseModelSelect, Select):
model_attr = getattr(curr, key)
else:
for piece in key.split('__'):
for dest, attr, _ in self._joins.get(curr, ()):
for dest, attr, _, _ in self._joins.get(curr, ()):
if attr == piece or (isinstance(dest, ModelAlias) and
dest.alias == piece):
curr = dest
@ -6999,8 +7175,7 @@ class BaseModelCursorWrapper(DictCursorWrapper):
if raw_node._coerce:
converters[idx] = node.python_value
fields[idx] = node
if (column == node.name or column == node.column_name) and \
not raw_node.is_alias():
if not raw_node.is_alias():
self.columns[idx] = node.name
elif isinstance(node, Function) and node._coerce:
if node._python_value is not None:
@ -7100,6 +7275,7 @@ class ModelCursorWrapper(BaseModelCursorWrapper):
self.src_to_dest = []
accum = collections.deque(self.from_list)
dests = set()
while accum:
curr = accum.popleft()
if isinstance(curr, Join):
@ -7110,11 +7286,14 @@ class ModelCursorWrapper(BaseModelCursorWrapper):
if curr not in self.joins:
continue
for key, attr, constructor in self.joins[curr]:
is_dict = isinstance(curr, dict)
for key, attr, constructor, join_type in self.joins[curr]:
if key not in self.key_to_constructor:
self.key_to_constructor[key] = constructor
self.src_to_dest.append((curr, attr, key,
isinstance(curr, dict)))
# (src, attr, dest, is_dict, join_type).
self.src_to_dest.append((curr, attr, key, is_dict,
join_type))
dests.add(key)
accum.append(key)
@ -7127,7 +7306,7 @@ class ModelCursorWrapper(BaseModelCursorWrapper):
self.key_to_constructor[src] = src.model
# Indicate which sources are also dests.
for src, _, dest, _ in self.src_to_dest:
for src, _, dest, _, _ in self.src_to_dest:
self.src_is_dest[src] = src in dests and (dest in selected_src
or src in selected_src)
@ -7171,7 +7350,7 @@ class ModelCursorWrapper(BaseModelCursorWrapper):
setattr(instance, column, value)
# Need to do some analysis on the joins before this.
for (src, attr, dest, is_dict) in self.src_to_dest:
for (src, attr, dest, is_dict, join_type) in self.src_to_dest:
instance = objects[src]
try:
joined_instance = objects[dest]
@ -7184,6 +7363,12 @@ class ModelCursorWrapper(BaseModelCursorWrapper):
(dest not in set_keys and not self.src_is_dest.get(dest)):
continue
# If no fields were set on either the source or the destination,
# then we have nothing to do here.
if instance not in set_keys and dest not in set_keys \
and join_type.endswith('OUTER'):
continue
if is_dict:
instance[attr] = joined_instance
else:
@ -7309,7 +7494,7 @@ def prefetch(sq, *subqueries):
rel_map.setdefault(rel_model, [])
rel_map[rel_model].append(pq)
deps[query_model] = {}
deps.setdefault(query_model, {})
id_map = deps[query_model]
has_relations = bool(rel_map.get(query_model))

@ -61,12 +61,14 @@ class DataSet(object):
def get_export_formats(self):
return {
'csv': CSVExporter,
'json': JSONExporter}
'json': JSONExporter,
'tsv': TSVExporter}
def get_import_formats(self):
return {
'csv': CSVImporter,
'json': JSONImporter}
'json': JSONImporter,
'tsv': TSVImporter}
def __getitem__(self, table):
if table not in self._models and table in self.tables:
@ -244,6 +246,29 @@ class Table(object):
self.dataset.update_cache(self.name)
def __getitem__(self, item):
try:
return self.model_class[item]
except self.model_class.DoesNotExist:
pass
def __setitem__(self, item, value):
if not isinstance(value, dict):
raise ValueError('Table.__setitem__() value must be a dict')
pk = self.model_class._meta.primary_key
value[pk.name] = item
try:
with self.dataset.transaction() as txn:
self.insert(**value)
except IntegrityError:
self.dataset.update_cache(self.name)
self.update(columns=[pk.name], **value)
def __delitem__(self, item):
del self.model_class[item]
def insert(self, **data):
self._migrate_new_columns(data)
return self.model_class.insert(**data).execute()
@ -343,6 +368,12 @@ class CSVExporter(Exporter):
writer.writerow(row)
class TSVExporter(CSVExporter):
def export(self, file_obj, header=True, **kwargs):
kwargs.setdefault('delimiter', '\t')
return super(TSVExporter, self).export(file_obj, header, **kwargs)
class Importer(object):
def __init__(self, table, strict=False):
self.table = table
@ -413,3 +444,9 @@ class CSVImporter(Importer):
count += 1
return count
class TSVImporter(CSVImporter):
def load(self, file_obj, header=True, **kwargs):
kwargs.setdefault('delimiter', '\t')
return super(TSVImporter, self).load(file_obj, header, **kwargs)

@ -1,6 +1,9 @@
from peewee import ModelDescriptor
# Hybrid methods/attributes, based on similar functionality in SQLAlchemy:
# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html
class hybrid_method(object):
class hybrid_method(ModelDescriptor):
def __init__(self, func, expr=None):
self.func = func
self.expr = expr or func
@ -15,7 +18,7 @@ class hybrid_method(object):
return self
class hybrid_property(object):
class hybrid_property(ModelDescriptor):
def __init__(self, fget, fset=None, fdel=None, expr=None):
self.fget = fget
self.fset = fset

@ -777,7 +777,7 @@ class SqliteMigrator(SchemaMigrator):
clean = []
for column in columns:
if re.match('%s(?:[\'"`\]]?\s|$)' % column_to_update, column):
column = new_columne + column[len(column_to_update):]
column = new_column + column[len(column_to_update):]
clean.append(column)
return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean))

@ -7,7 +7,10 @@ except ImportError:
from peewee import ImproperlyConfigured
from peewee import MySQLDatabase
from peewee import NodeList
from peewee import SQL
from peewee import TextField
from peewee import fn
class MySQLConnectorDatabase(MySQLDatabase):
@ -18,7 +21,10 @@ class MySQLConnectorDatabase(MySQLDatabase):
def cursor(self, commit=None):
if self.is_closed():
if self.autoconnect:
self.connect()
else:
raise InterfaceError('Error, database connection not opened.')
return self._state.conn.cursor(buffered=True)
@ -32,3 +38,12 @@ class JSONField(TextField):
def python_value(self, value):
if value is not None:
return json.loads(value)
def Match(columns, expr, modifier=None):
if isinstance(columns, (list, tuple)):
match = fn.MATCH(*columns) # Tuple of one or more columns / fields.
else:
match = fn.MATCH(columns) # Single column / field.
args = expr if modifier is None else NodeList((expr, SQL(modifier)))
return NodeList((match, fn.AGAINST(args)))

@ -33,6 +33,7 @@ That's it!
"""
import heapq
import logging
import random
import time
from collections import namedtuple
from itertools import chain
@ -153,7 +154,7 @@ class PooledDatabase(object):
len(self._in_use) >= self._max_connections):
raise MaxConnectionsExceeded('Exceeded maximum connections.')
conn = super(PooledDatabase, self)._connect()
ts = time.time()
ts = time.time() - random.random() / 1000
key = self.conn_key(conn)
logger.debug('Created new connection %s.', key)

@ -3,6 +3,7 @@ Collection of postgres-specific extensions, currently including:
* Support for hstore, a key/value type storage
"""
import json
import logging
import uuid
@ -277,18 +278,19 @@ class HStoreField(IndexedFieldMixin, Field):
class JSONField(Field):
field_type = 'JSON'
_json_datatype = 'json'
def __init__(self, dumps=None, *args, **kwargs):
if Json is None:
raise Exception('Your version of psycopg2 does not support JSON.')
self.dumps = dumps
self.dumps = dumps or json.dumps
super(JSONField, self).__init__(*args, **kwargs)
def db_value(self, value):
if value is None:
return value
if not isinstance(value, Json):
return Json(value, dumps=self.dumps)
return Cast(self.dumps(value), self._json_datatype)
return value
def __getitem__(self, value):
@ -307,6 +309,7 @@ def cast_jsonb(node):
class BinaryJSONField(IndexedFieldMixin, JSONField):
field_type = 'JSONB'
_json_datatype = 'jsonb'
__hash__ = Field.__hash__
def contains(self, other):
@ -449,7 +452,10 @@ class PostgresqlExtDatabase(PostgresqlDatabase):
def cursor(self, commit=None):
if self.is_closed():
if self.autoconnect:
self.connect()
else:
raise InterfaceError('Error, database connection not opened.')
if commit is __named_cursor__:
return self._state.conn.cursor(name=str(uuid.uuid1()))
return self._state.conn.cursor()

@ -224,7 +224,7 @@ class PostgresqlMetadata(Metadata):
23: IntegerField,
25: TextField,
700: FloatField,
701: FloatField,
701: DoubleField,
1042: CharField, # blank-padded CHAR
1043: CharField,
1082: DateField,

@ -73,7 +73,7 @@ def model_to_dict(model, recurse=True, backrefs=False, only=None,
field_data = model.__data__.get(field.name)
if isinstance(field, ForeignKeyField) and recurse:
if field_data:
if field_data is not None:
seen.add(field)
rel_obj = getattr(model, field.name)
field_data = model_to_dict(
@ -191,6 +191,10 @@ class ReconnectMixin(object):
(OperationalError, '2006'), # MySQL server has gone away.
(OperationalError, '2013'), # Lost connection to MySQL server.
(OperationalError, '2014'), # Commands out of sync.
# mysql-connector raises a slightly different error when an idle
# connection is terminated by the server. This is equivalent to 2013.
(OperationalError, 'MySQL Connection not available.'),
)
def __init__(self, *args, **kwargs):

@ -65,7 +65,7 @@ class Model(_Model):
pre_init.send(self)
def save(self, *args, **kwargs):
pk_value = self._pk
pk_value = self._pk if self._meta.primary_key else True
created = kwargs.get('force_insert', False) or not bool(pk_value)
pre_save.send(self, created=created)
ret = super(Model, self).save(*args, **kwargs)

@ -69,6 +69,11 @@ class AutoIncrementField(AutoField):
return NodeList((node_list, SQL('AUTOINCREMENT')))
class TDecimalField(DecimalField):
field_type = 'TEXT'
def get_modifiers(self): pass
class JSONPath(ColumnBase):
def __init__(self, field, path=None):
super(JSONPath, self).__init__()
@ -1190,6 +1195,33 @@ def bm25(raw_match_info, *args):
L_O = A_O + col_count
X_O = L_O + col_count
# Worked example of pcnalx for two columns and two phrases, 100 docs total.
# {
# p = 2
# c = 2
# n = 100
# a0 = 4 -- avg number of tokens for col0, e.g. title
# a1 = 40 -- avg number of tokens for col1, e.g. body
# l0 = 5 -- curr doc has 5 tokens in col0
# l1 = 30 -- curr doc has 30 tokens in col1
#
# x000 -- hits this row for phrase0, col0
# x001 -- hits all rows for phrase0, col0
# x002 -- rows with phrase0 in col0 at least once
#
# x010 -- hits this row for phrase0, col1
# x011 -- hits all rows for phrase0, col1
# x012 -- rows with phrase0 in col1 at least once
#
# x100 -- hits this row for phrase1, col0
# x101 -- hits all rows for phrase1, col0
# x102 -- rows with phrase1 in col0 at least once
#
# x110 -- hits this row for phrase1, col1
# x111 -- hits all rows for phrase1, col1
# x112 -- rows with phrase1 in col1 at least once
# }
weights = get_weights(col_count, args)
for i in range(term_count):
@ -1213,8 +1245,8 @@ def bm25(raw_match_info, *args):
avg_length = float(match_info[A_O + j]) or 1. # avgdl
ratio = doc_length / avg_length
num = term_frequency * (K + 1)
b_part = 1 - B + (B * ratio)
num = term_frequency * (K + 1.0)
b_part = 1.0 - B + (B * ratio)
denom = term_frequency + (K * b_part)
pc_score = idf * (num / denom)

@ -122,7 +122,7 @@ class SubtitleStreamParser(BaseParser):
"""Returns a string """
tags = data.get("tags", None)
if tags:
info = tags.get("language", None)
info = tags.get("language", None) or tags.get("LANGUAGE", None)
return info, (info or "null")
return None, "null"

@ -29,7 +29,7 @@ from subliminal.utils import hash_napiprojekt, hash_opensubtitles, hash_shooter,
from subliminal.video import VIDEO_EXTENSIONS, Video, Episode, Movie
from subliminal.core import guessit, ProviderPool, io, is_windows_special_path, \
ThreadPoolExecutor, check_video
from subliminal_patch.exceptions import TooManyRequests, APIThrottled, ParseResponseError
from subliminal_patch.exceptions import TooManyRequests, APIThrottled
from subzero.language import Language, ENDSWITH_LANGUAGECODE_RE
from scandir import scandir, scandir_generic as _scandir_generic
@ -282,14 +282,10 @@ class SZProviderPool(ProviderPool):
logger.debug("RAR Traceback: %s", traceback.format_exc())
return False
except (TooManyRequests, DownloadLimitExceeded, ServiceUnavailable, APIThrottled, ParseResponseError) as e:
self.throttle_callback(subtitle.provider_name, e)
self.discarded_providers.add(subtitle.provider_name)
return False
except:
except Exception as e:
logger.exception('Unexpected error in provider %r, Traceback: %s', subtitle.provider_name,
traceback.format_exc())
self.throttle_callback(subtitle.provider_name, e)
self.discarded_providers.add(subtitle.provider_name)
return False
@ -613,16 +609,6 @@ def _search_external_subtitles(path, languages=None, only_one=False, scandir_gen
if adv_tag:
forced = "forced" in adv_tag
# extract the potential language code
try:
language_code = p_root.rsplit(".", 1)[1].replace('_', '-')
try:
Language.fromietf(language_code)
except:
language_code = None
except IndexError:
language_code = None
# remove possible language code for matching
p_root_bare = ENDSWITH_LANGUAGECODE_RE.sub("", p_root)
@ -635,19 +621,21 @@ def _search_external_subtitles(path, languages=None, only_one=False, scandir_gen
if match_strictness == "strict" or (match_strictness == "loose" and not filename_contains):
continue
# default language is undefined
language = Language('und')
language = None
# attempt to parse
if language_code:
# extract the potential language code
try:
language_code = p_root.rsplit(".", 1)[1].replace('_', '-')
try:
language = Language.fromietf(language_code)
language.forced = forced
except ValueError:
logger.error('Cannot parse language code %r', language_code)
language = None
language_code = None
except IndexError:
language_code = None
elif not language_code and only_one:
if not language and not language_code and only_one:
language = Language.rebuild(list(languages)[0], forced=forced)
subtitles[p] = language
@ -877,6 +865,7 @@ def save_subtitles(file_path, subtitles, single=False, directory=None, chmod=Non
if content:
if os.path.exists(subtitle_path):
os.remove(subtitle_path)
with open(subtitle_path, 'wb') as f:
f.write(content)
subtitle.storage_path = subtitle_path

@ -1,5 +1,4 @@
# coding=utf-8
from __future__ import absolute_import
import logging
import re
import datetime
@ -11,7 +10,7 @@ from requests import Session
from subliminal.cache import region
from subliminal.exceptions import DownloadLimitExceeded, AuthenticationError
from subliminal.providers.addic7ed import Addic7edProvider as _Addic7edProvider, \
Addic7edSubtitle as _Addic7edSubtitle, ParserBeautifulSoup, show_cells_re
Addic7edSubtitle as _Addic7edSubtitle, ParserBeautifulSoup
from subliminal.subtitle import fix_line_ending
from subliminal_patch.utils import sanitize
from subliminal_patch.exceptions import TooManyRequests
@ -20,6 +19,8 @@ from subzero.language import Language
logger = logging.getLogger(__name__)
show_cells_re = re.compile(b'<td class="(?:version|vr)">.*?</td>', re.DOTALL)
#: Series header parsing regex
series_year_re = re.compile(r'^(?P<series>[ \w\'.:(),*&!?-]+?)(?: \((?P<year>\d{4})\))?$')
@ -104,11 +105,15 @@ class Addic7edProvider(_Addic7edProvider):
tries = 0
while tries < 3:
r = self.session.get(self.server_url + 'login.php', timeout=10, headers={"Referer": self.server_url})
if "grecaptcha" in r.text:
if "g-recaptcha" in r.text or "grecaptcha" in r.text:
logger.info('Addic7ed: Solving captcha. This might take a couple of minutes, but should only '
'happen once every so often')
site_key = re.search(r'grecaptcha.execute\(\'(.+?)\',', r.text).group(1)
for g, s in (("g-recaptcha-response", r'g-recaptcha.+?data-sitekey=\"(.+?)\"'),
("recaptcha_response", r'grecaptcha.execute\(\'(.+?)\',')):
site_key = re.search(s, r.text).group(1)
if site_key:
break
if not site_key:
logger.error("Addic7ed: Captcha site-key not found!")
return
@ -122,7 +127,7 @@ class Addic7edProvider(_Addic7edProvider):
if not result:
raise Exception("Addic7ed: Couldn't solve captcha!")
data["recaptcha_response"] = result
data[g] = result
r = self.session.post(self.server_url + 'dologin.php', data, allow_redirects=False, timeout=10,
headers={"Referer": self.server_url + "login.php"})
@ -130,12 +135,11 @@ class Addic7edProvider(_Addic7edProvider):
if "relax, slow down" in r.text:
raise TooManyRequests(self.username)
if r.status_code != 302:
if "User <b></b> doesn't exist" in r.text and tries <= 2:
logger.info("Addic7ed: Error, trying again. (%s/%s)", tries+1, 3)
tries += 1
continue
if "Try again" in r.content or "Wrong password" in r.content:
raise AuthenticationError(self.username)
if r.status_code != 302:
logger.error("Addic7ed: Something went wrong when logging in")
raise AuthenticationError(self.username)
break

Loading…
Cancel
Save