You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
286 lines
8.0 KiB
286 lines
8.0 KiB
import copy
|
|
import re
|
|
import warnings
|
|
|
|
from collections import OrderedDict
|
|
|
|
from collections.abc import MutableMapping
|
|
from werkzeug.utils import cached_property
|
|
|
|
from .mask import Mask
|
|
from .errors import abort
|
|
|
|
from jsonschema import Draft4Validator
|
|
from jsonschema.exceptions import ValidationError
|
|
|
|
from .utils import not_none
|
|
from ._http import HTTPStatus
|
|
|
|
|
|
# Extracts the missing property name from jsonschema "required" errors such as
# "'name' is a required property". The name appears to be rendered like
# ``repr()``, so it is wrapped in single quotes normally but in double quotes
# when it itself contains a single quote; an optional ``u`` prefix (Python 2
# style reprs) is also tolerated. The closing quote must match the opening one.
RE_REQUIRED = re.compile(
    r"u?(?P<quote>['\"])(?P<name>.*)(?P=quote) is a required property", re.I | re.U
)
|
|
|
|
|
|
def instance(cls):
    """
    Return an instance for ``cls``.

    A class is instantiated with no arguments; any other value (typically an
    already-built object) is handed back untouched.
    """
    return cls() if isinstance(cls, type) else cls
|
|
|
|
|
|
class ModelBase(object):
    """
    Handles validation and swagger style inheritance for both subclasses.

    Subclasses must provide a ``_schema`` attribute (or property) holding the
    model's own JSON schema; ``__schema__`` wraps it in an ``allOf`` with
    references to the parents when the model inherits from other models.

    :param str name: The model public name
    """

    def __init__(self, name, *args, **kwargs):
        super(ModelBase, self).__init__(*args, **kwargs)
        self.__apidoc__ = {"name": name}
        self.name = name
        # Models this one inherits from (filled in by ``inherit``)
        self.__parents__ = []

        def instance_inherit(name, *parents):
            # Bound per instance so that ``model.inherit(...)`` passes this
            # model as the first parent, whereas ``Class.inherit(...)`` does not.
            return self.__class__.inherit(name, self, *parents)

        self.inherit = instance_inherit

    @property
    def ancestors(self):
        """
        Return the ancestors tree

        :returns: the names of this model and of all its transitive parents
        :rtype: set
        """
        ancestors = [p.ancestors for p in self.__parents__]
        return set.union(set([self.name]), *ancestors)

    def get_parent(self, name):
        """
        Find the model called ``name`` in this model's inheritance tree.

        The model itself is checked first, then each parent subtree in
        declaration order (depth-first).

        :param str name: The model name to look for
        :returns: the matching model
        :raises ValueError: if no model in the tree has this name
        """
        if self.name == name:
            return self
        for parent in self.__parents__:
            try:
                found = parent.get_parent(name)
            except ValueError:
                # Not in this branch: keep searching the remaining parents
                # instead of aborting the whole lookup.
                continue
            # Compare with None: an empty dict-based model is falsy but valid.
            if found is not None:
                return found
        raise ValueError("Parent {0} not found".format(name))

    @property
    def __schema__(self):
        """
        Return the JSON schema of this model.

        Without parents this is ``_schema`` itself; otherwise an ``allOf``
        made of a ``$ref`` to each parent definition followed by ``_schema``.
        """
        schema = self._schema

        if self.__parents__:
            refs = [
                {"$ref": "#/definitions/{0}".format(parent.name)}
                for parent in self.__parents__
            ]

            return {"allOf": refs + [schema]}
        else:
            return schema

    @classmethod
    def inherit(cls, name, *parents):
        """
        Inherit this model (use the Swagger composition pattern aka. allOf)

        :param str name: The new model name
        :param parents: The parent models, followed by (as the last positional
            argument) the new model's own fields/schema, which is passed to
            the constructor. When called on an instance, that instance is
            prepended as the first parent.
        """
        model = cls(name, parents[-1])
        model.__parents__ = parents[:-1]
        return model

    def validate(self, data, resolver=None, format_checker=None):
        """
        Validate ``data`` against this model's schema with a Draft 4 validator.

        On failure, calls ``abort`` with HTTP 400, the message
        "Input payload validation failed" and an ``errors`` dict mapping each
        error's dotted path (see ``format_error``) to its message.
        """
        validator = Draft4Validator(
            self.__schema__, resolver=resolver, format_checker=format_checker
        )
        try:
            validator.validate(data)
        except ValidationError:
            # Re-run to collect every error, not just the first one raised.
            abort(
                HTTPStatus.BAD_REQUEST,
                message="Input payload validation failed",
                errors=dict(self.format_error(e) for e in validator.iter_errors(data)),
            )

    def format_error(self, error):
        """
        Convert a jsonschema validation error into a ``(key, message)`` pair.

        ``key`` is the error path joined with dots. For ``required`` errors the
        missing property name (parsed from the message) is appended, so the key
        points at the missing field; if the message cannot be parsed the key
        falls back to the path of the enclosing object.
        """
        path = list(error.path)
        if error.validator == "required":
            match = RE_REQUIRED.match(error.message)
            if match is not None:
                path.append(match.group("name"))
        key = ".".join(str(p) for p in path)
        return key, error.message

    def __unicode__(self):
        # ``keys()`` comes from the mapping type the concrete subclass mixes in.
        return "Model({name},{{{fields}}})".format(
            name=self.name, fields=",".join(self.keys())
        )

    __str__ = __unicode__
