You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
804 lines
25 KiB
804 lines
25 KiB
11 months ago
|
import keyword
|
||
|
import warnings
|
||
|
import weakref
|
||
|
from collections import OrderedDict, defaultdict, deque
|
||
|
from copy import deepcopy
|
||
|
from itertools import islice, zip_longest
|
||
|
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
|
||
|
from typing import (
|
||
|
TYPE_CHECKING,
|
||
|
AbstractSet,
|
||
|
Any,
|
||
|
Callable,
|
||
|
Collection,
|
||
|
Dict,
|
||
|
Generator,
|
||
|
Iterable,
|
||
|
Iterator,
|
||
|
List,
|
||
|
Mapping,
|
||
|
NoReturn,
|
||
|
Optional,
|
||
|
Set,
|
||
|
Tuple,
|
||
|
Type,
|
||
|
TypeVar,
|
||
|
Union,
|
||
|
)
|
||
|
|
||
|
from typing_extensions import Annotated
|
||
|
|
||
|
from .errors import ConfigError
|
||
|
from .typing import (
|
||
|
NoneType,
|
||
|
WithArgsTypes,
|
||
|
all_literal_values,
|
||
|
display_as_type,
|
||
|
get_args,
|
||
|
get_origin,
|
||
|
is_literal_type,
|
||
|
is_union,
|
||
|
)
|
||
|
from .version import version_info
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from inspect import Signature
|
||
|
from pathlib import Path
|
||
|
|
||
|
from .config import BaseConfig
|
||
|
from .dataclasses import Dataclass
|
||
|
from .fields import ModelField
|
||
|
from .main import BaseModel
|
||
|
from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs
|
||
|
|
||
|
RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]]
|
||
|
|
||
|
# Public names of this module, i.e. what ``from <this module> import *`` exports.
__all__ = (
    'import_string',
    'sequence_like',
    'validate_field_name',
    'lenient_isinstance',
    'lenient_issubclass',
    'in_ipython',
    'is_valid_identifier',
    'deep_update',
    'update_not_none',
    'almost_equal_floats',
    'get_model',
    'to_camel',
    'is_valid_field',
    'smart_deepcopy',
    'PyObjectStr',
    'Representation',
    'GetterDict',
    'ValueItems',
    'version_info',  # required here to match behaviour in v1.3
    'ClassAttribute',
    'path_type',
    'ROOT_KEY',
    'get_unique_discriminator_alias',
    'get_discriminator_alias_and_values',
    'DUNDER_ATTRIBUTES',
)

# Key of the field that holds the value of a model with a custom root type; it is also the only
# underscore-prefixed name that is_valid_field accepts.
ROOT_KEY = '__root__'

# these are types that are returned unchanged by deepcopy; smart_deepcopy returns objects whose exact type
# is listed here without copying them at all
IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = {
    int,
    float,
    complex,
    str,
    bool,
    bytes,
    type,
    NoneType,
    FunctionType,
    BuiltinFunctionType,
    LambdaType,
    weakref.ref,
    CodeType,
    # note: including ModuleType will differ from behaviour of deepcopy by not producing error.
    # It might be not a good idea in general, but considering that this function used only internally
    # against default values of fields, this will allow to actually have a field with module as default value
    ModuleType,
    NotImplemented.__class__,
    Ellipsis.__class__,
}

# these are types that if empty, might be copied with simple copy() instead of deepcopy()
# (smart_deepcopy returns an empty tuple as-is, since tuple has no copy() method)
BUILTIN_COLLECTIONS: Set[Type[Any]] = {
    list,
    set,
    tuple,
    frozenset,
    dict,
    OrderedDict,
    defaultdict,
    deque,
}
|
||
|
|
||
|
|
||
|
def import_string(dotted_path: str) -> Any:
    """
    Resolve a dotted path such as ``'package.module.attr'`` to the object it names.

    Everything before the last dot is imported as a module, and the final component is then looked up as
    an attribute of that module (approach adapted from Django).

    :param dotted_path: dotted path to the target object; surrounding spaces are ignored
    :return: the attribute found on the imported module
    :raises ImportError: if the path contains no dot, or the module lacks the requested attribute;
        errors raised while importing the module itself propagate unchanged
    """
    from importlib import import_module

    cleaned_path = dotted_path.strip(' ')
    try:
        mod_name, attr_name = cleaned_path.rsplit('.', 1)
    except ValueError as exc:
        # rsplit produced a single element: there is no "module.attribute" split to make
        raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from exc

    target_module = import_module(mod_name)
    try:
        return getattr(target_module, attr_name)
    except AttributeError as exc:
        raise ImportError(f'Module "{mod_name}" does not define a "{attr_name}" attribute') from exc
