fix: support arbitrary types with custom __eq__ (#2502)

This commit is contained in:
Eric Jolibois
2021-05-09 11:44:53 +02:00
committed by GitHub
parent 07908b3846
commit 9cc19e9a8e
9 changed files with 66 additions and 42 deletions
+2
View File
@@ -0,0 +1,2 @@
- support arbitrary types with custom `__eq__`
- support `Annotated` in `validate_arguments` and in generic models with python 3.9
+5 -18
View File
@@ -1,23 +1,10 @@
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
overload,
)
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload
from . import validator
from .errors import ConfigError
from .main import BaseModel, Extra, create_model
from .typing import get_all_type_hints
from .utils import to_camel
__all__ = ('validate_arguments',)
@@ -87,17 +74,17 @@ class ValidatedFunction:
self.v_args_name = 'args'
self.v_kwargs_name = 'kwargs'
type_hints = get_type_hints(function)
type_hints = get_all_type_hints(function)
takes_args = False
takes_kwargs = False
fields: Dict[str, Tuple[Any, Any]] = {}
for i, (name, p) in enumerate(parameters.items()):
if p.annotation == p.empty:
if p.annotation is p.empty:
annotation = Any
else:
annotation = type_hints[name]
default = ... if p.default == p.empty else p.default
default = ... if p.default is p.empty else p.default
if p.kind == Parameter.POSITIONAL_ONLY:
self.arg_mapping[i] = name
fields[name] = annotation, default
+4 -4
View File
@@ -177,7 +177,7 @@ class FieldInfo(Representation):
self.include = ValueItems.merge(value, current_value, intersect=True)
def _validate(self) -> None:
if self.default not in (Undefined, Ellipsis) and self.default_factory is not None:
if self.default is not Undefined and self.default_factory is not None:
raise ValueError('cannot specify both default and default_factory')
@@ -386,9 +386,10 @@ class ModelField(Representation):
field_info = next(iter(field_infos), None)
if field_info is not None:
field_info.update_from_config(field_info_from_config)
if field_info.default not in (Undefined, Ellipsis):
if field_info.default is not Undefined:
raise ValueError(f'`Field` default cannot be set in `Annotated` for {field_name!r}')
if value not in (Undefined, Ellipsis):
if value is not Undefined and value is not Required:
# check also `Required` because of `validate_arguments` that sets `...` as default value
field_info.default = value
if isinstance(value, FieldInfo):
@@ -472,7 +473,6 @@ class ModelField(Representation):
self._type_analysis()
if self.required is Undefined:
self.required = True
self.field_info.default = Required
if self.default is Undefined and self.default_factory is None:
self.default = None
self.populate_validators()
+8 -3
View File
@@ -15,13 +15,14 @@ from typing import (
TypeVar,
Union,
cast,
get_type_hints,
)
from typing_extensions import Annotated
from .class_validators import gather_all_validators
from .fields import DeferredType
from .main import BaseModel, create_model
from .typing import display_as_type, get_args, get_origin, typing_base
from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
from .utils import all_identical, lenient_issubclass
_generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = {}
@@ -73,7 +74,7 @@ class GenericModel(BaseModel):
model_name = cls.__concrete_name__(params)
validators = gather_all_validators(cls)
type_hints = get_type_hints(cls).items()
type_hints = get_all_type_hints(cls).items()
instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar}
fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__}
@@ -159,6 +160,10 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
type_args = get_args(type_)
origin_type = get_origin(type_)
if origin_type is Annotated:
annotated_type, *annotations = type_args
return Annotated[replace_types(annotated_type, type_map), tuple(annotations)]
# Having type args is a good indicator that this is a typing module
# class instantiation or a generic alias of some sort.
if type_args:
+14
View File
@@ -18,6 +18,7 @@ from typing import ( # type: ignore
Union,
_eval_type,
cast,
get_type_hints,
)
from typing_extensions import Annotated, Literal
@@ -70,6 +71,18 @@ else:
return cast(Any, type_)._evaluate(globalns, localns, set())
if sys.version_info < (3, 9):
# Ensure we always get all the whole `Annotated` hint, not just the annotated type.
# For 3.6 to 3.8, `get_type_hints` doesn't recognize `typing_extensions.Annotated`,
# so it already returns the full annotation
get_all_type_hints = get_type_hints
else:
def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> Any:
return get_type_hints(obj, globalns, localns, include_extras=True)
if sys.version_info < (3, 7):
from typing import Callable as Callable
@@ -225,6 +238,7 @@ __all__ = (
'get_args',
'get_origin',
'typing_base',
'get_all_type_hints',
)
+2 -12
View File
@@ -1,11 +1,9 @@
import sys
from typing import get_type_hints
import pytest
from typing_extensions import Annotated
from pydantic import BaseModel, Field
from pydantic.fields import Undefined
from pydantic.typing import get_all_type_hints
@pytest.mark.parametrize(
@@ -43,15 +41,7 @@ def test_annotated(hint_fn, value):
assert M().x == 5
assert M(x=10).x == 10
# get_type_hints doesn't recognize typing_extensions.Annotated, so will return the full
# annotation. 3.9 w/ stock Annotated will return the wrapped type by default, but return the
# full thing with the new include_extras flag.
if sys.version_info >= (3, 9):
assert get_type_hints(M)['x'] is int
assert get_type_hints(M, include_extras=True)['x'] == hint
else:
assert get_type_hints(M)['x'] == hint
assert get_all_type_hints(M)['x'] == hint
@pytest.mark.parametrize(
+4 -4
View File
@@ -3,14 +3,13 @@ import inspect
import sys
from pathlib import Path
from typing import List
from unittest.mock import ANY
import pytest
from typing_extensions import Annotated
from pydantic import BaseModel, Field, ValidationError, validate_arguments
from pydantic.decorator import ValidatedFunction
from pydantic.errors import ConfigError
from pydantic.typing import Annotated
skip_pre_38 = pytest.mark.skipif(sys.version_info < (3, 8), reason='testing >= 3.8 behaviour only')
@@ -154,13 +153,14 @@ def test_field_can_provide_factory() -> None:
assert foo(1, 2, 3) == 6
@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed')
def test_annotated_field_can_provide_factory() -> None:
@validate_arguments
def foo2(a: int, b: Annotated[int, Field(default_factory=lambda: 99)] = ANY, *args: int) -> int:
def foo2(a: int, b: Annotated[int, Field(default_factory=lambda: 99)], *args: int) -> int:
"""mypy reports Incompatible default for argument "b" if we don't supply ANY as default"""
return a + b + sum(args)
assert foo2(1) == 100
@skip_pre_38
def test_positional_only(create_module):
+16
View File
@@ -1839,3 +1839,19 @@ def test_config_field_info_allow_mutation():
with pytest.raises(TypeError):
b.a = 'y'
assert b.dict() == {'a': 'x'}
def test_arbitrary_types_allowed_custom_eq():
class Foo:
def __eq__(self, other):
if other.__class__ is not Foo:
raise TypeError(f'Cannot interpret {other.__class__.__name__!r} as a valid type')
return True
class Model(BaseModel):
x: Foo = Foo()
class Config:
arbitrary_types_allowed = True
assert Model().x == Foo()
+11 -1
View File
@@ -3,7 +3,7 @@ from enum import Enum
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Sequence, Tuple, Type, TypeVar, Union
import pytest
from typing_extensions import Literal
from typing_extensions import Annotated, Literal
from pydantic import BaseModel, Field, ValidationError, root_validator, validator
from pydantic.generics import GenericModel, _generic_types_cache, iter_contained_typevars, replace_types
@@ -1071,3 +1071,13 @@ def test_generic_literal():
Fields = Literal['foo', 'bar']
m = GModel[Fields, str](field={'foo': 'x'})
assert m.dict() == {'field': {'foo': 'x'}}
@skip_36
def test_generic_annotated():
T = TypeVar('T')
class SomeGenericModel(GenericModel, Generic[T]):
some_field: Annotated[T, Field(alias='the_alias')]
SomeGenericModel[str](the_alias='qwe')