fix: resolve forward refs for inherited dataclasses (#2220)

* fix: resolve forward refs for inherited dataclasses

closes #1668

* chore: add change file

* fix: make test work everywhere

* chore: rename file just in case

As it doesn't solve the target issue, let's change the PR number

* docs: update change description
This commit is contained in:
Eric Jolibois
2021-02-13 11:10:14 +01:00
committed by GitHub
parent c314f5a909
commit 7bef40bb11
3 changed files with 30 additions and 1 deletions
+1
View File
@@ -0,0 +1 @@
Resolve forward refs for stdlib dataclasses converted into _pydantic_ ones
+2 -1
View File
@@ -5,6 +5,7 @@ from .error_wrappers import ValidationError
from .errors import DataclassTypeError
from .fields import Required
from .main import create_model, validate_model
from .typing import resolve_annotations
from .utils import ClassAttribute
if TYPE_CHECKING:
@@ -128,7 +129,7 @@ def _process_class(
_cls.__name__,
(_cls,),
{
'__annotations__': _cls.__annotations__,
'__annotations__': resolve_annotations(_cls.__annotations__, _cls.__module__),
'__post_init__': _pydantic_post_init,
# attrs for pickle to find this class
'__module__': __name__,
+27
View File
@@ -4,6 +4,7 @@ from typing import Optional, Tuple
import pytest
from pydantic import BaseModel, ConfigError, ValidationError
from pydantic.typing import Literal
skip_pre_37 = pytest.mark.skipif(sys.version_info < (3, 7), reason='testing >= 3.7 behaviour only')
@@ -480,6 +481,32 @@ def test_forward_ref_with_create_model(create_module):
assert instance.sub.dict() == {'foo': 'bar'}
@skip_pre_37
@pytest.mark.skipif(not Literal, reason='typing_extensions not installed')
def test_resolve_forward_ref_dataclass(create_module):
module = create_module(
# language=Python
"""
from __future__ import annotations
from dataclasses import dataclass
from pydantic import BaseModel
from pydantic.typing import Literal
@dataclass
class Base:
literal: Literal[1, 2]
class What(BaseModel):
base: Base
"""
)
m = module.What(base=module.Base(literal=1))
assert m.base.literal == 1
def test_nested_forward_ref():
class NestedTuple(BaseModel):
x: Tuple[int, Optional['NestedTuple']] # noqa: F821