|
||
|
|
||
|
|
||
|
def truncate(v: Any, *, max_len: int = 80) -> str:
    """
    Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long.

    Deprecated: pydantic no longer uses this helper and every call emits a ``DeprecationWarning``.

    :param v: any value (the previous ``Union[str]`` annotation was wrong: non-strings and types are
        handled below via their repr)
    :param max_len: target maximum length of the returned string
    :return: ``repr``-style text of ``v``, cut to ``max_len`` characters with a trailing ``…`` when too long
    """
    warnings.warn('`truncate` is no-longer used by pydantic and is deprecated', DeprecationWarning)
    if isinstance(v, str) and len(v) > (max_len - 2):
        # strings are cut *before* repr so the closing quote survives;
        # -3 so quote + string + … + quote has correct length
        return (v[: (max_len - 3)] + '…').__repr__()
    try:
        v = v.__repr__()
    except TypeError:
        # for a class, ``v.__repr__`` is the unbound method of its instances and needs an argument
        v = v.__class__.__repr__(v)  # in case v is a type
    if len(v) > max_len:
        v = v[: max_len - 1] + '…'
    return v
|
||
|
|
||
|
|
||
|
def sequence_like(v: Any) -> bool:
    """Return True when ``v`` is a list, tuple, set, frozenset, generator or deque."""
    accepted_types = (list, tuple, set, frozenset, GeneratorType, deque)
    return isinstance(v, accepted_types)
|
||
|
|
||
|
|
||
|
def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None:
    """
    Refuse a field name that would shadow an attribute of one of the model's base classes.

    Only attributes whose value is truthy count as shadowed: the check is the truthiness of
    ``getattr(base, field_name, None)``, evaluated base by base until one matches.

    :raises NameError: if some base has a truthy attribute called ``field_name``
    """
    shadows_attribute = any(getattr(base, field_name, None) for base in bases)
    if shadows_attribute:
        raise NameError(
            f'Field name "{field_name}" shadows a BaseModel attribute; '
            f'use a different field name with "alias=\'{field_name}\'".'
        )
|
||
|
|
||
|
|
||
|
def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
    """``isinstance`` variant that answers False instead of raising TypeError for an unusable second argument."""
    try:
        outcome = isinstance(o, class_or_tuple)  # type: ignore[arg-type]
    except TypeError:
        outcome = False
    return outcome
|
||
|
|
||
|
|
||
|
def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
    """
    ``issubclass`` variant that tolerates non-class first arguments.

    Anything that is not an instance of ``type`` yields False. A TypeError from ``issubclass`` is swallowed
    (returning False) only when ``cls`` is a parametrised generic (``WithArgsTypes``); otherwise it is re-raised.
    """
    try:
        if not isinstance(cls, type):
            return False
        return issubclass(cls, class_or_tuple)  # type: ignore[arg-type]
    except TypeError:
        if not isinstance(cls, WithArgsTypes):
            raise  # pragma: no cover
        return False
|
||
|
|
||
|
|
||
|
def in_ipython() -> bool:
    """
    Report whether we are running under IPython, which includes Jupyter notebooks.

    The check tries to evaluate the name ``__IPYTHON__``; if it cannot be resolved a ``NameError`` is raised
    and the answer is False.
    """
    try:
        eval('__IPYTHON__')
    except NameError:
        running_in_ipython = False
    else:  # pragma: no cover
        running_in_ipython = True
    return running_in_ipython
|
||
|
|
||
|
|
||
|
def is_valid_identifier(identifier: str) -> bool:
    """
    Tell whether ``identifier`` could be used as a Python name.

    :param identifier: the candidate name.
    :return: True only if it is syntactically an identifier *and* not a reserved keyword.
    """
    if not identifier.isidentifier():
        return False
    return not keyword.iskeyword(identifier)
|
||
|
|
||
|
|
||
|
KeyType = TypeVar('KeyType')


