diff --git a/pydantic/decorator.py b/pydantic/decorator.py index d99ab1d..1f901c8 100644 --- a/pydantic/decorator.py +++ b/pydantic/decorator.py @@ -14,20 +14,27 @@ if TYPE_CHECKING: Callable = TypeVar('Callable', bound=AnyCallable) -def validate_arguments(function: 'Callable') -> 'Callable': +def validate_arguments(func: 'Callable' = None, **config_params: Any) -> 'Callable': """ Decorator to validate the arguments passed to a function. """ - vd = ValidatedFunction(function) - @wraps(function) - def wrapper_function(*args: Any, **kwargs: Any) -> Any: - return vd.call(*args, **kwargs) + def validate(_func: 'Callable') -> 'Callable': + vd = ValidatedFunction(_func, **config_params) - wrapper_function.vd = vd # type: ignore - wrapper_function.raw_function = vd.raw_function # type: ignore - wrapper_function.model = vd.model # type: ignore - return cast('Callable', wrapper_function) + @wraps(_func) + def wrapper_function(*args: Any, **kwargs: Any) -> Any: + return vd.call(*args, **kwargs) + + wrapper_function.vd = vd # type: ignore + wrapper_function.raw_function = vd.raw_function # type: ignore + wrapper_function.model = vd.model # type: ignore + return cast('Callable', wrapper_function) + + if func: + return validate(func) + else: + return cast('Callable', validate) ALT_V_ARGS = 'v__args' @@ -36,7 +43,7 @@ V_POSITIONAL_ONLY_NAME = 'v__positional_only' class ValidatedFunction: - def __init__(self, function: 'Callable'): + def __init__(self, function: 'Callable', **config_params: Any): from inspect import signature, Parameter parameters: Mapping[str, Parameter] = signature(function).parameters @@ -100,7 +107,7 @@ class ValidatedFunction: # same with kwargs fields[self.v_kwargs_name] = Dict[Any, Any], None - self.create_model(fields, takes_args, takes_kwargs) + self.create_model(fields, takes_args, takes_kwargs, **config_params) def call(self, *args: Any, **kwargs: Any) -> Any: values = self.build_values(args, kwargs) @@ -170,9 +177,17 @@ class ValidatedFunction: else: return self.raw_function(**d) - def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool) -> None: + def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, **config_params: Any) -> None: pos_args = len(self.arg_mapping) + if TYPE_CHECKING: + + class CustomConfig: + pass + + else: + CustomConfig = type('Config', (), config_params) + class DecoratorBaseModel(BaseModel): @validator(self.v_args_name, check_fields=False, allow_reuse=True) def check_args(cls, v: List[Any]) -> List[Any]: @@ -196,7 +211,7 @@ class ValidatedFunction: keys = ', '.join(map(repr, v)) raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}') - class Config: + class Config(CustomConfig): extra = Extra.forbid self.model = create_model(to_camel(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)