From 16263bafea12e0089e22a6facc2f36ef6be0c71d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 18 Sep 2019 11:38:21 +0100 Subject: [PATCH] None behaviour (#803) * tweaks to None behaviour * prevent sub_fields for Optional fields by default * rewrite None validation * rename whole > each_item on validators * cleanup processing of the Json type * fix schema coverage and cleanup * tweak validate_model * change and docs * fix validators on optional fields * coverage * remove is_none_validator * minor performance improvements to ErrorWrapper * fix coverage * fix PaymentCardNumber * undo schema changes, fix * tweak validators --- changes/803-samuelcolvin.rst | 2 + docs/examples/validators_pre_item.py | 56 ++++++++++++ docs/examples/validators_pre_whole.py | 54 ------------ docs/index.rst | 18 ++-- pydantic/class_validators.py | 27 ++++-- pydantic/color.py | 3 - pydantic/error_wrappers.py | 71 ++++++++------- pydantic/fields.py | 121 +++++++++++++++----------- pydantic/main.py | 13 ++- pydantic/networks.py | 3 +- pydantic/schema.py | 8 +- pydantic/types.py | 29 ++---- pydantic/validators.py | 54 +++++------- tests/test_edge_cases.py | 54 ++++++++++-- tests/test_errors.py | 17 +--- tests/test_generics.py | 8 +- tests/test_schema.py | 25 +++++- tests/test_types.py | 22 ++++- tests/test_validators.py | 74 ++++++++++++---- tests/test_validators_dataclass.py | 6 +- 20 files changed, 396 insertions(+), 269 deletions(-) create mode 100644 changes/803-samuelcolvin.rst create mode 100644 docs/examples/validators_pre_item.py delete mode 100644 docs/examples/validators_pre_whole.py diff --git a/changes/803-samuelcolvin.rst b/changes/803-samuelcolvin.rst new file mode 100644 index 0000000..76eda87 --- /dev/null +++ b/changes/803-samuelcolvin.rst @@ -0,0 +1,2 @@ +Improve handling of ``None`` and ``Optional``, replace ``whole`` with ``each_item`` (inverse meaning, default ``False``) +on validators. diff --git a/docs/examples/validators_pre_item.py b/docs/examples/validators_pre_item.py new file mode 100644 index 0000000..4a941a9 --- /dev/null +++ b/docs/examples/validators_pre_item.py @@ -0,0 +1,56 @@ +from typing import List +from pydantic import BaseModel, ValidationError, validator + +class DemoModel(BaseModel): + square_numbers: List[int] = [] + cube_numbers: List[int] = [] + + @validator('*', pre=True) # '*' is same as 'cube_numbers', 'square_numbers' here + def split_str(cls, v): + if isinstance(v, str): + return v.split('|') + return v + + @validator('cube_numbers', 'square_numbers') + def check_sum(cls, v): + if sum(v) > 42: + raise ValueError(f'sum of numbers greater than 42') + return v + + @validator('square_numbers', each_item=True) + def check_squares(cls, v): + assert v ** 0.5 % 1 == 0, f'{v} is not a square number' + return v + + @validator('cube_numbers', each_item=True) + def check_cubes(cls, v): + # 64 ** (1 / 3) == 3.9999999999999996! this is not a good way of checking cubes + assert v ** (1 / 3) % 1 == 0, f'{v} is not a cubed number' + return v + +print(DemoModel(square_numbers=[1, 4, 9])) +# > DemoModel square_numbers=[1, 4, 9] cube_numbers=[] +print(DemoModel(square_numbers='1|4|16')) +# > DemoModel square_numbers=[1, 4, 16] cube_numbers=[] +print(DemoModel(square_numbers=[16], cube_numbers=[8, 27])) +# > DemoModel square_numbers=[16] cube_numbers=[8, 27] + +try: + DemoModel(square_numbers=[1, 4, 2]) +except ValidationError as e: + print(e) +""" +1 validation error for DemoModel +square_numbers -> 2 + 2 is not a square number (type=assertion_error) +""" + +try: + DemoModel(cube_numbers=[27, 27]) +except ValidationError as e: + print(e) +""" +1 validation error for DemoModel +cube_numbers + sum of numbers greater than 42 (type=value_error) +""" diff --git a/docs/examples/validators_pre_whole.py b/docs/examples/validators_pre_whole.py deleted file mode 100644 index a81bace..0000000 --- a/docs/examples/validators_pre_whole.py +++ /dev/null @@ -1,54 +0,0 @@ -import json -from typing import List - -from pydantic import BaseModel, ValidationError, validator - - -class DemoModel(BaseModel): - numbers: List[int] = [] - people: List[str] = [] - - @validator('people', 'numbers', pre=True, whole=True) - def json_decode(cls, v): - if isinstance(v, str): - try: - return json.loads(v) - except ValueError: - pass - return v - - @validator('numbers') - def check_numbers_low(cls, v): - if v > 4: - raise ValueError(f'number too large {v} > 4') - return v - - @validator('numbers', whole=True) - def check_sum_numbers_low(cls, v): - if sum(v) > 8: - raise ValueError(f'sum of numbers greater than 8') - return v - - -print(DemoModel(numbers='[1, 1, 2, 2]')) -# > DemoModel numbers=[1, 1, 2, 2] people=[] - -try: - DemoModel(numbers='[1, 2, 5]') -except ValidationError as e: - print(e) -""" -1 validation error -numbers -> 2 - number too large 5 > 4 (type=value_error) -""" - -try: - DemoModel(numbers=[3, 3, 3]) -except ValidationError as e: - print(e) -""" -1 validation error -numbers - sum of numbers greater than 8 (type=value_error) -""" diff --git a/docs/index.rst b/docs/index.rst index 75231fb..d54c2dc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -198,7 +198,7 @@ A few things to note on validators: * validators are "class methods", the first value they receive here will be the ``UserModel`` not an instance of ``UserModel`` -* their signature can be ``(cls, value)`` or ``(cls, value, values, config, field)``. As of **v0.20**, any subset of +* their signature can be ``(cls, value)`` or ``(cls, value, values, config, field)``. Any subset of ``values``, ``config`` and ``field`` is also permitted, eg. ``(cls, value, field)``, however due to the way validators are inspected, the variadic key word argument ("``**kwargs``") **must** be called ``kwargs``. * validators should either return the new value or raise a ``ValueError``, ``TypeError``, or ``AssertionError`` @@ -226,18 +226,12 @@ A few things to note on validators: (Within each group fields remain in the order they were defined.) -.. note:: - - From ``v0.18`` onwards validators are not called on keys of dictionaries. If you wish to validate keys, - use ``whole`` (see below). - - -Pre and Whole Validators -~~~~~~~~~~~~~~~~~~~~~~~~ +Pre and per-item validators +~~~~~~~~~~~~~~~~~~~~~~~~~~~ Validators can do a few more complex things: -.. literalinclude:: examples/validators_pre_whole.py +.. literalinclude:: examples/validators_pre_item.py (This script is complete, it should run "as is") @@ -246,8 +240,8 @@ A few more things to note: * a single validator can apply to multiple fields, either by defining multiple fields or by the special value ``'*'`` which means that validator will be called for all fields. * the keyword argument ``pre`` will cause validators to be called prior to other validation -* the ``whole`` keyword argument will mean validators are applied to entire objects rather than individual values - (applies for complex typing objects eg. ``List``, ``Dict``, ``Set``) +* the ``each_item`` keyword argument will mean validators are applied to individual values + (eg. of ``List``, ``Dict``, ``Set`` etc.) not the whole object Validate Always ~~~~~~~~~~~~~~~ diff --git a/pydantic/class_validators.py b/pydantic/class_validators.py index 17044b3..bd223ab 100644 --- a/pydantic/class_validators.py +++ b/pydantic/class_validators.py @@ -1,3 +1,4 @@ +import warnings from collections import ChainMap from functools import wraps from inspect import Signature, signature @@ -18,10 +19,12 @@ if TYPE_CHECKING: # pragma: no cover class Validator: - def __init__(self, func: AnyCallable, pre: bool, whole: bool, always: bool, check_fields: bool): + __slots__ = 'func', 'pre', 'each_item', 'always', 'check_fields' + + def __init__(self, func: AnyCallable, pre: bool, each_item: bool, always: bool, check_fields: bool): self.func = func self.pre = pre - self.whole = whole + self.each_item = each_item self.always = always self.check_fields = check_fields @@ -30,13 +33,19 @@ _FUNCS: Set[str] = set() def validator( - *fields: str, pre: bool = False, whole: bool = False, always: bool = False, check_fields: bool = True + *fields: str, + pre: bool = False, + each_item: bool = False, + always: bool = False, + check_fields: bool = True, + whole: bool = None, ) -> Callable[[AnyCallable], classmethod]: """ Decorate methods on the class indicating that they should be used to validate fields :param fields: which field(s) the method should be called on :param pre: whether or not this validator should be called before the standard validators (else after) - :param whole: for complex objects (sets, lists etc.) whether to validate individual elements or the whole object + :param each_item: for complex objects (sets, lists etc.) whether to validate individual elements rather than the + whole object :param always: whether this method and other validators should be called even if the value is missing :param check_fields: whether to check that the fields actually exist on the model """ @@ -48,6 +57,14 @@ def validator( "E.g. usage should be `@validator('', ...)`" ) + if whole is not None: + warnings.warn( + 'The "whole" keyword argument is deprecated, use "each_item" (inverse meaning, default False) instead', + DeprecationWarning, + ) + assert each_item is False, '"each_item" and "whole" conflict, remove "whole"' + each_item = not whole + def dec(f: AnyCallable) -> classmethod: # avoid validators with duplicated names since without this validators can be overwritten silently # which generally isn't the intended behaviour, don't run in ipython - see #312 @@ -59,7 +76,7 @@ def validator( f_cls = classmethod(f) f_cls.__validator_config = ( # type: ignore fields, - Validator(func=f, pre=pre, whole=whole, always=always, check_fields=check_fields), + Validator(func=f, pre=pre, each_item=each_item, always=always, check_fields=check_fields), ) return f_cls diff --git a/pydantic/color.py b/pydantic/color.py index 6f9ed5a..314ebab 100644 --- a/pydantic/color.py +++ b/pydantic/color.py @@ -12,8 +12,6 @@ import re from colorsys import hls_to_rgb, rgb_to_hls from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast -from pydantic.validators import not_none_validator - from .errors import ColorError from .utils import almost_equal_floats @@ -184,7 +182,6 @@ class Color: @classmethod def __get_validators__(cls) -> 'CallableGenerator': - yield not_none_validator yield cls def __str__(self) -> str: diff --git a/pydantic/error_wrappers.py b/pydantic/error_wrappers.py index edbe319..5f1d98a 100644 --- a/pydantic/error_wrappers.py +++ b/pydantic/error_wrappers.py @@ -5,39 +5,26 @@ if TYPE_CHECKING: # pragma: no cover from .main import BaseConfig # noqa: F401 from .types import ModelOrDc # noqa: F401 + Loc = Tuple[Union[int, str], ...] + __all__ = 'ErrorWrapper', 'ValidationError' class ErrorWrapper: - __slots__ = 'exc', 'loc' + __slots__ = 'exc', '_loc' - def __init__(self, exc: Exception, *, loc: Union[Tuple[str, ...], str]) -> None: + def __init__(self, exc: Exception, loc: Union[str, 'Loc']) -> None: self.exc = exc - self.loc: Tuple[str, ...] = loc if isinstance(loc, tuple) else (loc,) # type: ignore + self._loc = loc - def dict(self, config: Type['BaseConfig'], *, loc_prefix: Optional[Tuple[str, ...]] = None) -> Dict[str, Any]: - loc = self.loc if loc_prefix is None else loc_prefix + self.loc - - type_ = get_exc_type(type(self.exc)) - msg_template = config.error_msg_templates.get(type_) or getattr(self.exc, 'msg_template', None) - ctx = getattr(self.exc, 'ctx', None) - if msg_template: - if ctx: - msg: str = msg_template.format(**ctx) - else: - msg = msg_template + def loc_tuple(self) -> 'Loc': + if isinstance(self._loc, tuple): + return self._loc else: - msg = str(self.exc) - - d: Dict[str, Any] = {'loc': loc, 'msg': msg, 'type': type_} - - if ctx is not None: - d['ctx'] = ctx - - return d + return (self._loc,) def __repr__(self) -> str: - return f'' + return f'' # ErrorList is something like Union[List[Union[List[ErrorWrapper], ErrorWrapper]], ErrorWrapper] @@ -92,24 +79,46 @@ def _display_error_type_and_ctx(error: Dict[str, Any]) -> str: def flatten_errors( - errors: Sequence[Any], config: Type['BaseConfig'], *, loc: Optional[Tuple[str, ...]] = None + errors: Sequence[Any], config: Type['BaseConfig'], loc: Optional['Loc'] = None ) -> Generator[Dict[str, Any], None, None]: for error in errors: if isinstance(error, ErrorWrapper): - if isinstance(error.exc, ValidationError): - if loc is not None: - error_loc = loc + error.loc - else: - error_loc = error.loc - yield from flatten_errors(error.exc.raw_errors, config, loc=error_loc) + + if loc: + error_loc = loc + error.loc_tuple() else: - yield error.dict(config, loc_prefix=loc) + error_loc = error.loc_tuple() + + if isinstance(error.exc, ValidationError): + yield from flatten_errors(error.exc.raw_errors, config, error_loc) + else: + yield error_dict(error.exc, config, error_loc) elif isinstance(error, list): yield from flatten_errors(error, config, loc=loc) else: raise RuntimeError(f'Unknown error object: {error}') +def error_dict(exc: Exception, config: Type['BaseConfig'], loc: 'Loc') -> Dict[str, Any]: + type_ = get_exc_type(type(exc)) + msg_template = config.error_msg_templates.get(type_) or getattr(exc, 'msg_template', None) + ctx = getattr(exc, 'ctx', None) + if msg_template: + if ctx: + msg = msg_template.format(**ctx) + else: + msg = msg_template + else: + msg = str(exc) + + d: Dict[str, Any] = {'loc': loc, 'msg': msg, 'type': type_} + + if ctx is not None: + d['ctx'] = ctx + + return d + + _EXC_TYPE_CACHE: Dict[Type[Exception], str] = {} diff --git a/pydantic/fields.py b/pydantic/fields.py index b3337e0..6bf6a6f 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -21,10 +21,11 @@ from typing import ( from . import errors as errors_ from .class_validators import Validator, make_generic_validator from .error_wrappers import ErrorWrapper +from .errors import NoneIsNotAllowedError from .types import Json, JsonWrapper from .typing import AnyCallable, AnyType, Callable, ForwardRef, display_as_type, is_literal_type, literal_values from .utils import lenient_issubclass, sequence_like -from .validators import NoneType, constant_validator, dict_validator, find_validators +from .validators import constant_validator, dict_validator, find_validators, validate_json try: from typing_extensions import Literal @@ -32,6 +33,7 @@ except ImportError: Literal = None # type: ignore Required: Any = Ellipsis +NoneType = type(None) if TYPE_CHECKING: # pragma: no cover from .class_validators import ValidatorCallable # noqa: F401 @@ -42,7 +44,7 @@ if TYPE_CHECKING: # pragma: no cover ValidatorsList = List[ValidatorCallable] ValidateReturn = Tuple[Optional[Any], Optional[ErrorList]] - LocType = Union[Tuple[str, ...], str] + LocStr = Union[Tuple[Union[int, str], ...], str] # used to be an enum but changed to int's for small performance improvement as less access overhead @@ -51,7 +53,7 @@ SHAPE_LIST = 2 SHAPE_SET = 3 SHAPE_MAPPING = 4 SHAPE_TUPLE = 5 -SHAPE_TUPLE_ELLIPS = 6 +SHAPE_TUPLE_ELLIPSIS = 6 SHAPE_SEQUENCE = 7 SHAPE_FROZENSET = 8 @@ -62,8 +64,8 @@ class Field: 'sub_fields', 'key_field', 'validators', - 'whole_pre_validators', - 'whole_post_validators', + 'pre_validators', + 'post_validators', 'default', 'required', 'model_config', @@ -106,8 +108,8 @@ class Field: self.sub_fields: Optional[List[Field]] = None self.key_field: Optional[Field] = None self.validators: 'ValidatorsList' = [] - self.whole_pre_validators: 'ValidatorsList' = [] - self.whole_post_validators: 'ValidatorsList' = [] + self.pre_validators: Optional['ValidatorsList'] = None + self.post_validators: Optional['ValidatorsList'] = None self.parse_json: bool = False self.shape: int = SHAPE_SINGLETON self.prepare() @@ -175,14 +177,17 @@ class Field: if not self.required and self.default is None: self.allow_none = True - self._populate_sub_fields() + self._type_analysis() self._populate_validators() - def _populate_sub_fields(self) -> None: # noqa: C901 (ignore complexity) + def _type_analysis(self) -> None: # noqa: C901 (ignore complexity) # typing interface is horrible, we have to do some ugly checks if lenient_issubclass(self.type_, JsonWrapper): self.type_ = self.type_.inner_type # type: ignore self.parse_json = True + elif lenient_issubclass(self.type_, Json): + self.type_ = Any # type: ignore + self.parse_json = True if self.type_ is Pattern: # python 3.7 only, Pattern is a typing object but without sub fields @@ -203,10 +208,17 @@ class Field: types_ = [] for type_ in self.type_.__args__: # type: ignore if type_ is NoneType: # type: ignore - self.allow_none = True self.required = False + self.allow_none = True + continue types_.append(type_) - self.sub_fields = [self._create_sub_type(t, f'{self.name}_{display_as_type(t)}') for t in types_] + + if len(types_) == 1: + self.type_ = types_[0] + # re-run to correctly interpret the new self.type_ + self._type_analysis() + else: + self.sub_fields = [self._create_sub_type(t, f'{self.name}_{display_as_type(t)}') for t in types_] return if issubclass(origin, Tuple): # type: ignore @@ -215,7 +227,7 @@ class Field: for i, t in enumerate(self.type_.__args__): # type: ignore if t is Ellipsis: self.type_ = self.type_.__args__[0] # type: ignore - self.shape = SHAPE_TUPLE_ELLIPS + self.shape = SHAPE_TUPLE_ELLIPSIS return self.sub_fields.append(self._create_sub_type(t, f'{self.name}_{i}')) return @@ -226,7 +238,7 @@ class Field: if get_validators: self.class_validators.update( { - f'list_{i}': Validator(validator, whole=True, pre=True, always=True, check_fields=False) + f'list_{i}': Validator(validator, each_item=False, pre=True, always=True, check_fields=False) for i, validator in enumerate(get_validators()) } ) @@ -260,7 +272,7 @@ class Field: return self.__class__( type_=type_, name=name, - class_validators=None if for_keys else {k: v for k, v in self.class_validators.items() if not v.whole}, + class_validators=None if for_keys else {k: v for k, v in self.class_validators.items() if v.each_item}, model_config=self.model_config, ) @@ -269,43 +281,51 @@ class Field: if not self.sub_fields: get_validators = getattr(self.type_, '__get_validators__', None) v_funcs = ( - *[v.func for v in class_validators_ if not v.whole and v.pre], + *[v.func for v in class_validators_ if v.each_item and v.pre], *(get_validators() if get_validators else list(find_validators(self.type_, self.model_config))), - *[v.func for v in class_validators_ if not v.whole and not v.pre], + *[v.func for v in class_validators_ if v.each_item and not v.pre], ) self.validators = self._prep_vals(v_funcs) # Add const validator - if self.schema is not None and self.schema.const: - self.whole_pre_validators = self._prep_vals([constant_validator]) + self.pre_validators = [] + self.post_validators = [] + if self.schema and self.schema.const: + self.pre_validators = [make_generic_validator(constant_validator)] if class_validators_: - self.whole_pre_validators.extend(self._prep_vals(v.func for v in class_validators_ if v.whole and v.pre)) - self.whole_post_validators = self._prep_vals(v.func for v in class_validators_ if v.whole and not v.pre) + self.pre_validators += self._prep_vals(v.func for v in class_validators_ if not v.each_item and v.pre) + self.post_validators = self._prep_vals(v.func for v in class_validators_ if not v.each_item and not v.pre) + + if self.parse_json: + self.pre_validators.append(make_generic_validator(validate_json)) + + self.pre_validators = self.pre_validators or None + self.post_validators = self.post_validators or None @staticmethod def _prep_vals(v_funcs: Iterable[AnyCallable]) -> 'ValidatorsList': return [make_generic_validator(f) for f in v_funcs if f] def validate( - self, v: Any, values: Dict[str, Any], *, loc: 'LocType', cls: Optional['ModelOrDc'] = None + self, v: Any, values: Dict[str, Any], *, loc: 'LocStr', cls: Optional['ModelOrDc'] = None ) -> 'ValidateReturn': - if self.allow_none and not self.validate_always and v is None: - return None, None - loc = loc if isinstance(loc, tuple) else (loc,) - - if v is not None and self.parse_json: - v, error = self._validate_json(v, loc) - if error: - return v, error - - errors: Optional['ErrorList'] = None - if self.whole_pre_validators: - v, errors = self._apply_validators(v, values, loc, cls, self.whole_pre_validators) + errors: Optional['ErrorList'] + if self.pre_validators: + v, errors = self._apply_validators(v, values, loc, cls, self.pre_validators) if errors: return v, errors + if v is None: + if self.allow_none: + if self.post_validators: + return self._apply_validators(v, values, loc, cls, self.post_validators) + else: + return None, None + else: + return v, ErrorWrapper(NoneIsNotAllowedError(), loc) + if self.shape == SHAPE_SINGLETON: v, errors = self._validate_singleton(v, values, loc, cls) elif self.shape == SHAPE_MAPPING: @@ -313,21 +333,15 @@ class Field: elif self.shape == SHAPE_TUPLE: v, errors = self._validate_tuple(v, values, loc, cls) else: - # sequence, list, tuple, set, generator + # sequence, list, set, generator, tuple with ellipsis, frozen set v, errors = self._validate_sequence_like(v, values, loc, cls) - if not errors and self.whole_post_validators: - v, errors = self._apply_validators(v, values, loc, cls, self.whole_post_validators) + if not errors and self.post_validators: + v, errors = self._apply_validators(v, values, loc, cls, self.post_validators) return v, errors - def _validate_json(self, v: Any, loc: Tuple[str, ...]) -> Tuple[Optional[Any], Optional[ErrorWrapper]]: - try: - return Json.validate(v), None - except (ValueError, TypeError) as exc: - return v, ErrorWrapper(exc, loc=loc) - def _validate_sequence_like( # noqa: C901 (ignore complexity) - self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional['ModelOrDc'] + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] ) -> 'ValidateReturn': """ Validate sequence-like containers: lists, tuples, sets and generators @@ -344,8 +358,9 @@ class Field: e = errors_.FrozenSetError() else: e = errors_.SequenceError() - return v, ErrorWrapper(e, loc=loc) + return v, ErrorWrapper(e, loc) + loc = loc if isinstance(loc, tuple) else (loc,) result = [] errors: List[ErrorList] = [] for i, v_ in enumerate(v): @@ -365,7 +380,7 @@ class Field: converted = set(result) elif self.shape == SHAPE_FROZENSET: converted = frozenset(result) - elif self.shape == SHAPE_TUPLE_ELLIPS: + elif self.shape == SHAPE_TUPLE_ELLIPSIS: converted = tuple(result) elif self.shape == SHAPE_SEQUENCE: if isinstance(v, tuple): @@ -377,7 +392,7 @@ class Field: return converted, None def _validate_tuple( - self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional['ModelOrDc'] + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] ) -> 'ValidateReturn': e: Optional[Exception] = None if not sequence_like(v): @@ -388,8 +403,9 @@ class Field: e = errors_.TupleLengthError(actual_length=actual_length, expected_length=expected_length) if e: - return v, ErrorWrapper(e, loc=loc) + return v, ErrorWrapper(e, loc) + loc = loc if isinstance(loc, tuple) else (loc,) result = [] errors: List[ErrorList] = [] for i, (v_, field) in enumerate(zip(v, self.sub_fields)): # type: ignore @@ -406,13 +422,14 @@ class Field: return tuple(result), None def _validate_mapping( - self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional['ModelOrDc'] + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] ) -> 'ValidateReturn': try: v_iter = dict_validator(v) except TypeError as exc: - return v, ErrorWrapper(exc, loc=loc) + return v, ErrorWrapper(exc, loc) + loc = loc if isinstance(loc, tuple) else (loc,) result, errors = {}, [] for k, v_ in v_iter.items(): v_loc = *loc, '__key__' @@ -434,7 +451,7 @@ class Field: return result, None def _validate_singleton( - self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional['ModelOrDc'] + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] ) -> 'ValidateReturn': if self.sub_fields: errors = [] @@ -449,13 +466,13 @@ class Field: return self._apply_validators(v, values, loc, cls, self.validators) def _apply_validators( - self, v: Any, values: Dict[str, Any], loc: 'LocType', cls: Optional['ModelOrDc'], validators: 'ValidatorsList' + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'], validators: 'ValidatorsList' ) -> 'ValidateReturn': for validator in validators: try: v = validator(cls, v, values, self, self.model_config) except (ValueError, TypeError, AssertionError) as exc: - return v, ErrorWrapper(exc, loc=loc) + return v, ErrorWrapper(exc, loc) return v, None def include_in_schema(self) -> bool: diff --git a/pydantic/main.py b/pydantic/main.py index 1c1c72d..60319d5 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -684,12 +684,13 @@ def validate_model( # noqa: C901 (ignore complexity) fields_set = set() config = model.__config__ check_extra = config.extra is not Extra.ignore + cls_ = cls or model.__class__ for name, field in model.__fields__.items(): if type(field.type_) == ForwardRef: raise ConfigError( f'field "{field.name}" not yet prepared so type is still a ForwardRef, ' - f'you might need to call {model.__class__.__name__}.update_forward_refs().' + f'you might need to call {cls_.__name__}.update_forward_refs().' ) value = input_data.get(field.alias, _missing) @@ -702,7 +703,13 @@ def validate_model( # noqa: C901 (ignore complexity) if field.required: errors.append(ErrorWrapper(MissingError(), loc=field.alias)) continue - value = deepcopy(field.default) + + if field.default is None: + # deepcopy is quite slow on None + value = None + else: + value = deepcopy(field.default) + if not model.__config__.validate_all and not field.validate_always: values[name] = value continue @@ -711,7 +718,7 @@ def validate_model( # noqa: C901 (ignore complexity) if check_extra: names_used.add(field.name if using_name else field.alias) - v_, errors_ = field.validate(value, values, loc=field.alias, cls=cls or model.__class__) # type: ignore + v_, errors_ = field.validate(value, values, loc=field.alias, cls=cls_) # type: ignore if isinstance(errors_, ErrorWrapper): errors.append(errors_) elif isinstance(errors_, list): diff --git a/pydantic/networks.py b/pydantic/networks.py index ce9431f..2080803 100644 --- a/pydantic/networks.py +++ b/pydantic/networks.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Set, Tuple, Ty from . import errors from .utils import change_exception -from .validators import constr_length_validator, not_none_validator, str_validator +from .validators import constr_length_validator, str_validator if TYPE_CHECKING: # pragma: no cover from .fields import Field @@ -141,7 +141,6 @@ class AnyUrl(str): @classmethod def __get_validators__(cls) -> 'CallableGenerator': - yield not_none_validator yield cls.validate @classmethod diff --git a/pydantic/schema.py b/pydantic/schema.py index c68c087..3d87798 100644 --- a/pydantic/schema.py +++ b/pydantic/schema.py @@ -680,6 +680,7 @@ field_class_to_schema_enum_enabled: Tuple[Tuple[Any, Dict[str, Any]], ...] = ( (Color, {'type': 'string', 'format': 'color'}), ) +json_scheme = {'type': 'string', 'format': 'json-string'} # Order is important, subclasses of Path must go before Path, etc field_class_to_schema_enum_disabled = ( @@ -690,7 +691,7 @@ field_class_to_schema_enum_disabled = ( (date, {'type': 'string', 'format': 'date'}), (time, {'type': 'string', 'format': 'time'}), (timedelta, {'type': 'number', 'format': 'time-delta'}), - (Json, {'type': 'string', 'format': 'json-string'}), + (Json, json_scheme), (IPv4Network, {'type': 'string', 'format': 'ipv4network'}), (IPv6Network, {'type': 'string', 'format': 'ipv6network'}), (IPvAnyNetwork, {'type': 'string', 'format': 'ipvanynetwork'}), @@ -732,7 +733,10 @@ def field_singleton_schema( # noqa: C901 (ignore complexity) known_models=known_models, ) if field.type_ is Any or type(field.type_) == TypeVar: - return {}, definitions, nested_models # no restrictions + if field.parse_json: + return json_scheme, definitions, nested_models + else: + return {}, definitions, nested_models # no restrictions if is_callable_type(field.type_): raise SkipField(f'Callable {field.name} was excluded from schema since JSON schema has no equivalent type.') f_schema: Dict[str, Any] = {} diff --git a/pydantic/types.py b/pydantic/types.py index 0abaa79..a971de5 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -1,4 +1,3 @@ -import json import re from decimal import Decimal from enum import Enum @@ -17,7 +16,6 @@ from .validators import ( decimal_validator, float_validator, int_validator, - not_none_validator, number_multiple_validator, number_size_validator, path_exists_validator, @@ -96,7 +94,6 @@ class ConstrainedBytes(bytes): @classmethod def __get_validators__(cls) -> 'CallableGenerator': - yield not_none_validator yield bytes_validator yield constr_strip_whitespace yield constr_length_validator @@ -155,7 +152,6 @@ class ConstrainedStr(str): @classmethod def __get_validators__(cls) -> 'CallableGenerator': - yield not_none_validator yield strict_str_validator if cls.strict else str_validator yield constr_strip_whitespace yield constr_length_validator @@ -235,11 +231,10 @@ class PyObject: except errors.StrError: raise errors.PyObjectError(error_message='value is neither a valid import path not a valid callable') - if value is not None: - try: - return import_string(value) - except ImportError as e: - raise errors.PyObjectError(error_message=str(e)) + try: + return import_string(value) + except ImportError as e: + raise errors.PyObjectError(error_message=str(e)) class ConstrainedNumberMeta(type): @@ -342,7 +337,6 @@ class ConstrainedDecimal(Decimal, metaclass=ConstrainedNumberMeta): @classmethod def __get_validators__(cls) -> 'CallableGenerator': - yield not_none_validator yield decimal_validator yield number_size_validator yield number_multiple_validator @@ -458,19 +452,7 @@ class JsonMeta(type): class Json(metaclass=JsonMeta): - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield str_validator - yield cls.validate - - @classmethod - def validate(cls, v: Any) -> Any: - try: - return json.loads(v) - except ValueError: - raise errors.JsonError() - except TypeError: - raise errors.JsonTypeError() + pass class SecretStr: @@ -554,7 +536,6 @@ class PaymentCardNumber(str): @classmethod def __get_validators__(cls) -> 'CallableGenerator': - yield not_none_validator yield str_validator yield constr_strip_whitespace yield constr_length_validator diff --git a/pydantic/validators.py b/pydantic/validators.py index 388e8e8..7b07c03 100644 --- a/pydantic/validators.py +++ b/pydantic/validators.py @@ -1,3 +1,4 @@ +import json import re import sys from collections import OrderedDict @@ -39,19 +40,6 @@ if TYPE_CHECKING: # pragma: no cover Number = Union[int, float, Decimal] StrBytes = Union[str, bytes] -NoneType = type(None) - - -def not_none_validator(v: Any) -> Any: - if v is None: - raise errors.NoneIsNotAllowedError() - return v - - -def is_none_validator(v: Any) -> None: - if v is not None: - raise errors.NoneIsAllowedError() - def str_validator(v: Any) -> Optional[str]: if isinstance(v, str): @@ -59,8 +47,6 @@ def str_validator(v: Any) -> Optional[str]: return v.value else: return v - elif v is None: - return None elif isinstance(v, (float, int, Decimal)): # is there anything else we want to add here? If you think so, create an issue. return str(v) @@ -89,12 +75,12 @@ def bytes_validator(v: Any) -> bytes: raise errors.BytesError() -BOOL_FALSE = {False, 0, '0', 'off', 'f', 'false', 'n', 'no'} -BOOL_TRUE = {True, 1, '1', 'on', 't', 'true', 'y', 'yes'} +BOOL_FALSE = {0, '0', 'off', 'f', 'false', 'n', 'no'} +BOOL_TRUE = {1, '1', 'on', 't', 'true', 'y', 'yes'} def bool_validator(v: Any) -> bool: - if isinstance(v, bool): + if v is True or v is False: return v if isinstance(v, bytes): v = v.decode() @@ -109,7 +95,7 @@ def bool_validator(v: Any) -> bool: def int_validator(v: Any) -> int: - if isinstance(v, int) and not isinstance(v, bool): + if isinstance(v, int) and not (v is True or v is False): return v with change_exception(errors.IntegerError, TypeError, ValueError): @@ -410,6 +396,15 @@ def constr_strip_whitespace(v: 'StrBytes', field: 'Field', config: 'BaseConfig') return v +def validate_json(v: Any) -> Any: + try: + return json.loads(v) + except ValueError: + raise errors.JsonError() + except TypeError: + raise errors.JsonTypeError() + + T = TypeVar('T') @@ -451,7 +446,7 @@ class IfConfig: return any(getattr(config, name) not in {None, False} for name in self.config_attr_names) -pattern_validators = [not_none_validator, str_validator, pattern_validator] +pattern_validators = [str_validator, pattern_validator] # order is important here, for example: bool is a subclass of int so has to come first, datetime before date same, # IPv4Interface before IPv4Address, etc _VALIDATORS: List[Tuple[AnyType, List[Any]]] = [ @@ -460,7 +455,6 @@ _VALIDATORS: List[Tuple[AnyType, List[Any]]] = [ ( str, [ - not_none_validator, str_validator, IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'), IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'), @@ -469,7 +463,6 @@ _VALIDATORS: List[Tuple[AnyType, List[Any]]] = [ ( bytes, [ - not_none_validator, bytes_validator, IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'), IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'), @@ -478,7 +471,6 @@ _VALIDATORS: List[Tuple[AnyType, List[Any]]] = [ (bool, [bool_validator]), (int, [int_validator]), (float, [float_validator]), - (NoneType, [is_none_validator]), # type: ignore (Path, [path_validator]), (datetime, [parse_datetime]), (date, [parse_date]), @@ -490,14 +482,14 @@ _VALIDATORS: List[Tuple[AnyType, List[Any]]] = [ (tuple, [tuple_validator]), (set, [set_validator]), (frozenset, [frozenset_validator]), - (UUID, [not_none_validator, uuid_validator]), - (Decimal, [not_none_validator, decimal_validator]), - (IPv4Interface, [not_none_validator, ip_v4_interface_validator]), - (IPv6Interface, [not_none_validator, ip_v6_interface_validator]), - (IPv4Address, [not_none_validator, ip_v4_address_validator]), - (IPv6Address, [not_none_validator, ip_v6_address_validator]), - (IPv4Network, [not_none_validator, ip_v4_network_validator]), - (IPv6Network, [not_none_validator, ip_v6_network_validator]), + (UUID, [uuid_validator]), + (Decimal, [decimal_validator]), + (IPv4Interface, [ip_v4_interface_validator]), + (IPv6Interface, [ip_v6_interface_validator]), + (IPv4Address, [ip_v4_address_validator]), + (IPv6Address, [ip_v6_address_validator]), + (IPv4Network, [ip_v4_network_validator]), + (IPv6Network, [ip_v6_network_validator]), ] diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 83dc9eb..12dbcc8 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -16,6 +16,7 @@ from pydantic import ( constr, errors, validate_model, + validator, ) @@ -26,7 +27,6 @@ def test_str_bytes(): m = Model(v='s') assert m.v == 's' assert '' == repr(m.fields['v']) - assert 'not_none_validator' in [v.__qualname__ for v in m.fields['v'].sub_fields[0].validators] m = Model(v=b'b') assert m.v == 'b' @@ -34,8 +34,7 @@ def test_str_bytes(): with pytest.raises(ValidationError) as exc_info: Model(v=None) assert exc_info.value.errors() == [ - {'loc': ('v',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'}, - {'loc': ('v',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'}, + {'loc': ('v',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'} ] @@ -73,8 +72,7 @@ def test_union_int_str(): with pytest.raises(ValidationError) as exc_info: Model(v=None) assert exc_info.value.errors() == [ - {'loc': ('v',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}, - {'loc': ('v',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'}, + {'loc': ('v',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'} ] @@ -274,9 +272,9 @@ def test_list_unions(): with pytest.raises(ValidationError) as exc_info: Model(v=[1, 2, None]) + assert exc_info.value.errors() == [ - {'loc': ('v', 2), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}, - {'loc': ('v', 2), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'}, + {'loc': ('v', 2), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'} ] @@ -773,7 +771,6 @@ def test_multiple_errors(): Model(a='foobar') assert exc_info.value.errors() == [ - {'loc': ('a',), 'msg': 'value is not none', 'type': 'type_error.none.allowed'}, {'loc': ('a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}, {'loc': ('a',), 'msg': 'value is not a valid float', 'type': 'type_error.float'}, {'loc': ('a',), 'msg': 'value is not a valid decimal', 'type': 'type_error.decimal'}, @@ -968,3 +965,44 @@ def test_ignored_type(): b: int assert Model.__fields__.keys() == {'b'} + + +def test_optional_subfields(): + class Model(BaseModel): + a: Optional[int] + + assert Model.__fields__['a'].sub_fields is None + assert Model.__fields__['a'].allow_none is True + + with pytest.raises(ValidationError) as exc_info: + Model(a='foobar') + + assert exc_info.value.errors() == [ + {'loc': ('a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'} + ] + assert Model().a is None + assert Model(a=None).a is None + assert Model(a=12).a == 12 + + +def test_not_optional_subfields(): + class Model(BaseModel): + a: Optional[int] + + @validator('a') + def check_a(cls, v): + return v + + assert Model.__fields__['a'].sub_fields is None + # assert Model.__fields__['a'].required is True + assert Model.__fields__['a'].allow_none is True + + with pytest.raises(ValidationError) as exc_info: + Model(a='foobar') + + assert exc_info.value.errors() == [ + {'loc': ('a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'} + ] + assert Model().a is None + assert Model(a=None).a is None + assert Model(a=12).a == 12 diff --git a/tests/test_errors.py b/tests/test_errors.py index 9e1ffb7..cc27c9a 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -38,7 +38,7 @@ def test_interval_validation_error(): class MyModel(BaseModel): foobar: Union[Foo, Bar] - @validator('foobar', pre=True, whole=True) + @validator('foobar', pre=True) def check_action(cls, v): if isinstance(v, dict): model_type = v.get('model_type') @@ -67,7 +67,7 @@ def test_error_on_optional(): class Foobar(BaseModel): foo: Optional[str] = None - @validator('foo', always=True, whole=True) + @validator('foo', always=True, pre=True) def check_foo(cls, v): raise ValueError('custom error') @@ -97,7 +97,6 @@ def test_error_on_optional(): {'loc': ('d',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}, {'loc': ('d',), 'msg': 'value is not a valid uuid', 'type': 'type_error.uuid'}, {'loc': ('e', '__key__'), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}, - {'loc': ('f', 0), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}, {'loc': ('f', 0), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'}, { 'loc': ('g',), @@ -171,14 +170,6 @@ def test_error_on_optional(): "msg": "value is not a valid integer", "type": "type_error.integer" }, - { - "loc": [ - "f", - 0 - ], - "msg": "value is not a valid integer", - "type": "type_error.integer" - }, { "loc": [ "f", @@ -212,7 +203,7 @@ def test_error_on_optional(): ( '__str__', """\ -11 validation errors for Model +10 validation errors for Model a value is not a valid integer (type=type_error.integer) b -> x @@ -227,8 +218,6 @@ d value is not a valid uuid (type=type_error.uuid) e -> __key__ value is not a valid integer (type=type_error.integer) -f -> 0 - value is not a valid integer (type=type_error.integer) f -> 0 none is not an allowed value (type=type_error.none.not_allowed) g diff --git a/tests/test_generics.py b/tests/test_generics.py index d84a34e..4de02c8 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -40,7 +40,7 @@ def test_value_validation(): class Response(GenericModel, Generic[T]): data: T - @validator('data') + @validator('data', each_item=True) def validate_value_nonzero(cls, v): if isinstance(v, dict): return v # ensure v is actually a value of the dict, not the dict itself @@ -330,15 +330,13 @@ def test_generic(): with pytest.raises(ValidationError) as exc_info: Result[Data, Error](data=[Data(number=1, text='a')], error=Error(message='error'), positive_number=1) assert exc_info.value.errors() == [ - {'loc': ('error',), 'msg': 'Must not provide both data and error', 'type': 'value_error'}, - {'loc': ('error',), 'msg': 'value is not none', 'type': 'type_error.none.allowed'}, + {'loc': ('error',), 'msg': 'Must not provide both data and error', 'type': 'value_error'} ] with pytest.raises(ValidationError) as exc_info: Result[Data, Error](data=[Data(number=1, text='a')], error=Error(message='error'), positive_number=1) assert exc_info.value.errors() == [ - {'loc': ('error',), 'msg': 'Must not provide both data and error', 'type': 'value_error'}, - {'loc': ('error',), 'msg': 'value is not none', 'type': 'type_error.none.allowed'}, + {'loc': ('error',), 'msg': 'Must not provide both data and error', 'type': 'value_error'} ] diff --git a/tests/test_schema.py b/tests/test_schema.py index 469b0fc..a59fd74 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -330,13 +330,10 @@ def test_tuple(field_type, expected_schema): base_schema = { 'title': 'Model', 'type': 'object', - 'properties': {'a': {'title': 'A', 'type': 'array', 'items': None}}, + 'properties': {'a': {'title': 'A', 'type': 'array'}}, 'required': ['a'], } - # noinspection PyTypeChecker base_schema['properties']['a']['items'] = expected_schema - if expected_schema is None: - base_schema['properties']['a'].pop('items', None) assert Model.schema() == base_schema @@ -1288,6 +1285,26 @@ def test_optional_dict(): assert Model(something={'foo': 'Bar'}).dict() == {'something': {'foo': 'Bar'}} +def test_optional_validator(): + class Model(BaseModel): + something: Optional[str] + + @validator('something', always=True) + def check_something(cls, v): + assert v is None or 'x' not in v, 'should not contain x' + return v + + assert Model.schema() == { + 'title': 'Model', + 'type': 'object', + 'properties': {'something': {'title': 'Something', 'type': 'string'}}, + } + + assert Model().dict() == {'something': None} + assert Model(something=None).dict() == {'something': None} + assert Model(something='hello').dict() == {'something': 'hello'} + + def test_field_with_validator(): class Model(BaseModel): something: Optional[int] = None diff --git a/tests/test_types.py b/tests/test_types.py index c6f9302..f3ec19f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -42,6 +42,7 @@ from pydantic import ( conlist, constr, create_model, + validator, ) try: @@ -1282,7 +1283,9 @@ def test_path_validation_fails(): with pytest.raises(ValidationError) as exc_info: Model(foo=None) - assert exc_info.value.errors() == [{'loc': ('foo',), 'msg': 'value is not a valid path', 'type': 'type_error.path'}] + assert exc_info.value.errors() == [ + {'loc': ('foo',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'} + ] @pytest.mark.parametrize( @@ -1623,6 +1626,23 @@ def test_json_not_str(): } +def test_json_pre_validator(): + call_count = 0 + + class JsonModel(BaseModel): + json_obj: Json + + @validator('json_obj', pre=True) + def check(cls, v): + assert v == '"foobar"' + nonlocal call_count + call_count += 1 + return v + + assert JsonModel(json_obj='"foobar"').dict() == {'json_obj': 'foobar'} + assert call_count == 1 + + def test_pattern(): class Foobar(BaseModel): pattern: Pattern diff --git a/tests/test_validators.py b/tests/test_validators.py index b13062c..c068249 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -58,12 +58,12 @@ def test_validate_whole(): class Model(BaseModel): a: List[int] - @validator('a', whole=True, pre=True) + @validator('a', pre=True) def check_a1(cls, v): v.append('123') return v - @validator('a', whole=True) + @validator('a') def check_a2(cls, v): v.append(456) return v @@ -76,20 +76,20 @@ def test_validate_kwargs(): b: int a: List[int] - @validator('a') + @validator('a', each_item=True) def check_a1(cls, v, values, **kwargs): return v + values['b'] assert Model(a=[1, 2], b=6).a == [7, 8] -def test_validate_whole_error(): +def test_validate_pre_error(): calls = [] class Model(BaseModel): a: List[int] - @validator('a', whole=True, pre=True) + @validator('a', pre=True) def check_a1(cls, v): calls.append(f'check_a1 {v}') if 1 in v: @@ -97,7 +97,7 @@ def test_validate_whole_error(): v[0] += 1 return v - @validator('a', whole=True) + @validator('a') def check_a2(cls, v): calls.append(f'check_a2 {v}') if 10 in v: @@ -475,22 +475,22 @@ def test_inheritance_new(): assert Child(a=0).a == 6 -def test_no_key_validation(): +def test_validation_each_item(): class Model(BaseModel): foobar: Dict[int, int] - @validator('foobar') + @validator('foobar', each_item=True) def check_foobar(cls, v): return v + 1 assert Model(foobar={1: 1}).foobar == {1: 2} -def test_key_validation_whole(): +def test_key_validation(): class Model(BaseModel): foobar: Dict[int, int] - @validator('foobar', whole=True) + @validator('foobar') def check_foobar(cls, value): return {k + 1: v + 1 for k, v in value.items()} @@ -515,6 +515,23 @@ def test_validator_always_optional(): assert check_calls == 2 +def test_validator_always_pre(): + check_calls = 0 + + class Model(BaseModel): + a: str = None + + @validator('a', always=True, pre=True) + def check_a(cls, v): + nonlocal check_calls + check_calls += 1 + return v or 'default value' + + assert Model(a='y').a == 'y' + assert Model().a == 'default value' + assert check_calls == 2 + + def test_validator_always_post(): class Model(BaseModel): a: str = None @@ -524,15 +541,14 @@ def test_validator_always_post(): return v or 'default value' assert Model(a='y').a == 'y' - with pytest.raises(ValidationError): - Model() + assert Model().a == 'default value' def test_validator_always_post_optional(): class Model(BaseModel): a: Optional[str] = None - @validator('a', always=True) + @validator('a', always=True, pre=True) def check_a(cls, v): return v or 'default value' @@ -560,13 +576,13 @@ def test_datetime_validator(): assert check_calls == 3 -def test_whole_called_once(): +def test_pre_called_once(): check_calls = 0 class Model(BaseModel): a: Tuple[int, int, int] - @validator('a', pre=True, whole=True) + @validator('a', pre=True) def check_a(cls, v): nonlocal check_calls check_calls += 1 @@ -671,3 +687,31 @@ def test_assert_raises_validation_error(): assert exc_info.value.errors() == [ {'loc': ('a',), 'msg': f'invalid a{injected_by_pytest}', 'type': 'assertion_error'} ] + + +def test_optional_validator(): + val_calls = [] + + class Model(BaseModel): + something: Optional[str] + + @validator('something') + def check_something(cls, v): + val_calls.append(v) + return v + + assert Model().dict() == {'something': None} + assert Model(something=None).dict() == {'something': None} + assert Model(something='hello').dict() == {'something': 'hello'} + assert val_calls == [None, 'hello'] + + +def test_whole(): + with pytest.warns(DeprecationWarning, match='The "whole" keyword argument is deprecated'): + + class Model(BaseModel): + x: List[int] + + @validator('x', whole=True) + def check_something(cls, v): + return v diff --git a/tests/test_validators_dataclass.py b/tests/test_validators_dataclass.py index 7e408c4..4da776d 100755 --- a/tests/test_validators_dataclass.py +++ b/tests/test_validators_dataclass.py @@ -19,17 +19,17 @@ def test_simple(): assert MyDataclass(a='this is foobar good').a == 'this is foobar good changed' -def test_validate_whole(): +def test_validate_pre(): @dataclass class MyDataclass: a: List[int] - @validator('a', whole=True, pre=True) + @validator('a', pre=True) def check_a1(cls, v): v.append('123') return v - @validator('a', whole=True) + @validator('a') def check_a2(cls, v): v.append(456) return v