mirror of
https://github.com/kennethreitz/pydantic.git
synced 2026-06-05 23:00:18 +00:00
Pass model_class to schema_extra staticmethod (#1125)
* Pass model_class to schema_extra staticmethod Resolves #1122 * Add changelog * Apply suggestions from code review Co-Authored-By: Samuel Colvin <samcolvin@gmail.com> * Fix import after rebase * Fix test bug * Use TypeError instead of assert as per review * Rename var so declaration fits one one line * tiny tweaks Co-authored-by: Samuel Colvin <samcolvin@gmail.com>
This commit is contained in:
committed by
Samuel Colvin
parent
e169bd60e4
commit
cd8b504568
@@ -0,0 +1 @@
|
||||
Pass model class to the `Config.schema_extra` callable
|
||||
@@ -1,5 +1,5 @@
|
||||
# output-json
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Type
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Person(BaseModel):
|
||||
@@ -8,7 +8,7 @@ class Person(BaseModel):
|
||||
|
||||
class Config:
|
||||
@staticmethod
|
||||
def schema_extra(schema: Dict[str, Any]) -> None:
|
||||
def schema_extra(schema: Dict[str, Any], model: Type['Person']) -> None:
|
||||
for prop in schema.get('properties', {}).values():
|
||||
prop.pop('title', None)
|
||||
|
||||
|
||||
+8
-2
@@ -5,6 +5,7 @@ from decimal import Decimal
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from pathlib import Path
|
||||
from types import FunctionType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -450,7 +451,7 @@ def model_process_schema(
|
||||
sub-models of the returned schema will be referenced, but their definitions will not be included in the schema. All
|
||||
the definitions are returned as the second value.
|
||||
"""
|
||||
from inspect import getdoc
|
||||
from inspect import getdoc, signature
|
||||
|
||||
ref_prefix = ref_prefix or default_prefix
|
||||
known_models = known_models or set()
|
||||
@@ -465,7 +466,12 @@ def model_process_schema(
|
||||
s.update(m_schema)
|
||||
schema_extra = model.__config__.schema_extra
|
||||
if callable(schema_extra):
|
||||
schema_extra(s)
|
||||
if not isinstance(schema_extra, FunctionType):
|
||||
raise TypeError(f'{model.__name__}.Config.schema_extra callable is expected to be a staticmethod')
|
||||
if len(signature(schema_extra).parameters) == 1:
|
||||
schema_extra(s)
|
||||
else:
|
||||
schema_extra(s, model)
|
||||
else:
|
||||
s.update(schema_extra)
|
||||
return s, m_definitions, nested_models
|
||||
|
||||
@@ -1490,6 +1490,20 @@ def test_model_with_schema_extra():
|
||||
|
||||
|
||||
def test_model_with_schema_extra_callable():
|
||||
class Model(BaseModel):
|
||||
name: str = None
|
||||
|
||||
class Config:
|
||||
@staticmethod
|
||||
def schema_extra(schema, model_class):
|
||||
schema.pop('properties')
|
||||
schema['type'] = 'override'
|
||||
assert model_class is Model
|
||||
|
||||
assert Model.schema() == {'title': 'Model', 'type': 'override'}
|
||||
|
||||
|
||||
def test_model_with_schema_extra_callable_no_model_class():
|
||||
class Model(BaseModel):
|
||||
name: str = None
|
||||
|
||||
@@ -1502,6 +1516,20 @@ def test_model_with_schema_extra_callable():
|
||||
assert Model.schema() == {'title': 'Model', 'type': 'override'}
|
||||
|
||||
|
||||
def test_model_with_schema_extra_callable_classmethod_asserts():
|
||||
class Model(BaseModel):
|
||||
name: str = None
|
||||
|
||||
class Config:
|
||||
@classmethod
|
||||
def schema_extra(cls, schema, model_class):
|
||||
schema.pop('properties')
|
||||
schema['type'] = 'override'
|
||||
|
||||
with pytest.raises(TypeError, match='Model.Config.schema_extra callable is expected to be a staticmethod'):
|
||||
Model.schema()
|
||||
|
||||
|
||||
def test_model_with_extra_forbidden():
|
||||
class Model(BaseModel):
|
||||
a: str
|
||||
|
||||
Reference in New Issue
Block a user