diff --git a/HISTORY.rst b/HISTORY.rst index 05d0af0..02a7bf4 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -8,6 +8,7 @@ v0.23 (unreleased) * improve documentation for contributing section, # 441 @pilosus * improve README.rst to include essential information about the package, #446 by @pilosus * ``IntEnum`` support, #444 by @potykion +* fix PyObject callable value, #409 by @pilosus * fix ``ForwardRef`` collection bug, #450 by @tigerwings * Support specialized ``ClassVars``, #455 by @tyrylu diff --git a/pydantic/errors.py b/pydantic/errors.py index c3236da..e816d47 100644 --- a/pydantic/errors.py +++ b/pydantic/errors.py @@ -114,7 +114,7 @@ class PathNotADirectoryError(_PathValueError): class PyObjectError(PydanticTypeError): - msg_template = 'ensure this value contains valid import path: {error_message}' + msg_template = 'ensure this value contains valid import path or valid callable: {error_message}' class SequenceError(PydanticTypeError): diff --git a/pydantic/main.py b/pydantic/main.py index 164d776..ab536f1 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -31,7 +31,7 @@ from .fields import Field from .json import custom_pydantic_encoder, pydantic_encoder from .parse import Protocol, load_file, load_str_bytes from .schema import model_schema -from .types import StrBytes +from .types import PyObject, StrBytes from .utils import ( AnyCallable, AnyType, @@ -180,7 +180,11 @@ class MetaModel(ABCMeta): ) for var_name, value in namespace.items(): - if not var_name.startswith('_') and not isinstance(value, TYPE_BLACKLIST) and var_name not in class_vars: + if ( + not var_name.startswith('_') + and (annotations.get(var_name) == PyObject or not isinstance(value, TYPE_BLACKLIST)) + and var_name not in class_vars + ): validate_field_name(bases, var_name) fields[var_name] = Field.infer( name=var_name, diff --git a/pydantic/types.py b/pydantic/types.py index 1a05912..ffe8995 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -12,7 +12,7 @@ from ipaddress import ( _BaseNetwork, ) from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Pattern, Set, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Optional, Pattern, Set, Tuple, Type, Union, cast from uuid import UUID from . import errors @@ -268,11 +268,18 @@ class PyObject: @classmethod def __get_validators__(cls) -> 'CallableGenerator': - yield str_validator yield cls.validate @classmethod def validate(cls, value: Any) -> Any: + if isinstance(value, Callable): # type: ignore + return value + + try: + value = str_validator(value) + except errors.StrError: + raise errors.PyObjectError(error_message='value is neither a valid import path not a valid callable') + if value is not None: try: return import_string(value) diff --git a/tests/test_types.py b/tests/test_types.py index 9c404f3..c01d3e7 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -48,6 +48,10 @@ class ConBytesModel(BaseModel): v: conbytes(max_length=10) = b'foobar' +def foo(): + return 42 + + def test_constrained_bytes_good(): m = ConBytesModel(v=b'short') assert m.v == b'short' @@ -145,7 +149,8 @@ def test_module_import(): assert exc_info.value.errors() == [ { 'loc': ('module',), - 'msg': 'ensure this value contains valid import path: ' '"foobar" doesn\'t look like a module path', + 'msg': 'ensure this value contains valid import path or valid callable: ' + '"foobar" doesn\'t look like a module path', 'type': 'type_error.pyobject', 'ctx': {'error_message': '"foobar" doesn\'t look like a module path'}, } @@ -156,12 +161,25 @@ def test_module_import(): assert exc_info.value.errors() == [ { 'loc': ('module',), - 'msg': 'ensure this value contains valid import path: ' 'Module "os" does not define a "missing" attribute', + 'msg': 'ensure this value contains valid import path or valid callable: ' + 'Module "os" does not define a "missing" attribute', 'type': 'type_error.pyobject', 'ctx': {'error_message': 'Module "os" does not define a "missing" attribute'}, } ] + with pytest.raises(ValidationError) as exc_info: + PyObjectModel(module=[1, 2, 3]) + assert exc_info.value.errors() == [ + { + 'loc': ('module',), + 'msg': 'ensure this value contains valid import path or valid callable: ' + 'value is neither a valid import path not a valid callable', + 'type': 'type_error.pyobject', + 'ctx': {'error_message': 'value is neither a valid import path not a valid callable'}, + } + ] + def test_pyobject_none(): class PyObjectModel(BaseModel): @@ -171,6 +189,15 @@ def test_pyobject_none(): assert m.module is None +def test_pyobject_callable(): + class PyObjectModel(BaseModel): + foo: PyObject = foo + + m = PyObjectModel() + assert m.foo is foo + assert m.foo() == 42 + + class CheckModel(BaseModel): bool_check = True str_check = 's'