You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
281 lines
9.8 KiB
281 lines
9.8 KiB
import threading
|
|
|
|
from peewee import *
|
|
from peewee import Alias
|
|
from peewee import CompoundSelectQuery
|
|
from peewee import Metadata
|
|
from peewee import SENTINEL
|
|
from peewee import callable_
|
|
|
|
|
|
_clone_set = lambda s: set(s) if s else set()
|
|
|
|
|
|
def model_to_dict(model, recurse=True, backrefs=False, only=None,
|
|
exclude=None, seen=None, extra_attrs=None,
|
|
fields_from_query=None, max_depth=None, manytomany=False):
|
|
"""
|
|
Convert a model instance (and any related objects) to a dictionary.
|
|
|
|
:param bool recurse: Whether foreign-keys should be recursed.
|
|
:param bool backrefs: Whether lists of related objects should be recursed.
|
|
:param only: A list (or set) of field instances indicating which fields
|
|
should be included.
|
|
:param exclude: A list (or set) of field instances that should be
|
|
excluded from the dictionary.
|
|
:param list extra_attrs: Names of model instance attributes or methods
|
|
that should be included.
|
|
:param SelectQuery fields_from_query: Query that was source of model. Take
|
|
fields explicitly selected by the query and serialize them.
|
|
:param int max_depth: Maximum depth to recurse, value <= 0 means no max.
|
|
:param bool manytomany: Process many-to-many fields.
|
|
"""
|
|
max_depth = -1 if max_depth is None else max_depth
|
|
if max_depth == 0:
|
|
recurse = False
|
|
|
|
only = _clone_set(only)
|
|
extra_attrs = _clone_set(extra_attrs)
|
|
should_skip = lambda n: (n in exclude) or (only and (n not in only))
|
|
|
|
if fields_from_query is not None:
|
|
for item in fields_from_query._returning:
|
|
if isinstance(item, Field):
|
|
only.add(item)
|
|
elif isinstance(item, Alias):
|
|
extra_attrs.add(item._alias)
|
|
|
|
data = {}
|
|
exclude = _clone_set(exclude)
|
|
seen = _clone_set(seen)
|
|
exclude |= seen
|
|
model_class = type(model)
|
|
|
|
if manytomany:
|
|
for name, m2m in model._meta.manytomany.items():
|
|
if should_skip(name):
|
|
continue
|
|
|
|
exclude.update((m2m, m2m.rel_model._meta.manytomany[m2m.backref]))
|
|
for fkf in m2m.through_model._meta.refs:
|
|
exclude.add(fkf)
|
|
|
|
accum = []
|
|
for rel_obj in getattr(model, name):
|
|
accum.append(model_to_dict(
|
|
rel_obj,
|
|
recurse=recurse,
|
|
backrefs=backrefs,
|
|
only=only,
|
|
exclude=exclude,
|
|
max_depth=max_depth - 1))
|
|
data[name] = accum
|
|
|
|
for field in model._meta.sorted_fields:
|
|
if should_skip(field):
|
|
continue
|
|
|
|
field_data = model.__data__.get(field.name)
|
|
if isinstance(field, ForeignKeyField) and recurse:
|
|
if field_data is not None:
|
|
seen.add(field)
|
|
rel_obj = getattr(model, field.name)
|
|
field_data = model_to_dict(
|
|
rel_obj,
|
|
recurse=recurse,
|
|
backrefs=backrefs,
|
|
only=only,
|
|
exclude=exclude,
|
|
seen=seen,
|
|
max_depth=max_depth - 1)
|
|
else:
|
|
field_data = None
|
|
|
|
data[field.name] = field_data
|
|
|
|
if extra_attrs:
|
|
for attr_name in extra_attrs:
|
|
attr = getattr(model, attr_name)
|
|
if callable_(attr):
|
|
data[attr_name] = attr()
|
|
else:
|
|
data[attr_name] = attr
|
|
|
|
if backrefs and recurse:
|
|
for foreign_key, rel_model in model._meta.backrefs.items():
|
|
if foreign_key.backref == '+': continue
|
|
descriptor = getattr(model_class, foreign_key.backref)
|
|
if descriptor in exclude or foreign_key in exclude:
|
|
continue
|
|
if only and (descriptor not in only) and (foreign_key not in only):
|
|
continue
|
|
|
|
accum = []
|
|
exclude.add(foreign_key)
|
|
related_query = getattr(model, foreign_key.backref)
|
|
|
|
for rel_obj in related_query:
|
|
accum.append(model_to_dict(
|
|
rel_obj,
|
|
recurse=recurse,
|
|
backrefs=backrefs,
|
|
only=only,
|
|
exclude=exclude,
|
|
max_depth=max_depth - 1))
|
|
|
|
data[foreign_key.backref] = accum
|
|
|
|
return data
|
|
|
|
|
|
def update_model_from_dict(instance, data, ignore_unknown=False):
|
|
meta = instance._meta
|
|
backrefs = dict([(fk.backref, fk) for fk in meta.backrefs])
|
|
|
|
for key, value in data.items():
|
|
if key in meta.combined:
|
|
field = meta.combined[key]
|
|
is_backref = False
|
|
elif key in backrefs:
|
|
field = backrefs[key]
|
|
is_backref = True
|
|
elif ignore_unknown:
|
|
setattr(instance, key, value)
|
|
continue
|
|
else:
|
|
raise AttributeError('Unrecognized attribute "%s" for model '
|
|
'class %s.' % (key, type(instance)))
|
|
|
|
is_foreign_key = isinstance(field, ForeignKeyField)
|
|
|
|
if not is_backref and is_foreign_key and isinstance(value, dict):
|
|
try:
|
|
rel_instance = instance.__rel__[field.name]
|
|
except KeyError:
|
|
rel_instance = field.rel_model()
|
|
setattr(
|
|
instance,
|
|
field.name,
|
|
update_model_from_dict(rel_instance, value, ignore_unknown))
|
|
elif is_backref and isinstance(value, (list, tuple)):
|
|
instances = [
|
|
dict_to_model(field.model, row_data, ignore_unknown)
|
|
for row_data in value]
|
|
for rel_instance in instances:
|
|
setattr(rel_instance, field.name, instance)
|
|
setattr(instance, field.backref, instances)
|
|
else:
|
|
setattr(instance, field.name, value)
|
|
|
|
return instance
|
|
|
|
|
|
def dict_to_model(model_class, data, ignore_unknown=False):
|
|
return update_model_from_dict(model_class(), data, ignore_unknown)
|
|
|
|
|
|
class ReconnectMixin(object):
|
|
"""
|
|
Mixin class that attempts to automatically reconnect to the database under
|
|
certain error conditions.
|
|
|
|
For example, MySQL servers will typically close connections that are idle
|
|
for 28800 seconds ("wait_timeout" setting). If your application makes use
|
|
of long-lived connections, you may find your connections are closed after
|
|
a period of no activity. This mixin will attempt to reconnect automatically
|
|
when these errors occur.
|
|
|
|
This mixin class probably should not be used with Postgres (unless you
|
|
REALLY know what you are doing) and definitely has no business being used
|
|
with Sqlite. If you wish to use with Postgres, you will need to adapt the
|
|
`reconnect_errors` attribute to something appropriate for Postgres.
|
|
"""
|
|
reconnect_errors = (
|
|
# Error class, error message fragment (or empty string for all).
|
|
(OperationalError, '2006'), # MySQL server has gone away.
|
|
(OperationalError, '2013'), # Lost connection to MySQL server.
|
|
(OperationalError, '2014'), # Commands out of sync.
|
|
(OperationalError, '4031'), # Client interaction timeout.
|
|
|
|
# mysql-connector raises a slightly different error when an idle
|
|
# connection is terminated by the server. This is equivalent to 2013.
|
|
(OperationalError, 'MySQL Connection not available.'),
|
|
)
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super(ReconnectMixin, self).__init__(*args, **kwargs)
|
|
|
|
# Normalize the reconnect errors to a more efficient data-structure.
|
|
self._reconnect_errors = {}
|
|
for exc_class, err_fragment in self.reconnect_errors:
|
|
self._reconnect_errors.setdefault(exc_class, [])
|
|
self._reconnect_errors[exc_class].append(err_fragment.lower())
|
|
|
|
def execute_sql(self, sql, params=None, commit=SENTINEL):
|
|
try:
|
|
return super(ReconnectMixin, self).execute_sql(sql, params, commit)
|
|
except Exception as exc:
|
|
exc_class = type(exc)
|
|
if exc_class not in self._reconnect_errors:
|
|
raise exc
|
|
|
|
exc_repr = str(exc).lower()
|
|
for err_fragment in self._reconnect_errors[exc_class]:
|
|
if err_fragment in exc_repr:
|
|
break
|
|
else:
|
|
raise exc
|
|
|
|
if not self.is_closed():
|
|
self.close()
|
|
self.connect()
|
|
|
|
return super(ReconnectMixin, self).execute_sql(sql, params, commit)
|
|
|
|
|
|
def resolve_multimodel_query(query, key='_model_identifier'):
|
|
mapping = {}
|
|
accum = [query]
|
|
while accum:
|
|
curr = accum.pop()
|
|
if isinstance(curr, CompoundSelectQuery):
|
|
accum.extend((curr.lhs, curr.rhs))
|
|
continue
|
|
|
|
model_class = curr.model
|
|
name = model_class._meta.table_name
|
|
mapping[name] = model_class
|
|
curr._returning.append(Value(name).alias(key))
|
|
|
|
def wrapped_iterator():
|
|
for row in query.dicts().iterator():
|
|
identifier = row.pop(key)
|
|
model = mapping[identifier]
|
|
yield model(**row)
|
|
|
|
return wrapped_iterator()
|
|
|
|
|
|
class ThreadSafeDatabaseMetadata(Metadata):
|
|
"""
|
|
Metadata class to allow swapping database at run-time in a multi-threaded
|
|
application. To use:
|
|
|
|
class Base(Model):
|
|
class Meta:
|
|
model_metadata_class = ThreadSafeDatabaseMetadata
|
|
"""
|
|
def __init__(self, *args, **kwargs):
|
|
# The database attribute is stored in a thread-local.
|
|
self._database = None
|
|
self._local = threading.local()
|
|
super(ThreadSafeDatabaseMetadata, self).__init__(*args, **kwargs)
|
|
|
|
def _get_db(self):
|
|
return getattr(self._local, 'database', self._database)
|
|
def _set_db(self, db):
|
|
if self._database is None:
|
|
self._database = db
|
|
self._local.database = db
|
|
database = property(_get_db, _set_db)
|