You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
168 lines
6.9 KiB
168 lines
6.9 KiB
2 years ago
|
from pyrsistent._checked_types import CheckedType, _restore_pickle, InvariantException, store_invariants
|
||
|
from pyrsistent._field_common import (
|
||
|
set_fields, check_type, is_field_ignore_extra_complaint, PFIELD_NO_INITIAL, serialize, check_global_invariants
|
||
|
)
|
||
|
from pyrsistent._pmap import PMap, pmap
|
||
|
|
||
|
|
||
|
class _PRecordMeta(type):
    """
    Metaclass for PRecord classes.

    While the class is being created it lets ``set_fields`` and ``store_invariants``
    populate the class namespace, then derives two lookup tables from the collected
    ``_precord_fields`` mapping and forces ``__slots__`` to be empty.
    """

    def __new__(mcs, name, bases, dct):
        # Both helpers write their results into ``dct`` under the given keys.
        set_fields(dct, bases, name='_precord_fields')
        store_invariants(dct, bases, '_precord_invariants', '__invariant__')

        declared_fields = dct['_precord_fields']

        # Names of the fields flagged as mandatory.
        dct['_precord_mandatory_fields'] = {
            field_name
            for field_name, field in declared_fields.items()
            if field.mandatory
        }

        # Initial value (or initial-value callable) per field, for the fields that declare one.
        dct['_precord_initial_values'] = {
            field_name: field.initial
            for field_name, field in declared_fields.items()
            if field.initial is not PFIELD_NO_INITIAL
        }

        dct['__slots__'] = ()

        return super().__new__(mcs, name, bases, dct)
|
||
|
|
||
|
|
||
|
class PRecord(PMap, CheckedType, metaclass=_PRecordMeta):
    """
    A PMap restricted to a fixed, declared set of fields. A record type is declared as a
    python class inheriting from PRecord. Since a PRecord is a PMap, all Mapping methods
    (iteration, subscript access, ...) are available on it.

    More documentation and examples of PRecord usage is available at https://github.com/tobgu/pyrsistent
    """

    def __new__(cls, **kwargs):
        # Internal shortcut: when both of these special keys are supplied the underlying
        # map structure already exists and can be handed straight to the parent
        # constructor. In every other case the evolver builds the structure.
        if '_precord_size' in kwargs and '_precord_buckets' in kwargs:
            return super().__new__(cls, kwargs['_precord_size'], kwargs['_precord_buckets'])

        factory_fields = kwargs.pop('_factory_fields', None)
        ignore_extra = kwargs.pop('_ignore_extra', False)

        declared_initials = cls._precord_initial_values
        if declared_initials:
            # Callable initial values are invoked to produce the value; explicitly
            # supplied keyword arguments then override the declared initial values.
            values = {
                key: (initial() if callable(initial) else initial)
                for key, initial in declared_initials.items()
            }
            values.update(kwargs)
        else:
            values = kwargs

        evolver = _PRecordEvolver(
            cls,
            pmap(pre_size=len(cls._precord_fields)),
            _factory_fields=factory_fields,
            _ignore_extra=ignore_extra,
        )
        for key, value in values.items():
            evolver[key] = value

        return evolver.persistent()

    def set(self, *args, **kwargs):
        """
        Set one or more fields in the record. Unlike the PMap version this method also
        accepts keyword arguments, and several keyword arguments may be given at once to
        update multiple fields in a single operation.

        When positional arguments are given, the first two are used as key and value and
        any keyword arguments are not used.
        """
        # Keyword form works because declared field names are valid python identifiers.
        if not args:
            return self.update(kwargs)

        return super().set(args[0], args[1])

    def evolver(self):
        """
        Return an evolver for this record.
        """
        return _PRecordEvolver(self.__class__, self)

    def __repr__(self):
        rendered_items = ', '.join('{0}={1}'.format(k, repr(v)) for k, v in self.items())
        return "{0}({1})".format(self.__class__.__name__, rendered_items)

    @classmethod
    def create(cls, kwargs, _factory_fields=None, ignore_extra=False):
        """
        Factory method. Builds a new PRecord of this type from the values in kwargs.
        If kwargs already is an instance of this type it is returned unchanged.

        :param ignore_extra: A boolean which when set to True will drop any keys in kwargs that
                             are not among the fields declared on the PRecord.
        """
        if isinstance(kwargs, cls):
            return kwargs

        if ignore_extra:
            kwargs = dict((name, kwargs[name]) for name in cls._precord_fields if name in kwargs)

        return cls(_factory_fields=_factory_fields, _ignore_extra=ignore_extra, **kwargs)

    def __reduce__(self):
        # Pickling support: rebuild from the class and a plain dict of the contents.
        return _restore_pickle, (self.__class__, dict(self),)

    def serialize(self, format=None):
        """
        Serialize this PRecord to a dict, passing every value through the serializer
        declared for its field together with the given format.
        """
        fields = self._precord_fields
        return {
            key: serialize(fields[key].serializer, format, value)
            for key, value in self.items()
        }
