mirror of
https://github.com/kennethreitz/pydantic.git
synced 2026-06-05 23:00:18 +00:00
Try to evaluate forward refs after model created (#2588)
* Try to evaluate forward refs after model created * Upadate docs and remove code duplication * Update changes/2588-uriyyo.md Co-authored-by: Eric Jolibois <em.jolibois@gmail.com> * Update docs/usage/postponed_annotations.md Co-authored-by: Eric Jolibois <em.jolibois@gmail.com> * Remove unused import Co-authored-by: Eric Jolibois <em.jolibois@gmail.com>
This commit is contained in:
@@ -0,0 +1 @@
|
||||
Try to evaluate forward refs automatically at model creation.
|
||||
@@ -8,7 +8,5 @@ class Foo(BaseModel):
|
||||
sibling: Foo = None
|
||||
|
||||
|
||||
Foo.update_forward_refs()
|
||||
|
||||
print(Foo())
|
||||
print(Foo(sibling={'a': '321'}))
|
||||
|
||||
@@ -7,7 +7,5 @@ class Foo(BaseModel):
|
||||
sibling: 'Foo' = None
|
||||
|
||||
|
||||
Foo.update_forward_refs()
|
||||
|
||||
print(Foo())
|
||||
print(Foo(sibling={'a': '321'}))
|
||||
|
||||
@@ -45,9 +45,8 @@ Resolving this is beyond the call for *pydantic*: either remove the future impor
|
||||
|
||||
## Self-referencing Models
|
||||
|
||||
Data structures with self-referencing models are also supported, provided the function
|
||||
`update_forward_refs()` is called once the model is created (you will be reminded
|
||||
with a friendly error message if you forget).
|
||||
Data structures with self-referencing models are also supported. Self-referencing fields will be automatically
|
||||
resolved after model creation.
|
||||
|
||||
Within the model, you can refer to the not-yet-constructed model using a string:
|
||||
|
||||
|
||||
+12
-6
@@ -1,4 +1,3 @@
|
||||
import sys
|
||||
import warnings
|
||||
from abc import ABCMeta
|
||||
from copy import deepcopy
|
||||
@@ -42,7 +41,7 @@ from .typing import (
|
||||
is_namedtuple,
|
||||
is_union_origin,
|
||||
resolve_annotations,
|
||||
update_field_forward_refs,
|
||||
update_model_forward_refs,
|
||||
)
|
||||
from .utils import (
|
||||
ROOT_KEY,
|
||||
@@ -289,6 +288,8 @@ class ModelMetaclass(ABCMeta):
|
||||
cls = super().__new__(mcs, name, bases, new_namespace, **kwargs)
|
||||
# set __signature__ attr only for model class, but not for its instances
|
||||
cls.__signature__ = ClassAttribute('__signature__', generate_model_signature(cls.__init__, fields, config))
|
||||
cls.__try_update_forward_refs__()
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
@@ -746,15 +747,20 @@ class BaseModel(Representation, metaclass=ModelMetaclass):
|
||||
else:
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def __try_update_forward_refs__(cls) -> None:
|
||||
"""
|
||||
Same as update_forward_refs but will not raise exception
|
||||
when forward references are not defined.
|
||||
"""
|
||||
update_model_forward_refs(cls, cls.__fields__.values(), {}, (NameError,))
|
||||
|
||||
@classmethod
|
||||
def update_forward_refs(cls, **localns: Any) -> None:
|
||||
"""
|
||||
Try to update ForwardRefs on fields based on this Model, globalns and localns.
|
||||
"""
|
||||
globalns = sys.modules[cls.__module__].__dict__.copy()
|
||||
globalns.setdefault(cls.__name__, cls)
|
||||
for f in cls.__fields__.values():
|
||||
update_field_forward_refs(f, globalns=globalns, localns=localns)
|
||||
update_model_forward_refs(cls, cls.__fields__.values(), localns)
|
||||
|
||||
def __iter__(self) -> 'TupleGenerator':
|
||||
"""
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import ( # type: ignore
|
||||
ClassVar,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
NewType,
|
||||
@@ -251,6 +252,7 @@ __all__ = (
|
||||
'new_type_supertype',
|
||||
'is_classvar',
|
||||
'update_field_forward_refs',
|
||||
'update_model_forward_refs',
|
||||
'TupleGenerator',
|
||||
'DictStrAny',
|
||||
'DictAny',
|
||||
@@ -441,6 +443,29 @@ def update_field_forward_refs(field: 'ModelField', globalns: Any, localns: Any)
|
||||
update_field_forward_refs(sub_f, globalns=globalns, localns=localns)
|
||||
|
||||
|
||||
def update_model_forward_refs(
|
||||
model: Type[Any],
|
||||
fields: Iterable['ModelField'],
|
||||
localns: 'DictStrAny',
|
||||
exc_to_suppress: Tuple[Type[BaseException], ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Try to update model fields ForwardRefs based on model and localns.
|
||||
"""
|
||||
if model.__module__ in sys.modules:
|
||||
globalns = sys.modules[model.__module__].__dict__.copy()
|
||||
else:
|
||||
globalns = {}
|
||||
|
||||
globalns.setdefault(model.__name__, model)
|
||||
|
||||
for f in fields:
|
||||
try:
|
||||
update_field_forward_refs(f, globalns=globalns, localns=localns)
|
||||
except exc_to_suppress:
|
||||
pass
|
||||
|
||||
|
||||
def get_class(type_: Type[Any]) -> Union[None, bool, Type[Any]]:
|
||||
"""
|
||||
Tries to get the class of a Type[T] annotation. Returns True if Type is used
|
||||
|
||||
@@ -41,6 +41,57 @@ class Model(BaseModel):
|
||||
assert module.Model().dict() == {'a': None}
|
||||
|
||||
|
||||
@skip_pre_37
|
||||
def test_postponed_annotations_auto_update_forward_refs(create_module):
|
||||
module = create_module(
|
||||
# language=Python
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Model(BaseModel):
|
||||
a: Model
|
||||
"""
|
||||
)
|
||||
|
||||
assert module.Model.__fields__['a'].type_ is module.Model
|
||||
|
||||
|
||||
def test_forward_ref_auto_update_no_model(create_module):
|
||||
module = create_module(
|
||||
# language=Python
|
||||
"""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Foo(BaseModel):
|
||||
a: 'Bar'
|
||||
|
||||
class Bar(BaseModel):
|
||||
b: 'Foo'
|
||||
"""
|
||||
)
|
||||
|
||||
from pydantic.typing import ForwardRef
|
||||
|
||||
assert module.Foo.__fields__['a'].type_ == ForwardRef('Bar')
|
||||
assert module.Bar.__fields__['b'].type_ is module.Foo
|
||||
|
||||
|
||||
def test_forward_ref_one_of_fields_not_defined(create_module):
|
||||
@create_module
|
||||
def module():
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Foo(BaseModel):
|
||||
foo: 'Foo'
|
||||
bar: 'Bar' # noqa: F821
|
||||
|
||||
from pydantic.typing import ForwardRef
|
||||
|
||||
assert module.Foo.__fields__['bar'].type_ == ForwardRef('Bar')
|
||||
assert module.Foo.__fields__['foo'].type_ is module.Foo
|
||||
|
||||
|
||||
def test_basic_forward_ref(create_module):
|
||||
@create_module
|
||||
def module():
|
||||
@@ -509,14 +560,6 @@ def test_nested_forward_ref():
|
||||
class NestedTuple(BaseModel):
|
||||
x: Tuple[int, Optional['NestedTuple']] # noqa: F821
|
||||
|
||||
with pytest.raises(ConfigError) as exc_info:
|
||||
NestedTuple.parse_obj({'x': ('1', {'x': ('2', {'x': ('3', None)})})})
|
||||
assert str(exc_info.value) == (
|
||||
'field "x_1" not yet prepared so type is still a ForwardRef, '
|
||||
'you might need to call NestedTuple.update_forward_refs().'
|
||||
)
|
||||
|
||||
NestedTuple.update_forward_refs()
|
||||
obj = NestedTuple.parse_obj({'x': ('1', {'x': ('2', {'x': ('3', None)})})})
|
||||
assert obj.dict() == {'x': (1, {'x': (2, {'x': (3, None)})})}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user