From 50fd2c5b48ffe611b5c4feb24f26f7202217faab Mon Sep 17 00:00:00 2001 From: primal100 Date: Thu, 11 Apr 2019 23:13:57 +0100 Subject: [PATCH] Validators for dataclasses (#454) * Added validators for dataclass, fix #415 * Added dataclass validators * Added dataclass validators * Updated docs for added validating to dataclass * Updated docs for added validating to dataclass * Fixed line endings * Set __validators__ type to Mapping instead of Dict * Update History * Use __mro__ instead of __bases__ for gather_validators * Fix PR number * Fix issue.rst header underline * Fix HISTORY.rst merge conflict * Fix utils.py merge conflict * fix utils.py * Rebase and other fixes * Fix rebase and other issues * Change history * Remove unnecessary lines in main.py * Rebase * Update history * Rename ModelType to ModelOrDc * Added inheritance replace test * More consiste dataclass validator tests * fix history. * Remove Optional ModelOrDc Type * Fix ModelOrDc --- HISTORY.rst | 3 +- .../{datetime.py => datetime_example.py} | 0 docs/examples/validators_dataclass.py | 20 ++++ docs/index.rst | 12 +- pydantic/class_validators.py | 11 +- pydantic/dataclasses.py | 10 +- pydantic/fields.py | 18 ++- pydantic/main.py | 13 ++- pydantic/types.py | 3 + tests/test_validators_dataclass.py | 110 ++++++++++++++++++ 10 files changed, 176 insertions(+), 24 deletions(-) rename docs/examples/{datetime.py => datetime_example.py} (100%) create mode 100644 docs/examples/validators_dataclass.py create mode 100755 tests/test_validators_dataclass.py diff --git a/HISTORY.rst b/HISTORY.rst index 2dfe00b..a9c7441 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -6,7 +6,8 @@ History v0.x (xxxx-xx-xx) .................. * fix handling ``ForwardRef`` in sub-types, like ``Union``, #464 by @tiangolo -* fix secret serialization, #465 by @Atheuz +* fix secret serialization, #465 by @atheuz +* Support custom validators for dataclasses, #454 by @primal100 v0.23 (2019-04-04) .................. diff --git a/docs/examples/datetime.py b/docs/examples/datetime_example.py similarity index 100% rename from docs/examples/datetime.py rename to docs/examples/datetime_example.py diff --git a/docs/examples/validators_dataclass.py b/docs/examples/validators_dataclass.py new file mode 100644 index 0000000..f8bec34 --- /dev/null +++ b/docs/examples/validators_dataclass.py @@ -0,0 +1,20 @@ +from datetime import datetime + +from pydantic import validator +from pydantic.dataclasses import dataclass + + +@dataclass +class DemoDataclass: + ts: datetime = None + + @validator('ts', pre=True, always=True) + def set_ts_now(cls, v): + return v or datetime.now() + + +print(DemoDataclass()) +# > DemoDataclass(ts=datetime.datetime(2019, 4, 2, 18, 1, 46, 66149)) + +print(DemoDataclass(ts='2017-11-08T14:00')) +# > DemoDataclass ts=datetime.datetime(2017, 11, 8, 14, 0) \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 1784732..9125bd7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -147,8 +147,6 @@ Since version ``v0.17`` nested dataclasses are supported both in dataclasses and Dataclasses attributes can be populated by tuples, dictionaries or instances of that dataclass. -Currently validators don't work with dataclasses, if it's something you want please create an issue on github. - Choices ....... @@ -222,6 +220,14 @@ to set a dynamic default value. You'll often want to use this together with ``pre`` since otherwise the with ``always=True`` *pydantic* would try to validate the default ``None`` which would cause an error. +Dataclass Validators +~~~~~~~~~~~~~~~~~~~~ + +Validators also work in Dataclasses. + +.. literalinclude:: examples/validators_dataclass.py + +(This script is complete, it should run "as is") Field Checks ~~~~~~~~~~~~ @@ -434,7 +440,7 @@ types: * ``[±]P[DD]T[HH]H[MM]M[SS]S`` (ISO 8601 format for timedelta) -.. literalinclude:: examples/datetime.py +.. literalinclude:: examples/datetime_example.py Exotic Types diff --git a/pydantic/class_validators.py b/pydantic/class_validators.py index b71ce29..ca0023b 100644 --- a/pydantic/class_validators.py +++ b/pydantic/class_validators.py @@ -1,3 +1,4 @@ +from collections import ChainMap from dataclasses import dataclass from functools import wraps from inspect import Signature, signature @@ -9,10 +10,11 @@ from .errors import ConfigError from .utils import AnyCallable, in_ipython if TYPE_CHECKING: # pragma: no cover - from .main import BaseConfig, BaseModel + from .main import BaseConfig from .fields import Field + from .types import ModelOrDc - ValidatorCallable = Callable[[Optional[Type[BaseModel]], Any, Dict[str, Any], Field, Type[BaseConfig]], Any] + ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], Field, Type[BaseConfig]], Any] @dataclass @@ -211,3 +213,8 @@ def _generic_validator_basic(validator: AnyCallable, sig: Signature, args: Set[s else: # args == {'values', 'field', 'config'} return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config) + + +def gather_validators(type_: 'ModelOrDc') -> Dict[str, classmethod]: + all_attributes = ChainMap(*[cls.__dict__ for cls in type_.__mro__]) + return {k: v for k, v in all_attributes.items() if hasattr(v, '__validator_config')} diff --git a/pydantic/dataclasses.py b/pydantic/dataclasses.py index 4b6af0f..9f43940 100644 --- a/pydantic/dataclasses.py +++ b/pydantic/dataclasses.py @@ -1,6 +1,7 @@ import dataclasses from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Optional, Type, Union +from .class_validators import gather_validators from .error_wrappers import ValidationError from .errors import DataclassTypeError from .fields import Required @@ -24,7 +25,7 @@ if TYPE_CHECKING: # pragma: no cover def _pydantic_post_init(self: 'DataclassType') -> None: - d = validate_model(self.__pydantic_model__, self.__dict__) + d = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__) object.__setattr__(self, '__dict__', d) object.__setattr__(self, '__initialised__', True) if self.__post_init_original__: @@ -50,7 +51,7 @@ def setattr_validate_assignment(self: 'DataclassType', name: str, value: Any) -> if self.__initialised__: d = dict(self.__dict__) d.pop(name) - value, error_ = self.__pydantic_model__.__fields__[name].validate(value, d, loc=name) + value, error_ = self.__pydantic_model__.__fields__[name].validate(value, d, loc=name, cls=self.__class__) if error_: raise ValidationError([error_]) @@ -79,7 +80,10 @@ def _process_class( } cls.__post_init_original__ = post_init_original - cls.__pydantic_model__ = create_model(cls.__name__, __config__=config, __module__=_cls.__module__, **fields) + validators = gather_validators(cls) + cls.__pydantic_model__ = create_model( + cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **fields + ) cls.__initialised__ = False cls.__validate__ = classmethod(_validate_dataclass) diff --git a/pydantic/fields.py b/pydantic/fields.py index 70a846b..96c8356 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: # pragma: no cover from .error_wrappers import ErrorList from .main import BaseConfig, BaseModel # noqa: F401 from .schema import Schema # noqa: F401 + from .types import ModelOrDc # noqa: F401 ValidatorsList = List[ValidatorCallable] ValidateReturn = Tuple[Optional[Any], Optional[ErrorList]] @@ -261,7 +262,7 @@ class Field: return [make_generic_validator(f) for f in v_funcs if f] def validate( - self, v: Any, values: Dict[str, Any], *, loc: 'LocType', cls: Optional[Type['BaseModel']] = None + self, v: Any, values: Dict[str, Any], *, loc: 'LocType', cls: Optional['ModelOrDc'] = None ) -> 'ValidateReturn': if self.allow_none and not self.validate_always and v is None: return None, None @@ -300,7 +301,7 @@ class Field: return v, ErrorWrapper(exc, loc=loc, config=self.model_config) def _validate_sequence_like( - self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional[Type['BaseModel']] + self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional['ModelOrDc'] ) -> 'ValidateReturn': """ Validate sequence-like containers: lists, tuples, sets and generators @@ -343,7 +344,7 @@ class Field: return converted, None def _validate_tuple( - self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional[Type['BaseModel']] + self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional['ModelOrDc'] ) -> 'ValidateReturn': e: Optional[Exception] = None if not sequence_like(v): @@ -372,7 +373,7 @@ class Field: return tuple(result), None def _validate_mapping( - self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional[Type['BaseModel']] + self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional['ModelOrDc'] ) -> 'ValidateReturn': try: v_iter = dict_validator(v) @@ -400,7 +401,7 @@ class Field: return result, None def _validate_singleton( - self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional[Type['BaseModel']] + self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional['ModelOrDc'] ) -> 'ValidateReturn': if self.sub_fields: errors = [] @@ -415,12 +416,7 @@ class Field: return self._apply_validators(v, values, loc, cls, self.validators) def _apply_validators( - self, - v: Any, - values: Dict[str, Any], - loc: 'LocType', - cls: Optional[Type['BaseModel']], - validators: 'ValidatorsList', + self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional['ModelOrDc'], validators: 'ValidatorsList' ) -> 'ValidateReturn': for validator in validators: try: diff --git a/pydantic/main.py b/pydantic/main.py index 0331e07..29efe97 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -45,7 +45,8 @@ from .utils import ( ) if TYPE_CHECKING: # pragma: no cover - from .types import CallableGenerator + from .dataclasses import DataclassType # noqa: F401 + from .types import CallableGenerator, ModelOrDc from .class_validators import ValidatorListDict AnyGenerator = Generator[Any, None, None] @@ -513,12 +514,13 @@ class BaseModel(metaclass=MetaModel): return ret -def create_model( +def create_model( # noqa: C901 (ignore complexity) model_name: str, *, __config__: Type[BaseConfig] = None, __base__: Type[BaseModel] = None, __module__: Optional[str] = None, + __validators__: Dict[str, classmethod] = None, **field_definitions: Any, ) -> BaseModel: """ @@ -526,6 +528,7 @@ def create_model( :param model_name: name of the created model :param __config__: config class to use for the new model :param __base__: base class for the new model to inherit from + :param __validators__: a dict of method names and @validator class methods :param **field_definitions: fields of the model (or extra fields if a base is supplied) in the format `=(, )` or `= eg. `foobar=(str, ...)` or `foobar=123` """ @@ -558,6 +561,8 @@ def create_model( fields[f_name] = f_value namespace: 'DictStrAny' = {'__annotations__': annotations, '__module__': __module__} + if __validators__: + namespace.update(__validators__) namespace.update(fields) if __config__: namespace['Config'] = inherit_config(__config__, BaseConfig) @@ -566,7 +571,7 @@ def create_model( def validate_model( # noqa: C901 (ignore complexity) - model: Union[BaseModel, Type[BaseModel]], input_data: 'DictStrAny', raise_exc: bool = True + model: Union[BaseModel, Type[BaseModel]], input_data: 'DictStrAny', raise_exc: bool = True, cls: 'ModelOrDc' = None ) -> Union['DictStrAny', Tuple['DictStrAny', Optional[ValidationError]]]: """ validate data against a model. @@ -601,7 +606,7 @@ def validate_model( # noqa: C901 (ignore complexity) elif check_extra: names_used.add(field.name if using_name else field.alias) - v_, errors_ = field.validate(value, values, loc=field.alias, cls=model.__class__) # type: ignore + v_, errors_ = field.validate(value, values, loc=field.alias, cls=cls or model.__class__) # type: ignore if isinstance(errors_, ErrorWrapper): errors.append(errors_) elif isinstance(errors_, list): diff --git a/pydantic/types.py b/pydantic/types.py index 791608d..05eec58 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -89,9 +89,12 @@ NetworkType = Union[str, bytes, int, Tuple[Union[str, bytes, int], Union[str, in if TYPE_CHECKING: # pragma: no cover + from .dataclasses import DataclassType # noqa: F401 + from .main import BaseModel # noqa: F401 from .utils import AnyCallable CallableGenerator = Generator[AnyCallable, None, None] + ModelOrDc = Type[Union['BaseModel', 'DataclassType']] class StrictStr(str): diff --git a/tests/test_validators_dataclass.py b/tests/test_validators_dataclass.py new file mode 100755 index 0000000..7e408c4 --- /dev/null +++ b/tests/test_validators_dataclass.py @@ -0,0 +1,110 @@ +from dataclasses import asdict, is_dataclass +from typing import List + +import pytest + +from pydantic import ValidationError, validator +from pydantic.dataclasses import dataclass + + +def test_simple(): + @dataclass + class MyDataclass: + a: str + + @validator('a') + def change_a(cls, v): + return v + ' changed' + + assert MyDataclass(a='this is foobar good').a == 'this is foobar good changed' + + +def test_validate_whole(): + @dataclass + class MyDataclass: + a: List[int] + + @validator('a', whole=True, pre=True) + def check_a1(cls, v): + v.append('123') + return v + + @validator('a', whole=True) + def check_a2(cls, v): + v.append(456) + return v + + assert MyDataclass(a=[1, 2]).a == [1, 2, 123, 456] + + +def test_validate_multiple(): + # also test TypeError + @dataclass + class MyDataclass: + a: str + b: str + + @validator('a', 'b') + def check_a_and_b(cls, v, field, **kwargs): + if len(v) < 4: + raise TypeError(f'{field.alias} is too short') + return v + 'x' + + assert asdict(MyDataclass(a='1234', b='5678')) == {'a': '1234x', 'b': '5678x'} + + with pytest.raises(ValidationError) as exc_info: + MyDataclass(a='x', b='x') + assert exc_info.value.errors() == [ + {'loc': ('a',), 'msg': 'a is too short', 'type': 'type_error'}, + {'loc': ('b',), 'msg': 'b is too short', 'type': 'type_error'}, + ] + + +def test_classmethod(): + @dataclass + class MyDataclass: + a: str + + @validator('a') + def check_a(cls, v): + assert cls is MyDataclass and is_dataclass(MyDataclass) + return v + + m = MyDataclass(a='this is foobar good') + assert m.a == 'this is foobar good' + m.check_a('x') + + +def test_validate_parent(): + @dataclass + class Parent: + a: str + + @validator('a') + def change_a(cls, v): + return v + ' changed' + + @dataclass + class Child(Parent): + pass + + assert Parent(a='this is foobar good').a == 'this is foobar good changed' + assert Child(a='this is foobar good').a == 'this is foobar good changed' + + +def test_inheritance_replace(): + @dataclass + class Parent: + a: int + + @validator('a') + def add_to_a(cls, v): + return v + 1 + + @dataclass + class Child(Parent): + @validator('a') + def add_to_a(cls, v): + return v + 5 + + assert Child(a=0).a == 5