diff --git a/HISTORY.rst b/HISTORY.rst index 2223b96..491253e 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -7,6 +7,7 @@ v0.6.0 (2017-11-XX) ................... * assignment validation #94, thanks petroswork! * JSON in environment variables for complex types, #96 +* add ``validator`` decorators for complex validation, #97 v0.5.0 (2017-10-23) ................... diff --git a/docs/examples/validators_complex.py b/docs/examples/validators_complex.py new file mode 100644 index 0000000..7e4cc5d --- /dev/null +++ b/docs/examples/validators_complex.py @@ -0,0 +1,54 @@ +import json +from typing import List, Set + +from pydantic import BaseModel, ValidationError, validator + + +class DemoModel(BaseModel): + numbers: List[int] = [] + people: List[str] = [] + + @validator('people', 'numbers', pre=True, whole=True) + def json_decode(cls, v): + if isinstance(v, str): + try: + return json.loads(v) + except ValueError: + pass + return v + + @validator('numbers') + def check_numbers_low(cls, v): + if v > 4: + raise ValueError(f'number to large {v} > 4') + return v + + @validator('numbers', whole=True) + def check_sum_numbers_low(cls, v): + if sum(v) > 8: + raise ValueError(f'sum of numbers greater than 8') + return v + + +print(DemoModel(numbers='[1, 1, 2, 2]')) +# > DemoModel numbers=[1, 1, 2, 2] people=[] + +try: + DemoModel(numbers='[1, 2, 5]') +except ValidationError as e: + print(e) +""" +error validating input +numbers: + number to large 5 > 4 (error_type=ValueError track=int index=2) +""" + +try: + DemoModel(numbers=[3, 3, 3]) +except ValidationError as e: + print(e) +""" +error validating input +numbers: + sum of numbers greater than 8 (error_type=ValueError track=int) +""" diff --git a/docs/examples/validators_simple.py b/docs/examples/validators_simple.py new file mode 100644 index 0000000..0739599 --- /dev/null +++ b/docs/examples/validators_simple.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel, ValidationError, validator + + +class UserModel(BaseModel): + name: 
str + password1: str + password2: str + + @validator('name') + def name_must_contain_space(cls, v): + if ' ' not in v: + raise ValueError('must contain a space') + return v.title() + + @validator('password2') + def passwords_match(cls, v, values, **kwargs): + if 'password1' in values and v != values['password1']: + raise ValueError('passwords do not match') + return v + + +print(UserModel(name='samuel colvin', password1='zxcvbn', password2='zxcvbn')) +# > UserModel name='Samuel Colvin' password1='zxcvbn' password2='zxcvbn' + +try: + UserModel(name='samuel', password1='zxcvbn', password2='zxcvbn2') +except ValidationError as e: + print(e) +""" +2 errors validating input +name: + must contain a space (error_type=ValueError track=str) +password2: + passwords do not match (error_type=ValueError track=str) +""" diff --git a/docs/index.rst b/docs/index.rst index 4a6fa70..ac63cf3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -106,6 +106,44 @@ pydantic uses python's standard ``enum`` classes to define choices. (This script is complete, it should run "as is") +Validators +.......... + +Custom validation and complex relationships between objects can be achieved using the ``validator`` decorator. + +.. literalinclude:: examples/validators_simple.py + +(This script is complete, it should run "as is") + +A few things to note on validators: + +* validators are "class methods", the first value they receive here will be the ``UserModel`` not an instance + of ``UserModel`` +* their signature can be ``(cls, value)`` or ``(cls, value, *, values, config, field)`` +* validators should either return the new value or raise a ``ValueError`` or ``TypeError`` +* where validators rely on other values, you should be aware that: + + - Validation is done in the order fields are defined, eg. here ``password2`` has access to ``password1`` + (and ``name``), but ``password1`` does not have access to ``password2``. 
You should heed the warning + :ref:`below <usage_mypy_required>` regarding field order and required fields. + + - If validation fails on another field (or that field is missing) it will not be included in ``values``, hence + ``if 'password1' in values and ...`` in this example. + +Validators can do a few more complex things: + +.. literalinclude:: examples/validators_complex.py + +(This script is complete, it should run "as is") + +A few more things to note: + +* a single validator can apply to multiple fields +* the keyword argument ``pre`` will cause validators to be called prior to other validation +* the ``whole`` keyword argument will mean validators are applied to entire objects rather than individual values + (applies for complex typing objects eg. ``List``, ``Dict``, ``Set``) + + Recursive Models ................ @@ -227,6 +265,8 @@ Pydantic provides a few useful optional or union types: If these aren't sufficient you can of course define your own. +.. _usage_mypy_required: + Required Fields and mypy ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/pydantic/__init__.py b/pydantic/__init__.py index a6bc7ea..792c5cb 100644 --- a/pydantic/__init__.py +++ b/pydantic/__init__.py @@ -2,7 +2,7 @@ from .env_settings import BaseSettings from .exceptions import * from .fields import Required -from .main import BaseModel +from .main import BaseModel, validator from .parse import Protocol from .types import * from .version import VERSION diff --git a/pydantic/fields.py b/pydantic/fields.py index 0ad116d..43009d9 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -12,6 +12,8 @@ Required: Any = Ellipsis class ValidatorSignature(IntEnum): JUST_VALUE = 1 VALUE_KWARGS = 2 + CLS_JUST_VALUE = 3 + CLS_VALUE_KWARGS = 4 class Shape(IntEnum): @@ -22,8 +24,11 @@ class Shape(IntEnum): class Field: - __slots__ = ('type_', 'key_type_', 'sub_fields', 'key_field', 'validators', 'default', 'required', 'model_config', - 'name', 'alias', 'description', 'info', 'validate_always', 'allow_none', 'shape', 
'multipart') + __slots__ = ( + 'type_', 'key_type_', 'sub_fields', 'key_field', 'validators', 'whole_pre_validators', 'whole_post_validators', + 'default', 'required', 'model_config', 'name', 'alias', 'description', 'info', 'validate_always', + 'allow_none', 'shape' + ) def __init__( self, *, @@ -42,16 +47,17 @@ class Field: self.type_: type = type_ self.key_type_: type = None self.validate_always: bool = getattr(self.type_, 'validate_always', False) - self.sub_fields = None + self.sub_fields: List[Field] = None self.key_field: Field = None self.validators = [] + self.whole_pre_validators = None + self.whole_post_validators = None self.default: Any = default self.required: bool = required self.model_config = model_config self.description: str = description self.allow_none: bool = allow_none self.shape: Shape = Shape.SINGLETON - self.multipart = False self.info = {} self._prepare(class_validators or {}) @@ -92,8 +98,7 @@ class Field: self.allow_none = True self._populate_sub_fields(class_validators) - if self.sub_fields is None: - self._populate_validators(class_validators) + self._populate_validators(class_validators) self.info = OrderedDict([ ('type', type_display(self.type_)), @@ -102,7 +107,7 @@ class Field: ]) if self.required: self.info.pop('default') - if self.multipart: + if self.sub_fields: self.info['sub_fields'] = self.sub_fields else: self.info['validators'] = [v[1].__qualname__ for v in self.validators] @@ -113,7 +118,7 @@ class Field: def _populate_sub_fields(self, class_validators): # typing interface is horrible, we have to do some ugly checks - origin = getattr(self.type_, '__origin__', None) + origin = _get_type_origin(self.type_) if origin is None: # field is not "typing" object eg. Union, Dict, List etc. 
return @@ -134,7 +139,6 @@ class Field: name=f'{self.name}_{type_display(t)}', model_config=self.model_config, ) for t in types_] - self.multipart = True elif issubclass(origin, List): self.type_ = self.type_.__args__[0] self.shape = Shape.LIST @@ -156,8 +160,8 @@ class Field: model_config=self.model_config, ) - if self.sub_fields is None and getattr(self.type_, '__origin__', False): - self.multipart = True + if not self.sub_fields and _get_type_origin(self.type_): + # type_ has been refined eg. as the type of a List and sub_fields needs to be populated self.sub_fields = [self.__class__( type_=self.type_, class_validators=class_validators, @@ -169,46 +173,61 @@ class Field: )] def _populate_validators(self, class_validators): - get_validators = getattr(self.type_, 'get_validators', None) - v_funcs = ( - class_validators.get(f'validate_{self.name}_pre'), + if not self.sub_fields: + get_validators = getattr(self.type_, 'get_validators', None) + v_funcs = ( + *tuple(f for f, pre, whole in class_validators if not whole and pre), + *(get_validators() if get_validators else find_validators(self.type_)), + *tuple(f for f, pre, whole in class_validators if not whole and not pre), + ) + self.validators = self._prep_vals(v_funcs) - *(get_validators() if get_validators else find_validators(self.type_)), + if class_validators: + self.whole_pre_validators = self._prep_vals(f for f, pre, whole in class_validators if whole and pre) + self.whole_post_validators = self._prep_vals(f for f, pre, whole in class_validators if whole and not pre) - class_validators.get(f'validate_{self.name}'), - class_validators.get(f'validate_{self.name}_post'), - ) + def _prep_vals(self, v_funcs): + v = [] for f in v_funcs: if not f or (self.allow_none and f is not_none_validator): continue - self.validators.append(( + v.append(( _get_validator_signature(f), f, )) + return tuple(v) - def validate(self, v, values, index=None): + def validate(self, v, values, index=None, cls=None): if self.allow_none 
and v is None: return None, None + if self.whole_pre_validators: + v, errors = self._apply_validators(v, values, index, cls, self.whole_pre_validators) + if errors: + return v, errors + if self.shape is Shape.SINGLETON: - return self._validate_singleton(v, values, index) + v, errors = self._validate_singleton(v, values, index, cls) elif self.shape is Shape.MAPPING: - return self._validate_mapping(v, values) + v, errors = self._validate_mapping(v, values, cls) else: # list or set - result, errors = self._validate_sequence(v, values) + v, errors = self._validate_sequence(v, values, cls) if not errors and self.shape is Shape.SET: - return set(result), errors - return result, errors + v = set(v) - def _validate_sequence(self, v, values): + if not errors and self.whole_post_validators: + v, errors = self._apply_validators(v, values, index, cls, self.whole_post_validators) + return v, errors + + def _validate_sequence(self, v, values, cls): result, errors = [], [] try: v_iter = enumerate(v) except TypeError as exc: return v, Error(exc, None, None) for i, v_ in v_iter: - single_result, single_errors = self._validate_singleton(v_, values, i) + single_result, single_errors = self._validate_singleton(v_, values, i, cls) if single_errors: errors.append(single_errors) else: @@ -218,7 +237,7 @@ class Field: else: return result, None - def _validate_mapping(self, v, values): + def _validate_mapping(self, v, values, cls): if isinstance(v, dict): v_iter = v else: @@ -229,11 +248,11 @@ class Field: result, errors = {}, [] for k, v_ in v_iter.items(): - key_result, key_errors = self.key_field.validate(k, values, 'key') + key_result, key_errors = self.key_field.validate(k, values, 'key', cls) if key_errors: errors.append(key_errors) continue - value_result, value_errors = self._validate_singleton(v_, values, k) + value_result, value_errors = self._validate_singleton(v_, values, k, cls) if value_errors: errors.append(value_errors) continue @@ -243,27 +262,34 @@ class Field: else: 
return result, None - def _validate_singleton(self, v, values, index): - if self.multipart: + def _validate_singleton(self, v, values, index, cls): + if self.sub_fields: errors = [] for field in self.sub_fields: - value, error = field.validate(v, values, index) + value, error = field.validate(v, values, index, cls) if error: errors.append(error) else: return value, None return v, errors[0] if len(self.sub_fields) == 1 else errors else: - for signature, validator in self.validators: - try: - if signature is ValidatorSignature.JUST_VALUE: - v = validator(v) - else: - # ValidatorSignature.VALUE_KWARGS - v = validator(v, values=values, config=self.model_config, field=self) - except (ValueError, TypeError) as exc: - return v, Error(exc, self.type_, index) - return v, None + return self._apply_validators(v, values, index, cls, self.validators) + + def _apply_validators(self, v, values, index, cls, validators): + for signature, validator in validators: + try: + if signature is ValidatorSignature.JUST_VALUE: + v = validator(v) + elif signature is ValidatorSignature.VALUE_KWARGS: + v = validator(v, values=values, config=self.model_config, field=self) + elif signature is ValidatorSignature.CLS_JUST_VALUE: + v = validator(cls, v) + else: + # ValidatorSignature.CLS_VALUE_KWARGS + v = validator(cls, v, values=values, config=self.model_config, field=self) + except (ValueError, TypeError) as exc: + return v, Error(exc, self.type_, index) + return v, None def __repr__(self): return f'' @@ -279,6 +305,7 @@ def _get_validator_signature(validator): try: signature = inspect.signature(validator) except ValueError: + # TODO we should probably have a white list of allowed validators here, rather than assuming # happens on builtins like float return ValidatorSignature.JUST_VALUE @@ -286,15 +313,24 @@ def _get_validator_signature(validator): # 1. we can deal with it before validation begins # 2. 
(more importantly) it doesn't get confused with a TypeError when executing the validator try: - if len(signature.parameters) == 1: - signature.bind(1) - return ValidatorSignature.JUST_VALUE + if 'cls' in signature._parameters: + if len(signature.parameters) == 2: + signature.bind(object(), 1) + return ValidatorSignature.CLS_JUST_VALUE + else: + signature.bind(object(), 1, values=2, config=3, field=4) + return ValidatorSignature.CLS_VALUE_KWARGS else: - signature.bind(1, values=2, config=3, field=4) - return ValidatorSignature.VALUE_KWARGS + if len(signature.parameters) == 1: + signature.bind(1) + return ValidatorSignature.JUST_VALUE + else: + signature.bind(1, values=2, config=3, field=4) + return ValidatorSignature.VALUE_KWARGS except TypeError as e: raise ConfigError(f'Invalid signature for validator {validator}: {signature}, should be: ' - f'(value) or (value, *, values, config, field)') from e + f'(value) or (value, *, values, config, field) or for class validators ' + f'(cls, value) or (cls, value, *, values, config, field)') from e def _get_field_config(config, name): @@ -302,3 +338,10 @@ def _get_field_config(config, name): if isinstance(field_config, str): field_config = {'alias': field_config} return field_config + + +def _get_type_origin(obj): + """ + Like obj.__class__ or type(obj) but for typing objects + """ + return getattr(obj, '__origin__', None) diff --git a/pydantic/main.py b/pydantic/main.py index 757450d..b709680 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -36,6 +36,20 @@ def inherit_config(self_config, parent_config) -> BaseConfig: TYPE_BLACKLIST = FunctionType, property, type, classmethod, staticmethod +def _extract_validators(namespace): + validators = {} + for var_name, value in namespace.items(): + validator_config = getattr(value, '__validator_config', None) + if validator_config: + fields, *v = validator_config + for field in fields: + if field in validators: + validators[field].append(v) + else: + validators[field] = [v] + 
return validators + + class MetaModel(type): @classmethod def __prepare__(mcs, *args, **kwargs): @@ -50,9 +64,7 @@ class MetaModel(type): config = inherit_config(base.config, config) config = inherit_config(namespace.get('Config'), config) - class_validators = { - n: f for n, f in namespace.items() if n.startswith('validate_') and isinstance(f, FunctionType) - } + validators = _extract_validators(namespace) for f in fields.values(): f.set_config(config) @@ -65,7 +77,7 @@ class MetaModel(type): name=ann_name, value=..., annotation=ann_type, - class_validators=class_validators, + class_validators=validators.get(ann_name), config=config, ) @@ -75,7 +87,7 @@ class MetaModel(type): name=var_name, value=value, annotation=annotations.get(var_name), - class_validators=class_validators, + class_validators=validators.get(var_name), config=config, ) @@ -227,9 +239,11 @@ class BaseModel(metaclass=MetaModel): values[name] = field.default continue - values[name], errors_ = field.validate(value, values) + v_, errors_ = field.validate(value, values, cls=self.__class__) if errors_: errors[field.alias] = errors_ + else: + values[name] = v_ if (not self.config.ignore_extra) or self.config.allow_extra: extra = input_data.keys() - {f.alias for f in self.__fields__.values()} @@ -287,3 +301,17 @@ class BaseModel(metaclass=MetaModel): def __str__(self): return self.to_string() + + +def validator(*fields, pre=False, whole=False): + """ + Decorate methods on the class indicating that they should be used to validate fields + :param fields: which field(s) the method should be called on + :param pre: whether or not this validator should be called before the standard validators (else after) + :param whole: for complex objects (sets, lists etc.) 
whether to validate individual elements or the whole object + """ + def dec(f): + f_cls = classmethod(f) + f_cls.__validator_config = fields, f, pre, whole + return f_cls + return dec diff --git a/tests/test_validators.py b/tests/test_validators.py new file mode 100644 index 0000000..0f0f84e --- /dev/null +++ b/tests/test_validators.py @@ -0,0 +1,160 @@ +from typing import List + +import pytest + +from pydantic import BaseModel, ValidationError, validator + + +def test_simple(): + class Model(BaseModel): + a: str + + @validator('a') + def check_a(cls, v): + if 'foobar' not in v: + raise ValueError('"foobar" not found in a') + return v + + assert Model(a='this is foobar good').a == 'this is foobar good' + with pytest.raises(ValidationError) as exc_info: + Model(a='snap') + assert '"foobar" not found in a' in str(exc_info.value) + + +def test_validate_whole(): + class Model(BaseModel): + a: List[int] + + @validator('a', whole=True, pre=True) + def check_a1(cls, v): + v.append('123') + return v + + @validator('a', whole=True) + def check_a2(cls, v): + v.append(456) + return v + + assert Model(a=[1, 2]).a == [1, 2, 123, 456] + + +def test_validate_kwargs(): + class Model(BaseModel): + b: int + a: List[int] + + @validator('a') + def check_a1(cls, v, values, **kwargs): + return v + values['b'] + + assert Model(a=[1, 2], b=6).a == [7, 8] + + +def test_validate_whole_error(): + calls = [] + + class Model(BaseModel): + a: List[int] + + @validator('a', whole=True, pre=True) + def check_a1(cls, v): + calls.append(f'check_a1 {v}') + if 1 in v: + raise ValueError('a1 broken') + v[0] += 1 + return v + + @validator('a', whole=True) + def check_a2(cls, v): + calls.append(f'check_a2 {v}') + if 10 in v: + raise ValueError('a2 broken') + return v + + assert Model(a=[3, 8]).a == [4, 8] + assert calls == ['check_a1 [3, 8]', 'check_a2 [4, 8]'] + calls = [] + with pytest.raises(ValidationError) as exc_info: + Model(a=[1, 3]) + assert 'a1 broken' in str(exc_info.value) + assert calls == 
['check_a1 [1, 3]'] + + calls = [] + with pytest.raises(ValidationError) as exc_info: + Model(a=[5, 10]) + assert 'a2 broken' in str(exc_info.value) + assert calls == ['check_a1 [5, 10]', 'check_a2 [6, 10]'] + + +class ValidateAssignmentModel(BaseModel): + a: int = 4 + b: str = ... + + @validator('b') + def b_length(cls, v, values, **kwargs): + if 'a' in values and len(v) < values['a']: + raise ValueError('b too short') + return v + + class Config: + validate_assignment = True + + +def test_validating_assignment_ok(): + p = ValidateAssignmentModel(b='hello') + assert p.b == 'hello' + + +def test_validating_assignment_fail(): + with pytest.raises(ValidationError): + ValidateAssignmentModel(a=10, b='hello') + + p = ValidateAssignmentModel(b='hello') + with pytest.raises(ValidationError): + p.b = 'x' + + +def test_validating_assignment_values(): + with pytest.raises(ValidationError) as exc_info: + ValidateAssignmentModel(a='x', b='xx') + assert """\ +error validating input +a: + invalid literal for int() with base 10: 'x' (error_type=ValueError track=int)""" == str(exc_info.value) + + +def test_validate_multiple(): + # also test TypeError + class Model(BaseModel): + a: str + b: str + + @validator('a', 'b') + def check_a_and_b(cls, v, field, **kwargs): + if len(v) < 4: + raise TypeError(f'{field.alias} is too short') + return v + 'x' + + assert Model(a='1234', b='5678').values() == {'a': '1234x', 'b': '5678x'} + with pytest.raises(ValidationError) as exc_info: + Model(a='x', b='x') + assert """\ +2 errors validating input +a: + a is too short (error_type=TypeError track=str) +b: + b is too short (error_type=TypeError track=str)""" == str(exc_info.value) + + +def test_classmethod(): + class Model(BaseModel): + a: str + + @validator('a') + def check_a(cls, v): + assert cls is Model + return v + + m = Model(a='this is foobar good') + assert m.a == 'this is foobar good' + m.check_a('x')