def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]:
    """
    Return a copy of ``mapping`` updated, recursively, with each of ``updating_mappings`` in turn.

    When a key already holds a dict and the incoming value is also a dict, the two are merged recursively;
    in every other case the incoming value simply replaces the old one. ``mapping`` is not modified: it is
    shallow-copied, and nested dicts are copied before being merged into.
    """
    result = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            merge_nested = key in result and isinstance(result[key], dict) and isinstance(new_value, dict)
            result[key] = deep_update(result[key], new_value) if merge_nested else new_value
    return result
|
||
|
|
||
|
|
||
|
def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None:
    """Update ``mapping`` in place with every keyword argument whose value is not None."""
    non_none = {key: value for key, value in update.items() if value is not None}
    mapping.update(non_none)
|
||
|
|
||
|
|
||
|
def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool:
    """Return True if the two floats differ by at most ``delta`` (an absolute tolerance)."""
    difference = abs(value_1 - value_2)
    return difference <= delta
|
||
|
|
||
|
|
||
|
def generate_model_signature(
    init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig']
) -> 'Signature':
    """
    Generate signature for model based on its fields.

    The parameters of ``init`` are kept, except its first (``self``-like) argument. If ``init`` accepts
    ``**kwargs``:
    - Every field not already present (by field name or alias) is added as a keyword-only parameter.
    - The parameter is named after the field's alias, or after its field name when the alias is not a valid
      identifier, ``config.allow_population_by_field_name`` is set and the field name is valid.
    - The ``**kwargs`` parameter is kept, at the end, only when some field could not be given a name or
      ``config.extra`` is ``Extra.allow``.

    :param init: the model's ``__init__``
    :param fields: model fields keyed by field name
    :param config: the model's config class
    :return: the generated signature (``return_annotation=None``)
    """
    from inspect import Parameter, Signature, signature

    from .config import Extra

    present_params = signature(init).parameters.values()
    merged_params: Dict[str, Parameter] = {}
    var_kw = None
    use_var_kw = False

    for param in islice(present_params, 1, None):  # skip self arg
        if param.kind is param.VAR_KEYWORD:
            var_kw = param
            continue
        merged_params[param.name] = param

    if var_kw:  # if custom init has no var_kw, fields which are not declared in it cannot be passed through
        allow_names = config.allow_population_by_field_name
        for field_name, field in fields.items():
            param_name = field.alias
            if field_name in merged_params or param_name in merged_params:
                continue
            elif not is_valid_identifier(param_name):
                if allow_names and is_valid_identifier(field_name):
                    param_name = field_name
                else:
                    # the field can only be passed through **kwargs
                    use_var_kw = True
                    continue

            # TODO: replace annotation with actual expected types once #1055 solved
            kwargs = {'default': field.default} if not field.required else {}
            merged_params[param_name] = Parameter(
                param_name, Parameter.KEYWORD_ONLY, annotation=field.annotation, **kwargs
            )

    if config.extra is Extra.allow:
        use_var_kw = True

    if var_kw and use_var_kw:
        # Make sure the parameter for extra kwargs
        # does not have the same name as a field
        default_model_signature = [
            ('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
            ('data', Parameter.VAR_KEYWORD),
        ]
        if [(p.name, p.kind) for p in present_params] == default_model_signature:
            # if this is the standard model signature, use extra_data as the extra args name
            var_kw_name = 'extra_data'
        else:
            # else start from var_kw
            var_kw_name = var_kw.name

        # Generate a name that's definitely unique. It must not match a field name, nor any parameter
        # already collected (e.g. a field whose *alias* is ``extra_data``). Otherwise the assignment below
        # would replace that parameter in place: the field would vanish from the signature, or the
        # VAR_KEYWORD would land mid-list and ``Signature`` would reject the parameter order.
        while var_kw_name in fields or var_kw_name in merged_params:
            var_kw_name += '_'
        merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)

    return Signature(parameters=list(merged_params.values()), return_annotation=None)
