You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
265 lines
10 KiB
265 lines
10 KiB
1 year ago
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload
from . import validator
from .config import Extra
from .errors import ConfigError
from .main import BaseModel, create_model
from .typing import get_all_type_hints
from .utils import to_camel
# Only the decorator is part of this module's public API.
__all__ = ('validate_arguments',)
from .typing import AnyCallable
# Type variable bound to any callable, so the decorator can be typed as returning
# the same callable type it receives.
AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
# Accepted forms of the decorator's ``config`` argument: ``None`` (no custom config),
# a class (used directly as the base of the model's Config), or a dict of config
# attributes (turned into a Config class).
ConfigType = Union[None, Type[Any], Dict[str, Any]]
@overload
def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']:
    ...


@overload
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT':
    ...


def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any:
    """
    Decorator to validate the arguments passed to a function.

    Can be applied bare (``@validate_arguments``) or called with a config
    (``@validate_arguments(config=...)``), in which case it returns the decorator.

    The wrapper validates each call through a ``ValidatedFunction`` before invoking
    the original function, and exposes:

    - ``vd``: the ``ValidatedFunction`` instance
    - ``validate``: builds and validates the model instance without calling the function
    - ``raw_function``: the undecorated function
    - ``model``: the model generated from the function's signature
    """

    def validate(_func: 'AnyCallable') -> 'AnyCallable':
        vd = ValidatedFunction(_func, config)

        # preserve the wrapped function's name, docstring, etc.
        @wraps(_func)
        def wrapper_function(*args: Any, **kwargs: Any) -> Any:
            return vd.call(*args, **kwargs)

        wrapper_function.vd = vd  # type: ignore
        wrapper_function.validate = vd.init_model_instance  # type: ignore
        wrapper_function.raw_function = vd.raw_function  # type: ignore
        wrapper_function.model = vd.model  # type: ignore
        return wrapper_function

    # bare usage passes the function directly; ``validate_arguments(config=...)``
    # passes nothing and gets the decorator back. Compare against None rather than
    # relying on truthiness, which a callable object may override.
    if func is not None:
        return validate(func)
    return validate
