From 943a8a06e543cc43e04b98ef1cd0666cb04c35f0 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 24 Jan 2020 10:31:16 +0000 Subject: [PATCH] change alias priority logic (#1178) * fix alias priority so alias_generators don't take priority * improve test names * remove debugs * Apply suggestions from code review * more tests and allow custom alias_priority on fields * precedence tests and docs * tweaks and add change * suggestions --- changes/1178-samuelcolvin.md | 3 + changes/make_history.py | 8 +- .../examples/model_config_alias_precedence.py | 22 +- docs/usage/model_config.md | 23 +- pydantic/fields.py | 12 +- pydantic/main.py | 19 +- tests/test_aliases.py | 356 ++++++++++++++++++ tests/test_edge_cases.py | 160 -------- tests/test_main.py | 52 --- tests/test_schema.py | 22 +- 10 files changed, 436 insertions(+), 241 deletions(-) create mode 100644 changes/1178-samuelcolvin.md create mode 100644 tests/test_aliases.py diff --git a/changes/1178-samuelcolvin.md b/changes/1178-samuelcolvin.md new file mode 100644 index 0000000..2721a3d --- /dev/null +++ b/changes/1178-samuelcolvin.md @@ -0,0 +1,3 @@ +**Breaking Change:** alias precedence logic changed so aliases on a field always take priority over +an alias from `alias_generator` to avoid buggy/unexpected behaviour, +see [here](https://pydantic-docs.helpmanual.io/usage/model_config/#alias-precedence) for details diff --git a/changes/make_history.py b/changes/make_history.py index 3f92621..34d1212 100755 --- a/changes/make_history.py +++ b/changes/make_history.py @@ -20,7 +20,11 @@ for p in THIS_DIR.glob('*.md'): if '\n\n' in content: raise RuntimeError(f'{p.name!r}: content includes multiple paragraphs') content = content.replace('\n', '\n ') - priority = 0 if '**breaking change' in content.lower() else 1 + priority = 0 + if '**breaking change' in content.lower(): + priority = 2 + elif content.startswith('**'): + priority = 1 bullet_list.append((priority, int(gh_id), f'* {content}, #{gh_id} by @{creator}')) if not bullet_list: @@ -29,7 +33,7 @@ if not bullet_list: version = SourceFileLoader('version', 'pydantic/version.py').load_module() chunk_title = f'v{version.VERSION} ({date.today():%Y-%m-%d})' -new_chunk = '## {}\n\n{}\n\n'.format(chunk_title, '\n'.join(c for *_, c in sorted(bullet_list))) +new_chunk = '## {}\n\n{}\n\n'.format(chunk_title, '\n'.join(c for *_, c in sorted(bullet_list, reverse=True))) print(f'{chunk_title}...{len(bullet_list)} items') history_path = THIS_DIR / '..' / 'HISTORY.md' diff --git a/docs/examples/model_config_alias_precedence.py b/docs/examples/model_config_alias_precedence.py index a6407f9..3babb24 100644 --- a/docs/examples/model_config_alias_precedence.py +++ b/docs/examples/model_config_alias_precedence.py @@ -1,21 +1,19 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field class Voice(BaseModel): - name: str - language_code: str + name: str = Field(None, alias='ActorName') + language_code: str = None + mood: str = None + +class Character(Voice): + act: int = 1 class Config: + fields = {'language_code': 'lang'} + @classmethod def alias_generator(cls, string: str) -> str: # this is the same as `alias_generator = to_camel` above return ''.join(word.capitalize() for word in string.split('_')) -class Character(Voice): - mood: str - - class Config: - fields = {'mood': 'Mood', 'language_code': 'lang'} - -c = Character(Mood='happy', Name='Filiz', lang='tr-TR') -print(c) -print(c.dict(by_alias=True)) +print(Character.schema(by_alias=True)) diff --git a/docs/usage/model_config.md b/docs/usage/model_config.md index 3fd7e46..dacd4dd 100644 --- a/docs/usage/model_config.md +++ b/docs/usage/model_config.md @@ -107,12 +107,27 @@ it should be trivial to modify the `to_camel` function above. ## Alias Precedence -Aliases defined on the `Config` class of child models will take priority over any aliases defined on `Config` of a -parent model: +!!! warning + Alias priority logic changed in **v1.4** to resolve buggy and unexpected behaviour in previous versions. + In some circumstances this may represent a **breaking change**, + see [#1178](https://github.com/samuelcolvin/pydantic/issues/1178) and the precedence order below for details. + +In the case where a field's alias may be defined in multiple places, +the selected value is determined as follows (in descending order of priority): + +1. Set via `Field(..., alias=)`, directly on the model +2. Defined in `Config.fields`, directly on the model +3. Set via `Field(..., alias=)`, on a parent model +4. Defined in `Config.fields`, on a parent model +5. Generated by `alias_generator`, regardless of whether it's on the model or a parent + +!!! note + This means an `alias_generator` defined on a child model **does not** take priority over an alias defined + on a field in a parent model. + +For example: ```py {!.tmp_examples/model_config_alias_precedence.py!} ``` _(This script is complete, it should run "as is")_ - -This includes when a child model uses `alias_generator` where the aliases of all parent model fields will be updated. diff --git a/pydantic/fields.py b/pydantic/fields.py index c064dfb..c9818dd 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -18,7 +18,6 @@ from typing import ( Type, TypeVar, Union, - cast, ) from . import errors as errors_ @@ -60,6 +59,7 @@ class FieldInfo(Representation): __slots__ = ( 'default', 'alias', + 'alias_priority', 'title', 'description', 'const', @@ -79,6 +79,7 @@ class FieldInfo(Representation): def __init__(self, default: Any, **kwargs: Any) -> None: self.default = default self.alias = kwargs.pop('alias', None) + self.alias_priority = kwargs.pop('alias_priority', 2 if self.alias else None) self.title = kwargs.pop('title', None) self.description = kwargs.pop('description', None) self.const = kwargs.pop('const', None) @@ -288,9 +289,12 @@ class ModelField(Representation): self.model_config = config info_from_config = config.get_field_info(self.name) config.prepare_field(self) - if info_from_config: - self.field_info.alias = info_from_config.get('alias') or self.field_info.alias or self.name - self.alias = cast(str, self.field_info.alias) + new_alias = info_from_config.get('alias') + new_alias_priority = info_from_config.get('alias_priority') or 0 + if new_alias and new_alias_priority >= (self.field_info.alias_priority or 0): + self.field_info.alias = new_alias + self.field_info.alias_priority = new_alias_priority + self.alias = new_alias @property def alt_alias(self) -> bool: diff --git a/pydantic/main.py b/pydantic/main.py index 5372f5c..993f55c 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -73,14 +73,23 @@ class BaseConfig: @classmethod def get_field_info(cls, name: str) -> Dict[str, Any]: - field_info = cls.fields.get(name) or {} - if isinstance(field_info, str): - field_info = {'alias': field_info} - elif cls.alias_generator and 'alias' not in field_info: + fields_value = cls.fields.get(name) + + if isinstance(fields_value, str): + field_info: Dict[str, Any] = {'alias': fields_value} + elif isinstance(fields_value, dict): + field_info = fields_value + else: + field_info = {} + + if 'alias' in field_info: + field_info.setdefault('alias_priority', 2) + + if field_info.get('alias_priority', 0) <= 1 and cls.alias_generator: alias = cls.alias_generator(name) if not isinstance(alias, str): raise TypeError(f'Config.alias_generator must return str, not {type(alias)}') - field_info['alias'] = alias + field_info.update(alias=alias, alias_priority=1) return field_info @classmethod diff --git a/tests/test_aliases.py b/tests/test_aliases.py new file mode 100644 index 0000000..751cf5e --- /dev/null +++ b/tests/test_aliases.py @@ -0,0 +1,356 @@ +import re +from typing import Any, List, Optional + +import pytest + +from pydantic import BaseConfig, BaseModel, Extra, ValidationError +from pydantic.fields import Field + + +def test_alias_generator(): + def to_camel(string: str): + return ''.join(x.capitalize() for x in string.split('_')) + + class MyModel(BaseModel): + a: List[str] = None + foo_bar: str + + class Config: + alias_generator = to_camel + + data = {'A': ['foo', 'bar'], 'FooBar': 'foobar'} + v = MyModel(**data) + assert v.a == ['foo', 'bar'] + assert v.foo_bar == 'foobar' + assert v.dict(by_alias=True) == data + + +def test_alias_generator_with_field_schema(): + def to_upper_case(string: str): + return string.upper() + + class MyModel(BaseModel): + my_shiny_field: Any # Alias from Config.fields will be used + foo_bar: str # Alias from Config.fields will be used + baz_bar: str # Alias will be generated + another_field: str # Alias will be generated + + class Config: + alias_generator = to_upper_case + fields = {'my_shiny_field': 'MY_FIELD', 'foo_bar': {'alias': 'FOO'}, 'another_field': {'not_alias': 'a'}} + + data = {'MY_FIELD': ['a'], 'FOO': 'bar', 'BAZ_BAR': 'ok', 'ANOTHER_FIELD': '...'} + m = MyModel(**data) + assert m.dict(by_alias=True) == data + + +def test_alias_generator_wrong_type_error(): + def return_bytes(string): + return b'not a string' + + with pytest.raises(TypeError) as e: + + class MyModel(BaseModel): + bar: Any + + class Config: + alias_generator = return_bytes + + assert str(e.value) == "Config.alias_generator must return str, not " + + +def test_infer_alias(): + class Model(BaseModel): + a = 'foobar' + + class Config: + fields = {'a': '_a'} + + assert Model(_a='different').a == 'different' + assert repr(Model.__fields__['a']) == ( + "ModelField(name='a', type=str, required=False, default='foobar', alias='_a')" + ) + + +def test_alias_error(): + class Model(BaseModel): + a = 123 + + class Config: + fields = {'a': '_a'} + + assert Model(_a='123').a == 123 + + with pytest.raises(ValidationError) as exc_info: + Model(_a='foo') + assert exc_info.value.errors() == [ + {'loc': ('_a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'} + ] + + +def test_annotation_config(): + class Model(BaseModel): + b: float + a: int = 10 + _c: str + + class Config: + fields = {'b': 'foobar'} + + assert list(Model.__fields__.keys()) == ['b', 'a'] + assert [f.alias for f in Model.__fields__.values()] == ['foobar', 'a'] + assert Model(foobar='123').b == 123.0 + + +def test_alias_camel_case(): + class Model(BaseModel): + one_thing: int + another_thing: int + + class Config(BaseConfig): + @classmethod + def get_field_info(cls, name): + field_config = super().get_field_info(name) or {} + if 'alias' not in field_config: + field_config['alias'] = re.sub(r'(?:^|_)([a-z])', lambda m: m.group(1).upper(), name) + return field_config + + v = Model(**{'OneThing': 123, 'AnotherThing': '321'}) + assert v.one_thing == 123 + assert v.another_thing == 321 + assert v == {'one_thing': 123, 'another_thing': 321} + + +def test_get_field_info_inherit(): + class ModelOne(BaseModel): + class Config(BaseConfig): + @classmethod + def get_field_info(cls, name): + field_config = super().get_field_info(name) or {} + if 'alias' not in field_config: + field_config['alias'] = re.sub(r'_([a-z])', lambda m: m.group(1).upper(), name) + return field_config + + class ModelTwo(ModelOne): + one_thing: int + another_thing: int + third_thing: int + + class Config: + fields = {'third_thing': 'Banana'} + + v = ModelTwo(**{'oneThing': 123, 'anotherThing': '321', 'Banana': 1}) + assert v == {'one_thing': 123, 'another_thing': 321, 'third_thing': 1} + + +def test_pop_by_field_name(): + class Model(BaseModel): + last_updated_by: Optional[str] = None + + class Config: + extra = Extra.forbid + allow_population_by_field_name = True + fields = {'last_updated_by': 'lastUpdatedBy'} + + assert Model(lastUpdatedBy='foo').dict() == {'last_updated_by': 'foo'} + assert Model(last_updated_by='foo').dict() == {'last_updated_by': 'foo'} + with pytest.raises(ValidationError) as exc_info: + Model(lastUpdatedBy='foo', last_updated_by='bar') + assert exc_info.value.errors() == [ + {'loc': ('last_updated_by',), 'msg': 'extra fields not permitted', 'type': 'value_error.extra'} + ] + + +def test_population_by_alias(): + with pytest.warns(DeprecationWarning, match='"allow_population_by_alias" is deprecated and replaced by'): + + class Model(BaseModel): + a: str + + class Config: + allow_population_by_alias = True + fields = {'a': {'alias': '_a'}} + + assert Model.__config__.allow_population_by_field_name is True + assert Model(a='different').a == 'different' + assert Model(a='different').dict() == {'a': 'different'} + assert Model(a='different').dict(by_alias=True) == {'_a': 'different'} + + +def test_alias_child_precedence(): + class Parent(BaseModel): + x: int + + class Config: + fields = {'x': 'x1'} + + class Child(Parent): + y: int + + class Config: + fields = {'y': 'y2', 'x': 'x2'} + + assert Child.__fields__['y'].alias == 'y2' + assert Child.__fields__['x'].alias == 'x2' + + +def test_alias_generator_parent(): + class Parent(BaseModel): + x: int + + class Config: + allow_population_by_field_name = True + + @classmethod + def alias_generator(cls, f_name): + return f_name + '1' + + class Child(Parent): + y: int + + class Config: + @classmethod + def alias_generator(cls, f_name): + return f_name + '2' + + assert Child.__fields__['y'].alias == 'y2' + assert Child.__fields__['x'].alias == 'x2' + + +def test_alias_generator_on_parent(): + class Parent(BaseModel): + x: bool = Field(..., alias='a_b_c') + y: str + + class Config: + @staticmethod + def alias_generator(x): + return x.upper() + + class Child(Parent): + y: str + z: str + + assert Parent.__fields__['x'].alias == 'a_b_c' + assert Parent.__fields__['y'].alias == 'Y' + assert Child.__fields__['x'].alias == 'a_b_c' + assert Child.__fields__['y'].alias == 'Y' + assert Child.__fields__['z'].alias == 'Z' + + +def test_alias_generator_on_child(): + class Parent(BaseModel): + x: bool = Field(..., alias='abc') + y: str + + class Child(Parent): + y: str + z: str + + class Config: + @staticmethod + def alias_generator(x): + return x.upper() + + assert [f.alias for f in Parent.__fields__.values()] == ['abc', 'y'] + assert [f.alias for f in Child.__fields__.values()] == ['abc', 'Y', 'Z'] + + +def test_low_priority_alias(): + class Parent(BaseModel): + x: bool = Field(..., alias='abc', alias_priority=1) + y: str + + class Child(Parent): + y: str + z: str + + class Config: + @staticmethod + def alias_generator(x): + return x.upper() + + assert [f.alias for f in Parent.__fields__.values()] == ['abc', 'y'] + assert [f.alias for f in Child.__fields__.values()] == ['X', 'Y', 'Z'] + + +def test_low_priority_alias_config(): + class Parent(BaseModel): + x: bool + y: str + + class Config: + fields = {'x': dict(alias='abc', alias_priority=1)} + + class Child(Parent): + y: str + z: str + + class Config: + @staticmethod + def alias_generator(x): + return x.upper() + + assert [f.alias for f in Parent.__fields__.values()] == ['abc', 'y'] + assert [f.alias for f in Child.__fields__.values()] == ['X', 'Y', 'Z'] + + +def test_field_vs_config(): + class Model(BaseModel): + x: str = Field(..., alias='x_on_field') + y: str + z: str + + class Config: + fields = {'x': dict(alias='x_on_config'), 'y': dict(alias='y_on_config')} + + assert [f.alias for f in Model.__fields__.values()] == ['x_on_field', 'y_on_config', 'z'] + + +def test_alias_priority(): + class Parent(BaseModel): + a: str = Field(..., alias='a_field_parent') + b: str = Field(..., alias='b_field_parent') + c: str = Field(..., alias='c_field_parent') + d: str + e: str + + class Config: + fields = { + 'a': dict(alias='a_config_parent'), + 'c': dict(alias='c_config_parent'), + 'd': dict(alias='d_config_parent'), + } + + @staticmethod + def alias_generator(x): + return f'{x}_generator_parent' + + class Child(Parent): + a: str = Field(..., alias='a_field_child') + + class Config: + fields = { + 'a': dict(alias='a_config_child'), + 'b': dict(alias='b_config_child'), + } + + @staticmethod + def alias_generator(x): + return f'{x}_generator_child' + + # debug([f.alias for f in Parent.__fields__.values()], [f.alias for f in Child.__fields__.values()]) + assert [f.alias for f in Parent.__fields__.values()] == [ + 'a_field_parent', + 'b_field_parent', + 'c_field_parent', + 'd_config_parent', + 'e_generator_parent', + ] + assert [f.alias for f in Child.__fields__.values()] == [ + 'a_field_child', + 'b_config_child', + 'c_field_parent', + 'd_config_parent', + 'e_generator_child', + ] diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 3f84cdf..25a07f9 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -1,4 +1,3 @@ -import re import sys from decimal import Decimal from enum import Enum @@ -7,7 +6,6 @@ from typing import Any, Dict, FrozenSet, Generic, List, Optional, Set, Tuple, Ty import pytest from pydantic import ( - BaseConfig, BaseModel, BaseSettings, Extra, @@ -328,49 +326,6 @@ def test_any_dict(): assert Model(v={2: [1, 2, 3]}).dict() == {'v': {2: [1, 2, 3]}} -def test_infer_alias(): - class Model(BaseModel): - a = 'foobar' - - class Config: - fields = {'a': '_a'} - - assert Model(_a='different').a == 'different' - assert repr(Model.__fields__['a']) == ( - "ModelField(name='a', type=str, required=False, default='foobar', alias='_a')" - ) - - -def test_alias_error(): - class Model(BaseModel): - a = 123 - - class Config: - fields = {'a': '_a'} - - assert Model(_a='123').a == 123 - - with pytest.raises(ValidationError) as exc_info: - Model(_a='foo') - assert exc_info.value.errors() == [ - {'loc': ('_a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'} - ] - - -def test_annotation_config(): - class Model(BaseModel): - b: float - a: int = 10 - _c: str - - class Config: - fields = {'b': 'foobar'} - - assert list(Model.__fields__.keys()) == ['b', 'a'] - assert [f.alias for f in Model.__fields__.values()] == ['foobar', 'a'] - assert Model(foobar='123').b == 123.0 - - def test_success_values_include(): class Model(BaseModel): a: int = 1 @@ -721,47 +676,6 @@ def test_string_none(): ] -def test_alias_camel_case(): - class Model(BaseModel): - one_thing: int - another_thing: int - - class Config(BaseConfig): - @classmethod - def get_field_info(cls, name): - field_config = super().get_field_info(name) or {} - if 'alias' not in field_config: - field_config['alias'] = re.sub(r'(?:^|_)([a-z])', lambda m: m.group(1).upper(), name) - return field_config - - v = Model(**{'OneThing': 123, 'AnotherThing': '321'}) - assert v.one_thing == 123 - assert v.another_thing == 321 - assert v == {'one_thing': 123, 'another_thing': 321} - - -def test_get_field_info_inherit(): - class ModelOne(BaseModel): - class Config(BaseConfig): - @classmethod - def get_field_info(cls, name): - field_config = super().get_field_info(name) or {} - if 'alias' not in field_config: - field_config['alias'] = re.sub(r'_([a-z])', lambda m: m.group(1).upper(), name) - return field_config - - class ModelTwo(ModelOne): - one_thing: int - another_thing: int - third_thing: int - - class Config: - fields = {'third_thing': 'Banana'} - - v = ModelTwo(**{'oneThing': 123, 'anotherThing': '321', 'Banana': 1}) - assert v == {'one_thing': 123, 'another_thing': 321, 'third_thing': 1} - - def test_return_errors_ok(): class Model(BaseModel): foo: int @@ -846,24 +760,6 @@ def test_multiple_errors(): assert Model(a=None).a is None -def test_pop_by_alias(): - class Model(BaseModel): - last_updated_by: Optional[str] = None - - class Config: - extra = Extra.forbid - allow_population_by_field_name = True - fields = {'last_updated_by': 'lastUpdatedBy'} - - assert Model(lastUpdatedBy='foo').dict() == {'last_updated_by': 'foo'} - assert Model(last_updated_by='foo').dict() == {'last_updated_by': 'foo'} - with pytest.raises(ValidationError) as exc_info: - Model(lastUpdatedBy='foo', last_updated_by='bar') - assert exc_info.value.errors() == [ - {'loc': ('last_updated_by',), 'msg': 'extra fields not permitted', 'type': 'value_error.extra'} - ] - - def test_validate_all(): class Model(BaseModel): a: int @@ -1118,22 +1014,6 @@ def test_scheme_deprecated(): foo: int = Schema(4) -def test_population_by_alias(): - with pytest.warns(DeprecationWarning, match='"allow_population_by_alias" is deprecated and replaced by'): - - class Model(BaseModel): - a: str - - class Config: - allow_population_by_alias = True - fields = {'a': {'alias': '_a'}} - - assert Model.__config__.allow_population_by_field_name is True - assert Model(a='different').a == 'different' - assert Model(a='different').dict() == {'a': 'different'} - assert Model(a='different').dict(by_alias=True) == {'_a': 'different'} - - def test_fields_deprecated(): class Model(BaseModel): v: str = 'x' @@ -1145,46 +1025,6 @@ def test_fields_deprecated(): assert Model.__fields__.keys() == {'v'} -def test_alias_child_precedence(): - class Parent(BaseModel): - x: int - - class Config: - fields = {'x': 'x1'} - - class Child(Parent): - y: int - - class Config: - fields = {'y': 'y2', 'x': 'x2'} - - assert Child.__fields__['y'].alias == 'y2' - assert Child.__fields__['x'].alias == 'x2' - - -def test_alias_generator_parent(): - class Parent(BaseModel): - x: int - - class Config: - allow_population_by_field_name = True - - @classmethod - def alias_generator(cls, f_name): - return f_name + '1' - - class Child(Parent): - y: int - - class Config: - @classmethod - def alias_generator(cls, f_name): - return f_name + '2' - - assert Child.__fields__['y'].alias == 'y2' - assert Child.__fields__['x'].alias == 'x2' - - def test_optional_field_constraints(): class MyModel(BaseModel): my_int: Optional[int] = Field(..., ge=3) diff --git a/tests/test_main.py b/tests/test_main.py index d093d52..678c70e 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -812,58 +812,6 @@ def test_dict_with_extra_keys(): assert m.dict(by_alias=True) == {'alias_a': None, 'extra_key': 'extra'} -def test_alias_generator(): - def to_camel(string: str): - return ''.join(x.capitalize() for x in string.split('_')) - - class MyModel(BaseModel): - a: List[str] = None - foo_bar: str - - class Config: - alias_generator = to_camel - - data = {'A': ['foo', 'bar'], 'FooBar': 'foobar'} - v = MyModel(**data) - assert v.a == ['foo', 'bar'] - assert v.foo_bar == 'foobar' - assert v.dict(by_alias=True) == data - - -def test_alias_generator_with_field_schema(): - def to_upper_case(string: str): - return string.upper() - - class MyModel(BaseModel): - my_shiny_field: Any # Alias from Config.fields will be used - foo_bar: str # Alias from Config.fields will be used - baz_bar: str # Alias will be generated - another_field: str # Alias will be generated - - class Config: - alias_generator = to_upper_case - fields = {'my_shiny_field': 'MY_FIELD', 'foo_bar': {'alias': 'FOO'}, 'another_field': {'not_alias': 'a'}} - - data = {'MY_FIELD': ['a'], 'FOO': 'bar', 'BAZ_BAR': 'ok', 'ANOTHER_FIELD': '...'} - m = MyModel(**data) - assert m.dict(by_alias=True) == data - - -def test_alias_generator_wrong_type_error(): - def return_bytes(string): - return b'not a string' - - with pytest.raises(TypeError) as e: - - class MyModel(BaseModel): - bar: Any - - class Config: - alias_generator = return_bytes - - assert str(e.value) == "Config.alias_generator must return str, not " - - def test_root(): class MyModel(BaseModel): __root__: str diff --git a/tests/test_schema.py b/tests/test_schema.py index 3d5ee2c..bec0fa9 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -100,7 +100,7 @@ def test_by_alias(): title = 'Apple Pie' fields = {'a': 'Snap', 'b': 'Crackle'} - s = { + assert ApplePie.schema() == { 'type': 'object', 'title': 'Apple Pie', 'properties': { @@ -109,11 +109,29 @@ def test_by_alias(): }, 'required': ['Snap'], } - assert ApplePie.schema() == s assert list(ApplePie.schema(by_alias=True)['properties'].keys()) == ['Snap', 'Crackle'] assert list(ApplePie.schema(by_alias=False)['properties'].keys()) == ['a', 'b'] +def test_by_alias_generator(): + class ApplePie(BaseModel): + a: float + b: int = 10 + + class Config: + @staticmethod + def alias_generator(x): + return x.upper() + + assert ApplePie.schema() == { + 'title': 'ApplePie', + 'type': 'object', + 'properties': {'A': {'title': 'A', 'type': 'number'}, 'B': {'title': 'B', 'default': 10, 'type': 'integer'}}, + 'required': ['A'], + } + assert ApplePie.schema(by_alias=False)['properties'].keys() == {'a', 'b'} + + def test_sub_model(): class Foo(BaseModel): """hello"""