support kw_only on dataclasses (#3674)

* support `kw_only`

* add changes file

* add test for `kw_only`

* tweak error message

Co-authored-by: detachhead <detachhead@users.noreply.github.com>
Co-authored-by: Samuel Colvin <s@muelcolvin.com>
This commit is contained in:
DetachHead
2022-08-08 22:44:16 +10:00
committed by GitHub
parent 90a103ec3c
commit 0b8b7eb4b6
3 changed files with 96 additions and 32 deletions
+1
View File
@@ -0,0 +1 @@
Support `kw_only` in dataclasses
+81 -32
View File
@@ -31,6 +31,7 @@ This means we **don't want to create a new dataclass that inherits from it**
The trick is to create a wrapper around `M` that will act as a proxy to trigger
validation without altering default `M` behaviour.
"""
import sys
from contextlib import contextmanager
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload
@@ -85,38 +86,73 @@ __all__ = [
'make_dataclass_validator',
]
if sys.version_info >= (3, 10):
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
) -> Callable[[Type[Any]], 'DataclassClassOrWrapper']:
...
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
kw_only: bool = ...,
) -> Callable[[Type[Any]], 'DataclassClassOrWrapper']:
...
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
_cls: Type[Any],
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
kw_only: bool = ...,
) -> 'DataclassClassOrWrapper':
...
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
_cls: Type[Any],
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
) -> 'DataclassClassOrWrapper':
...
else:
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
) -> Callable[[Type[Any]], 'DataclassClassOrWrapper']:
...
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
_cls: Type[Any],
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
) -> 'DataclassClassOrWrapper':
...
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@@ -131,6 +167,7 @@ def dataclass(
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
kw_only: bool = False,
) -> Union[Callable[[Type[Any]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
"""
Like the python standard lib dataclasses but with type validation.
@@ -149,9 +186,21 @@ def dataclass(
default_validate_on_init = False
else:
dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass
dc_cls = dataclasses.dataclass( # type: ignore
cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
)
if sys.version_info >= (3, 10):
dc_cls = dataclasses.dataclass(
cls,
init=init,
repr=repr,
eq=eq,
order=order,
unsafe_hash=unsafe_hash,
frozen=frozen,
kw_only=kw_only,
)
else:
dc_cls = dataclasses.dataclass( # type: ignore
cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
)
default_validate_on_init = True
should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init
+14
View File
@@ -1,6 +1,7 @@
import dataclasses
import pickle
import re
import sys
from collections.abc import Hashable
from datetime import datetime
from pathlib import Path
@@ -1346,3 +1347,16 @@ def test_self_reference_dataclass():
self_reference: 'MyDataclass'
assert MyDataclass.__pydantic_model__.__fields__['self_reference'].type_ is MyDataclass
@pytest.mark.skipif(sys.version_info < (3, 10), reason='kw_only is not available in python < 3.10')
def test_kw_only():
@pydantic.dataclasses.dataclass(kw_only=True)
class A:
a: int | None = None
b: str
with pytest.raises(TypeError, match='takes 1 positional argument but 3 were given'):
A(1, '')
assert A(b='hi').b == 'hi'