add configs to validate_arguments (#1378)

* add `configs` to validate_arguments

* simplify `validate_arguments` and add annotation for parameter `configs`

* change double quotes to single quotes

* reformat code

* fix mypy error

* fix mypy 'maximum semantic analysis' error

* rename 'configs' > 'config_params'

Co-authored-by: Samuel Colvin <s@muelcolvin.com>
This commit is contained in:
quantpy
2020-06-27 19:48:10 +08:00
committed by GitHub
parent d122b1dbdc
commit e690f0878e
+28 -13
View File
@@ -14,20 +14,27 @@ if TYPE_CHECKING:
Callable = TypeVar('Callable', bound=AnyCallable)
def validate_arguments(function: 'Callable') -> 'Callable':
def validate_arguments(func: 'Callable' = None, **config_params: Any) -> 'Callable':
"""
Decorator to validate the arguments passed to a function.
"""
vd = ValidatedFunction(function)
@wraps(function)
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
return vd.call(*args, **kwargs)
def validate(_func: 'Callable') -> 'Callable':
vd = ValidatedFunction(_func, **config_params)
wrapper_function.vd = vd # type: ignore
wrapper_function.raw_function = vd.raw_function # type: ignore
wrapper_function.model = vd.model # type: ignore
return cast('Callable', wrapper_function)
@wraps(_func)
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
return vd.call(*args, **kwargs)
wrapper_function.vd = vd # type: ignore
wrapper_function.raw_function = vd.raw_function # type: ignore
wrapper_function.model = vd.model # type: ignore
return cast('Callable', wrapper_function)
if func:
return validate(func)
else:
return cast('Callable', validate)
ALT_V_ARGS = 'v__args'
@@ -36,7 +43,7 @@ V_POSITIONAL_ONLY_NAME = 'v__positional_only'
class ValidatedFunction:
def __init__(self, function: 'Callable'):
def __init__(self, function: 'Callable', **config_params: Any):
from inspect import signature, Parameter
parameters: Mapping[str, Parameter] = signature(function).parameters
@@ -100,7 +107,7 @@ class ValidatedFunction:
# same with kwargs
fields[self.v_kwargs_name] = Dict[Any, Any], None
self.create_model(fields, takes_args, takes_kwargs)
self.create_model(fields, takes_args, takes_kwargs, **config_params)
def call(self, *args: Any, **kwargs: Any) -> Any:
values = self.build_values(args, kwargs)
@@ -170,9 +177,17 @@ class ValidatedFunction:
else:
return self.raw_function(**d)
def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool) -> None:
def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, **config_params: Any) -> None:
pos_args = len(self.arg_mapping)
if TYPE_CHECKING:
class CustomConfig:
pass
else:
CustomConfig = type('Config', (), config_params)
class DecoratorBaseModel(BaseModel):
@validator(self.v_args_name, check_fields=False, allow_reuse=True)
def check_args(cls, v: List[Any]) -> List[Any]:
@@ -196,7 +211,7 @@ class ValidatedFunction:
keys = ', '.join(map(repr, v))
raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}')
class Config:
class Config(CustomConfig):
extra = Extra.forbid
self.model = create_model(to_camel(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)