|
||
|
|
||
|
|
||
|
class _PRecordEvolver(PMap._Evolver):
    """
    Evolver for PRecord instances. Runs the field factory, type check and field
    invariant on every assignment, and collects invariant failures so that they are
    raised together when persistent() is called.
    """
    __slots__ = ('_destination_cls', '_invariant_error_codes', '_missing_fields', '_factory_fields', '_ignore_extra')

    def __init__(self, cls, original_pmap, _factory_fields=None, _ignore_extra=False):
        super().__init__(original_pmap)
        self._destination_cls = cls
        self._factory_fields = _factory_fields
        self._ignore_extra = _ignore_extra
        # Problems found while setting fields; reported in persistent().
        self._invariant_error_codes = []
        self._missing_fields = []

    def __setitem__(self, key, original_value):
        self.set(key, original_value)

    def set(self, key, original_value):
        """
        Assign original_value to the field named key and return the evolver.

        Raises AttributeError when key is not a declared field of the destination class.
        An InvariantException raised by the field factory is not propagated; its errors
        are recorded for persistent() and the field is left unassigned.
        """
        field = self._destination_cls._precord_fields.get(key)
        if not field:
            raise AttributeError("'{0}' is not among the specified fields for {1}".format(key, self._destination_cls.__name__))

        # The factory is applied when no restriction is given or the field is in it.
        if self._factory_fields is None or field in self._factory_fields:
            try:
                if is_field_ignore_extra_complaint(PRecord, field, self._ignore_extra):
                    value = field.factory(original_value, ignore_extra=self._ignore_extra)
                else:
                    value = field.factory(original_value)
            except InvariantException as e:
                self._invariant_error_codes.extend(e.invariant_errors)
                self._missing_fields.extend(e.missing_fields)
                return self
        else:
            value = original_value

        check_type(self._destination_cls, field, key, value)

        is_ok, error_code = field.invariant(value)
        if not is_ok:
            self._invariant_error_codes.append(error_code)

        return super().set(key, value)

    def persistent(self):
        """
        Return the resulting record as an instance of the destination class.

        Raises InvariantException when any field invariant failed during set() or when a
        mandatory field is absent from the result. Global invariants are checked last.
        """
        cls = self._destination_cls

        # The dirty flag must be read before the parent call below.
        was_dirty = self.is_dirty()
        base_map = super().persistent()

        # Wrap the map data in the destination class unless an untouched instance of
        # that class came back.
        needs_wrapping = was_dirty or not isinstance(base_map, cls)
        result = cls(_precord_buckets=base_map._buckets, _precord_size=base_map._size) if needs_wrapping else base_map

        mandatory = cls._precord_mandatory_fields
        if mandatory:
            absent = mandatory - set(result.keys())
            self._missing_fields.extend('{0}.{1}'.format(cls.__name__, f) for f in absent)

        if self._invariant_error_codes or self._missing_fields:
            raise InvariantException(tuple(self._invariant_error_codes), tuple(self._missing_fields),
                                     'Field invariant failed')

        check_global_invariants(result, cls._precord_invariants)

        return result
|