# Fallback names for the synthetic *args / **kwargs model fields, used when the
# function has no var-positional / var-keyword parameter of its own but does have
# an ordinary parameter that already occupies the default name.
ALT_V_ARGS, ALT_V_KWARGS = 'v__args', 'v__kwargs'
# Bookkeeping field: positional-only parameters that the caller passed by keyword.
V_POSITIONAL_ONLY_NAME = 'v__positional_only'
# Bookkeeping field: parameters that received both a positional and a keyword value.
V_DUPLICATE_KWARGS = 'v__duplicate_kwargs'
class ValidatedFunction:
    """
    Validate a function's call arguments against a model built from its signature.

    Every parameter becomes a model field. Extra fields hold ``*args``/``**kwargs``
    (real or synthetic) and the ``v__positional_only`` / ``v__duplicate_kwargs``
    bookkeeping values, so the model's validators can raise ``TypeError`` for calls
    the function's signature would reject.
    """

    def __init__(self, function: 'AnyCallableT', config: 'ConfigType'):  # noqa C901
        """
        Inspect ``function``'s signature and build the validation model.

        :param function: the function whose arguments are validated
        :param config: ``None``, a config class, or a dict of config attributes
        :raises ConfigError: if a parameter uses one of the reserved internal field names
        """
        from inspect import Parameter, signature

        parameters: Mapping[str, Parameter] = signature(function).parameters

        # the internal bookkeeping fields share the model namespace with the
        # function's parameters, so a parameter with one of these names would clash
        reserved = {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}
        if parameters.keys() & reserved:
            raise ConfigError(
                f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" '
                f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator'
            )

        self.raw_function = function
        # maps positional index -> parameter name, for parameters fillable positionally
        self.arg_mapping: Dict[int, str] = {}
        self.positional_only_args = set()
        self.v_args_name = 'args'
        self.v_kwargs_name = 'kwargs'

        type_hints = get_all_type_hints(function)
        takes_args = False
        takes_kwargs = False
        fields: Dict[str, Tuple[Any, Any]] = {}
        for i, (name, p) in enumerate(parameters.items()):
            if p.annotation is p.empty:
                annotation = Any
            else:
                annotation = type_hints[name]

            # ``...`` marks the field as required
            default = ... if p.default is p.empty else p.default
            if p.kind == Parameter.POSITIONAL_ONLY:
                self.arg_mapping[i] = name
                fields[name] = annotation, default
                fields[V_POSITIONAL_ONLY_NAME] = List[str], None
                self.positional_only_args.add(name)
            elif p.kind == Parameter.POSITIONAL_OR_KEYWORD:
                self.arg_mapping[i] = name
                fields[name] = annotation, default
                fields[V_DUPLICATE_KWARGS] = List[str], None
            elif p.kind == Parameter.KEYWORD_ONLY:
                fields[name] = annotation, default
            elif p.kind == Parameter.VAR_POSITIONAL:
                self.v_args_name = name
                fields[name] = Tuple[annotation, ...], None
                takes_args = True
            else:
                assert p.kind == Parameter.VAR_KEYWORD, p.kind
                self.v_kwargs_name = name
                fields[name] = Dict[str, annotation], None  # type: ignore
                takes_kwargs = True

        # these checks avoid a clash between "args" and a field with that name
        if not takes_args and self.v_args_name in fields:
            self.v_args_name = ALT_V_ARGS

        # same with "kwargs"
        if not takes_kwargs and self.v_kwargs_name in fields:
            self.v_kwargs_name = ALT_V_KWARGS

        if not takes_args:
            # we add the field so validation below can raise the correct exception
            fields[self.v_args_name] = List[Any], None

        if not takes_kwargs:
            # same with kwargs
            fields[self.v_kwargs_name] = Dict[Any, Any], None

        self.create_model(fields, takes_args, takes_kwargs, config)

    def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel:
        """Map the call arguments onto model fields and return the validated model instance."""
        values = self.build_values(args, kwargs)
        return self.model(**values)

    def call(self, *args: Any, **kwargs: Any) -> Any:
        """Validate the call arguments, then invoke the wrapped function with the validated values."""
        m = self.init_model_instance(*args, **kwargs)
        return self.execute(m)

    def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Translate a call's ``args``/``kwargs`` into a dict of model field values.

        Positional values beyond the named positional parameters are collected under
        the var-positional field. Keyword values that match no field are collected
        under the var-keyword field. Positional-only parameters passed by keyword,
        and parameters given both positionally and by keyword, are recorded in the
        bookkeeping fields so the model's validators can reject the call.
        """
        values: Dict[str, Any] = {}

        arg_iter = enumerate(args)
        for i, a in arg_iter:
            arg_name = self.arg_mapping.get(i)
            if arg_name is not None:
                values[arg_name] = a
            else:
                # first surplus positional: it and everything after it go to *args
                values[self.v_args_name] = [a] + [rest for _, rest in arg_iter]
                break

        var_kwargs: Dict[str, Any] = {}
        wrong_positional_args = []
        duplicate_kwargs = []
        var_field_names = (self.v_args_name, self.v_kwargs_name)
        fields_alias = {
            field.alias for name, field in self.model.__fields__.items() if name not in var_field_names
        }
        non_var_fields = set(self.model.__fields__) - set(var_field_names)
        for k, v in kwargs.items():
            if k in non_var_fields or k in fields_alias:
                if k in self.positional_only_args:
                    wrong_positional_args.append(k)
                if k in values:
                    # already filled positionally
                    duplicate_kwargs.append(k)
                values[k] = v
            else:
                var_kwargs[k] = v

        if var_kwargs:
            values[self.v_kwargs_name] = var_kwargs
        if wrong_positional_args:
            values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args
        if duplicate_kwargs:
            values[V_DUPLICATE_KWARGS] = duplicate_kwargs
        return values

    def execute(self, m: BaseModel) -> Any:
        """
        Call the wrapped function with the validated values held by ``m``.

        Only fields the caller actually set (or fields with a default factory) are
        forwarded, so omitted arguments fall back to the wrapped function's own defaults.
        """
        d = {k: v for k, v in m._iter() if k in m.__fields_set__ or m.__fields__[k].default_factory}
        var_kwargs = d.pop(self.v_kwargs_name, {})

        if self.v_args_name in d:
            # fields before *args are passed positionally, then *args is spread,
            # and anything after it is passed by keyword
            args_: List[Any] = []
            in_kwargs = False
            kwargs = {}
            for name, value in d.items():
                if in_kwargs:
                    kwargs[name] = value
                elif name == self.v_args_name:
                    args_ += value
                    in_kwargs = True
                else:
                    args_.append(value)
            return self.raw_function(*args_, **kwargs, **var_kwargs)
        elif self.positional_only_args:
            # positional-only parameters cannot be passed by keyword
            args_ = []
            kwargs = {}
            for name, value in d.items():
                if name in self.positional_only_args:
                    args_.append(value)
                else:
                    kwargs[name] = value
            return self.raw_function(*args_, **kwargs, **var_kwargs)
        else:
            return self.raw_function(**d, **var_kwargs)

    def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
        """
        Build ``self.model`` from ``fields``, with validators that reject invalid call shapes.

        :raises ConfigError: if the custom config sets ``fields`` or ``alias_generator``
        """
        pos_args = len(self.arg_mapping)

        class CustomConfig:
            pass

        if not TYPE_CHECKING:  # pragma: no branch
            if isinstance(config, dict):
                CustomConfig = type('Config', (), config)  # noqa: F811
            elif config is not None:
                CustomConfig = config  # noqa: F811

        if hasattr(CustomConfig, 'fields') or hasattr(CustomConfig, 'alias_generator'):
            raise ConfigError(
                'Setting the "fields" and "alias_generator" property on custom Config for '
                '@validate_arguments is not yet supported, please remove.'
            )

        class DecoratorBaseModel(BaseModel):
            @validator(self.v_args_name, check_fields=False, allow_reuse=True)
            def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
                if takes_args or v is None:
                    return v

                raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given')

            @validator(self.v_kwargs_name, check_fields=False, allow_reuse=True)
            def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
                if takes_kwargs or v is None:
                    return v

                plural = '' if len(v) == 1 else 's'
                keys = ', '.join(map(repr, v.keys()))
                raise TypeError(f'unexpected keyword argument{plural}: {keys}')

            @validator(V_POSITIONAL_ONLY_NAME, check_fields=False, allow_reuse=True)
            def check_positional_only(cls, v: Optional[List[str]]) -> None:
                if v is None:
                    return

                plural = '' if len(v) == 1 else 's'
                keys = ', '.join(map(repr, v))
                raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}')

            @validator(V_DUPLICATE_KWARGS, check_fields=False, allow_reuse=True)
            def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
                if v is None:
                    return

                plural = '' if len(v) == 1 else 's'
                keys = ', '.join(map(repr, v))
                raise TypeError(f'multiple values for argument{plural}: {keys}')

            class Config(CustomConfig):
                # default ``extra`` to forbid unless the custom config sets it
                extra = getattr(CustomConfig, 'extra', Extra.forbid)

        self.model = create_model(to_camel(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)