From fe2a5e2170a00d210e200468f37ef65501a907cf Mon Sep 17 00:00:00 2001 From: Eric Jolibois Date: Wed, 22 Dec 2021 20:39:18 +0100 Subject: [PATCH] fix: smart union with typeddict (#3543) --- pydantic/fields.py | 3 ++- pydantic/utils.py | 8 ++++++++ tests/test_types.py | 18 +++++++++++++++++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/pydantic/fields.py b/pydantic/fields.py index ff7b962..44a3aa1 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -50,6 +50,7 @@ from .utils import ( ValueItems, get_discriminator_alias_and_values, get_unique_discriminator_alias, + lenient_isinstance, lenient_issubclass, sequence_like, smart_deepcopy, @@ -1048,7 +1049,7 @@ class ModelField(Representation): return v, None except TypeError: # compound type - if isinstance(v, get_origin(field.outer_type_)): + if lenient_isinstance(v, get_origin(field.outer_type_)): value, error = field.validate(v, values, loc=loc, cls=cls) if not error: return value, None diff --git a/pydantic/utils.py b/pydantic/utils.py index ec94f76..d4ca693 100644 --- a/pydantic/utils.py +++ b/pydantic/utils.py @@ -53,6 +53,7 @@ __all__ = ( 'import_string', 'sequence_like', 'validate_field_name', + 'lenient_isinstance', 'lenient_issubclass', 'in_ipython', 'deep_update', @@ -163,6 +164,13 @@ def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None ) +def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...]]) -> bool: + try: + return isinstance(o, class_or_tuple) + except TypeError: + return False + + def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...]]) -> bool: try: return isinstance(cls, type) and issubclass(cls, class_or_tuple) diff --git a/tests/test_types.py b/tests/test_types.py index 2a31613..bbf4c23 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -27,7 +27,7 @@ from typing import ( from uuid import UUID import pytest -from typing_extensions import Literal +from typing_extensions import Literal, TypedDict from pydantic import ( UUID1, @@ -3120,6 +3120,22 @@ def test_smart_union_compouned_types_edge_case(): assert Model(x=[1, '2']).x == ['1', '2'] +def test_smart_union_typeddict(): + class Dict1(TypedDict): + foo: str + + class Dict2(TypedDict): + bar: str + + class M(BaseModel): + d: Union[Dict2, Dict1] + + class Config: + smart_union = True + + assert M(d=dict(foo='baz')).d == {'foo': 'baz'} + + @pytest.mark.parametrize( 'value,result', (