Major fixes to database subsystem.

pull/684/head
Louis Vézina 5 years ago
parent 55037cfde2
commit ae6f7117fc

@@ -3,13 +3,20 @@ import atexit
from get_args import args
from peewee import *
from playhouse.sqliteq import SqliteQueueDatabase
from playhouse.migrate import *
from helper import path_replace, path_replace_movie, path_replace_reverse, path_replace_reverse_movie
database = SqliteDatabase(os.path.join(args.config_dir, 'db', 'bazarr.db'))
database = SqliteQueueDatabase(
os.path.join(args.config_dir, 'db', 'bazarr.db'),
use_gevent=False,
autostart=True,
queue_max_size=256, # Max. # of pending writes that can accumulate.
results_timeout=30.0 # Max. time to wait for query to be executed.
)
database.pragma('wal_checkpoint', 'TRUNCATE') # Run a checkpoint and merge remaining wal-journal.
database.timeout = 30 # Number of seconds to wait for the database lock.
database.cache_size = -1024 # Number of KB of cache for wal-journal.
# Must be negative because positive means number of pages.
database.wal_autocheckpoint = 50 # Run an automatic checkpoint every 50 write transactions.
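
Note: since SqliteQueueDatabase funnels writes through a background worker thread, the queue should be drained before the process exits. A minimal sketch, assuming the module-level database handle above (atexit is already imported at the top of this file):

@atexit.register
def _stop_database_worker():
    database.stop()  # flush pending writes and stop the worker thread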

@@ -65,7 +65,7 @@ except ImportError:
mysql = None
__version__ = '3.9.6'
__version__ = '3.11.2'
__all__ = [
'AsIs',
'AutoField',
@@ -206,15 +206,15 @@ __sqlite_datetime_formats__ = (
'%H:%M')
__sqlite_date_trunc__ = {
'year': '%Y',
'month': '%Y-%m',
'day': '%Y-%m-%d',
'hour': '%Y-%m-%d %H',
'minute': '%Y-%m-%d %H:%M',
'year': '%Y-01-01 00:00:00',
'month': '%Y-%m-01 00:00:00',
'day': '%Y-%m-%d 00:00:00',
'hour': '%Y-%m-%d %H:00:00',
'minute': '%Y-%m-%d %H:%M:00',
'second': '%Y-%m-%d %H:%M:%S'}
__mysql_date_trunc__ = __sqlite_date_trunc__.copy()
__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i'
__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i:00'
__mysql_date_trunc__['second'] = '%Y-%m-%d %H:%i:%S'
def _sqlite_date_part(lookup_type, datetime_string):
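
Note: the practical effect of the new truncation formats is that truncated values are complete datetime strings rather than prefixes, so they can be parsed back into datetime objects (see simple_date_time further down). A quick illustration:

import datetime
dt = datetime.datetime(2019, 3, 27, 14, 30, 59)
dt.strftime('%Y-%m')              # old 'month' truncation: '2019-03'
dt.strftime('%Y-%m-01 00:00:00')  # new 'month' truncation: '2019-03-01 00:00:00'
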
@@ -460,6 +460,9 @@ class DatabaseProxy(Proxy):
return _savepoint(self)
class ModelDescriptor(object): pass
# SQL Generation.
@@ -1141,11 +1144,24 @@ class ColumnBase(Node):
op = OP.IS if is_null else OP.IS_NOT
return Expression(self, op, None)
def contains(self, rhs):
return Expression(self, OP.ILIKE, '%%%s%%' % rhs)
if isinstance(rhs, Node):
rhs = Expression('%', OP.CONCAT,
Expression(rhs, OP.CONCAT, '%'))
else:
rhs = '%%%s%%' % rhs
return Expression(self, OP.ILIKE, rhs)
def startswith(self, rhs):
return Expression(self, OP.ILIKE, '%s%%' % rhs)
if isinstance(rhs, Node):
rhs = Expression(rhs, OP.CONCAT, '%')
else:
rhs = '%s%%' % rhs
return Expression(self, OP.ILIKE, rhs)
def endswith(self, rhs):
return Expression(self, OP.ILIKE, '%%%s' % rhs)
if isinstance(rhs, Node):
rhs = Expression('%', OP.CONCAT, rhs)
else:
rhs = '%%%s' % rhs
return Expression(self, OP.ILIKE, rhs)
def between(self, lo, hi):
return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi)))
def concat(self, rhs):
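
Note: with the isinstance(rhs, Node) branches above, a column or other expression can now serve as the search pattern, the wildcards being attached via SQL concatenation; plain strings keep the old interpolation path. A sketch against a hypothetical model:

# Rows where one column appears as a substring of another:
query = Employee.select().where(Employee.email.contains(Employee.username))
# Plain strings behave as before:
query = Employee.select().where(Employee.email.endswith('@example.com'))
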
@@ -1229,6 +1245,9 @@ class Alias(WrappedNode):
super(Alias, self).__init__(node)
self._alias = alias
def __hash__(self):
return hash(self._alias)
def alias(self, alias=None):
if alias is None:
return self.node
@@ -1376,8 +1395,16 @@ class Expression(ColumnBase):
def __sql__(self, ctx):
overrides = {'parentheses': not self.flat, 'in_expr': True}
if isinstance(self.lhs, Field):
overrides['converter'] = self.lhs.db_value
# First attempt to unwrap the node on the left-hand-side, so that we
# can get at the underlying Field if one is present.
node = raw_node = self.lhs
if isinstance(raw_node, WrappedNode):
node = raw_node.unwrap()
# Set up the appropriate converter if we have a field on the left side.
if isinstance(node, Field) and raw_node._coerce:
overrides['converter'] = node.db_value
else:
overrides['converter'] = None
@@ -2090,6 +2117,11 @@ class CompoundSelectQuery(SelectBase):
def _returning(self):
return self.lhs._returning
@database_required
def exists(self, database):
query = Select((self.limit(1),), (SQL('1'),)).bind(database)
return bool(query.scalar())
def _get_query_key(self):
return (self.lhs.get_query_key(), self.rhs.get_query_key())
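
Note: this gives compound selects (UNION, INTERSECT, etc.) the same exists() helper as ordinary selects, by wrapping the compound query in a SELECT 1 ... LIMIT 1. Sketch with hypothetical queries:

admins = User.select().where(User.is_admin)
editors = User.select().where(User.is_editor)
has_staff = (admins | editors).exists()  # True/False
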
@@ -2101,6 +2133,14 @@ class CompoundSelectQuery(SelectBase):
elif csq_setting == CSQ_PARENTHESES_ALWAYS:
return True
elif csq_setting == CSQ_PARENTHESES_UNNESTED:
if ctx.state.in_expr or ctx.state.in_function:
# If this compound select query is being used inside an
# expression, e.g., an IN or EXISTS().
return False
# If the query on the left or right is itself a compound select
# query, then we do not apply parentheses. However, if it is a
# regular SELECT query, we will apply parentheses.
return not isinstance(subq, CompoundSelectQuery)
def __sql__(self, ctx):
@@ -2433,7 +2473,6 @@ class Insert(_WriteQuery):
# Load and organize column defaults (if provided).
defaults = self.get_default_data()
value_lookups = {}
# First figure out what columns are being inserted (if they weren't
# specified explicitly). Resulting columns are normalized and ordered.
@@ -2443,47 +2482,48 @@
except StopIteration:
raise self.DefaultValuesException('Error: no rows to insert.')
if not isinstance(row, dict):
if not isinstance(row, Mapping):
columns = self.get_default_columns()
if columns is None:
raise ValueError('Bulk insert must specify columns.')
else:
# Infer column names from the dict of data being inserted.
accum = []
uses_strings = False # Are the dict keys strings or columns?
for key in row:
if isinstance(key, basestring):
column = getattr(self.table, key)
uses_strings = True
else:
column = key
for column in row:
if isinstance(column, basestring):
column = getattr(self.table, column)
accum.append(column)
value_lookups[column] = key
# Add any columns present in the default data that are not
# accounted for by the dictionary of row data.
column_set = set(accum)
for col in (set(defaults) - column_set):
accum.append(col)
value_lookups[col] = col.name if uses_strings else col
columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx))
rows_iter = itertools.chain(iter((row,)), rows_iter)
else:
clean_columns = []
seen = set()
for column in columns:
if isinstance(column, basestring):
column_obj = getattr(self.table, column)
else:
column_obj = column
value_lookups[column_obj] = column
clean_columns.append(column_obj)
seen.add(column_obj)
columns = clean_columns
for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)):
if col not in value_lookups:
if col not in seen:
columns.append(col)
value_lookups[col] = col
value_lookups = {}
for column in columns:
lookups = [column, column.name]
if isinstance(column, Field) and column.name != column.column_name:
lookups.append(column.column_name)
value_lookups[column] = lookups
ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ')
columns_converters = [
@@ -2497,7 +2537,18 @@
for i, (column, converter) in enumerate(columns_converters):
try:
if is_dict:
val = row[value_lookups[column]]
# The logic is a bit convoluted, but in order to be
# flexible in what we accept (dict keyed by
# column/field, field name, or underlying column name),
# we try accessing the row data dict using each
# possible key. If no match is found, throw an error.
for lookup in value_lookups[column]:
try:
val = row[lookup]
except KeyError: pass
else: break
else:
raise KeyError
else:
val = row[i]
except (KeyError, IndexError):
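
Note: the upshot of the lookup lists is that a row dict may be keyed by the field object, the field's attribute name, or the underlying column name, all resolving to the same column. Sketch, assuming a hypothetical model whose attribute name differs from its column name:

class Person(Model):
    first = CharField(column_name='first_name')

Person.insert({Person.first: 'Huey'}).execute()   # field object
Person.insert({'first': 'Huey'}).execute()        # attribute name
Person.insert({'first_name': 'Huey'}).execute()   # column name
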
@@ -2544,7 +2595,7 @@
.sql(self.table)
.literal(' '))
if isinstance(self._insert, dict) and not self._columns:
if isinstance(self._insert, Mapping) and not self._columns:
try:
self._simple_insert(ctx)
except self.DefaultValuesException:
@@ -2817,7 +2868,8 @@ class Database(_callable_context_manager):
truncate_table = True
def __init__(self, database, thread_safe=True, autorollback=False,
field_types=None, operations=None, autocommit=None, **kwargs):
field_types=None, operations=None, autocommit=None,
autoconnect=True, **kwargs):
self._field_types = merge_dict(FIELD, self.field_types)
self._operations = merge_dict(OP, self.operations)
if field_types:
@@ -2825,6 +2877,7 @@
if operations:
self._operations.update(operations)
self.autoconnect = autoconnect
self.autorollback = autorollback
self.thread_safe = thread_safe
if thread_safe:
@@ -2930,7 +2983,10 @@
def cursor(self, commit=None):
if self.is_closed():
self.connect()
if self.autoconnect:
self.connect()
else:
raise InterfaceError('Error, database connection not opened.')
return self._state.conn.cursor()
def execute_sql(self, sql, params=None, commit=SENTINEL):
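
Note: the new autoconnect flag keeps the historical implicit-connect behavior by default; with autoconnect=False, querying a closed database becomes a hard error rather than a silent connect. Sketch:

db = SqliteDatabase('app.db', autoconnect=False)
# User.select().count()   # would raise InterfaceError here
db.connect()
User.select().count()     # fine: connection opened explicitly
db.close()
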
@@ -3141,6 +3197,15 @@
def truncate_date(self, date_part, date_field):
raise NotImplementedError
def to_timestamp(self, date_field):
raise NotImplementedError
def from_timestamp(self, date_field):
raise NotImplementedError
def random(self):
return fn.random()
def bind(self, models, bind_refs=True, bind_backrefs=True):
for model in models:
model.bind(self, bind_refs=bind_refs, bind_backrefs=bind_backrefs)
@@ -3524,10 +3589,17 @@ class SqliteDatabase(Database):
return self._build_on_conflict_update(oc, query)
def extract_date(self, date_part, date_field):
return fn.date_part(date_part, date_field)
return fn.date_part(date_part, date_field, python_value=int)
def truncate_date(self, date_part, date_field):
return fn.date_trunc(date_part, date_field)
return fn.date_trunc(date_part, date_field,
python_value=simple_date_time)
def to_timestamp(self, date_field):
return fn.strftime('%s', date_field).cast('integer')
def from_timestamp(self, date_field):
return fn.datetime(date_field, 'unixepoch')
class PostgresqlDatabase(Database):
@@ -3552,9 +3624,11 @@ class PostgresqlDatabase(Database):
safe_create_index = False
sequences = True
def init(self, database, register_unicode=True, encoding=None, **kwargs):
def init(self, database, register_unicode=True, encoding=None,
isolation_level=None, **kwargs):
self._register_unicode = register_unicode
self._encoding = encoding
self._isolation_level = isolation_level
super(PostgresqlDatabase, self).init(database, **kwargs)
def _connect(self):
@@ -3566,6 +3640,8 @@
pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn)
if self._encoding:
conn.set_client_encoding(self._encoding)
if self._isolation_level:
conn.set_isolation_level(self._isolation_level)
return conn
def _set_server_version(self, conn):
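
Note: the isolation level is applied once per connection through psycopg2's set_isolation_level(). For example, with one of psycopg2's constants:

from psycopg2.extensions import ISOLATION_LEVEL_SERIALIZABLE
db = PostgresqlDatabase('app', user='postgres',
                        isolation_level=ISOLATION_LEVEL_SERIALIZABLE)
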
@@ -3695,6 +3771,13 @@
def truncate_date(self, date_part, date_field):
return fn.DATE_TRUNC(date_part, date_field)
def to_timestamp(self, date_field):
return self.extract_date('EPOCH', date_field)
def from_timestamp(self, date_field):
# Ironically, here, Postgres means "to the Postgresql timestamp type".
return fn.to_timestamp(date_field)
def get_noop_select(self, ctx):
return ctx.sql(Select().columns(SQL('0')).where(SQL('false')))
@@ -3727,9 +3810,13 @@ class MySQLDatabase(Database):
limit_max = 2 ** 64 - 1
safe_create_index = False
safe_drop_index = False
sql_mode = 'PIPES_AS_CONCAT'
def init(self, database, **kwargs):
params = {'charset': 'utf8', 'use_unicode': True}
params = {
'charset': 'utf8',
'sql_mode': self.sql_mode,
'use_unicode': True}
params.update(kwargs)
if 'password' in params and mysql_passwd:
params['passwd'] = params.pop('password')
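
Note: defaulting sql_mode to PIPES_AS_CONCAT makes MySQL treat ||, which peewee emits for OP.CONCAT, as string concatenation rather than logical OR, matching the other backends. With a hypothetical model:

# first || ' ' || last now concatenates on MySQL as well:
query = Person.select(Person.first.concat(' ').concat(Person.last))
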
@@ -3876,7 +3963,17 @@
return fn.EXTRACT(NodeList((SQL(date_part), SQL('FROM'), date_field)))
def truncate_date(self, date_part, date_field):
return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part])
return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part],
python_value=simple_date_time)
def to_timestamp(self, date_field):
return fn.UNIX_TIMESTAMP(date_field)
def from_timestamp(self, date_field):
return fn.FROM_UNIXTIME(date_field)
def random(self):
return fn.rand()
def get_noop_select(self, ctx):
return ctx.literal('DO 0')
@@ -4274,7 +4371,7 @@ class Field(ColumnBase):
def bind(self, model, name, set_attribute=True):
self.model = model
self.name = name
self.name = self.safe_name = name
self.column_name = self.column_name or name
if set_attribute:
setattr(model, name, self.accessor_class(model, self, name))
@@ -4299,7 +4396,7 @@
return ctx.sql(self.column)
def get_modifiers(self):
return
pass
def ddl_datatype(self, ctx):
if ctx and ctx.state.field_types:
@@ -4337,7 +4434,12 @@
class IntegerField(Field):
field_type = 'INT'
adapt = int
def adapt(self, value):
try:
return int(value)
except ValueError:
return value
class BigIntegerField(IntegerField):
@@ -4377,7 +4479,12 @@ class PrimaryKeyField(AutoField):
class FloatField(Field):
field_type = 'FLOAT'
adapt = float
def adapt(self, value):
try:
return float(value)
except ValueError:
return value
class DoubleField(FloatField):
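
Note: wrapping adapt() in try/except means a non-numeric value used in an expression no longer raises ValueError while the query is being built; the value passes through and the backend decides. Sketch with a hypothetical model:

# Previously failed at query-construction time; now the comparison is
# generated and left to the database:
query = Sample.select().where(Sample.counter == 'not-a-number')
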
@@ -4393,6 +4500,7 @@ class DecimalField(Field):
self.decimal_places = decimal_places
self.auto_round = auto_round
self.rounding = rounding or decimal.DefaultContext.rounding
self._exp = decimal.Decimal(10) ** (-self.decimal_places)
super(DecimalField, self).__init__(*args, **kwargs)
def get_modifiers(self):
@@ -4403,9 +4511,8 @@ class DecimalField(Field):
if not value:
return value if value is None else D(0)
if self.auto_round:
exp = D(10) ** (-self.decimal_places)
rounding = self.rounding
return D(text_type(value)).quantize(exp, rounding=rounding)
decimal_value = D(text_type(value))
return decimal_value.quantize(self._exp, rounding=self.rounding)
return value
def python_value(self, value):
@@ -4423,8 +4530,8 @@ class _StringField(Field):
return value.decode('utf-8')
return text_type(value)
def __add__(self, other): return self.concat(other)
def __radd__(self, other): return other.concat(self)
def __add__(self, other): return StringExpression(self, OP.CONCAT, other)
def __radd__(self, other): return StringExpression(other, OP.CONCAT, self)
class CharField(_StringField):
@@ -4652,6 +4759,12 @@ def format_date_time(value, formats, post_process=None):
pass
return value
def simple_date_time(value):
try:
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
except (TypeError, ValueError):
return value
class _BaseFormattedField(Field):
formats = None
@@ -4675,6 +4788,12 @@ class DateTimeField(_BaseFormattedField):
return format_date_time(value, self.formats)
return value
def to_timestamp(self):
return self.model._meta.database.to_timestamp(self)
def truncate(self, part):
return self.model._meta.database.truncate_date(part, self)
year = property(_date_part('year'))
month = property(_date_part('month'))
day = property(_date_part('day'))
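
Note: to_timestamp() and truncate() delegate to the new database-level helpers (strftime('%s', ...) on SQLite, EXTRACT(EPOCH ...) on Postgres, UNIX_TIMESTAMP() on MySQL), making epoch arithmetic portable. Sketch with a hypothetical model:

# Filter on the epoch value of a datetime column:
query = Event.select().where(Event.created.to_timestamp() > 1546300800)
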
@@ -4699,6 +4818,12 @@ class DateField(_BaseFormattedField):
return value.date()
return value
def to_timestamp(self):
return self.model._meta.database.to_timestamp(self)
def truncate(self, part):
return self.model._meta.database.truncate_date(part, self)
year = property(_date_part('year'))
month = property(_date_part('month'))
day = property(_date_part('day'))
@@ -4730,19 +4855,30 @@ class TimeField(_BaseFormattedField):
second = property(_date_part('second'))
def _timestamp_date_part(date_part):
def dec(self):
db = self.model._meta.database
expr = ((self / Value(self.resolution, converter=False))
if self.resolution > 1 else self)
return db.extract_date(date_part, db.from_timestamp(expr))
return dec
class TimestampField(BigIntegerField):
# Support second -> microsecond resolution.
valid_resolutions = [10**i for i in range(7)]
def __init__(self, *args, **kwargs):
self.resolution = kwargs.pop('resolution', None)
if not self.resolution:
self.resolution = 1
elif self.resolution in range(7):
elif self.resolution in range(2, 7):
self.resolution = 10 ** self.resolution
elif self.resolution not in self.valid_resolutions:
raise ValueError('TimestampField resolution must be one of: %s' %
', '.join(str(i) for i in self.valid_resolutions))
self.ticks_to_microsecond = 1000000 // self.resolution
self.utc = kwargs.pop('utc', False) or False
dflt = datetime.datetime.utcnow if self.utc else datetime.datetime.now
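
Note: narrowing the exponent range to 2..6 fixes resolution=1, which the old range(7) check rewrote to 10; small exponents still expand to powers of ten, and explicit tick counts are validated against valid_resolutions:

TimestampField(resolution=1)     # one tick/second (old code turned this into 10)
TimestampField(resolution=3)     # exponent form: 10 ** 3, i.e. milliseconds
TimestampField(resolution=1000)  # explicit ticks-per-second, same as above
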
@@ -4764,6 +4900,13 @@ class TimestampField(BigIntegerField):
ts = calendar.timegm(dt.utctimetuple())
return datetime.datetime.fromtimestamp(ts)
def get_timestamp(self, value):
if self.utc:
# If utc-mode is on, then we assume all naive datetimes are in UTC.
return calendar.timegm(value.utctimetuple())
else:
return time.mktime(value.timetuple())
def db_value(self, value):
if value is None:
return
@@ -4775,12 +4918,7 @@
else:
return int(round(value * self.resolution))
if self.utc:
# If utc-mode is on, then we assume all naive datetimes are in UTC.
timestamp = calendar.timegm(value.utctimetuple())
else:
timestamp = time.mktime(value.timetuple())
timestamp = self.get_timestamp(value)
if self.resolution > 1:
timestamp += (value.microsecond * .000001)
timestamp *= self.resolution
@@ -4789,9 +4927,8 @@
def python_value(self, value):
if value is not None and isinstance(value, (int, float, long)):
if self.resolution > 1:
ticks_to_microsecond = 1000000 // self.resolution
value, ticks = divmod(value, self.resolution)
microseconds = int(ticks * ticks_to_microsecond)
microseconds = int(ticks * self.ticks_to_microsecond)
else:
microseconds = 0
@@ -4805,6 +4942,18 @@
return value
def from_timestamp(self):
expr = ((self / Value(self.resolution, converter=False))
if self.resolution > 1 else self)
return self.model._meta.database.from_timestamp(expr)
year = property(_timestamp_date_part('year'))
month = property(_timestamp_date_part('month'))
day = property(_timestamp_date_part('day'))
hour = property(_timestamp_date_part('hour'))
minute = property(_timestamp_date_part('minute'))
second = property(_timestamp_date_part('second'))
class IPField(BigIntegerField):
def db_value(self, val):
@@ -4860,6 +5009,7 @@ class ForeignKeyField(Field):
'"backref" for Field objects.')
backref = related_name
self._is_self_reference = model == 'self'
self.rel_model = model
self.rel_field = field
self.declared_backref = backref
@@ -4908,7 +5058,7 @@
raise ValueError('ForeignKeyField "%s"."%s" specifies an '
'object_id_name that conflicts with its field '
'name.' % (model._meta.name, name))
if self.rel_model == 'self':
if self._is_self_reference:
self.rel_model = model
if isinstance(self.rel_field, basestring):
self.rel_field = getattr(self.rel_model, self.rel_field)
@@ -4918,6 +5068,7 @@
# Bind field before assigning backref, so field is bound when
# calling declared_backref() (if callable).
super(ForeignKeyField, self).bind(model, name, set_attribute)
self.safe_name = self.object_id_name
if callable_(self.declared_backref):
self.backref = self.declared_backref(self)
@@ -5143,7 +5294,7 @@ class VirtualField(MetaField):
def bind(self, model, name, set_attribute=True):
self.model = model
self.column_name = self.name = name
self.column_name = self.name = self.safe_name = name
setattr(model, name, self.accessor_class(model, self, name))
@@ -5190,7 +5341,7 @@ class CompositeKey(MetaField):
def bind(self, model, name, set_attribute=True):
self.model = model
self.column_name = self.name = name
self.column_name = self.name = self.safe_name = name
setattr(model, self.name, self)
@@ -6139,7 +6290,12 @@ class Model(with_metaclass(ModelBase, Node)):
return cls.select().filter(*dq_nodes, **filters)
def get_id(self):
return getattr(self, self._meta.primary_key.name)
# Using getattr(self, pk-name) could accidentally trigger a query if
# the primary-key is a foreign-key. So we use the safe_name attribute,
# which defaults to the field-name, but will be the object_id_name for
# foreign-key fields.
if self._meta.primary_key is not False:
return getattr(self, self._meta.primary_key.safe_name)
_pk = property(get_id)
@@ -6254,13 +6410,14 @@ class Model(with_metaclass(ModelBase, Node)):
return (
other.__class__ == self.__class__ and
self._pk is not None and
other._pk == self._pk)
self._pk == other._pk)
def __ne__(self, other):
return not self == other
def __sql__(self, ctx):
return ctx.sql(getattr(self, self._meta.primary_key.name))
return ctx.sql(Value(getattr(self, self._meta.primary_key.name),
converter=self._meta.primary_key.db_value))
@classmethod
def bind(cls, database, bind_refs=True, bind_backrefs=True):
@@ -6327,6 +6484,18 @@ class ModelAlias(Node):
self.__dict__['alias'] = alias
def __getattr__(self, attr):
# Hack to work-around the fact that properties or other objects
# implementing the descriptor protocol (on the model being aliased),
# will not work correctly when we use getattr(). So we explicitly pass
# the model alias to the descriptor's getter.
try:
obj = self.model.__dict__[attr]
except KeyError:
pass
else:
if isinstance(obj, ModelDescriptor):
return obj.__get__(None, self)
model_attr = getattr(self.model, attr)
if isinstance(model_attr, Field):
self.__dict__[attr] = FieldAlias.create(self, model_attr)
@@ -6675,6 +6844,13 @@ class ModelSelect(BaseModelSelect, Select):
return fk_fields[0], is_backref
if on is None:
# If multiple foreign-keys exist, try using the FK whose name
# matches that of the related model. If not, raise an error as this
# is ambiguous.
for fk in fk_fields:
if fk.name == dest._meta.name:
return fk, is_backref
raise ValueError('More than one foreign key between %s and %s.'
' Please specify which you are joining on.' %
(src, dest))
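
Note: this lets the common two-foreign-keys pattern join without an explicit on clause, provided one FK is named after the target model. Sketch with hypothetical models:

class Tweet(Model):
    user = ForeignKeyField(User, backref='tweets')
    editor = ForeignKeyField(User, backref='edited_tweets')

# Previously ambiguous; now resolves to Tweet.user, whose name matches User:
query = Tweet.select().join(User)
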
@@ -6710,7 +6886,7 @@
on, attr, constructor = self._normalize_join(src, dest, on, attr)
if attr:
self._joins.setdefault(src, [])
self._joins[src].append((dest, attr, constructor))
self._joins[src].append((dest, attr, constructor, join_type))
elif on is not None:
raise ValueError('Cannot specify on clause with cross join.')
@@ -6732,7 +6908,7 @@
def ensure_join(self, lm, rm, on=None, **join_kwargs):
join_ctx = self._join_ctx
for dest, attr, constructor in self._joins.get(lm, []):
for dest, _, constructor, _ in self._joins.get(lm, []):
if dest == rm:
return self
return self.switch(lm).join(rm, on=on, **join_kwargs).switch(join_ctx)
@@ -6757,7 +6933,7 @@
model_attr = getattr(curr, key)
else:
for piece in key.split('__'):
for dest, attr, _ in self._joins.get(curr, ()):
for dest, attr, _, _ in self._joins.get(curr, ()):
if attr == piece or (isinstance(dest, ModelAlias) and
dest.alias == piece):
curr = dest
@@ -6999,8 +7175,7 @@ class BaseModelCursorWrapper(DictCursorWrapper):
if raw_node._coerce:
converters[idx] = node.python_value
fields[idx] = node
if (column == node.name or column == node.column_name) and \
not raw_node.is_alias():
if not raw_node.is_alias():
self.columns[idx] = node.name
elif isinstance(node, Function) and node._coerce:
if node._python_value is not None:
@@ -7100,6 +7275,7 @@ class ModelCursorWrapper(BaseModelCursorWrapper):
self.src_to_dest = []
accum = collections.deque(self.from_list)
dests = set()
while accum:
curr = accum.popleft()
if isinstance(curr, Join):
@@ -7110,11 +7286,14 @@
if curr not in self.joins:
continue
for key, attr, constructor in self.joins[curr]:
is_dict = isinstance(curr, dict)
for key, attr, constructor, join_type in self.joins[curr]:
if key not in self.key_to_constructor:
self.key_to_constructor[key] = constructor
self.src_to_dest.append((curr, attr, key,
isinstance(curr, dict)))
# (src, attr, dest, is_dict, join_type).
self.src_to_dest.append((curr, attr, key, is_dict,
join_type))
dests.add(key)
accum.append(key)
@@ -7127,7 +7306,7 @@
self.key_to_constructor[src] = src.model
# Indicate which sources are also dests.
for src, _, dest, _ in self.src_to_dest:
for src, _, dest, _, _ in self.src_to_dest:
self.src_is_dest[src] = src in dests and (dest in selected_src
or src in selected_src)
@@ -7171,7 +7350,7 @@
setattr(instance, column, value)
# Need to do some analysis on the joins before this.
for (src, attr, dest, is_dict) in self.src_to_dest:
for (src, attr, dest, is_dict, join_type) in self.src_to_dest:
instance = objects[src]
try:
joined_instance = objects[dest]
@@ -7184,6 +7363,12 @@
(dest not in set_keys and not self.src_is_dest.get(dest)):
continue
# If no fields were set on either the source or the destination,
# then we have nothing to do here.
if instance not in set_keys and dest not in set_keys \
and join_type.endswith('OUTER'):
continue
if is_dict:
instance[attr] = joined_instance
else:
@@ -7309,7 +7494,7 @@ def prefetch(sq, *subqueries):
rel_map.setdefault(rel_model, [])
rel_map[rel_model].append(pq)
deps[query_model] = {}
deps.setdefault(query_model, {})
id_map = deps[query_model]
has_relations = bool(rel_map.get(query_model))

@@ -61,12 +61,14 @@ class DataSet(object):
def get_export_formats(self):
return {
'csv': CSVExporter,
'json': JSONExporter}
'json': JSONExporter,
'tsv': TSVExporter}
def get_import_formats(self):
return {
'csv': CSVImporter,
'json': JSONImporter}
'json': JSONImporter,
'tsv': TSVImporter}
def __getitem__(self, table):
if table not in self._models and table in self.tables:
@@ -244,6 +246,29 @@ class Table(object):
self.dataset.update_cache(self.name)
def __getitem__(self, item):
try:
return self.model_class[item]
except self.model_class.DoesNotExist:
pass
def __setitem__(self, item, value):
if not isinstance(value, dict):
raise ValueError('Table.__setitem__() value must be a dict')
pk = self.model_class._meta.primary_key
value[pk.name] = item
try:
with self.dataset.transaction() as txn:
self.insert(**value)
except IntegrityError:
self.dataset.update_cache(self.name)
self.update(columns=[pk.name], **value)
def __delitem__(self, item):
del self.model_class[item]
def insert(self, **data):
self._migrate_new_columns(data)
return self.model_class.insert(**data).execute()
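
Note: the new item protocol treats a dataset table like a dict keyed on primary key: reads return the row (or None), assignment upserts, del removes. Sketch against a hypothetical DataSet handle:

users = ds['user']
row = users[1]               # fetch by primary key, None if missing
users[1] = {'name': 'Huey'}  # insert; falls back to update on IntegrityError
del users[1]                 # delete by primary key
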
@@ -343,6 +368,12 @@ class CSVExporter(Exporter):
writer.writerow(row)
class TSVExporter(CSVExporter):
def export(self, file_obj, header=True, **kwargs):
kwargs.setdefault('delimiter', '\t')
return super(TSVExporter, self).export(file_obj, header, **kwargs)
class Importer(object):
def __init__(self, table, strict=False):
self.table = table
@@ -413,3 +444,9 @@ class CSVImporter(Importer):
count += 1
return count
class TSVImporter(CSVImporter):
def load(self, file_obj, header=True, **kwargs):
kwargs.setdefault('delimiter', '\t')
return super(TSVImporter, self).load(file_obj, header, **kwargs)
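
Note: the TSV classes reuse the CSV machinery with a tab delimiter, so they plug straight into the existing freeze/thaw round-trip. Sketch with a hypothetical DataSet handle:

ds.freeze(ds['user'].all(), format='tsv', filename='users.tsv')
ds.thaw('user', format='tsv', filename='users.tsv')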

@@ -1,6 +1,9 @@
from peewee import ModelDescriptor
# Hybrid methods/attributes, based on similar functionality in SQLAlchemy:
# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html
class hybrid_method(object):
class hybrid_method(ModelDescriptor):
def __init__(self, func, expr=None):
self.func = func
self.expr = expr or func
@@ -15,7 +18,7 @@ class hybrid_method(object):
return self
class hybrid_property(object):
class hybrid_property(ModelDescriptor):
def __init__(self, fget, fset=None, fdel=None, expr=None):
self.fget = fget
self.fset = fset
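
Note: deriving the hybrids from ModelDescriptor lets ModelAlias.__getattr__ (see the ModelAlias hunk above) route attribute access through the descriptor with the alias as its owner, so hybrid expressions now work on aliases. Sketch with a hypothetical model:

class Interval(Model):
    start = IntegerField()
    stop = IntegerField()

    @hybrid_property
    def length(self):
        return self.stop - self.start

IAlias = Interval.alias()
query = IAlias.select().where(IAlias.length > 5)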

@@ -777,7 +777,7 @@ class SqliteMigrator(SchemaMigrator):
clean = []
for column in columns:
if re.match('%s(?:[\'"`\]]?\s|$)' % column_to_update, column):
column = new_columne + column[len(column_to_update):]
column = new_column + column[len(column_to_update):]
clean.append(column)
return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean))

@@ -7,7 +7,10 @@ except ImportError:
from peewee import ImproperlyConfigured
from peewee import MySQLDatabase
from peewee import NodeList
from peewee import SQL
from peewee import TextField
from peewee import fn
class MySQLConnectorDatabase(MySQLDatabase):
@@ -18,7 +21,10 @@ class MySQLConnectorDatabase(MySQLDatabase):
def cursor(self, commit=None):
if self.is_closed():
self.connect()
if self.autoconnect:
self.connect()
else:
raise InterfaceError('Error, database connection not opened.')
return self._state.conn.cursor(buffered=True)
@@ -32,3 +38,12 @@ class JSONField(TextField):
def python_value(self, value):
if value is not None:
return json.loads(value)
def Match(columns, expr, modifier=None):
if isinstance(columns, (list, tuple)):
match = fn.MATCH(*columns) # Tuple of one or more columns / fields.
else:
match = fn.MATCH(columns) # Single column / field.
args = expr if modifier is None else NodeList((expr, SQL(modifier)))
return NodeList((match, fn.AGAINST(args)))
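
Note: Match() builds the MATCH(...) AGAINST(...) predicate for MySQL full-text indexes; columns may be a single field or a tuple of them, and modifier is appended verbatim inside AGAINST(). Sketch with a hypothetical Post model:

query = Post.select().where(Match((Post.title, Post.body), 'peewee orm'))
boolean = Post.select().where(
    Match(Post.title, '+peewee -django', 'IN BOOLEAN MODE'))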

@@ -33,6 +33,7 @@ That's it!
"""
import heapq
import logging
import random
import time
from collections import namedtuple
from itertools import chain
@@ -153,7 +154,7 @@ class PooledDatabase(object):
len(self._in_use) >= self._max_connections):
raise MaxConnectionsExceeded('Exceeded maximum connections.')
conn = super(PooledDatabase, self)._connect()
ts = time.time()
ts = time.time() - random.random() / 1000
key = self.conn_key(conn)
logger.debug('Created new connection %s.', key)
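
Note: the small random offset appears intended to keep heap timestamps unique; the pool stores (timestamp, connection) tuples in a heap, and identical timestamps would make heapq fall back to comparing the connection objects themselves, which are not orderable on Python 3.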

@@ -3,6 +3,7 @@ Collection of postgres-specific extensions, currently including:
* Support for hstore, a key/value type storage
"""
import json
import logging
import uuid
@@ -277,18 +278,19 @@ class HStoreField(IndexedFieldMixin, Field):
class JSONField(Field):
field_type = 'JSON'
_json_datatype = 'json'
def __init__(self, dumps=None, *args, **kwargs):
if Json is None:
raise Exception('Your version of psycopg2 does not support JSON.')
self.dumps = dumps
self.dumps = dumps or json.dumps
super(JSONField, self).__init__(*args, **kwargs)
def db_value(self, value):
if value is None:
return value
if not isinstance(value, Json):
return Json(value, dumps=self.dumps)
return Cast(self.dumps(value), self._json_datatype)
return value
def __getitem__(self, value):
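
Note: with dumps defaulting to json.dumps and db_value casting the serialized string to the json datatype, a custom serializer can be swapped in. Sketch:

from functools import partial
data = JSONField(dumps=partial(json.dumps, ensure_ascii=False))
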
@@ -307,6 +309,7 @@ def cast_jsonb(node):
class BinaryJSONField(IndexedFieldMixin, JSONField):
field_type = 'JSONB'
_json_datatype = 'jsonb'
__hash__ = Field.__hash__
def contains(self, other):
@@ -449,7 +452,10 @@ class PostgresqlExtDatabase(PostgresqlDatabase):
def cursor(self, commit=None):
if self.is_closed():
self.connect()
if self.autoconnect:
self.connect()
else:
raise InterfaceError('Error, database connection not opened.')
if commit is __named_cursor__:
return self._state.conn.cursor(name=str(uuid.uuid1()))
return self._state.conn.cursor()

@@ -224,7 +224,7 @@ class PostgresqlMetadata(Metadata):
23: IntegerField,
25: TextField,
700: FloatField,
701: FloatField,
701: DoubleField,
1042: CharField, # blank-padded CHAR
1043: CharField,
1082: DateField,

@@ -73,7 +73,7 @@ def model_to_dict(model, recurse=True, backrefs=False, only=None,
field_data = model.__data__.get(field.name)
if isinstance(field, ForeignKeyField) and recurse:
if field_data:
if field_data is not None:
seen.add(field)
rel_obj = getattr(model, field.name)
field_data = model_to_dict(
@@ -191,6 +191,10 @@ class ReconnectMixin(object):
(OperationalError, '2006'), # MySQL server has gone away.
(OperationalError, '2013'), # Lost connection to MySQL server.
(OperationalError, '2014'), # Commands out of sync.
# mysql-connector raises a slightly different error when an idle
# connection is terminated by the server. This is equivalent to 2013.
(OperationalError, 'MySQL Connection not available.'),
)
def __init__(self, *args, **kwargs):
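
Note: ReconnectMixin retries a failed query once when the error matches one of the listed (exception type, message prefix) pairs; the new entry covers mysql-connector's wording for a dropped idle connection. Typical composition (class name hypothetical):

from playhouse.shortcuts import ReconnectMixin

class ReconnectingMySQLDatabase(ReconnectMixin, MySQLDatabase):
    pass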

@@ -65,7 +65,7 @@ class Model(_Model):
pre_init.send(self)
def save(self, *args, **kwargs):
pk_value = self._pk
pk_value = self._pk if self._meta.primary_key else True
created = kwargs.get('force_insert', False) or not bool(pk_value)
pre_save.send(self, created=created)
ret = super(Model, self).save(*args, **kwargs)

@@ -69,6 +69,11 @@ class AutoIncrementField(AutoField):
return NodeList((node_list, SQL('AUTOINCREMENT')))
class TDecimalField(DecimalField):
field_type = 'TEXT'
def get_modifiers(self): pass
class JSONPath(ColumnBase):
def __init__(self, field, path=None):
super(JSONPath, self).__init__()
@@ -1190,6 +1195,33 @@ def bm25(raw_match_info, *args):
L_O = A_O + col_count
X_O = L_O + col_count
# Worked example of pcnalx for two columns and two phrases, 100 docs total.
# {
# p = 2
# c = 2
# n = 100
# a0 = 4 -- avg number of tokens for col0, e.g. title
# a1 = 40 -- avg number of tokens for col1, e.g. body
# l0 = 5 -- curr doc has 5 tokens in col0
# l1 = 30 -- curr doc has 30 tokens in col1
#
# x000 -- hits this row for phrase0, col0
# x001 -- hits all rows for phrase0, col0
# x002 -- rows with phrase0 in col0 at least once
#
# x010 -- hits this row for phrase0, col1
# x011 -- hits all rows for phrase0, col1
# x012 -- rows with phrase0 in col1 at least once
#
# x100 -- hits this row for phrase1, col0
# x101 -- hits all rows for phrase1, col0
# x102 -- rows with phrase1 in col0 at least once
#
# x110 -- hits this row for phrase1, col1
# x111 -- hits all rows for phrase1, col1
# x112 -- rows with phrase1 in col1 at least once
# }
weights = get_weights(col_count, args)
for i in range(term_count):
@@ -1213,8 +1245,8 @@ def bm25(raw_match_info, *args):
avg_length = float(match_info[A_O + j]) or 1. # avgdl
ratio = doc_length / avg_length
num = term_frequency * (K + 1)
b_part = 1 - B + (B * ratio)
num = term_frequency * (K + 1.0)
b_part = 1.0 - B + (B * ratio)
denom = term_frequency + (K * b_part)
pc_score = idf * (num / denom)
