diff --git a/changes/3434-jasujm.md b/changes/3434-jasujm.md new file mode 100644 index 0000000..bd93d66 --- /dev/null +++ b/changes/3434-jasujm.md @@ -0,0 +1,2 @@ +When generating field schema, pass optional `field` argument (of type +`pydantic.fields.ModelField`) to `__modify_schema__()` if present diff --git a/docs/examples/schema_with_field.py b/docs/examples/schema_with_field.py new file mode 100644 index 0000000..9288efd --- /dev/null +++ b/docs/examples/schema_with_field.py @@ -0,0 +1,31 @@ +# output-json +from typing import Optional + +from pydantic import BaseModel, Field +from pydantic.fields import ModelField + + +class RestrictedAlphabetStr(str): + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value, field: ModelField): + alphabet = field.field_info.extra['alphabet'] + if any(c not in alphabet for c in value): + raise ValueError(f'{value!r} is not restricted to {alphabet!r}') + return cls(value) + + @classmethod + def __modify_schema__(cls, field_schema, field: Optional[ModelField]): + if field: + alphabet = field.field_info.extra['alphabet'] + field_schema['examples'] = [c * 3 for c in alphabet] + + +class MyModel(BaseModel): + value: RestrictedAlphabetStr = Field(alphabet='ABC') + + +print(MyModel.schema_json(indent=2)) diff --git a/docs/usage/schema.md b/docs/usage/schema.md index 2ee1a63..9928b27 100644 --- a/docs/usage/schema.md +++ b/docs/usage/schema.md @@ -150,6 +150,21 @@ For versions of Python prior to 3.9, `typing_extensions.Annotated` can be used. Custom field types can customise the schema generated for them using the `__modify_schema__` class method; see [Custom Data Types](types.md#custom-data-types) for more details. +`__modify_schema__` can also take a `field` argument which will have type `Optional[ModelField]`. +*pydantic* will inspect the signature of `__modify_schema__` to determine whether the `field` argument should be +included. + +```py +{!.tmp_examples/schema_with_field.py!} +``` +_(This script is complete, it should run "as is")_ + +Outputs: + +```json +{!.tmp_examples/schema_with_field.json!} +``` + ## JSON Schema Types Types, custom field types, and constraints (like `max_length`) are mapped to the corresponding spec formats in the diff --git a/pydantic/schema.py b/pydantic/schema.py index ff6f8af..e979678 100644 --- a/pydantic/schema.py +++ b/pydantic/schema.py @@ -90,6 +90,19 @@ TypeModelOrEnum = Union[Type['BaseModel'], Type[Enum]] TypeModelSet = Set[TypeModelOrEnum] +def _apply_modify_schema( + modify_schema: Callable[..., None], field: Optional[ModelField], field_schema: Dict[str, Any] +) -> None: + from inspect import signature + + sig = signature(modify_schema) + args = set(sig.parameters.keys()) + if 'field' in args or 'kwargs' in args: + modify_schema(field_schema, field=field) + else: + modify_schema(field_schema) + + def schema( models: Sequence[Union[Type['BaseModel'], Type['Dataclass']]], *, @@ -335,7 +348,7 @@ def get_field_schema_validations(field: ModelField) -> Dict[str, Any]: f_schema.update(field.field_info.extra) modify_schema = getattr(field.outer_type_, '__modify_schema__', None) if modify_schema: - modify_schema(f_schema) + _apply_modify_schema(modify_schema, field, f_schema) return f_schema @@ -567,7 +580,7 @@ def field_type_schema( field_type = field.outer_type_ modify_schema = getattr(field_type, '__modify_schema__', None) if modify_schema: - modify_schema(f_schema) + _apply_modify_schema(modify_schema, field, f_schema) return f_schema, definitions, nested_models @@ -579,6 +592,7 @@ def model_process_schema( ref_prefix: Optional[str] = None, ref_template: str = default_ref_template, known_models: TypeModelSet = None, + field: Optional[ModelField] = None, ) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: """ Used by ``model_schema()``, you probably should be using that function. @@ -592,7 +606,7 @@ def model_process_schema( known_models = known_models or set() if lenient_issubclass(model, Enum): model = cast(Type[Enum], model) - s = enum_process_schema(model) + s = enum_process_schema(model, field=field) return s, {}, set() model = cast(Type['BaseModel'], model) s = {'title': model.__config__.title or model.__name__} @@ -674,7 +688,7 @@ def model_type_schema( return out_schema, definitions, nested_models -def enum_process_schema(enum: Type[Enum]) -> Dict[str, Any]: +def enum_process_schema(enum: Type[Enum], *, field: Optional[ModelField] = None) -> Dict[str, Any]: """ Take a single `enum` and generate its schema. @@ -695,7 +709,7 @@ def enum_process_schema(enum: Type[Enum]) -> Dict[str, Any]: modify_schema = getattr(enum, '__modify_schema__', None) if modify_schema: - modify_schema(schema_) + _apply_modify_schema(modify_schema, field, schema_) return schema_ @@ -871,7 +885,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity) enum_name = model_name_map[field_type] f_schema, schema_overrides = get_field_info_schema(field, schema_overrides) f_schema.update(get_schema_ref(enum_name, ref_prefix, ref_template, schema_overrides)) - definitions[enum_name] = enum_process_schema(field_type) + definitions[enum_name] = enum_process_schema(field_type, field=field) elif is_namedtuple(field_type): sub_schema, *_ = model_process_schema( field_type.__pydantic_model__, @@ -880,6 +894,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity) ref_prefix=ref_prefix, ref_template=ref_template, known_models=known_models, + field=field, ) items_schemas = list(sub_schema['properties'].values()) f_schema.update( @@ -895,7 +910,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity) modify_schema = getattr(field_type, '__modify_schema__', None) if modify_schema: - modify_schema(f_schema) + _apply_modify_schema(modify_schema, field, f_schema) if f_schema: return f_schema, definitions, nested_models @@ -914,6 +929,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity) ref_prefix=ref_prefix, ref_template=ref_template, known_models=known_models, + field=field, ) definitions.update(sub_definitions) definitions[model_name] = sub_schema diff --git a/tests/test_schema.py b/tests/test_schema.py index c1e1f01..eb1f081 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -32,6 +32,7 @@ from typing_extensions import Annotated, Literal from pydantic import BaseModel, Extra, Field, ValidationError, confrozenset, conlist, conset, validator from pydantic.color import Color from pydantic.dataclasses import dataclass +from pydantic.fields import ModelField from pydantic.generics import GenericModel from pydantic.networks import AnyUrl, EmailStr, IPvAnyAddress, IPvAnyInterface, IPvAnyNetwork, NameEmail, stricturl from pydantic.schema import ( @@ -2628,6 +2629,27 @@ def test_complex_nested_generic(): } +def test_schema_with_field_parameter(): + class RestrictedAlphabetStr(str): + @classmethod + def __modify_schema__(cls, field_schema, field: Optional[ModelField]): + assert isinstance(field, ModelField) + alphabet = field.field_info.extra['alphabet'] + field_schema['examples'] = [c * 3 for c in alphabet] + + class MyModel(BaseModel): + value: RestrictedAlphabetStr = Field(alphabet='ABC') + + assert MyModel.schema() == { + 'title': 'MyModel', + 'type': 'object', + 'properties': { + 'value': {'title': 'Value', 'alphabet': 'ABC', 'examples': ['AAA', 'BBB', 'CCC'], 'type': 'string'} + }, + 'required': ['value'], + } + + def test_discriminated_union(): class BlackCat(BaseModel): pet_type: Literal['cat']