Better validators (#97)

* working on improved validators

* full tests for validators

* tweaks

* tweaking fields.py

* adding docs

* add history

* fix classmethod validators
This commit is contained in:
Samuel Colvin
2017-11-07 13:06:44 +00:00
committed by GitHub
parent 02dc2f2697
commit dfc5924936
8 changed files with 417 additions and 56 deletions
+1
View File
@@ -7,6 +7,7 @@ v0.6.0 (2017-11-XX)
...................
* assignment validation #94, thanks petroswork!
* JSON in environment variables for complex types, #96
* add ``validator`` decorators for complex validation, #97
v0.5.0 (2017-10-23)
...................
+54
View File
@@ -0,0 +1,54 @@
import json
from typing import List, Set
from pydantic import BaseModel, ValidationError, validator
class DemoModel(BaseModel):
    """Example model demonstrating ``pre``, ``whole`` and multi-field validators."""

    numbers: List[int] = []
    people: List[str] = []

    @validator('people', 'numbers', pre=True, whole=True)
    def json_decode(cls, v):
        """Decode a JSON string into a python object; applied to the whole value before other validation."""
        if isinstance(v, str):
            try:
                return json.loads(v)
            except ValueError:
                # deliberate best-effort: a string that is not valid JSON is passed on unchanged
                pass
        return v

    @validator('numbers')
    def check_numbers_low(cls, v):
        """Reject any individual number greater than 4 (not ``whole``, so ``v`` is a single element)."""
        if v > 4:
            raise ValueError(f'number to large {v} > 4')
        return v

    @validator('numbers', whole=True)
    def check_sum_numbers_low(cls, v):
        """Reject the list as a whole when its elements sum to more than 8."""
        if sum(v) > 8:
            # plain string: the message has no placeholders, so no f-string is needed
            raise ValueError('sum of numbers greater than 8')
        return v
# a JSON string is accepted for ``numbers``: ``json_decode`` turns it into a list first
print(DemoModel(numbers='[1, 1, 2, 2]'))
# > DemoModel numbers=[1, 1, 2, 2] people=[]

# a single element above 4 fails ``check_numbers_low``; the error reports the element's index
try:
    DemoModel(numbers='[1, 2, 5]')
except ValidationError as e:
    print(e)
"""
error validating input
numbers:
number to large 5 > 4 (error_type=ValueError track=int index=2)
"""

# every element is acceptable on its own, but the total fails ``check_sum_numbers_low``
try:
    DemoModel(numbers=[3, 3, 3])
except ValidationError as e:
    print(e)
"""
error validating input
numbers:
sum of numbers greater than 8 (error_type=ValueError track=int)
"""
+35
View File
@@ -0,0 +1,35 @@
from pydantic import BaseModel, ValidationError, validator
class UserModel(BaseModel):
    """Example model with a single-field validator and a validator that depends on another field."""

    name: str
    password1: str
    password2: str

    @validator('name')
    def name_must_contain_space(cls, v):
        """Require a space in the name and return it title-cased."""
        if ' ' not in v:
            raise ValueError('must contain a space')
        return v.title()

    @validator('password2')
    def passwords_match(cls, v, values, **kwargs):
        """Check ``password2`` against ``password1`` when ``password1`` is present in ``values``."""
        # 'password1' is looked up defensively: the comparison is skipped when it is not in ``values``
        if 'password1' in values and v != values['password1']:
            raise ValueError('passwords do not match')
        return v
# valid input: the name validator returns the title-cased name
print(UserModel(name='samuel colvin', password1='zxcvbn', password2='zxcvbn'))
# > UserModel name='Samuel Colvin' password1='zxcvbn' password2='zxcvbn'

# invalid input: both validators fail and both errors are reported together
try:
    UserModel(name='samuel', password1='zxcvbn', password2='zxcvbn2')
except ValidationError as e:
    print(e)
"""
2 errors validating input
name:
must contain a space (error_type=ValueError track=str)
password2:
passwords do not match (error_type=ValueError track=str)
"""
+40
View File
@@ -106,6 +106,44 @@ pydantic uses python's standard ``enum`` classes to define choices.
(This script is complete, it should run "as is")
Validators
..........
Custom validation and complex relationships between objects can be achieved using the ``validator`` decorator.
.. literalinclude:: examples/validators_simple.py
(This script is complete, it should run "as is")
A few things to note on validators:
* validators are "class methods", the first value they receive here will be the ``UserModel`` not an instance
of ``UserModel``
* their signature can either be ``(cls, value)`` or ``(cls, value, *, values, config, field)``
* validators should either return the new value or raise a ``ValueError`` or ``TypeError``
* where validators rely on other values, you should be aware that:
- Validation is done in the order fields are defined, eg. here ``password2`` has access to ``password1``
(and ``name``), but ``password1`` does not have access to ``password2``. You should heed the warning
:ref:`below <usage_mypy_required>` regarding field order and required fields.
- If validation fails on another field (or that field is missing) it will not be included in ``values``, hence
``if 'password1' in values and ...`` in this example.
Validators can do a few more complex things:
.. literalinclude:: examples/validators_complex.py
(This script is complete, it should run "as is")
A few more things to note:
* a single validator can apply to multiple fields
* the keyword argument ``pre`` will cause validators to be called prior to other validation
* the ``whole`` keyword argument will mean validators are applied to entire objects rather than individual values
(applies for complex typing objects eg. ``List``, ``Dict``, ``Set``)
Recursive Models
................
@@ -227,6 +265,8 @@ Pydantic provides a few useful optional or union types:
If these aren't sufficient you can of course define your own.
.. _usage_mypy_required:
Required Fields and mypy
~~~~~~~~~~~~~~~~~~~~~~~~
+1 -1
View File
@@ -2,7 +2,7 @@
from .env_settings import BaseSettings
from .exceptions import *
from .fields import Required
from .main import BaseModel
from .main import BaseModel, validator
from .parse import Protocol
from .types import *
from .version import VERSION
+92 -49
View File
@@ -12,6 +12,8 @@ Required: Any = Ellipsis
class ValidatorSignature(IntEnum):
    """Call conventions a validator may use; decides how ``_apply_validators`` invokes it."""
    # called as validator(value)
    JUST_VALUE = 1
    # called as validator(value, values=..., config=..., field=...)
    VALUE_KWARGS = 2
    # called as validator(cls, value)
    CLS_JUST_VALUE = 3
    # called as validator(cls, value, values=..., config=..., field=...)
    CLS_VALUE_KWARGS = 4
class Shape(IntEnum):
@@ -22,8 +24,11 @@ class Shape(IntEnum):
class Field:
__slots__ = ('type_', 'key_type_', 'sub_fields', 'key_field', 'validators', 'default', 'required', 'model_config',
'name', 'alias', 'description', 'info', 'validate_always', 'allow_none', 'shape', 'multipart')
__slots__ = (
'type_', 'key_type_', 'sub_fields', 'key_field', 'validators', 'whole_pre_validators', 'whole_post_validators',
'default', 'required', 'model_config', 'name', 'alias', 'description', 'info', 'validate_always',
'allow_none', 'shape'
)
def __init__(
self, *,
@@ -42,16 +47,17 @@ class Field:
self.type_: type = type_
self.key_type_: type = None
self.validate_always: bool = getattr(self.type_, 'validate_always', False)
self.sub_fields = None
self.sub_fields: List[Field] = None
self.key_field: Field = None
self.validators = []
self.whole_pre_validators = None
self.whole_post_validators = None
self.default: Any = default
self.required: bool = required
self.model_config = model_config
self.description: str = description
self.allow_none: bool = allow_none
self.shape: Shape = Shape.SINGLETON
self.multipart = False
self.info = {}
self._prepare(class_validators or {})
@@ -92,8 +98,7 @@ class Field:
self.allow_none = True
self._populate_sub_fields(class_validators)
if self.sub_fields is None:
self._populate_validators(class_validators)
self._populate_validators(class_validators)
self.info = OrderedDict([
('type', type_display(self.type_)),
@@ -102,7 +107,7 @@ class Field:
])
if self.required:
self.info.pop('default')
if self.multipart:
if self.sub_fields:
self.info['sub_fields'] = self.sub_fields
else:
self.info['validators'] = [v[1].__qualname__ for v in self.validators]
@@ -113,7 +118,7 @@ class Field:
def _populate_sub_fields(self, class_validators):
# typing interface is horrible, we have to do some ugly checks
origin = getattr(self.type_, '__origin__', None)
origin = _get_type_origin(self.type_)
if origin is None:
# field is not "typing" object eg. Union, Dict, List etc.
return
@@ -134,7 +139,6 @@ class Field:
name=f'{self.name}_{type_display(t)}',
model_config=self.model_config,
) for t in types_]
self.multipart = True
elif issubclass(origin, List):
self.type_ = self.type_.__args__[0]
self.shape = Shape.LIST
@@ -156,8 +160,8 @@ class Field:
model_config=self.model_config,
)
if self.sub_fields is None and getattr(self.type_, '__origin__', False):
self.multipart = True
if not self.sub_fields and _get_type_origin(self.type_):
# type_ has been refined eg. as the type of a List and sub_fields needs to be populated
self.sub_fields = [self.__class__(
type_=self.type_,
class_validators=class_validators,
@@ -169,46 +173,61 @@ class Field:
)]
def _populate_validators(self, class_validators):
get_validators = getattr(self.type_, 'get_validators', None)
v_funcs = (
class_validators.get(f'validate_{self.name}_pre'),
if not self.sub_fields:
get_validators = getattr(self.type_, 'get_validators', None)
v_funcs = (
*tuple(f for f, pre, whole in class_validators if not whole and pre),
*(get_validators() if get_validators else find_validators(self.type_)),
*tuple(f for f, pre, whole in class_validators if not whole and not pre),
)
self.validators = self._prep_vals(v_funcs)
*(get_validators() if get_validators else find_validators(self.type_)),
if class_validators:
self.whole_pre_validators = self._prep_vals(f for f, pre, whole in class_validators if whole and pre)
self.whole_post_validators = self._prep_vals(f for f, pre, whole in class_validators if whole and not pre)
class_validators.get(f'validate_{self.name}'),
class_validators.get(f'validate_{self.name}_post'),
)
def _prep_vals(self, v_funcs):
v = []
for f in v_funcs:
if not f or (self.allow_none and f is not_none_validator):
continue
self.validators.append((
v.append((
_get_validator_signature(f),
f,
))
return tuple(v)
def validate(self, v, values, index=None):
def validate(self, v, values, index=None, cls=None):
if self.allow_none and v is None:
return None, None
if self.whole_pre_validators:
v, errors = self._apply_validators(v, values, index, cls, self.whole_pre_validators)
if errors:
return v, errors
if self.shape is Shape.SINGLETON:
return self._validate_singleton(v, values, index)
v, errors = self._validate_singleton(v, values, index, cls)
elif self.shape is Shape.MAPPING:
return self._validate_mapping(v, values)
v, errors = self._validate_mapping(v, values, cls)
else:
# list or set
result, errors = self._validate_sequence(v, values)
v, errors = self._validate_sequence(v, values, cls)
if not errors and self.shape is Shape.SET:
return set(result), errors
return result, errors
v = set(v)
def _validate_sequence(self, v, values):
if not errors and self.whole_post_validators:
v, errors = self._apply_validators(v, values, index, cls, self.whole_post_validators)
return v, errors
def _validate_sequence(self, v, values, cls):
result, errors = [], []
try:
v_iter = enumerate(v)
except TypeError as exc:
return v, Error(exc, None, None)
for i, v_ in v_iter:
single_result, single_errors = self._validate_singleton(v_, values, i)
single_result, single_errors = self._validate_singleton(v_, values, i, cls)
if single_errors:
errors.append(single_errors)
else:
@@ -218,7 +237,7 @@ class Field:
else:
return result, None
def _validate_mapping(self, v, values):
def _validate_mapping(self, v, values, cls):
if isinstance(v, dict):
v_iter = v
else:
@@ -229,11 +248,11 @@ class Field:
result, errors = {}, []
for k, v_ in v_iter.items():
key_result, key_errors = self.key_field.validate(k, values, 'key')
key_result, key_errors = self.key_field.validate(k, values, 'key', cls)
if key_errors:
errors.append(key_errors)
continue
value_result, value_errors = self._validate_singleton(v_, values, k)
value_result, value_errors = self._validate_singleton(v_, values, k, cls)
if value_errors:
errors.append(value_errors)
continue
@@ -243,27 +262,34 @@ class Field:
else:
return result, None
def _validate_singleton(self, v, values, index):
if self.multipart:
def _validate_singleton(self, v, values, index, cls):
if self.sub_fields:
errors = []
for field in self.sub_fields:
value, error = field.validate(v, values, index)
value, error = field.validate(v, values, index, cls)
if error:
errors.append(error)
else:
return value, None
return v, errors[0] if len(self.sub_fields) == 1 else errors
else:
for signature, validator in self.validators:
try:
if signature is ValidatorSignature.JUST_VALUE:
v = validator(v)
else:
# ValidatorSignature.VALUE_KWARGS
v = validator(v, values=values, config=self.model_config, field=self)
except (ValueError, TypeError) as exc:
return v, Error(exc, self.type_, index)
return v, None
return self._apply_validators(v, values, index, cls, self.validators)
def _apply_validators(self, v, values, index, cls, validators):
    """
    Run ``v`` through each ``(signature, validator)`` pair in order.

    The value returned by one validator is the input of the next. The recorded signature decides whether
    ``cls`` is passed as the first positional argument and whether the ``values``, ``config`` and ``field``
    keyword arguments are supplied.

    :param v: the value to validate
    :param values: passed to validators as the ``values`` keyword argument
    :param index: recorded in the ``Error`` if a validator fails
    :param cls: passed as first argument to validators with a ``CLS_*`` signature
    :param validators: iterable of ``(ValidatorSignature, callable)`` pairs
    :return: ``(value, None)`` if every validator succeeds, otherwise ``(value, Error)`` for the first
        validator raising ``ValueError`` or ``TypeError``; ``value`` is then the last successfully produced value
    """
    for signature, validator in validators:
        if signature is ValidatorSignature.JUST_VALUE or signature is ValidatorSignature.VALUE_KWARGS:
            args = (v,)
        else:
            args = (cls, v)

        if signature is ValidatorSignature.JUST_VALUE or signature is ValidatorSignature.CLS_JUST_VALUE:
            kwargs = {}
        else:
            kwargs = {'values': values, 'config': self.model_config, 'field': self}

        try:
            v = validator(*args, **kwargs)
        except (ValueError, TypeError) as exc:
            return v, Error(exc, self.type_, index)
    return v, None
def __repr__(self):
return f'<Field {self}>'
@@ -279,6 +305,7 @@ def _get_validator_signature(validator):
try:
signature = inspect.signature(validator)
except ValueError:
# TODO we should probably have a white list of allowed validators here, rather than assuming
# happens on builtins like float
return ValidatorSignature.JUST_VALUE
@@ -286,15 +313,24 @@ def _get_validator_signature(validator):
# 1. we can deal with it before validation begins
# 2. (more importantly) it doesn't get confused with a TypeError when executing the validator
try:
if len(signature.parameters) == 1:
signature.bind(1)
return ValidatorSignature.JUST_VALUE
if 'cls' in signature._parameters:
if len(signature.parameters) == 2:
signature.bind(object(), 1)
return ValidatorSignature.CLS_JUST_VALUE
else:
signature.bind(object(), 1, values=2, config=3, field=4)
return ValidatorSignature.CLS_VALUE_KWARGS
else:
signature.bind(1, values=2, config=3, field=4)
return ValidatorSignature.VALUE_KWARGS
if len(signature.parameters) == 1:
signature.bind(1)
return ValidatorSignature.JUST_VALUE
else:
signature.bind(1, values=2, config=3, field=4)
return ValidatorSignature.VALUE_KWARGS
except TypeError as e:
raise ConfigError(f'Invalid signature for validator {validator}: {signature}, should be: '
f'(value) or (value, *, values, config, field)') from e
f'(value) or (value, *, values, config, field) or for class validators '
f'(cls, value) or (cls, value, *, values, config, field)') from e
def _get_field_config(config, name):
@@ -302,3 +338,10 @@ def _get_field_config(config, name):
if isinstance(field_config, str):
field_config = {'alias': field_config}
return field_config
def _get_type_origin(obj):
    """
    Return ``obj.__origin__``, or ``None`` when ``obj`` has no such attribute.

    Plays the role of ``obj.__class__`` / ``type(obj)`` for typing objects.
    """
    try:
        return obj.__origin__
    except AttributeError:
        return None
+34 -6
View File
@@ -36,6 +36,20 @@ def inherit_config(self_config, parent_config) -> BaseConfig:
TYPE_BLACKLIST = FunctionType, property, type, classmethod, staticmethod
def _extract_validators(namespace):
    """
    Collect the validators declared in a class namespace, grouped by the field they apply to.

    Any namespace value with a truthy ``__validator_config`` attribute counts as a validator; that attribute
    is unpacked as ``(fields, *details)``.

    :param namespace: mapping of attribute name to value
    :return: dict mapping each field name to a list of ``details`` lists, in the iteration order of ``namespace``
    """
    by_field = {}
    for member in namespace.values():
        config = getattr(member, '__validator_config', None)
        if not config:
            continue
        field_names, *details = config
        for field_name in field_names:
            by_field.setdefault(field_name, []).append(details)
    return by_field
class MetaModel(type):
@classmethod
def __prepare__(mcs, *args, **kwargs):
@@ -50,9 +64,7 @@ class MetaModel(type):
config = inherit_config(base.config, config)
config = inherit_config(namespace.get('Config'), config)
class_validators = {
n: f for n, f in namespace.items() if n.startswith('validate_') and isinstance(f, FunctionType)
}
validators = _extract_validators(namespace)
for f in fields.values():
f.set_config(config)
@@ -65,7 +77,7 @@ class MetaModel(type):
name=ann_name,
value=...,
annotation=ann_type,
class_validators=class_validators,
class_validators=validators.get(ann_name),
config=config,
)
@@ -75,7 +87,7 @@ class MetaModel(type):
name=var_name,
value=value,
annotation=annotations.get(var_name),
class_validators=class_validators,
class_validators=validators.get(var_name),
config=config,
)
@@ -227,9 +239,11 @@ class BaseModel(metaclass=MetaModel):
values[name] = field.default
continue
values[name], errors_ = field.validate(value, values)
v_, errors_ = field.validate(value, values, cls=self.__class__)
if errors_:
errors[field.alias] = errors_
else:
values[name] = v_
if (not self.config.ignore_extra) or self.config.allow_extra:
extra = input_data.keys() - {f.alias for f in self.__fields__.values()}
@@ -287,3 +301,17 @@ class BaseModel(metaclass=MetaModel):
def __str__(self):
return self.to_string()
def validator(*fields, pre=False, whole=False):
    """
    Decorator marking a method on a model class as a validator for one or more fields.

    The decorated function is wrapped in ``classmethod`` and the wrapper is tagged with a
    ``__validator_config`` attribute holding ``(fields, original_function, pre, whole)``.

    :param fields: which field(s) the method should be called on
    :param pre: whether or not this validator should be called before the standard validators (else after)
    :param whole: for complex objects (sets, lists etc.) whether to validate individual elements or the whole object
    """
    def dec(f):
        wrapped = classmethod(f)
        setattr(wrapped, '__validator_config', (fields, f, pre, whole))
        return wrapped

    return dec
