diff --git a/changes/1895-PrettyWood.md b/changes/1895-PrettyWood.md new file mode 100644 index 0000000..bd80647 --- /dev/null +++ b/changes/1895-PrettyWood.md @@ -0,0 +1 @@ +stop calling parent class `root_validator` if overridden diff --git a/pydantic/main.py b/pydantic/main.py index f2a674a..5644c0a 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -249,8 +249,14 @@ class ModelMetaclass(ABCMeta): } or None, '__validators__': vg.validators, - '__pre_root_validators__': unique_list(pre_root_validators + pre_rv_new), - '__post_root_validators__': unique_list(post_root_validators + post_rv_new), + '__pre_root_validators__': unique_list( + pre_root_validators + pre_rv_new, + name_factory=lambda v: v.__name__, + ), + '__post_root_validators__': unique_list( + post_root_validators + post_rv_new, + name_factory=lambda skip_on_failure_and_v: skip_on_failure_and_v[1].__name__, + ), '__schema_cache__': {}, '__json_encoder__': staticmethod(json_encoder), '__custom_root_type__': _custom_root_type, diff --git a/pydantic/utils.py b/pydantic/utils.py index 6a6fadb..8d050b4 100644 --- a/pydantic/utils.py +++ b/pydantic/utils.py @@ -279,16 +279,25 @@ def to_camel(string: str) -> str: T = TypeVar('T') -def unique_list(input_list: Union[List[T], Tuple[T, ...]]) -> List[T]: +def unique_list( + input_list: Union[List[T], Tuple[T, ...]], + *, + name_factory: Callable[[T], str] = str, +) -> List[T]: """ Make a list unique while maintaining order. + We update the list if another one with the same name is set + (e.g. root validator overridden in subclass) """ - result = [] - unique_set = set() + result: List[T] = [] + result_names: List[str] = [] for v in input_list: - if v not in unique_set: - unique_set.add(v) + v_name = name_factory(v) + if v_name not in result_names: + result_names.append(v_name) result.append(v) + else: + result[result_names.index(v_name)] = v return result diff --git a/tests/test_validators.py b/tests/test_validators.py index c863eda..6c1c55f 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1256,3 +1256,39 @@ def test_exceptions_in_field_validators_restore_original_field_value(): with pytest.raises(RuntimeError, match='test error'): model.foo = 'raise_exception' assert model.foo == 'foo' + + +def test_overridden_root_validators(mocker): + validate_stub = mocker.stub(name='validate') + + class A(BaseModel): + x: str + + @root_validator(pre=True) + def pre_root(cls, values): + validate_stub('A', 'pre') + return values + + @root_validator(pre=False) + def post_root(cls, values): + validate_stub('A', 'post') + return values + + class B(A): + @root_validator(pre=True) + def pre_root(cls, values): + validate_stub('B', 'pre') + return values + + @root_validator(pre=False) + def post_root(cls, values): + validate_stub('B', 'post') + return values + + A(x='pika') + assert validate_stub.call_args_list == [mocker.call('A', 'pre'), mocker.call('A', 'post')] + + validate_stub.reset_mock() + + B(x='pika') + assert validate_stub.call_args_list == [mocker.call('B', 'pre'), mocker.call('B', 'post')]