fix: dataclass wrapper was not always called (#4484)

This commit is contained in:
Eric Jolibois
2022-09-05 18:32:29 +02:00
committed by GitHub
parent 91bb8d4482
commit f1e9883157
3 changed files with 101 additions and 2 deletions
+1
View File
@@ -0,0 +1 @@
fix: dataclass wrapper was not always called
+23 -2
View File
@@ -34,7 +34,20 @@ validation without altering default `M` behaviour.
import sys
from contextlib import contextmanager
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Generator,
Optional,
Set,
Type,
TypeVar,
Union,
overload,
)
from typing_extensions import dataclass_transform
@@ -184,7 +197,7 @@ def dataclass(
def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
import dataclasses
if is_builtin_dataclass(cls):
if is_builtin_dataclass(cls) and _extra_dc_args(_cls) == _extra_dc_args(_cls.__bases__[0]): # type: ignore
dc_cls_doc = ''
dc_cls = DataclassProxy(cls)
default_validate_on_init = False
@@ -418,6 +431,14 @@ def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value:
object.__setattr__(self, name, value)
def _extra_dc_args(cls: Type[Any]) -> Set[str]:
return {
x
for x in dir(cls)
if x not in getattr(cls, '__dataclass_fields__', {}) and not (x.startswith('__') and x.endswith('__'))
}
def is_builtin_dataclass(_cls: Type[Any]) -> bool:
"""
Whether a class is a stdlib dataclass
+77
View File
@@ -1395,3 +1395,80 @@ def test_extra_forbid_list_error():
@pydantic.dataclasses.dataclass
class Foo:
a: List[Bar(a=1)]
def test_parent_post_init():
@dataclasses.dataclass
class A:
a: float = 1
def __post_init__(self):
self.a *= 2
@pydantic.dataclasses.dataclass
class B(A):
@validator('a')
def validate_a(cls, value):
value += 3
return value
assert B().a == 5 # 1 * 2 + 3
def test_subclass_post_init_post_parse():
@dataclasses.dataclass
class A:
a: float = 1
@pydantic.dataclasses.dataclass
class B(A):
def __post_init_post_parse__(self):
self.a *= 2
@validator('a')
def validate_a(cls, value):
value += 3
return value
assert B().a == 8 # (1 + 3) * 2
def test_subclass_post_init():
@dataclasses.dataclass
class A:
a: int = 1
@pydantic.dataclasses.dataclass
class B(A):
def __post_init__(self):
self.a *= 2
@validator('a')
def validate_a(cls, value):
value += 3
return value
assert B().a == 5 # 1 * 2 + 3
def test_subclass_post_init_inheritance():
@dataclasses.dataclass
class A:
a: int = 1
@pydantic.dataclasses.dataclass
class B(A):
def __post_init__(self):
self.a *= 2
@validator('a')
def validate_a(cls, value):
value += 3
return value
@pydantic.dataclasses.dataclass
class C(B):
def __post_init__(self):
self.a *= 3
assert C().a == 6 # 1 * 3 + 3