From 73015d2a728e4feba2b2bdbf756ffcaaa8a75307 Mon Sep 17 00:00:00 2001 From: Evghenii Goncearov Date: Mon, 2 Jul 2018 14:08:43 +0300 Subject: [PATCH] Allow arbitrary types in model (#209) * Allow arbitrary types in model * Replaced ConfigError with RuntimeError * Corrections of the ArbitraryTypeError exception class --- docs/index.rst | 2 ++ pydantic/errors.py | 10 ++++++++ pydantic/fields.py | 3 ++- pydantic/main.py | 1 + pydantic/validators.py | 15 ++++++++++-- tests/test_main.py | 52 ++++++++++++++++++++++++++++++++++++------ 6 files changed, 73 insertions(+), 10 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index eb5c382..3012f8f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -345,6 +345,8 @@ Options: ``False``) :error_msg_templates: let's you to override default error message templates. Pass in a dictionary with keys matching the error messages you want to override (default: ``{}``) +:arbitrary_types_allowed: whether to allow arbitrary user types for fields (they are validated simply by checking if the + value is instance of that type). If False - RuntimeError will be raised on model declaration (default: ``False``) .. warning:: diff --git a/pydantic/errors.py b/pydantic/errors.py index a1939af..5b12193 100644 --- a/pydantic/errors.py +++ b/pydantic/errors.py @@ -2,6 +2,8 @@ from decimal import Decimal from pathlib import Path from typing import Union +from .utils import display_as_type + class PydanticErrorMixin: code: str @@ -225,3 +227,11 @@ class UUIDVersionError(PydanticValueError): def __init__(self, *, required_version: int) -> None: super().__init__(required_version=required_version) + + +class ArbitraryTypeError(PydanticTypeError): + code = 'arbitrary_type' + msg_template = 'instance of {expected_arbitrary_type} expected' + + def __init__(self, *, expected_arbitrary_type) -> None: + super().__init__(expected_arbitrary_type=display_as_type(expected_arbitrary_type)) diff --git a/pydantic/fields.py b/pydantic/fields.py index 4667b1c..85e9156 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -213,7 +213,8 @@ class Field: get_validators = getattr(self.type_, 'get_validators', None) v_funcs = ( *tuple(v.func for v in self.class_validators if not v.whole and v.pre), - *(get_validators() if get_validators else find_validators(self.type_)), + *(get_validators() if get_validators else find_validators(self.type_, + self.model_config.arbitrary_types_allowed)), *tuple(v.func for v in self.class_validators if not v.whole and not v.pre), ) self.validators = self._prep_vals(v_funcs) diff --git a/pydantic/main.py b/pydantic/main.py index 7a0c6d8..739cc46 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -30,6 +30,7 @@ class BaseConfig: fields = {} validate_assignment = False error_msg_templates: Dict[str, str] = {} + arbitrary_types_allowed = False @classmethod def get_field_schema(cls, name): diff --git a/pydantic/validators.py b/pydantic/validators.py index 37ebb78..1031ce7 100644 --- a/pydantic/validators.py +++ b/pydantic/validators.py @@ -214,6 +214,14 @@ def path_exists_validator(v) -> Path: return v +def make_arbitrary_type_validator(type_): + def arbitrary_type_validator(v) -> type_: + if isinstance(v, type_): + return v + raise errors.ArbitraryTypeError(expected_arbitrary_type=type_) + return arbitrary_type_validator + + # order is important here, for example: bool is a subclass of int so has to come first, datetime before date same _VALIDATORS = [ (Enum, [enum_validator]), @@ -242,7 +250,7 @@ _VALIDATORS = [ ] -def find_validators(type_): +def find_validators(type_, arbitrary_types_allowed=False): if type_ is Any: return [] for val_type, validators in _VALIDATORS: @@ -251,4 +259,7 @@ def find_validators(type_): return validators except TypeError as e: raise RuntimeError(f'error checking inheritance of {type_!r} (type: {display_as_type(type_)})') from e - raise errors.ConfigError(f'no validator found for {type_}') + if arbitrary_types_allowed: + return [make_arbitrary_type_validator(type_)] + else: + raise RuntimeError(f'no validator found for {type_}') diff --git a/tests/test_main.py b/tests/test_main.py index 6aa22cf..425d7ef 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -205,13 +205,6 @@ def test_invalid_validator(): assert exc_info.value.args[0].startswith('Invalid signature for validator') -def test_no_validator(): - with pytest.raises(errors.ConfigError) as exc_info: - class NoValidatorModel(BaseModel): - x: object = ... - assert exc_info.value.args[0] == "no validator found for " - - def test_unable_to_infer(): with pytest.raises(errors.ConfigError) as exc_info: class InvalidDefinitionModel(BaseModel): @@ -462,3 +455,48 @@ def test_default_copy(): u1 = User() u2 = User() assert u1.friends is not u2.friends + + +class ArbitraryType: + pass + + +def test_arbitrary_type_allowed_validation_success(): + class ArbitraryTypeAllowedModel(BaseModel): + t: ArbitraryType + + class Config: + arbitrary_types_allowed = True + + arbitrary_type_instance = ArbitraryType() + m = ArbitraryTypeAllowedModel(t=arbitrary_type_instance) + assert m.t == arbitrary_type_instance + + +def test_arbitrary_type_allowed_validation_fails(): + class ArbitraryTypeAllowedModel(BaseModel): + t: ArbitraryType + + class Config: + arbitrary_types_allowed = True + + class C: + pass + + with pytest.raises(ValidationError) as exc_info: + ArbitraryTypeAllowedModel(t=C()) + assert exc_info.value.errors() == [ + { + 'loc': ('t',), + 'msg': "instance of ArbitraryType expected", + 'type': 'type_error.arbitrary_type', + 'ctx': {'expected_arbitrary_type': 'ArbitraryType'} + }, + ] + + +def test_arbitrary_types_not_allowed(): + with pytest.raises(RuntimeError) as exc_info: + class ArbitraryTypeNotAllowedModel(BaseModel): + t: ArbitraryType + assert exc_info.value.args[0].startswith('no validator found for')