Support extracting attributes from the parent namespace (#4663)

* support extracting attributes from the parent namespace

* fix more tested test

* move parent_frame_namespace
This commit is contained in:
Samuel Colvin
2022-11-08 10:26:24 +00:00
committed by GitHub
parent 8d98c499df
commit a7656fcd09
3 changed files with 128 additions and 1 deletions
+22
View File
@@ -27,6 +27,7 @@ __all__ = (
'origin_is_union',
'NotRequired',
'Required',
'parent_frame_namespace',
'get_type_hints',
)
@@ -191,6 +192,27 @@ def is_finalvar(ann_type: type[Any]) -> bool:
return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))
def parent_frame_namespace(*, parent_depth: int = 2) -> dict[str, Any] | None:
"""
We allow use of items in parent namespace to get around the issue with `get_type_hints` only looking in the
global module namespace. See https://github.com/pydantic/pydantic/issues/2678#issuecomment-1008139014 -> Scope
and suggestion at the end of the next comment by @gvanrossum.
WARNING 1: it matters exactly where this is called. By default, this function will build a namespace from the
parent of where it is called.
WARNING 2: this only looks in the parent namespace, not other parents since (AFAIK) there's no way to collect a
dict of exactly what's in scope. Using `f_back` would work sometimes but would be very wrong and confusing in many
other cases. See https://discuss.python.org/t/is-there-a-way-to-access-parent-nested-namespaces/20659.
"""
frame = sys._getframe(parent_depth)
# if f_back is None, it's the global module namespace and we don't need to include it here
if frame.f_back is None:
return None
else:
return frame.f_locals
if sys.version_info >= (3, 10): # noqa C901
get_type_hints = typing.get_type_hints
+15 -1
View File
@@ -92,7 +92,15 @@ class ModelMetaclass(ABCMeta):
namespace['__hash__'] = hash_func
cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs) # type: ignore
_model_construction.complete_model_class(cls, cls_name, validator_functions, bases, raise_errors=False)
_model_construction.complete_model_class(
cls,
cls_name,
validator_functions,
bases,
types_namespace=_typing_extra.parent_frame_namespace(),
raise_errors=False,
)
return cls
else:
# this is the BaseModel class itself being created, no logic required
@@ -452,6 +460,12 @@ class BaseModel(_repr.Representation, metaclass=ModelMetaclass):
if not force and cls.__pydantic_model_complete__:
return None
else:
parents_namespace = _typing_extra.parent_frame_namespace()
if types_namespace and parents_namespace:
types_namespace = {**parents_namespace, **types_namespace}
elif parents_namespace:
types_namespace = parents_namespace
return _model_construction.complete_model_class(
cls,
cls.__name__,
+91
View File
@@ -749,3 +749,94 @@ def test_force_rebuild():
assert Foobar.__pydantic_model_complete__ is True
assert Foobar.model_rebuild() is None
assert Foobar.model_rebuild(force=True) is True
def test_nested_annotation(create_module):
module = create_module(
# language=Python
"""
from __future__ import annotations
from pydantic import BaseModel
def nested():
class Foo(BaseModel):
a: int
class Bar(BaseModel):
b: Foo
return Bar
"""
)
bar_model = module.nested()
assert bar_model.__pydantic_model_complete__ is True
assert bar_model(b={'a': 1}).dict() == {'b': {'a': 1}}
def test_nested_more_annotation(create_module):
@create_module
def module():
from pydantic import BaseModel
def nested():
class Foo(BaseModel):
a: int
def more_nested():
class Bar(BaseModel):
b: 'Foo'
return Bar
return more_nested()
bar_model = module.nested()
# this does not work because Foo is in a parent scope
assert bar_model.__pydantic_model_complete__ is False
def test_nested_annotation_priority(create_module):
@create_module
def module():
from annotated_types import Gt
from typing_extensions import Annotated
from pydantic import BaseModel
Foobar = Annotated[int, Gt(0)] # noqa: F841
def nested():
Foobar = Annotated[int, Gt(10)] # noqa: F841
class Bar(BaseModel):
b: 'Foobar'
return Bar
bar_model = module.nested()
assert bar_model.__fields__['b'].metadata[0].gt == 10
assert bar_model(b=11).dict() == {'b': 11}
with pytest.raises(ValidationError, match=r'Input should be greater than 10 \[type=greater_than,'):
bar_model(b=1)
def test_nested_model_rebuild(create_module):
@create_module
def module():
from pydantic import BaseModel
def nested():
class Bar(BaseModel):
b: 'Foo'
class Foo(BaseModel):
a: int
assert Bar.__pydantic_model_complete__ is False
Bar.model_rebuild()
return Bar
bar_model = module.nested()
assert bar_model.__pydantic_model_complete__ is True
assert bar_model(b={'a': 1}).dict() == {'b': {'a': 1}}