|
|
|
|
|
|
class RawModel(ModelBase):
    """
    Shared implementation for field-based models (see ``Model`` and
    ``OrderedModel``), which mix it with a mapping type holding the fields.
    Stores API doc metadata and can also be used for response marshalling.

    :param str name: The model public name
    :param str mask: an optional default model mask
    :param bool strict: when true, the generated schema sets
        ``additionalProperties`` to false, so validation rejects payload keys
        that are not declared in the model
    """

    # Mapping type used for the schema ``properties`` and for collecting
    # fields in ``clone``.
    wrapper = dict

    def __init__(self, name, *args, **kwargs):
        self.__mask__ = kwargs.pop("mask", None)
        self.__strict__ = kwargs.pop("strict", False)
        if self.__mask__ and not isinstance(self.__mask__, Mask):
            self.__mask__ = Mask(self.__mask__)
        super(RawModel, self).__init__(name, *args, **kwargs)

        def instance_clone(name, *parents):
            # Bound per instance so ``model.clone(...)`` copies this model's
            # fields first, whereas ``Class.clone(...)`` does not.
            return self.__class__.clone(name, self, *parents)

        self.clone = instance_clone

    @property
    def _schema(self):
        """
        Build this model's own JSON schema (without parent ``allOf`` refs).

        Field classes are instantiated via ``instance`` before their
        ``__schema__`` is read. ``not_none`` presumably strips the ``None``
        placeholders (no required fields, no discriminator, no mask).
        """
        properties = self.wrapper()
        required = set()
        discriminator = None
        for name, field in self.items():
            field = instance(field)
            properties[name] = field.__schema__
            if field.required:
                required.add(name)
            if getattr(field, "discriminator", False):
                discriminator = name

        definition = {
            "required": sorted(required) or None,
            "properties": properties,
            "discriminator": discriminator,
            "x-mask": str(self.__mask__) if self.__mask__ else None,
            "type": "object",
        }

        if self.__strict__:
            definition["additionalProperties"] = False

        return not_none(definition)

    @cached_property
    def resolved(self):
        """
        Resolve real fields before submitting them to marshal

        Returns a deep copy of this model merged with the resolved fields of
        each parent (parent fields win on name clashes). If a discriminator
        field is present, its ``default`` is set to this model's name.

        :raises ValueError: if more than one discriminator field is found
        """
        # Duplicate fields so this model itself is never mutated
        resolved = copy.deepcopy(self)

        # Recursively copy parent fields if necessary
        for parent in self.__parents__:
            resolved.update(parent.resolved)

        # Handle discriminator
        candidates = [
            key
            for key, field in resolved.items()
            if getattr(field, "discriminator", None)
        ]
        # Ensure there is only one discriminator
        if len(candidates) > 1:
            raise ValueError("There can only be one discriminator by schema")
        # Ensure discriminator always output the model name
        elif len(candidates) == 1:
            key = candidates[0]
            # The field object may be shared with a parent's cached
            # ``resolved``; copy it before setting the default so the parent
            # keeps outputting its own name.
            discriminator = copy.copy(resolved[key])
            discriminator.default = self.name
            resolved[key] = discriminator

        return resolved

    def extend(self, name, fields):
        """
        Extend this model (Duplicate all fields)

        :param str name: The new model name
        :param fields: The new model extra fields, either a single dict or a
            list/tuple of dicts
        :deprecated: since 0.9. Use :meth:`clone` instead.
        """
        warnings.warn(
            "extend is deprecated, use clone instead",
            DeprecationWarning,
            stacklevel=2,
        )
        if isinstance(fields, (list, tuple)):
            return self.clone(name, *fields)
        else:
            return self.clone(name, fields)

    @classmethod
    def clone(cls, name, *parents):
        """
        Clone these models (Duplicate all fields)

        It can be used from the class

        >>> model = Model.clone("NewModel", fields_1, fields_2)

        or from an instantiated model, in which case that model's own fields
        are copied first

        >>> new_model = model.clone("NewModel", fields_1, fields_2)

        :param str name: The new model name
        :param dict parents: The field mappings to copy, in order; later ones
            override earlier ones on key clashes
        """
        fields = cls.wrapper()
        for parent in parents:
            fields.update(copy.deepcopy(parent))
        return cls(name, fields)

    def __deepcopy__(self, memo):
        # Same constructor call shape as before (name, fields, mask, strict),
        # but with no fields yet so the copy can be registered first.
        obj = self.__class__(
            self.name,
            [],
            mask=self.__mask__,
            strict=self.__strict__,
        )
        # Register the copy before copying fields so that any field whose deep
        # copy reaches this model again resolves to ``obj`` instead of
        # recursing forever.
        memo[id(self)] = obj
        for key, value in self.items():
            obj[key] = copy.deepcopy(value, memo)
        # Parents are shared, not copied.
        obj.__parents__ = self.__parents__
        return obj
