mirror of
https://github.com/kennethreitz/pydantic.git
synced 2026-06-05 23:00:18 +00:00
fix: support arbitrary types with custom __eq__ (#2502)
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
- support arbitrary types with custom `__eq__`
|
||||
- support `Annotated` in `validate_arguments` and in generic models with python 3.9
|
||||
+5
-18
@@ -1,23 +1,10 @@
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_type_hints,
|
||||
overload,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload
|
||||
|
||||
from . import validator
|
||||
from .errors import ConfigError
|
||||
from .main import BaseModel, Extra, create_model
|
||||
from .typing import get_all_type_hints
|
||||
from .utils import to_camel
|
||||
|
||||
__all__ = ('validate_arguments',)
|
||||
@@ -87,17 +74,17 @@ class ValidatedFunction:
|
||||
self.v_args_name = 'args'
|
||||
self.v_kwargs_name = 'kwargs'
|
||||
|
||||
type_hints = get_type_hints(function)
|
||||
type_hints = get_all_type_hints(function)
|
||||
takes_args = False
|
||||
takes_kwargs = False
|
||||
fields: Dict[str, Tuple[Any, Any]] = {}
|
||||
for i, (name, p) in enumerate(parameters.items()):
|
||||
if p.annotation == p.empty:
|
||||
if p.annotation is p.empty:
|
||||
annotation = Any
|
||||
else:
|
||||
annotation = type_hints[name]
|
||||
|
||||
default = ... if p.default == p.empty else p.default
|
||||
default = ... if p.default is p.empty else p.default
|
||||
if p.kind == Parameter.POSITIONAL_ONLY:
|
||||
self.arg_mapping[i] = name
|
||||
fields[name] = annotation, default
|
||||
|
||||
+4
-4
@@ -177,7 +177,7 @@ class FieldInfo(Representation):
|
||||
self.include = ValueItems.merge(value, current_value, intersect=True)
|
||||
|
||||
def _validate(self) -> None:
|
||||
if self.default not in (Undefined, Ellipsis) and self.default_factory is not None:
|
||||
if self.default is not Undefined and self.default_factory is not None:
|
||||
raise ValueError('cannot specify both default and default_factory')
|
||||
|
||||
|
||||
@@ -386,9 +386,10 @@ class ModelField(Representation):
|
||||
field_info = next(iter(field_infos), None)
|
||||
if field_info is not None:
|
||||
field_info.update_from_config(field_info_from_config)
|
||||
if field_info.default not in (Undefined, Ellipsis):
|
||||
if field_info.default is not Undefined:
|
||||
raise ValueError(f'`Field` default cannot be set in `Annotated` for {field_name!r}')
|
||||
if value not in (Undefined, Ellipsis):
|
||||
if value is not Undefined and value is not Required:
|
||||
# check also `Required` because of `validate_arguments` that sets `...` as default value
|
||||
field_info.default = value
|
||||
|
||||
if isinstance(value, FieldInfo):
|
||||
@@ -472,7 +473,6 @@ class ModelField(Representation):
|
||||
self._type_analysis()
|
||||
if self.required is Undefined:
|
||||
self.required = True
|
||||
self.field_info.default = Required
|
||||
if self.default is Undefined and self.default_factory is None:
|
||||
self.default = None
|
||||
self.populate_validators()
|
||||
|
||||
@@ -15,13 +15,14 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .class_validators import gather_all_validators
|
||||
from .fields import DeferredType
|
||||
from .main import BaseModel, create_model
|
||||
from .typing import display_as_type, get_args, get_origin, typing_base
|
||||
from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
|
||||
from .utils import all_identical, lenient_issubclass
|
||||
|
||||
_generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = {}
|
||||
@@ -73,7 +74,7 @@ class GenericModel(BaseModel):
|
||||
model_name = cls.__concrete_name__(params)
|
||||
validators = gather_all_validators(cls)
|
||||
|
||||
type_hints = get_type_hints(cls).items()
|
||||
type_hints = get_all_type_hints(cls).items()
|
||||
instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar}
|
||||
|
||||
fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__}
|
||||
@@ -159,6 +160,10 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
|
||||
type_args = get_args(type_)
|
||||
origin_type = get_origin(type_)
|
||||
|
||||
if origin_type is Annotated:
|
||||
annotated_type, *annotations = type_args
|
||||
return Annotated[replace_types(annotated_type, type_map), tuple(annotations)]
|
||||
|
||||
# Having type args is a good indicator that this is a typing module
|
||||
# class instantiation or a generic alias of some sort.
|
||||
if type_args:
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import ( # type: ignore
|
||||
Union,
|
||||
_eval_type,
|
||||
cast,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated, Literal
|
||||
@@ -70,6 +71,18 @@ else:
|
||||
return cast(Any, type_)._evaluate(globalns, localns, set())
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
# Ensure we always get all the whole `Annotated` hint, not just the annotated type.
|
||||
# For 3.6 to 3.8, `get_type_hints` doesn't recognize `typing_extensions.Annotated`,
|
||||
# so it already returns the full annotation
|
||||
get_all_type_hints = get_type_hints
|
||||
|
||||
else:
|
||||
|
||||
def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> Any:
|
||||
return get_type_hints(obj, globalns, localns, include_extras=True)
|
||||
|
||||
|
||||
if sys.version_info < (3, 7):
|
||||
from typing import Callable as Callable
|
||||
|
||||
@@ -225,6 +238,7 @@ __all__ = (
|
||||
'get_args',
|
||||
'get_origin',
|
||||
'typing_base',
|
||||
'get_all_type_hints',
|
||||
)
|
||||
|
||||
|
||||
|
||||
+2
-12
@@ -1,11 +1,9 @@
|
||||
import sys
|
||||
from typing import get_type_hints
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.fields import Undefined
|
||||
from pydantic.typing import get_all_type_hints
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -43,15 +41,7 @@ def test_annotated(hint_fn, value):
|
||||
|
||||
assert M().x == 5
|
||||
assert M(x=10).x == 10
|
||||
|
||||
# get_type_hints doesn't recognize typing_extensions.Annotated, so will return the full
|
||||
# annotation. 3.9 w/ stock Annotated will return the wrapped type by default, but return the
|
||||
# full thing with the new include_extras flag.
|
||||
if sys.version_info >= (3, 9):
|
||||
assert get_type_hints(M)['x'] is int
|
||||
assert get_type_hints(M, include_extras=True)['x'] == hint
|
||||
else:
|
||||
assert get_type_hints(M)['x'] == hint
|
||||
assert get_all_type_hints(M)['x'] == hint
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -3,14 +3,13 @@ import inspect
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, validate_arguments
|
||||
from pydantic.decorator import ValidatedFunction
|
||||
from pydantic.errors import ConfigError
|
||||
from pydantic.typing import Annotated
|
||||
|
||||
skip_pre_38 = pytest.mark.skipif(sys.version_info < (3, 8), reason='testing >= 3.8 behaviour only')
|
||||
|
||||
@@ -154,13 +153,14 @@ def test_field_can_provide_factory() -> None:
|
||||
assert foo(1, 2, 3) == 6
|
||||
|
||||
|
||||
@pytest.mark.skipif(not Annotated, reason='typing_extensions not installed')
|
||||
def test_annotated_field_can_provide_factory() -> None:
|
||||
@validate_arguments
|
||||
def foo2(a: int, b: Annotated[int, Field(default_factory=lambda: 99)] = ANY, *args: int) -> int:
|
||||
def foo2(a: int, b: Annotated[int, Field(default_factory=lambda: 99)], *args: int) -> int:
|
||||
"""mypy reports Incompatible default for argument "b" if we don't supply ANY as default"""
|
||||
return a + b + sum(args)
|
||||
|
||||
assert foo2(1) == 100
|
||||
|
||||
|
||||
@skip_pre_38
|
||||
def test_positional_only(create_module):
|
||||
|
||||
@@ -1839,3 +1839,19 @@ def test_config_field_info_allow_mutation():
|
||||
with pytest.raises(TypeError):
|
||||
b.a = 'y'
|
||||
assert b.dict() == {'a': 'x'}
|
||||
|
||||
|
||||
def test_arbitrary_types_allowed_custom_eq():
|
||||
class Foo:
|
||||
def __eq__(self, other):
|
||||
if other.__class__ is not Foo:
|
||||
raise TypeError(f'Cannot interpret {other.__class__.__name__!r} as a valid type')
|
||||
return True
|
||||
|
||||
class Model(BaseModel):
|
||||
x: Foo = Foo()
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
assert Model().x == Foo()
|
||||
|
||||
+11
-1
@@ -3,7 +3,7 @@ from enum import Enum
|
||||
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Sequence, Tuple, Type, TypeVar, Union
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Literal
|
||||
from typing_extensions import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, root_validator, validator
|
||||
from pydantic.generics import GenericModel, _generic_types_cache, iter_contained_typevars, replace_types
|
||||
@@ -1071,3 +1071,13 @@ def test_generic_literal():
|
||||
Fields = Literal['foo', 'bar']
|
||||
m = GModel[Fields, str](field={'foo': 'x'})
|
||||
assert m.dict() == {'field': {'foo': 'x'}}
|
||||
|
||||
|
||||
@skip_36
|
||||
def test_generic_annotated():
|
||||
T = TypeVar('T')
|
||||
|
||||
class SomeGenericModel(GenericModel, Generic[T]):
|
||||
some_field: Annotated[T, Field(alias='the_alias')]
|
||||
|
||||
SomeGenericModel[str](the_alias='qwe')
|
||||
|
||||
Reference in New Issue
Block a user