From 415eb54f966f0eec7cb2fd0aa61492c17e606367 Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Sun, 5 Dec 2021 16:20:48 +0200 Subject: [PATCH] Try to evaluate forward refs after model created (#2588) * Try to evaluate forward refs after model created * Upadate docs and remove code duplication * Update changes/2588-uriyyo.md Co-authored-by: Eric Jolibois * Update docs/usage/postponed_annotations.md Co-authored-by: Eric Jolibois * Remove unused import Co-authored-by: Eric Jolibois --- changes/2588-uriyyo.md | 1 + ...nnotations_self_referencing_annotations.py | 2 - ...ned_annotations_self_referencing_string.py | 2 - docs/usage/postponed_annotations.md | 5 +- pydantic/main.py | 18 ++++-- pydantic/typing.py | 25 ++++++++ tests/test_forward_ref.py | 59 ++++++++++++++++--- 7 files changed, 91 insertions(+), 21 deletions(-) create mode 100644 changes/2588-uriyyo.md diff --git a/changes/2588-uriyyo.md b/changes/2588-uriyyo.md new file mode 100644 index 0000000..938d282 --- /dev/null +++ b/changes/2588-uriyyo.md @@ -0,0 +1 @@ +Try to evaluate forward refs automatically at model creation. \ No newline at end of file diff --git a/docs/examples/postponed_annotations_self_referencing_annotations.py b/docs/examples/postponed_annotations_self_referencing_annotations.py index 902bd62..aa8966c 100644 --- a/docs/examples/postponed_annotations_self_referencing_annotations.py +++ b/docs/examples/postponed_annotations_self_referencing_annotations.py @@ -8,7 +8,5 @@ class Foo(BaseModel): sibling: Foo = None -Foo.update_forward_refs() - print(Foo()) print(Foo(sibling={'a': '321'})) diff --git a/docs/examples/postponed_annotations_self_referencing_string.py b/docs/examples/postponed_annotations_self_referencing_string.py index 2c918a2..3290051 100644 --- a/docs/examples/postponed_annotations_self_referencing_string.py +++ b/docs/examples/postponed_annotations_self_referencing_string.py @@ -7,7 +7,5 @@ class Foo(BaseModel): sibling: 'Foo' = None -Foo.update_forward_refs() - print(Foo()) print(Foo(sibling={'a': '321'})) diff --git a/docs/usage/postponed_annotations.md b/docs/usage/postponed_annotations.md index cc44588..8e975e2 100644 --- a/docs/usage/postponed_annotations.md +++ b/docs/usage/postponed_annotations.md @@ -45,9 +45,8 @@ Resolving this is beyond the call for *pydantic*: either remove the future impor ## Self-referencing Models -Data structures with self-referencing models are also supported, provided the function -`update_forward_refs()` is called once the model is created (you will be reminded -with a friendly error message if you forget). +Data structures with self-referencing models are also supported. Self-referencing fields will be automatically +resolved after model creation. Within the model, you can refer to the not-yet-constructed model using a string: diff --git a/pydantic/main.py b/pydantic/main.py index ff63969..cf900e3 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -1,4 +1,3 @@ -import sys import warnings from abc import ABCMeta from copy import deepcopy @@ -42,7 +41,7 @@ from .typing import ( is_namedtuple, is_union_origin, resolve_annotations, - update_field_forward_refs, + update_model_forward_refs, ) from .utils import ( ROOT_KEY, @@ -289,6 +288,8 @@ class ModelMetaclass(ABCMeta): cls = super().__new__(mcs, name, bases, new_namespace, **kwargs) # set __signature__ attr only for model class, but not for its instances cls.__signature__ = ClassAttribute('__signature__', generate_model_signature(cls.__init__, fields, config)) + cls.__try_update_forward_refs__() + return cls @@ -746,15 +747,20 @@ class BaseModel(Representation, metaclass=ModelMetaclass): else: return v + @classmethod + def __try_update_forward_refs__(cls) -> None: + """ + Same as update_forward_refs but will not raise exception + when forward references are not defined. + """ + update_model_forward_refs(cls, cls.__fields__.values(), {}, (NameError,)) + @classmethod def update_forward_refs(cls, **localns: Any) -> None: """ Try to update ForwardRefs on fields based on this Model, globalns and localns. """ - globalns = sys.modules[cls.__module__].__dict__.copy() - globalns.setdefault(cls.__name__, cls) - for f in cls.__fields__.values(): - update_field_forward_refs(f, globalns=globalns, localns=localns) + update_model_forward_refs(cls, cls.__fields__.values(), localns) def __iter__(self) -> 'TupleGenerator': """ diff --git a/pydantic/typing.py b/pydantic/typing.py index 5797432..9b28dd0 100644 --- a/pydantic/typing.py +++ b/pydantic/typing.py @@ -7,6 +7,7 @@ from typing import ( # type: ignore ClassVar, Dict, Generator, + Iterable, List, Mapping, NewType, @@ -251,6 +252,7 @@ __all__ = ( 'new_type_supertype', 'is_classvar', 'update_field_forward_refs', + 'update_model_forward_refs', 'TupleGenerator', 'DictStrAny', 'DictAny', @@ -441,6 +443,29 @@ def update_field_forward_refs(field: 'ModelField', globalns: Any, localns: Any) update_field_forward_refs(sub_f, globalns=globalns, localns=localns) +def update_model_forward_refs( + model: Type[Any], + fields: Iterable['ModelField'], + localns: 'DictStrAny', + exc_to_suppress: Tuple[Type[BaseException], ...] = (), +) -> None: + """ + Try to update model fields ForwardRefs based on model and localns. + """ + if model.__module__ in sys.modules: + globalns = sys.modules[model.__module__].__dict__.copy() + else: + globalns = {} + + globalns.setdefault(model.__name__, model) + + for f in fields: + try: + update_field_forward_refs(f, globalns=globalns, localns=localns) + except exc_to_suppress: + pass + + def get_class(type_: Type[Any]) -> Union[None, bool, Type[Any]]: """ Tries to get the class of a Type[T] annotation. Returns True if Type is used diff --git a/tests/test_forward_ref.py b/tests/test_forward_ref.py index 692a33d..12a337f 100644 --- a/tests/test_forward_ref.py +++ b/tests/test_forward_ref.py @@ -41,6 +41,57 @@ class Model(BaseModel): assert module.Model().dict() == {'a': None} +@skip_pre_37 +def test_postponed_annotations_auto_update_forward_refs(create_module): + module = create_module( + # language=Python + """ +from __future__ import annotations +from pydantic import BaseModel + +class Model(BaseModel): + a: Model +""" + ) + + assert module.Model.__fields__['a'].type_ is module.Model + + +def test_forward_ref_auto_update_no_model(create_module): + module = create_module( + # language=Python + """ +from pydantic import BaseModel + +class Foo(BaseModel): + a: 'Bar' + +class Bar(BaseModel): + b: 'Foo' +""" + ) + + from pydantic.typing import ForwardRef + + assert module.Foo.__fields__['a'].type_ == ForwardRef('Bar') + assert module.Bar.__fields__['b'].type_ is module.Foo + + +def test_forward_ref_one_of_fields_not_defined(create_module): + @create_module + def module(): + from pydantic import BaseModel + + class Foo(BaseModel): + foo: 'Foo' + bar: 'Bar' # noqa: F821 + + from pydantic.typing import ForwardRef + + assert module.Foo.__fields__['bar'].type_ == ForwardRef('Bar') + assert module.Foo.__fields__['foo'].type_ is module.Foo + + def test_basic_forward_ref(create_module): @create_module def module(): @@ -509,14 +560,6 @@ def test_nested_forward_ref(): class NestedTuple(BaseModel): x: Tuple[int, Optional['NestedTuple']] # noqa: F821 - with pytest.raises(ConfigError) as exc_info: - NestedTuple.parse_obj({'x': ('1', {'x': ('2', {'x': ('3', None)})})}) - assert str(exc_info.value) == ( - 'field "x_1" not yet prepared so type is still a ForwardRef, ' - 'you might need to call NestedTuple.update_forward_refs().' - ) - - NestedTuple.update_forward_refs() obj = NestedTuple.parse_obj({'x': ('1', {'x': ('2', {'x': ('3', None)})})}) assert obj.dict() == {'x': (1, {'x': (2, {'x': (3, None)})})}