+160
View File
@@ -0,0 +1,160 @@
from typing import List
import pytest
from pydantic import BaseModel, ValidationError, validator
def test_simple():
    """A single-field validator passes valid values through and surfaces its ``ValueError`` message."""
    class Model(BaseModel):
        a: str

        @validator('a')
        def check_a(cls, v):
            if 'foobar' not in v:
                raise ValueError('"foobar" not found in a')
            return v

    assert Model(a='this is foobar good').a == 'this is foobar good'

    with pytest.raises(ValidationError) as exc_info:
        Model(a='snap')
    assert '"foobar" not found in a' in str(exc_info.value)
def test_validate_whole():
    """``whole`` validators receive the entire list, once before (``pre``) and once after element validation."""
    class Model(BaseModel):
        a: List[int]

        @validator('a', whole=True, pre=True)
        def check_a1(cls, v):
            # appended as a string; the expected result below contains it as the int 123
            v.append('123')
            return v

        @validator('a', whole=True)
        def check_a2(cls, v):
            v.append(456)
            return v

    assert Model(a=[1, 2]).a == [1, 2, 123, 456]
def test_validate_kwargs():
    """A non-``whole`` validator on a list field runs per element and can read other fields via ``values``."""
    class Model(BaseModel):
        # ``b`` is declared before ``a`` and is read from ``values`` by the validator on ``a``
        b: int
        a: List[int]

        @validator('a')
        def check_a1(cls, v, values, **kwargs):
            return v + values['b']

    assert Model(a=[1, 2], b=6).a == [7, 8]
