From f08fd2fee738145efa894199fe47e49a84c96aa1 Mon Sep 17 00:00:00 2001 From: Timon Ruban Date: Mon, 16 Sep 2019 11:42:40 +0200 Subject: [PATCH] Add support for Type[T] typehints when arbitrary_types_allowed==True. (#808) * Add support for Type[T] typehints when arbitrary_types_allowe==True. * Add documentation. * Let black do its magic. * Ignore mypy warning - see here: https://github.com/python/mypy/issues/3060 * Prettify docs. * Change Changelog. * Refactor and simplify check for Type[T]. * Black again. ^^ - Really need pre-commit hooks. * Update pydantic/validators.py Co-Authored-By: Samuel Colvin * Rename arbitrary_class to class. * Black. * Add type hints. * Make private function public. * Add support for bare Type. * Black again. * Update docs. * CO_ct not meant for export. * Fix get_class for Python3.6 * Update error message of ClassError. * Use relative import. * Incorporate typing feedback (both versions are fine with mypy). * Move from issubclass to lenient_issubclass. * correct docs --- changes/807-timonbimon.rst | 1 + docs/examples/bare_type_type.py | 24 ++++++++++ docs/examples/type_type.py | 29 +++++++++++++ docs/index.rst | 14 +++++- pydantic/errors.py | 13 ++++++ pydantic/fields.py | 2 + pydantic/typing.py | 18 ++++++++ pydantic/validators.py | 27 +++++++++++- tests/test_main.py | 77 ++++++++++++++++++++++++++++++++- 9 files changed, 201 insertions(+), 4 deletions(-) create mode 100644 changes/807-timonbimon.rst create mode 100644 docs/examples/bare_type_type.py create mode 100644 docs/examples/type_type.py diff --git a/changes/807-timonbimon.rst b/changes/807-timonbimon.rst new file mode 100644 index 0000000..09a4cbd --- /dev/null +++ b/changes/807-timonbimon.rst @@ -0,0 +1 @@ +add support for ``Type[T]`` type hints \ No newline at end of file diff --git a/docs/examples/bare_type_type.py b/docs/examples/bare_type_type.py new file mode 100644 index 0000000..bbe45ed --- /dev/null +++ b/docs/examples/bare_type_type.py @@ -0,0 +1,24 @@ +from typing import Type + +from pydantic import BaseModel, ValidationError + + +class Foo: + pass + + +class LenientSimpleModel(BaseModel): + any_class_goes: Type + + +LenientSimpleModel(any_class_goes=int) +LenientSimpleModel(any_class_goes=Foo) +try: + LenientSimpleModel(any_class_goes=Foo()) +except ValidationError as e: + print(e) +""" +1 validation error +any_class_goes + subclass of type expected (type=type_error.class) +""" diff --git a/docs/examples/type_type.py b/docs/examples/type_type.py new file mode 100644 index 0000000..985ab6b --- /dev/null +++ b/docs/examples/type_type.py @@ -0,0 +1,29 @@ +from typing import Type + +from pydantic import BaseModel +from pydantic import ValidationError + +class Foo: + pass + +class Bar(Foo): + pass + +class Other: + pass + +class SimpleModel(BaseModel): + just_subclasses: Type[Foo] + + +SimpleModel(just_subclasses=Foo) +SimpleModel(just_subclasses=Bar) +try: + SimpleModel(just_subclasses=Other) +except ValidationError as e: + print(e) +""" +1 validation error +just_subclasses + subclass of Foo expected (type=type_error.class) +""" diff --git a/docs/index.rst b/docs/index.rst index 52d7e22..8f9245c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -818,6 +818,18 @@ With proper ordering in an annotated ``Union``, you can use this to parse types (This script is complete, it should run "as is") +Type Type +............ + +Pydantic supports the use of ``Type[T]`` to specify that a field may only accept classes (not instances) +that are subclasses of ``T``. + +.. literalinclude:: examples/type_type.py + +You may also use ``Type`` to specify that any class is allowed. + +.. literalinclude:: examples/bare_type_type.py + Custom Data Types ................. @@ -898,7 +910,7 @@ Options: :error_msg_templates: let's you to override default error message templates. Pass in a dictionary with keys matching the error messages you want to override (default: ``{}``) :arbitrary_types_allowed: whether to allow arbitrary user types for fields (they are validated simply by checking if the - value is instance of that type). If False - RuntimeError will be raised on model declaration (default: ``False``) + value is instance of that type). If ``False`` - ``RuntimeError`` will be raised on model declaration (default: ``False``) :json_encoders: customise the way types are encoded to json, see :ref:`JSON Serialisation ` for more details. :orm_mode: allows usage of :ref:`ORM mode ` diff --git a/pydantic/errors.py b/pydantic/errors.py index 58059be..890006d 100644 --- a/pydantic/errors.py +++ b/pydantic/errors.py @@ -324,6 +324,19 @@ class ArbitraryTypeError(PydanticTypeError): super().__init__(expected_arbitrary_type=display_as_type(expected_arbitrary_type)) +class ClassError(PydanticTypeError): + code = 'class' + msg_template = 'a class is expected' + + +class SubclassError(PydanticTypeError): + code = 'subclass' + msg_template = 'subclass of {expected_class} expected' + + def __init__(self, *, expected_class: AnyType) -> None: + super().__init__(expected_class=display_as_type(expected_class)) + + class JsonError(PydanticValueError): msg_template = 'Invalid JSON' diff --git a/pydantic/fields.py b/pydantic/fields.py index 67b7aca..47104f3 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -248,6 +248,8 @@ class Field: ) self.type_ = self.type_.__args__[1] # type: ignore self.shape = SHAPE_MAPPING + elif issubclass(origin, Type): # type: ignore + return else: raise TypeError(f'Fields of type "{origin}" are not supported.') diff --git a/pydantic/typing.py b/pydantic/typing.py index a7ad890..3ab676f 100644 --- a/pydantic/typing.py +++ b/pydantic/typing.py @@ -193,3 +193,21 @@ def update_field_forward_refs(field: 'Field', globalns: Any, localns: Any) -> No if field.sub_fields: for sub_f in field.sub_fields: update_field_forward_refs(sub_f, globalns=globalns, localns=localns) + + +def get_class(type_: AnyType) -> Union[None, bool, AnyType]: + """ + Tries to get the class of a Type[T] annotation. Returns True if Type is used + without brackets. Otherwise returns None. + """ + try: + origin = getattr(type_, '__origin__') + if origin is None: # Python 3.6 + origin = type_ + if issubclass(origin, Type): # type: ignore + if type_.__args__ is None or not isinstance(type_.__args__[0], type): + return True + return type_.__args__[0] + except AttributeError: + pass + return None diff --git a/pydantic/validators.py b/pydantic/validators.py index 7088d59..4c7fa23 100644 --- a/pydantic/validators.py +++ b/pydantic/validators.py @@ -26,8 +26,8 @@ from uuid import UUID from . import errors from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time -from .typing import AnyCallable, AnyType, ForwardRef, display_as_type, is_callable_type, is_literal_type -from .utils import almost_equal_floats, change_exception, sequence_like +from .typing import AnyCallable, AnyType, ForwardRef, display_as_type, get_class, is_callable_type, is_literal_type +from .utils import almost_equal_floats, change_exception, lenient_issubclass, sequence_like if TYPE_CHECKING: # pragma: no cover from .fields import Field @@ -404,6 +404,21 @@ def make_arbitrary_type_validator(type_: Type[T]) -> Callable[[T], T]: return arbitrary_type_validator +def make_class_validator(type_: Type[T]) -> Callable[[Any], Type[T]]: + def class_validator(v: Any) -> Type[T]: + if lenient_issubclass(v, type_): + return v + raise errors.SubclassError(expected_class=type_) + + return class_validator + + +def any_class_validator(v: Any) -> Type[T]: + if isinstance(v, type): + return v + raise errors.ClassError() + + def pattern_validator(v: Any) -> Pattern[str]: with change_exception(errors.PatternError, re.error): return re.compile(v) @@ -486,6 +501,14 @@ def find_validators( # noqa: C901 (ignore complexity) yield make_literal_validator(type_) return + class_ = get_class(type_) + if class_ is not None: + if isinstance(class_, type): + yield make_class_validator(class_) + else: + yield any_class_validator + return + supertype = _find_supertype(type_) if supertype is not None: type_ = supertype diff --git a/tests/test_main.py b/tests/test_main.py index 7c2d579..019776f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, ClassVar, List, Mapping +from typing import Any, ClassVar, List, Mapping, Type import pytest @@ -530,6 +530,81 @@ def test_arbitrary_types_not_allowed(): assert exc_info.value.args[0].startswith('no validator found for') +def test_type_type_validation_success(): + class ArbitraryClassAllowedModel(BaseModel): + t: Type[ArbitraryType] + + arbitrary_type_class = ArbitraryType + m = ArbitraryClassAllowedModel(t=arbitrary_type_class) + assert m.t == arbitrary_type_class + + +def test_type_type_subclass_validation_success(): + class ArbitraryClassAllowedModel(BaseModel): + t: Type[ArbitraryType] + + class ArbitrarySubType(ArbitraryType): + pass + + arbitrary_type_class = ArbitrarySubType + m = ArbitraryClassAllowedModel(t=arbitrary_type_class) + assert m.t == arbitrary_type_class + + +def test_type_type_validation_fails_for_instance(): + class ArbitraryClassAllowedModel(BaseModel): + t: Type[ArbitraryType] + + class C: + pass + + with pytest.raises(ValidationError) as exc_info: + ArbitraryClassAllowedModel(t=C) + assert exc_info.value.errors() == [ + { + 'loc': ('t',), + 'msg': 'subclass of ArbitraryType expected', + 'type': 'type_error.subclass', + 'ctx': {'expected_class': 'ArbitraryType'}, + } + ] + + +def test_type_type_validation_fails_for_basic_type(): + class ArbitraryClassAllowedModel(BaseModel): + t: Type[ArbitraryType] + + with pytest.raises(ValidationError) as exc_info: + ArbitraryClassAllowedModel(t=1) + assert exc_info.value.errors() == [ + { + 'loc': ('t',), + 'msg': 'subclass of ArbitraryType expected', + 'type': 'type_error.subclass', + 'ctx': {'expected_class': 'ArbitraryType'}, + } + ] + + +def test_bare_type_type_validation_success(): + class ArbitraryClassAllowedModel(BaseModel): + t: Type + + arbitrary_type_class = ArbitraryType + m = ArbitraryClassAllowedModel(t=arbitrary_type_class) + assert m.t == arbitrary_type_class + + +def test_bare_type_type_validation_fails(): + class ArbitraryClassAllowedModel(BaseModel): + t: Type + + arbitrary_type = ArbitraryType() + with pytest.raises(ValidationError) as exc_info: + ArbitraryClassAllowedModel(t=arbitrary_type) + assert exc_info.value.errors() == [{'loc': ('t',), 'msg': 'a class is expected', 'type': 'type_error.class'}] + + def test_annotation_field_name_shadows_attribute(): with pytest.raises(NameError): # When defining a model that has an attribute with the name of a built-in attribute, an exception is raised