Add optional field argument to __modify_schema__() (#3434)

Co-authored-by: Samuel Colvin <s@muelcolvin.com>
This commit is contained in:
Jaakko Moisio
2021-12-18 22:55:22 +02:00
committed by GitHub
parent f36040a4a3
commit 63337fbadc
5 changed files with 93 additions and 7 deletions
+2
View File
@@ -0,0 +1,2 @@
When generating field schema, pass optional `field` argument (of type
`pydantic.fields.ModelField`) to `__modify_schema__()` if present
+31
View File
@@ -0,0 +1,31 @@
# output-json
from typing import Optional
from pydantic import BaseModel, Field
from pydantic.fields import ModelField
class RestrictedAlphabetStr(str):
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, value, field: ModelField):
alphabet = field.field_info.extra['alphabet']
if any(c not in alphabet for c in value):
raise ValueError(f'{value!r} is not restricted to {alphabet!r}')
return cls(value)
@classmethod
def __modify_schema__(cls, field_schema, field: Optional[ModelField]):
if field:
alphabet = field.field_info.extra['alphabet']
field_schema['examples'] = [c * 3 for c in alphabet]
class MyModel(BaseModel):
value: RestrictedAlphabetStr = Field(alphabet='ABC')
print(MyModel.schema_json(indent=2))
+15
View File
@@ -150,6 +150,21 @@ For versions of Python prior to 3.9, `typing_extensions.Annotated` can be used.
Custom field types can customise the schema generated for them using the `__modify_schema__` class method;
see [Custom Data Types](types.md#custom-data-types) for more details.
`__modify_schema__` can also take a `field` argument which will have type `Optional[ModelField]`.
*pydantic* will inspect the signature of `__modify_schema__` to determine whether the `field` argument should be
included.
```py
{!.tmp_examples/schema_with_field.py!}
```
_(This script is complete, it should run "as is")_
Outputs:
```json
{!.tmp_examples/schema_with_field.json!}
```
## JSON Schema Types
Types, custom field types, and constraints (like `max_length`) are mapped to the corresponding spec formats in the
+23 -7
View File
@@ -90,6 +90,19 @@ TypeModelOrEnum = Union[Type['BaseModel'], Type[Enum]]
TypeModelSet = Set[TypeModelOrEnum]
def _apply_modify_schema(
modify_schema: Callable[..., None], field: Optional[ModelField], field_schema: Dict[str, Any]
) -> None:
from inspect import signature
sig = signature(modify_schema)
args = set(sig.parameters.keys())
if 'field' in args or 'kwargs' in args:
modify_schema(field_schema, field=field)
else:
modify_schema(field_schema)
def schema(
models: Sequence[Union[Type['BaseModel'], Type['Dataclass']]],
*,
@@ -335,7 +348,7 @@ def get_field_schema_validations(field: ModelField) -> Dict[str, Any]:
f_schema.update(field.field_info.extra)
modify_schema = getattr(field.outer_type_, '__modify_schema__', None)
if modify_schema:
modify_schema(f_schema)
_apply_modify_schema(modify_schema, field, f_schema)
return f_schema
@@ -567,7 +580,7 @@ def field_type_schema(
field_type = field.outer_type_
modify_schema = getattr(field_type, '__modify_schema__', None)
if modify_schema:
modify_schema(f_schema)
_apply_modify_schema(modify_schema, field, f_schema)
return f_schema, definitions, nested_models
@@ -579,6 +592,7 @@ def model_process_schema(
ref_prefix: Optional[str] = None,
ref_template: str = default_ref_template,
known_models: TypeModelSet = None,
field: Optional[ModelField] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
"""
Used by ``model_schema()``, you probably should be using that function.
@@ -592,7 +606,7 @@ def model_process_schema(
known_models = known_models or set()
if lenient_issubclass(model, Enum):
model = cast(Type[Enum], model)
s = enum_process_schema(model)
s = enum_process_schema(model, field=field)
return s, {}, set()
model = cast(Type['BaseModel'], model)
s = {'title': model.__config__.title or model.__name__}
@@ -674,7 +688,7 @@ def model_type_schema(
return out_schema, definitions, nested_models
def enum_process_schema(enum: Type[Enum]) -> Dict[str, Any]:
def enum_process_schema(enum: Type[Enum], *, field: Optional[ModelField] = None) -> Dict[str, Any]:
"""
Take a single `enum` and generate its schema.
@@ -695,7 +709,7 @@ def enum_process_schema(enum: Type[Enum]) -> Dict[str, Any]:
modify_schema = getattr(enum, '__modify_schema__', None)
if modify_schema:
modify_schema(schema_)
_apply_modify_schema(modify_schema, field, schema_)
return schema_
@@ -871,7 +885,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
enum_name = model_name_map[field_type]
f_schema, schema_overrides = get_field_info_schema(field, schema_overrides)
f_schema.update(get_schema_ref(enum_name, ref_prefix, ref_template, schema_overrides))
definitions[enum_name] = enum_process_schema(field_type)
definitions[enum_name] = enum_process_schema(field_type, field=field)
elif is_namedtuple(field_type):
sub_schema, *_ = model_process_schema(
field_type.__pydantic_model__,
@@ -880,6 +894,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
ref_prefix=ref_prefix,
ref_template=ref_template,
known_models=known_models,
field=field,
)
items_schemas = list(sub_schema['properties'].values())
f_schema.update(
@@ -895,7 +910,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
modify_schema = getattr(field_type, '__modify_schema__', None)
if modify_schema:
modify_schema(f_schema)
_apply_modify_schema(modify_schema, field, f_schema)
if f_schema:
return f_schema, definitions, nested_models
@@ -914,6 +929,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
ref_prefix=ref_prefix,
ref_template=ref_template,
known_models=known_models,
field=field,
)
definitions.update(sub_definitions)
definitions[model_name] = sub_schema
+22
View File
@@ -32,6 +32,7 @@ from typing_extensions import Annotated, Literal
from pydantic import BaseModel, Extra, Field, ValidationError, confrozenset, conlist, conset, validator
from pydantic.color import Color
from pydantic.dataclasses import dataclass
from pydantic.fields import ModelField
from pydantic.generics import GenericModel
from pydantic.networks import AnyUrl, EmailStr, IPvAnyAddress, IPvAnyInterface, IPvAnyNetwork, NameEmail, stricturl
from pydantic.schema import (
@@ -2628,6 +2629,27 @@ def test_complex_nested_generic():
}
def test_schema_with_field_parameter():
class RestrictedAlphabetStr(str):
@classmethod
def __modify_schema__(cls, field_schema, field: Optional[ModelField]):
assert isinstance(field, ModelField)
alphabet = field.field_info.extra['alphabet']
field_schema['examples'] = [c * 3 for c in alphabet]
class MyModel(BaseModel):
value: RestrictedAlphabetStr = Field(alphabet='ABC')
assert MyModel.schema() == {
'title': 'MyModel',
'type': 'object',
'properties': {
'value': {'title': 'Value', 'alphabet': 'ABC', 'examples': ['AAA', 'BBB', 'CCC'], 'type': 'string'}
},
'required': ['value'],
}
def test_discriminated_union():
class BlackCat(BaseModel):
pet_type: Literal['cat']