def test_validate_whole_error():
    """Whole-field validators run pre then post; a failing pre validator prevents the post one from running."""
    calls = []

    class Model(BaseModel):
        a: List[int]

        @validator('a', whole=True, pre=True)
        def check_a1(cls, v):
            calls.append(f'check_a1 {v}')
            if 1 in v:
                raise ValueError('a1 broken')
            v[0] += 1
            return v

        @validator('a', whole=True)
        def check_a2(cls, v):
            calls.append(f'check_a2 {v}')
            if 10 in v:
                raise ValueError('a2 broken')
            return v

    # both validators pass; check_a2 sees the list as modified by check_a1
    assert Model(a=[3, 8]).a == [4, 8]
    assert calls == ['check_a1 [3, 8]', 'check_a2 [4, 8]']

    # the validators close over ``calls``, so rebinding it here resets what they append to
    calls = []
    with pytest.raises(ValidationError) as exc_info:
        Model(a=[1, 3])
    assert 'a1 broken' in str(exc_info.value)
    # only the pre validator was called
    assert calls == ['check_a1 [1, 3]']

    calls = []
    with pytest.raises(ValidationError) as exc_info:
        Model(a=[5, 10])
    assert 'a2 broken' in str(exc_info.value)
    # the pre validator passed, so the post validator ran and failed
    assert calls == ['check_a1 [5, 10]', 'check_a2 [6, 10]']
