From e8822775e3a331ca2d06a083429bf85840db6515 Mon Sep 17 00:00:00 2001 From: Eric Jolibois Date: Fri, 24 Dec 2021 14:17:39 +0100 Subject: [PATCH] fix: support generic models with discriminated union (#3551) --- pydantic/fields.py | 4 ++++ tests/test_discrimated_union.py | 37 ++++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/pydantic/fields.py b/pydantic/fields.py index 44a3aa1..7f0094c 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -733,6 +733,10 @@ class ModelField(Representation): Note that this process can be aborted if a `ForwardRef` is encountered """ assert self.discriminator_key is not None + + if self.type_.__class__ is DeferredType: + return + assert self.sub_fields is not None sub_fields_mapping: Dict[str, 'ModelField'] = {} all_aliases: Set[str] = set() diff --git a/tests/test_discrimated_union.py b/tests/test_discrimated_union.py index 5203011..c7cd5f4 100644 --- a/tests/test_discrimated_union.py +++ b/tests/test_discrimated_union.py @@ -1,12 +1,14 @@ import re +import sys from enum import Enum -from typing import Union +from typing import Generic, TypeVar, Union import pytest from typing_extensions import Annotated, Literal from pydantic import BaseModel, Field, ValidationError from pydantic.errors import ConfigError +from pydantic.generics import GenericModel def test_discriminated_union_only_union(): @@ -361,3 +363,36 @@ def test_nested(): n: int assert isinstance(Model(**{'pet': {'pet_type': 'dog', 'name': 'Milou'}, 'n': 5}).pet, Dog) + + +@pytest.mark.skipif(sys.version_info < (3, 7), reason='generics only supported for python 3.7 and above') +def test_generic(): + T = TypeVar('T') + + class Success(GenericModel, Generic[T]): + type: Literal['Success'] = 'Success' + data: T + + class Failure(BaseModel): + type: Literal['Failure'] = 'Failure' + error_message: str + + class Container(GenericModel, Generic[T]): + result: Union[Success[T], Failure] = Field(discriminator='type') + + with pytest.raises(ValidationError, match="Discriminator 'type' is missing in value"): + Container[str].parse_obj({'result': {}}) + + with pytest.raises( + ValidationError, + match=re.escape("No match for discriminator 'type' and value 'Other' (allowed values: 'Success', 'Failure')"), + ): + Container[str].parse_obj({'result': {'type': 'Other'}}) + + with pytest.raises( + ValidationError, match=re.escape('Container[str]\nresult -> Success[str] -> data\n field required') + ): + Container[str].parse_obj({'result': {'type': 'Success'}}) + + # coercion is done properly + assert Container[str].parse_obj({'result': {'type': 'Success', 'data': 1}}).result.data == '1'