From 8f0980e9826517246c636e409bbffe4fa07bf28f Mon Sep 17 00:00:00 2001 From: xppt <21246102+xppt@users.noreply.github.com> Date: Fri, 26 Feb 2021 13:30:12 +0300 Subject: [PATCH] fix: prevent `RecursionError` while using recursive `GenericModel`s (#2338) Co-authored-by: Samuel Colvin Co-authored-by: Samuel Colvin --- changes/1370-xppt.md | 1 + pydantic/fields.py | 10 ++++++- pydantic/generics.py | 63 ++++++++++++++++++++++++++++-------------- tests/test_generics.py | 24 ++++++++++++++++ 4 files changed, 76 insertions(+), 22 deletions(-) create mode 100644 changes/1370-xppt.md diff --git a/changes/1370-xppt.md b/changes/1370-xppt.md new file mode 100644 index 0000000..095c10d --- /dev/null +++ b/changes/1370-xppt.md @@ -0,0 +1 @@ +fix: prevent `RecursionError` while using recursive `GenericModel`s diff --git a/pydantic/fields.py b/pydantic/fields.py index 79c49a9..4db34a2 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -427,7 +427,7 @@ class ModelField(Representation): e.g. calling it it multiple times may modify the field and configure it incorrectly. """ self._set_default_and_type() - if self.type_.__class__ == ForwardRef: + if self.type_.__class__ is ForwardRef or self.type_.__class__ is DeferredType: # self.type_ is currently a ForwardRef and there's nothing we can do now, # user will need to call model.update_forward_refs() return @@ -676,6 +676,8 @@ class ModelField(Representation): self, v: Any, values: Dict[str, Any], *, loc: 'LocStr', cls: Optional['ModelOrDc'] = None ) -> 'ValidateReturn': + assert self.type_.__class__ is not DeferredType + if self.type_.__class__ is ForwardRef: assert cls is not None raise ConfigError( @@ -983,3 +985,9 @@ def PrivateAttr( default, default_factory=default_factory, ) + + +class DeferredType: + """ + Used to postpone field preparation, while creating recursive generic models. + """ diff --git a/pydantic/generics.py b/pydantic/generics.py index be09b50..eeebeb0 100644 --- a/pydantic/generics.py +++ b/pydantic/generics.py @@ -9,6 +9,7 @@ from typing import ( Iterable, Iterator, List, + Mapping, Optional, Tuple, Type, @@ -19,7 +20,7 @@ from typing import ( ) from .class_validators import gather_all_validators -from .fields import FieldInfo, ModelField +from .fields import DeferredType from .main import BaseModel, create_model from .typing import display_as_type, get_args, get_origin, typing_base from .utils import all_identical, lenient_issubclass @@ -69,19 +70,15 @@ class GenericModel(BaseModel): if all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map: return cls # if arguments are equal to parameters it's the same object - # Recursively walk class type hints and replace generic typevars - # with concrete types that were passed. - type_hints = get_type_hints(cls).items() - instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar} - concrete_type_hints: Dict[str, Type[Any]] = { - k: replace_types(v, typevars_map) for k, v in instance_type_hints.items() - } - - # Create new model with original model as parent inserting fields with - # updated type hints. + # Create new model with original model as parent inserting fields with DeferredType. model_name = cls.__concrete_name__(params) validators = gather_all_validators(cls) - fields = _build_generic_fields(cls.__fields__, concrete_type_hints) + + type_hints = get_type_hints(cls).items() + instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar} + + fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__} + model_module, called_globally = get_caller_frame_info() created_model = cast( Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes @@ -121,6 +118,11 @@ class GenericModel(BaseModel): _generic_types_cache[(cls, params)] = created_model if len(params) == 1: _generic_types_cache[(cls, params[0])] = created_model + + # Recursively walk class type hints and replace generic typevars + # with concrete types that were passed. + _prepare_model_fields(created_model, fields, instance_type_hints, typevars_map) + return created_model @classmethod @@ -140,11 +142,11 @@ class GenericModel(BaseModel): return f'{cls.__name__}[{params_component}]' -def replace_types(type_: Any, type_map: Dict[Any, Any]) -> Any: +def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any: """Return type with all occurances of `type_map` keys recursively replaced with their values. :param type_: Any type, class or generic alias - :type_map: Mapping from `TypeVar` instance to concrete types. + :param type_map: Mapping from `TypeVar` instance to concrete types. :return: New type representing the basic structure of `type_` with all `typevar_map` keys recursively replaced. @@ -218,13 +220,6 @@ def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]: yield from iter_contained_typevars(arg) -def _build_generic_fields( - raw_fields: Dict[str, ModelField], - concrete_type_hints: Dict[str, Type[Any]], -) -> Dict[str, Tuple[Type[Any], FieldInfo]]: - return {k: (v, raw_fields[k].field_info) for k, v in concrete_type_hints.items() if k in raw_fields} - - def get_caller_frame_info() -> Tuple[Optional[str], bool]: """ Used inside a function to check whether it was called globally @@ -241,3 +236,29 @@ def get_caller_frame_info() -> Tuple[Optional[str], bool]: return None, False frame_globals = previous_caller_frame.f_globals return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals + + +def _prepare_model_fields( + created_model: Type[GenericModel], + fields: Mapping[str, Any], + instance_type_hints: Mapping[str, type], + typevars_map: Mapping[Any, type], +) -> None: + """ + Replace DeferredType fields with concrete type hints and prepare them. + """ + + for key, field in created_model.__fields__.items(): + if key not in fields: + assert field.type_.__class__ is not DeferredType + # https://github.com/nedbat/coveragepy/issues/198 + continue # pragma: no cover + + assert field.type_.__class__ is DeferredType, field.type_.__class__ + + field_type_hint = instance_type_hints[key] + concrete_type = replace_types(field_type_hint, typevars_map) + field.type_ = concrete_type + field.outer_type_ = concrete_type + field.prepare() + created_model.__annotations__[key] = concrete_type diff --git a/tests/test_generics.py b/tests/test_generics.py index 07847e4..4728188 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1015,3 +1015,27 @@ def test_generic_with_partial_callable(): Model[str, U].__concrete__ is False Model[str, U].__parameters__ == [U] Model[str, int].__concrete__ is False + + +@skip_36 +def test_generic_recursive_models(create_module): + @create_module + def module(): + from typing import Generic, TypeVar, Union + + from pydantic.generics import GenericModel + + T = TypeVar('T') + + class Model1(GenericModel, Generic[T]): + ref: 'Model2[T]' + + class Model2(GenericModel, Generic[T]): + ref: Union[T, Model1[T]] + + Model1.update_forward_refs() + + Model1 = module.Model1 + Model2 = module.Model2 + result = Model1[str].parse_obj(dict(ref=dict(ref=dict(ref=dict(ref=123))))) + assert result == Model1(ref=Model2(ref=Model1(ref=Model2(ref='123'))))