mirror of
https://github.com/kennethreitz/pydantic.git
synced 2026-06-05 23:00:18 +00:00
fix: prevent RecursionError while using recursive GenericModels (#2338)
Co-authored-by: Samuel Colvin <samcolvin@gmail.com> Co-authored-by: Samuel Colvin <samcolvin@gmail.com>
This commit is contained in:
@@ -0,0 +1 @@
|
||||
fix: prevent `RecursionError` while using recursive `GenericModel`s
|
||||
+9
-1
@@ -427,7 +427,7 @@ class ModelField(Representation):
|
||||
e.g. calling it it multiple times may modify the field and configure it incorrectly.
|
||||
"""
|
||||
self._set_default_and_type()
|
||||
if self.type_.__class__ == ForwardRef:
|
||||
if self.type_.__class__ is ForwardRef or self.type_.__class__ is DeferredType:
|
||||
# self.type_ is currently a ForwardRef and there's nothing we can do now,
|
||||
# user will need to call model.update_forward_refs()
|
||||
return
|
||||
@@ -676,6 +676,8 @@ class ModelField(Representation):
|
||||
self, v: Any, values: Dict[str, Any], *, loc: 'LocStr', cls: Optional['ModelOrDc'] = None
|
||||
) -> 'ValidateReturn':
|
||||
|
||||
assert self.type_.__class__ is not DeferredType
|
||||
|
||||
if self.type_.__class__ is ForwardRef:
|
||||
assert cls is not None
|
||||
raise ConfigError(
|
||||
@@ -983,3 +985,9 @@ def PrivateAttr(
|
||||
default,
|
||||
default_factory=default_factory,
|
||||
)
|
||||
|
||||
|
||||
class DeferredType:
|
||||
"""
|
||||
Used to postpone field preparation, while creating recursive generic models.
|
||||
"""
|
||||
|
||||
+42
-21
@@ -9,6 +9,7 @@ from typing import (
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
@@ -19,7 +20,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from .class_validators import gather_all_validators
|
||||
from .fields import FieldInfo, ModelField
|
||||
from .fields import DeferredType
|
||||
from .main import BaseModel, create_model
|
||||
from .typing import display_as_type, get_args, get_origin, typing_base
|
||||
from .utils import all_identical, lenient_issubclass
|
||||
@@ -69,19 +70,15 @@ class GenericModel(BaseModel):
|
||||
if all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map:
|
||||
return cls # if arguments are equal to parameters it's the same object
|
||||
|
||||
# Recursively walk class type hints and replace generic typevars
|
||||
# with concrete types that were passed.
|
||||
type_hints = get_type_hints(cls).items()
|
||||
instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar}
|
||||
concrete_type_hints: Dict[str, Type[Any]] = {
|
||||
k: replace_types(v, typevars_map) for k, v in instance_type_hints.items()
|
||||
}
|
||||
|
||||
# Create new model with original model as parent inserting fields with
|
||||
# updated type hints.
|
||||
# Create new model with original model as parent inserting fields with DeferredType.
|
||||
model_name = cls.__concrete_name__(params)
|
||||
validators = gather_all_validators(cls)
|
||||
fields = _build_generic_fields(cls.__fields__, concrete_type_hints)
|
||||
|
||||
type_hints = get_type_hints(cls).items()
|
||||
instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar}
|
||||
|
||||
fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__}
|
||||
|
||||
model_module, called_globally = get_caller_frame_info()
|
||||
created_model = cast(
|
||||
Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes
|
||||
@@ -121,6 +118,11 @@ class GenericModel(BaseModel):
|
||||
_generic_types_cache[(cls, params)] = created_model
|
||||
if len(params) == 1:
|
||||
_generic_types_cache[(cls, params[0])] = created_model
|
||||
|
||||
# Recursively walk class type hints and replace generic typevars
|
||||
# with concrete types that were passed.
|
||||
_prepare_model_fields(created_model, fields, instance_type_hints, typevars_map)
|
||||
|
||||
return created_model
|
||||
|
||||
@classmethod
|
||||
@@ -140,11 +142,11 @@ class GenericModel(BaseModel):
|
||||
return f'{cls.__name__}[{params_component}]'
|
||||
|
||||
|
||||
def replace_types(type_: Any, type_map: Dict[Any, Any]) -> Any:
|
||||
def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
|
||||
"""Return type with all occurances of `type_map` keys recursively replaced with their values.
|
||||
|
||||
:param type_: Any type, class or generic alias
|
||||
:type_map: Mapping from `TypeVar` instance to concrete types.
|
||||
:param type_map: Mapping from `TypeVar` instance to concrete types.
|
||||
:return: New type representing the basic structure of `type_` with all
|
||||
`typevar_map` keys recursively replaced.
|
||||
|
||||
@@ -218,13 +220,6 @@ def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
|
||||
yield from iter_contained_typevars(arg)
|
||||
|
||||
|
||||
def _build_generic_fields(
|
||||
raw_fields: Dict[str, ModelField],
|
||||
concrete_type_hints: Dict[str, Type[Any]],
|
||||
) -> Dict[str, Tuple[Type[Any], FieldInfo]]:
|
||||
return {k: (v, raw_fields[k].field_info) for k, v in concrete_type_hints.items() if k in raw_fields}
|
||||
|
||||
|
||||
def get_caller_frame_info() -> Tuple[Optional[str], bool]:
|
||||
"""
|
||||
Used inside a function to check whether it was called globally
|
||||
@@ -241,3 +236,29 @@ def get_caller_frame_info() -> Tuple[Optional[str], bool]:
|
||||
return None, False
|
||||
frame_globals = previous_caller_frame.f_globals
|
||||
return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
|
||||
|
||||
|
||||
def _prepare_model_fields(
|
||||
created_model: Type[GenericModel],
|
||||
fields: Mapping[str, Any],
|
||||
instance_type_hints: Mapping[str, type],
|
||||
typevars_map: Mapping[Any, type],
|
||||
) -> None:
|
||||
"""
|
||||
Replace DeferredType fields with concrete type hints and prepare them.
|
||||
"""
|
||||
|
||||
for key, field in created_model.__fields__.items():
|
||||
if key not in fields:
|
||||
assert field.type_.__class__ is not DeferredType
|
||||
# https://github.com/nedbat/coveragepy/issues/198
|
||||
continue # pragma: no cover
|
||||
|
||||
assert field.type_.__class__ is DeferredType, field.type_.__class__
|
||||
|
||||
field_type_hint = instance_type_hints[key]
|
||||
concrete_type = replace_types(field_type_hint, typevars_map)
|
||||
field.type_ = concrete_type
|
||||
field.outer_type_ = concrete_type
|
||||
field.prepare()
|
||||
created_model.__annotations__[key] = concrete_type
|
||||
|
||||
@@ -1015,3 +1015,27 @@ def test_generic_with_partial_callable():
|
||||
Model[str, U].__concrete__ is False
|
||||
Model[str, U].__parameters__ == [U]
|
||||
Model[str, int].__concrete__ is False
|
||||
|
||||
|
||||
@skip_36
|
||||
def test_generic_recursive_models(create_module):
|
||||
@create_module
|
||||
def module():
|
||||
from typing import Generic, TypeVar, Union
|
||||
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
class Model1(GenericModel, Generic[T]):
|
||||
ref: 'Model2[T]'
|
||||
|
||||
class Model2(GenericModel, Generic[T]):
|
||||
ref: Union[T, Model1[T]]
|
||||
|
||||
Model1.update_forward_refs()
|
||||
|
||||
Model1 = module.Model1
|
||||
Model2 = module.Model2
|
||||
result = Model1[str].parse_obj(dict(ref=dict(ref=dict(ref=dict(ref=123)))))
|
||||
assert result == Model1(ref=Model2(ref=Model1(ref=Model2(ref='123'))))
|
||||
|
||||
Reference in New Issue
Block a user