|
||
|
|
||
|
|
||
|
def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']:
    """
    Return the ``BaseModel`` class behind ``obj``.

    If ``obj`` carries a ``__pydantic_model__`` attribute (as the ``Dataclass`` type in the signature
    suggests), that class is used; otherwise ``obj`` itself is checked.

    :raises TypeError: if the resolved class is not a ``BaseModel`` subclass
    """
    from .main import BaseModel

    # getattr with a default swallows exactly the AttributeError a missing attribute would raise
    model_cls = getattr(obj, '__pydantic_model__', obj)  # type: ignore
    if issubclass(model_cls, BaseModel):
        return model_cls
    raise TypeError('Unsupported type, must be either BaseModel or dataclass')
|
||
|
|
||
|
|
||
|
def to_camel(string: str) -> str:
    """
    Convert a snake_case string to PascalCase, e.g. ``'foo_bar'`` -> ``'FooBar'``.

    Each underscore-separated word goes through ``str.capitalize``, so the rest of every word is lowercased.
    """
    return ''.join(word.capitalize() for word in string.split('_'))


def to_lower_camel(string: str) -> str:
    """
    Convert a snake_case string to camelCase, e.g. ``'foo_bar'`` -> ``'fooBar'``.

    Input made only of underscores (``'_'``, ``'__'``, ...) yields ``''``; previously it raised IndexError,
    because ``to_camel`` returns an empty string for it.
    """
    pascal_string = to_camel(string)
    # slicing rather than indexing keeps this safe when the PascalCase form is empty ('' or only underscores)
    return pascal_string[:1].lower() + pascal_string[1:]
|
||
|
|
||
|
|
||
|
T = TypeVar('T')


def unique_list(
    input_list: Union[List[T], Tuple[T, ...]],
    *,
    name_factory: Callable[[T], str] = str,
) -> List[T]:
    """
    Make a list unique while maintaining order.

    Items are identified by ``name_factory(item)``. A later item whose name was already seen replaces the
    earlier one in place (e.g. a root validator overridden in a subclass). Each name therefore keeps the
    position of its first occurrence and the value of its last.

    :param input_list: the items to de-duplicate
    :param name_factory: computes the identifying name of an item; defaults to ``str``
    :return: a new list holding one item per distinct name
    """
    result: List[T] = []
    # name -> index in ``result``; a dict gives O(1) lookups where scanning a list of names was O(n) per item
    positions: Dict[str, int] = {}
    for v in input_list:
        v_name = name_factory(v)
        index = positions.get(v_name)
        if index is None:
            positions[v_name] = len(result)
            result.append(v)
        else:
            result[index] = v

    return result
|
||
|
|
||
|
|
||
|
class PyObjectStr(str):
    """
    ``str`` subclass whose ``repr()`` is the bare text, without surrounding quotes.

    Useful with ``Representation`` when a value should be shown as something that is valid (or pseudo-valid)
    Python, rather than as a quoted string.
    """

    def __repr__(self) -> str:
        # delegate to str() so the text is shown verbatim
        return str(self)
|
||
|
|
||
|
|
||
|
class Representation:
    """
    Mixin that derives ``__str__``, ``__repr__``, ``__pretty__`` and ``__rich_repr__`` from ``__repr_args__``.
    See #884 for more details.

    ``__pretty__`` is consumed by [devtools](https://python-devtools.helpmanual.io/) and ``__rich_repr__`` by
    the Rich library to render human readable representations of objects.
    """

    __slots__: Tuple[str, ...] = tuple()

    def __repr_args__(self) -> 'ReprArgs':
        """
        Return the ``(name, value)`` pairs to display; subclasses generally override this.

        A pair may use ``None`` as its name to show the value alone, e.g.
        ``[('foo_name', 'foo'), (None, ['b', 'a', 'r'])]``.

        The default lists every attribute named in ``self.__slots__`` whose value is not None.
        """
        pairs = []
        for slot_name in self.__slots__:
            slot_value = getattr(self, slot_name)
            if slot_value is not None:
                pairs.append((slot_name, slot_value))
        return pairs

    def __repr_name__(self) -> str:
        """Label written in front of the parentheses by ``__repr__``: the instance's class name."""
        return self.__class__.__name__

    def __repr_str__(self, join_str: str) -> str:
        """Render ``__repr_args__`` as ``name=value`` (or bare ``value``) pieces joined by ``join_str``."""
        pieces = []
        for arg_name, arg_value in self.__repr_args__():
            pieces.append(repr(arg_value) if arg_name is None else f'{arg_name}={arg_value!r}')
        return join_str.join(pieces)

    def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]:
        """
        Yield the token stream used by devtools (https://python-devtools.helpmanual.io/) for pretty printing.

        Strings are emitted as text and ``fmt`` formats each value. NOTE(review): the bare integers 1 / 0 / -1
        appear to be devtools layout markers (presumably indent / newline / dedent) - confirm with devtools.
        """
        yield self.__repr_name__() + '('
        yield 1
        for arg_name, arg_value in self.__repr_args__():
            if arg_name is not None:
                yield arg_name + '='
            yield fmt(arg_value)
            yield ','
            yield 0
        yield -1
        yield ')'

    def __str__(self) -> str:
        return self.__repr_str__(' ')

    def __repr__(self) -> str:
        class_label = self.__repr_name__()
        arguments = self.__repr_str__(', ')
        return f'{class_label}({arguments})'

    def __rich_repr__(self) -> 'RichReprResult':
        """Yield fields for the Rich library: bare values for unnamed args, ``(name, value)`` tuples otherwise."""
        for arg_name, arg_value in self.__repr_args__():
            yield arg_value if arg_name is None else (arg_name, arg_value)