|
|
|
|
|
|
class Model(RawModel, dict, MutableMapping):
    """
    Field-based model backed by a plain ``dict``.

    Stores API doc metadata alongside its fields and can also be used for
    response marshalling.

    :param str name: The model public name
    :param str mask: an optional default model mask
    :param bool strict: when true, the generated schema forbids properties
        that are not declared in the model
    """
|
|
|
|
|
|
class OrderedModel(RawModel, OrderedDict, MutableMapping):
    """
    Field-based model backed by an ``OrderedDict``.

    Stores API doc metadata alongside its fields and can also be used for
    response marshalling. Fields keep their insertion order, and the same
    ordered mapping type is used for the generated schema ``properties``
    and when cloning.

    :param str name: The model public name
    :param str mask: an optional default model mask
    """

    # Ordered mapping for schema properties and cloned field collections.
    wrapper = OrderedDict
|
|
|
|
|
|
class SchemaModel(ModelBase):
    """
    Model whose API doc metadata comes from a ready-made JSON schema.

    :param str name: The model public name
    :param dict schema: The json schema we are documenting; a falsy value
        is replaced by an empty schema
    """

    def __init__(self, name, schema=None):
        super(SchemaModel, self).__init__(name)
        # Exposed as ``_schema`` so ModelBase.__schema__ can wrap it.
        self._schema = schema if schema else {}

    def __unicode__(self):
        template = "SchemaModel({name},{schema})"
        return template.format(name=self.name, schema=self._schema)

    __str__ = __unicode__
|