class ValidateAssignmentModel(BaseModel):
    """Model shared by the assignment tests: ``b`` must be at least ``a`` characters long."""

    a: int = 4
    b: str = ...

    @validator('b')
    def b_length(cls, v, values, **kwargs):
        # the length check is skipped when 'a' is not in ``values``
        if 'a' in values and len(v) < values['a']:
            raise ValueError('b too short')
        return v

    class Config:
        validate_assignment = True
def test_validating_assignment_ok():
    """A ``b`` value of five characters is accepted when ``a`` is left at its default."""
    p = ValidateAssignmentModel(b='hello')
    assert p.b == 'hello'
def test_validating_assignment_fail():
    """The ``b`` validator rejects too-short values both at construction and on attribute assignment."""
    # construction: 'hello' is shorter than a=10
    with pytest.raises(ValidationError):
        ValidateAssignmentModel(a=10, b='hello')

    # assignment: ``Config.validate_assignment`` is set on the model, so setting ``b`` is validated too
    p = ValidateAssignmentModel(b='hello')
    with pytest.raises(ValidationError):
        p.b = 'x'
def test_validating_assignment_values():
    """When ``a`` itself is invalid, only the error for ``a`` is reported; none is reported for ``b``."""
    with pytest.raises(ValidationError) as exc_info:
        ValidateAssignmentModel(a='x', b='xx')
    # the full error text is compared, so any additional error would fail the test
    assert """\
error validating input
a:
invalid literal for int() with base 10: 'x' (error_type=ValueError track=int)""" == str(exc_info.value)
def test_validate_multiple():
    """One validator registered for two fields runs for each of them; ``TypeError`` is reported like ``ValueError``."""
    # also test TypeError
    class Model(BaseModel):
        a: str
        b: str

        @validator('a', 'b')
        def check_a_and_b(cls, v, field, **kwargs):
            # ``field`` is the field currently being validated, used here to name it in the message
            if len(v) < 4:
                raise TypeError(f'{field.alias} is too short')
            return v + 'x'

    assert Model(a='1234', b='5678').values() == {'a': '1234x', 'b': '5678x'}

    with pytest.raises(ValidationError) as exc_info:
        Model(a='x', b='x')
    assert """\
2 errors validating input
a:
a is too short (error_type=TypeError track=str)
b:
b is too short (error_type=TypeError track=str)""" == str(exc_info.value)
def test_classmethod():
    """Validators receive the model class as ``cls`` and stay callable as ordinary class methods."""
    class Model(BaseModel):
        a: str

        @validator('a')
        def check_a(cls, v):
            assert cls is Model
            return v

    m = Model(a='this is foobar good')
    assert m.a == 'this is foobar good'
    # direct call through an instance: ``cls`` must still be bound to Model
    m.check_a('x')