|
||
|
|
||
|
|
||
|
class GetterDict(Representation):
    """
    Read-only, dict-like view over the attributes of an arbitrary object - just enough for validate_model.

    We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves.
    """

    __slots__ = ('_obj',)

    def __init__(self, obj: Any):
        self._obj = obj

    def __getitem__(self, key: str) -> Any:
        """Return ``getattr(obj, key)``, translating a missing attribute into ``KeyError``."""
        try:
            value = getattr(self._obj, key)
        except AttributeError as e:
            raise KeyError(key) from e
        return value

    def get(self, key: Any, default: Any = None) -> Any:
        """Return ``getattr(obj, key)``, or ``default`` when the attribute does not exist."""
        return getattr(self._obj, key, default)

    def extra_keys(self) -> Set[Any]:
        """
        Always empty: other attributes of obj are never collected unless the model explicitly asks for them.
        """
        return set()

    def keys(self) -> List[Any]:
        """
        Public attribute names of the wrapped object, returned as a list (not a set) so ordering is preserved
        like a python dictionary.
        """
        return list(self)

    def values(self) -> List[Any]:
        """Value of every key, looked up through ``__getitem__``."""
        return [self[name] for name in self]

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Lazily yield ``(key, value)`` pairs; values are read with ``get`` (missing -> None)."""
        for name in self:
            yield name, self.get(name)

    def __iter__(self) -> Iterator[str]:
        """Iterate over ``dir(obj)``, skipping every name that starts with an underscore."""
        yield from (name for name in dir(self._obj) if not name.startswith('_'))

    def __len__(self) -> int:
        """Number of public attribute names."""
        count = 0
        for _ in self:
            count += 1
        return count

    def __contains__(self, item: Any) -> bool:
        return item in self.keys()

    def __eq__(self, other: Any) -> bool:
        """Equal when the attribute dict matches ``dict(other.items())``."""
        return dict(self) == dict(other.items())

    def __repr_args__(self) -> 'ReprArgs':
        return [(None, dict(self))]

    def __repr_name__(self) -> str:
        return f'GetterDict[{display_as_type(self._obj)}]'
|
||
|
|
||
|
|
||
|
class ValueItems(Representation):
    """
    Class for more convenient calculation of excluded or included fields on values.

    Holds an include/exclude specification for one value. The specification is a mapping from key (or
    sequence index) to one of:
    - ``True`` / ``...``, meaning the whole item;
    - a nested specification (set or mapping) for that item.
    Sets are converted to mappings whose values are ``...``.
    """

    # NOTE(review): ``_type`` is declared here but never assigned by this class.
    __slots__ = ('_items', '_type')

    def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None:
        """
        :param value: the value the specification applies to. When it is a list or tuple, negative indexes
            and the ``'__all__'`` key in ``items`` are resolved against its length.
        :param items: a set of keys/indexes, or a mapping of key/index -> nested specification
        :raises TypeError: if ``items`` is neither a set nor a mapping, or holds invalid sequence entries
        """
        items = self._coerce_items(items)

        if isinstance(value, (list, tuple)):
            items = self._normalize_indexes(items, len(value))

        self._items: 'MappingIntStrAny' = items

    def is_excluded(self, item: Any) -> bool:
        """
        Check if item is fully excluded, i.e. its specification is ``True`` or ``...``.

        :param item: key or index of a value
        """
        return self.is_true(self._items.get(item))

    def is_included(self, item: Any) -> bool:
        """
        Check if item is present in the specification at all.

        :param item: key or index of value
        """
        return item in self._items

    def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]:
        """
        :param e: key or index of element on value
        :return: the nested specification for ``e``; None if ``e`` is absent or covered entirely
        """
        item = self._items.get(e)
        return item if not self.is_true(item) else None

    def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny':
        """
        Resolve negative indexes and the ``'__all__'`` key of a specification for a sequence.

        :param items: dict of indexes (sets are coerced to dicts beforehand) which will be normalized
        :param v_length: length of the sequence the indexes refer to
        :raises TypeError: for a non-set/mapping/True/``...`` value, or a key that is neither int nor '__all__'

        >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
        {0: True, 2: True, 3: True}
        >>> self._normalize_indexes({'__all__': True}, 4)
        {0: Ellipsis, 1: Ellipsis, 2: Ellipsis, 3: Ellipsis}
        """

        normalized_items: 'DictIntStrAny' = {}
        all_items = None
        for i, v in items.items():
            if not (isinstance(v, (Mapping, AbstractSet)) or self.is_true(v)):
                raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
            if i == '__all__':
                all_items = self._coerce_value(v)
                continue
            if not isinstance(i, int):
                raise TypeError(
                    'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
                    'expected integer keys or keyword "__all__"'
                )
            normalized_i = v_length + i if i < 0 else i
            # -1 and (v_length - 1) may both be present: merge rather than overwrite
            normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))

        if not all_items:
            return normalized_items
        if self.is_true(all_items):
            # '__all__' covers every element entirely: mark each index not listed explicitly with ``...``
            for i in range(v_length):
                normalized_items.setdefault(i, ...)
            return normalized_items
        # '__all__' holds a nested specification: merge it into every index not already fully covered
        for i in range(v_length):
            normalized_item = normalized_items.setdefault(i, {})
            if not self.is_true(normalized_item):
                normalized_items[i] = self.merge(all_items, normalized_item)
        return normalized_items

    @classmethod
    def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
        """
        Merge a ``base`` item with an ``override`` item.

        Both ``base`` and ``override`` are converted to dictionaries if possible.
        Sets are converted to dictionaries with the sets entries as keys and
        Ellipsis as values.

        Each key-value pair existing in ``base`` is merged with ``override``,
        while the rest of the key-value pairs are updated recursively with this function.

        Merging takes place based on the "union" of keys if ``intersect`` is
        set to ``False`` (default) and on the intersection of keys if
        ``intersect`` is set to ``True``.
        """
        override = cls._coerce_value(override)
        base = cls._coerce_value(base)
        if override is None:
            return base
        if cls.is_true(base) or base is None:
            return override
        if cls.is_true(override):
            return base if intersect else override

        # Intersection or union of keys, preserving ordering (base order first). Each key is listed exactly
        # once: the intersection used to be built twice (base-side + override-side), so every shared key was
        # merged twice. The result was the same, but work doubled at each nesting level.
        if intersect:
            merge_keys = [k for k in base if k in override]
        else:
            merge_keys = list(base) + [k for k in override if k not in base]

        merged: 'DictIntStrAny' = {}
        for k in merge_keys:
            merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
            if merged_item is not None:
                merged[k] = merged_item

        return merged

    @staticmethod
    def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny':
        """Return ``items`` as a mapping: mappings pass through, sets become ``{key: ...}``, anything else raises."""
        if isinstance(items, Mapping):
            pass
        elif isinstance(items, AbstractSet):
            items = dict.fromkeys(items, ...)
        else:
            class_name = getattr(items, '__class__', '???')
            assert_never(
                items,
                f'Unexpected type of exclude value {class_name}',
            )
        return items

    @classmethod
    def _coerce_value(cls, value: Any) -> Any:
        """Like ``_coerce_items``, but None, ``True`` and ``...`` are passed through untouched."""
        if value is None or cls.is_true(value):
            return value
        return cls._coerce_items(value)

    @staticmethod
    def is_true(v: Any) -> bool:
        """True for the two markers meaning "the whole item": the ``True`` singleton and ``...``."""
        return v is True or v is ...

    def __repr_args__(self) -> 'ReprArgs':
        return [(None, self._items)]
|
||
|
|
||
|
|
||
|
class ClassAttribute:
    """
    Hide class attribute from its instances.

    Descriptor: reading it from the owning class returns ``value``, while reading it from an instance
    raises ``AttributeError``.
    """

    __slots__ = (
        'name',
        'value',
    )

    def __init__(self, name: str, value: Any) -> None:
        self.name = name
        self.value = value

    def __get__(self, instance: Any, owner: Type[Any]) -> Any:
        """
        Return the stored value for class-level access; previously mis-annotated as ``-> None``.

        :raises AttributeError: when accessed through an instance
        """
        if instance is None:
            return self.value
        raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
|
||
|
|
||
|
|
||
|
# ``pathlib.Path`` predicate method name -> label reported by ``path_type``.
# Checked in insertion order; the first predicate that returns True wins.
path_types = {
    'is_dir': 'directory',
    'is_file': 'file',
    'is_mount': 'mount point',
    'is_symlink': 'symlink',
    'is_block_device': 'block device',
    'is_char_device': 'char device',
    'is_fifo': 'FIFO',
    'is_socket': 'socket',
}


def path_type(p: 'Path') -> str:
    """
    Describe what kind of filesystem object an existing path is.

    :param p: the path to inspect
    :return: the label of the first matching ``path_types`` predicate, or ``'unknown'``
    :raises AssertionError: if ``p`` does not exist
    """
    assert p.exists(), 'path does not exist'
    labels = (label for method_name, label in path_types.items() if getattr(p, method_name)())
    return next(labels, 'unknown')
|
||
|
|
||
|
|
||
|
Obj = TypeVar('Obj')


def smart_deepcopy(obj: Obj) -> Obj:
    """
    Copy ``obj`` as cheaply as possible while matching what ``copy.deepcopy`` would give.

    * exact types listed in ``IMMUTABLE_NON_COLLECTIONS_TYPES`` are returned as-is;
    * empty objects whose exact type is in ``BUILTIN_COLLECTIONS`` get a shallow ``.copy()``
      (an empty tuple is returned as-is);
    * everything else goes through ``copy.deepcopy``.
    """
    cls = obj.__class__
    if cls in IMMUTABLE_NON_COLLECTIONS_TYPES:
        # immutable and not a collection: deepcopy would hand back the same object anyway
        return obj

    try:
        if not obj and cls in BUILTIN_COLLECTIONS:
            # an empty builtin collection has no members to copy, so a shallow copy is enough
            return obj if cls is tuple else obj.copy()  # type: ignore # tuple doesn't have copy method
    except (TypeError, ValueError, RuntimeError):
        # truth-testing (or copying) some objects raises one of these; defer to deepcopy instead.
        # Any other exception type propagates.
        pass

    # general case, and the slowest: a real deep copy
    return deepcopy(obj)
|
||
|
|
||
|
|
||
|
def is_valid_field(name: str) -> bool:
    """
    Decide whether ``name`` may be used as a model field name.

    Names without a leading underscore are always accepted. Among underscore-prefixed names only
    ``ROOT_KEY`` is accepted.
    """
    return not name.startswith('_') or ROOT_KEY == name
|
||
|
|
||
|
|
||
|
# Dunder names that is_valid_private_name never reports as private attributes.
DUNDER_ATTRIBUTES = {
    '__annotations__',
    '__classcell__',
    '__doc__',
    '__module__',
    '__orig_bases__',
    '__orig_class__',
    '__qualname__',
}


def is_valid_private_name(name: str) -> bool:
    """
    Return True for names treated as private attributes.

    These are names rejected by ``is_valid_field`` (underscore-prefixed, other than the root key) that are
    not listed in ``DUNDER_ATTRIBUTES``.
    """
    if is_valid_field(name):
        return False
    return name not in DUNDER_ATTRIBUTES
|
||
|
|
||
|
|
||
|
# Private sentinel that pads the shorter iterable in ``all_identical``; callers never see it, so a
# length mismatch always yields a non-identical pair.
_EMPTY = object()


def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool:
    """
    Return True when ``left`` and ``right`` have equal length and pairwise hold the very same objects (``is``).

    >>> a, b = object(), object()
    >>> all_identical([a, b, a], [a, b, a])
    True
    >>> all_identical([a, b, [a]], [a, b, [a]])  # new list object, while "equal" is not "identical"
    False
    """
    pairs = zip_longest(left, right, fillvalue=_EMPTY)
    return all(first is second for first, second in pairs)
|
||
|
|
||
|
|
||
|
def assert_never(obj: NoReturn, msg: str) -> NoReturn:
    """
    Mark a branch that should be unreachable; it always raises ``TypeError(msg)``.

    Annotating ``obj`` as ``NoReturn`` lets ``mypy`` check that every possible type was handled before the
    call. Docs: https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks
    """
    error = TypeError(msg)
    raise error
|
||
|
|
||
|
|
||
|
def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str:
    """
    Return the single alias shared by every entry of ``all_aliases``.

    :raises ConfigError: if the aliases are not all identical
    :raises KeyError: if ``all_aliases`` is empty (popping from an empty set)
    """
    distinct_aliases = set(all_aliases)
    if len(distinct_aliases) <= 1:
        return distinct_aliases.pop()
    raise ConfigError(
        f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})'
    )
|
||
|
|
||
|
|
||
|
def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]:
    """
    Get alias and all valid values in the `Literal` type of the discriminator field
    `tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many.

    Three cases are handled:
    * ``tp`` is a ``Union`` (possibly wrapped in ``Annotated``): every member is inspected recursively, they
      must share one alias, and all their values are concatenated into one tuple;
    * ``tp`` has a truthy ``__custom_root_type__``: the members of its root field's type are inspected
      recursively, and they must all yield the same tuple of values;
    * otherwise ``tp.__fields__[discriminator_key]`` must exist and be typed as a ``Literal``.

    :raises TypeError: if looking up ``tp.__fields__[...].type_`` raises AttributeError
    :raises ConfigError: if the discriminator field is missing, not a ``Literal``, or inconsistent
    NOTE(review): the error messages read ``tp.__name__``, which not every type object provides.
    """
    # read from ``tp`` as passed in, i.e. before the ``Annotated`` / ``__pydantic_model__`` unwrapping below
    is_root_model = getattr(tp, '__custom_root_type__', False)

    # ``Annotated[X, ...]`` -> ``X``
    if get_origin(tp) is Annotated:
        tp = get_args(tp)[0]

    # objects exposing ``__pydantic_model__`` (presumably pydantic dataclasses) are inspected via that model
    if hasattr(tp, '__pydantic_model__'):
        tp = tp.__pydantic_model__

    if is_union(get_origin(tp)):
        # all members must share one alias; their allowed values are flattened into a single tuple
        alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key)
        return alias, tuple(v for values in all_values for v in values)
    elif is_root_model:
        # custom root type: inspect the members of the root field's type (expected to be a union)
        union_type = tp.__fields__[ROOT_KEY].type_
        alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key)

        # unlike the plain-union case, every member must yield exactly the same tuple of values
        if len(set(all_values)) > 1:
            raise ConfigError(
                f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}'
            )

        return alias, all_values[0]

    else:
        try:
            t_discriminator_type = tp.__fields__[discriminator_key].type_
        except AttributeError as e:
            # e.g. ``tp`` has no ``__fields__``
            raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e
        except KeyError as e:
            raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e

        if not is_literal_type(t_discriminator_type):
            raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`')

        return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type)
|
||
|
|
||
|
|
||
|
def _get_union_alias_and_all_values(
    union_type: Type[Any], discriminator_key: str
) -> Tuple[str, Tuple[Tuple[str, ...], ...]]:
    """
    Inspect every member of ``union_type`` for its discriminator alias and allowed values.

    :return: the alias shared by all members, plus one tuple of allowed values per member (in member order)
    :raises ConfigError: if the members do not all use the same alias
    """
    per_member = [get_discriminator_alias_and_values(member, discriminator_key) for member in get_args(union_type)]
    # transpose [('alias_a', ('v1', 'v2')), ('alias_b', ('v3',))]
    #        -> ('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))
    member_aliases, member_values = zip(*per_member)
    shared_alias = get_unique_discriminator_alias(member_aliases, discriminator_key)
    return shared_alias, member_values
|