Pass model_class to schema_extra staticmethod (#1125)

* Pass model_class to schema_extra staticmethod

Resolves #1122

* Add changelog

* Apply suggestions from code review

Co-Authored-By: Samuel Colvin <samcolvin@gmail.com>

* Fix import after rebase

* Fix test bug

* Use TypeError instead of assert as per review

* Rename var so declaration fits one one line

* tiny tweaks

Co-authored-by: Samuel Colvin <samcolvin@gmail.com>
This commit is contained in:
John Carter
2020-01-07 01:01:03 +13:00
committed by Samuel Colvin
parent e169bd60e4
commit cd8b504568
4 changed files with 39 additions and 4 deletions
+1
View File
@@ -0,0 +1 @@
Pass model class to the `Config.schema_extra` callable
+2 -2
View File
@@ -1,5 +1,5 @@
# output-json
from typing import Dict, Any
from typing import Dict, Any, Type
from pydantic import BaseModel
class Person(BaseModel):
@@ -8,7 +8,7 @@ class Person(BaseModel):
class Config:
@staticmethod
def schema_extra(schema: Dict[str, Any]) -> None:
def schema_extra(schema: Dict[str, Any], model: Type['Person']) -> None:
for prop in schema.get('properties', {}).values():
prop.pop('title', None)
+8 -2
View File
@@ -5,6 +5,7 @@ from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from types import FunctionType
from typing import (
TYPE_CHECKING,
Any,
@@ -450,7 +451,7 @@ def model_process_schema(
sub-models of the returned schema will be referenced, but their definitions will not be included in the schema. All
the definitions are returned as the second value.
"""
from inspect import getdoc
from inspect import getdoc, signature
ref_prefix = ref_prefix or default_prefix
known_models = known_models or set()
@@ -465,7 +466,12 @@ def model_process_schema(
s.update(m_schema)
schema_extra = model.__config__.schema_extra
if callable(schema_extra):
schema_extra(s)
if not isinstance(schema_extra, FunctionType):
raise TypeError(f'{model.__name__}.Config.schema_extra callable is expected to be a staticmethod')
if len(signature(schema_extra).parameters) == 1:
schema_extra(s)
else:
schema_extra(s, model)
else:
s.update(schema_extra)
return s, m_definitions, nested_models
+28
View File
@@ -1490,6 +1490,20 @@ def test_model_with_schema_extra():
def test_model_with_schema_extra_callable():
class Model(BaseModel):
name: str = None
class Config:
@staticmethod
def schema_extra(schema, model_class):
schema.pop('properties')
schema['type'] = 'override'
assert model_class is Model
assert Model.schema() == {'title': 'Model', 'type': 'override'}
def test_model_with_schema_extra_callable_no_model_class():
class Model(BaseModel):
name: str = None
@@ -1502,6 +1516,20 @@ def test_model_with_schema_extra_callable():
assert Model.schema() == {'title': 'Model', 'type': 'override'}
def test_model_with_schema_extra_callable_classmethod_asserts():
class Model(BaseModel):
name: str = None
class Config:
@classmethod
def schema_extra(cls, schema, model_class):
schema.pop('properties')
schema['type'] = 'override'
with pytest.raises(TypeError, match='Model.Config.schema_extra callable is expected to be a staticmethod'):
Model.schema()
def test_model_with_extra_forbidden():
class Model(BaseModel):
a: str