From e71f53d2b5a3a04962ecdf50edd021d0666a332c Mon Sep 17 00:00:00 2001 From: diabolo-dan Date: Sun, 5 Dec 2021 13:40:23 +0000 Subject: [PATCH] Improve generic subclass support (#2549) * Derive concrete subclasses for parameterised generics * Resolve type issues * Add negative assertions to generic subclass tests * Remove incorrect subclassing of partial. The type was incorrectly being picked up for this style of subclassing, and it can be regardless inferred through cls. * Apply feedback: * Improve parameterisation explanation * fix typos * Alias Parameterisation type * Apply suggestions from code review * start docstring with newline. * Use None as default over empty tuple. Co-authored-by: Samuel Colvin * Combine _assigned_parameters cases in __paramaterized_bases__ of generics * Add description for the `_assigned_parameters` variable. Co-authored-by: Samuel Colvin Co-authored-by: Samuel Colvin --- changes/2007-diabolo-dan.md | 1 + pydantic/generics.py | 77 ++++++++++++++++++++++++++++++++- pydantic/main.py | 10 +++-- tests/test_generics.py | 86 +++++++++++++++++++++++++++++++++++++ 4 files changed, 169 insertions(+), 5 deletions(-) create mode 100644 changes/2007-diabolo-dan.md diff --git a/changes/2007-diabolo-dan.md b/changes/2007-diabolo-dan.md new file mode 100644 index 0000000..d79c056 --- /dev/null +++ b/changes/2007-diabolo-dan.md @@ -0,0 +1 @@ +Add parameterised subclasses to `__bases__` when constructing new parameterised classes, so that `A <: B => A[int] <: B[int]`. diff --git a/pydantic/generics.py b/pydantic/generics.py index 17d782f..1fee636 100644 --- a/pydantic/generics.py +++ b/pydantic/generics.py @@ -30,6 +30,15 @@ _generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[B GenericModelT = TypeVar('GenericModelT', bound='GenericModel') TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type +Parametrization = Mapping[TypeVarType, Type[Any]] + +# _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations +# as captured during construction of the class (not instances). +# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created, +# `Model[int, str]`: {A: int, B: str}` will be stored in `_assigned_parameters`. +# (This information is only otherwise available after creation from the class name string). +_assigned_parameters: Dict[Type[Any], Parametrization] = {} + class GenericModel(BaseModel): __slots__ = () @@ -86,13 +95,15 @@ class GenericModel(BaseModel): create_model( model_name, __module__=model_module or cls.__module__, - __base__=cls, + __base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)), __config__=None, __validators__=validators, **fields, ), ) + _assigned_parameters[created_model] = typevars_map + if called_globally: # create global reference and therefore allow pickling object_by_reference = None reference_name = model_name @@ -142,6 +153,70 @@ class GenericModel(BaseModel): params_component = ', '.join(param_names) return f'{cls.__name__}[{params_component}]' + @classmethod + def __parameterized_bases__(cls, typevars_map: Parametrization) -> Iterator[Type[Any]]: + """ + Returns unbound bases of cls parameterised to given type variables + + :param typevars_map: Dictionary of type applications for binding subclasses. + Given a generic class `Model` with 2 type variables [S, T] + and a concrete model `Model[str, int]`, + the value `{S: str, T: int}` would be passed to `typevars_map`. + :return: an iterator of generic sub classes, parameterised by `typevars_map` + and other assigned parameters of `cls` + + e.g.: + ``` + class A(GenericModel, Generic[T]): + ... + + class B(A[V], Generic[V]): + ... + + assert A[int] in B.__parameterized_bases__({V: int}) + ``` + """ + + def build_base_model( + base_model: Type[GenericModel], mapped_types: Parametrization + ) -> Iterator[Type[GenericModel]]: + base_parameters = tuple([mapped_types[param] for param in base_model.__parameters__]) + parameterized_base = base_model.__class_getitem__(base_parameters) + if parameterized_base is base_model or parameterized_base is cls: + # Avoid duplication in MRO + return + yield parameterized_base + + for base_model in cls.__bases__: + if not issubclass(base_model, GenericModel): + # not a class that can be meaningfully parameterized + continue + elif not getattr(base_model, '__parameters__', None): + # base_model is "GenericModel" (and has no __parameters__) + # or + # base_model is already concrete, and will be included transitively via cls. + continue + elif cls in _assigned_parameters: + if base_model in _assigned_parameters: + # cls is partially parameterised but not from base_model + # e.g. cls = B[S], base_model = A[S] + # B[S][int] should subclass A[int], (and will be transitively via B[int]) + # but it's not viable to consistently subclass types with arbitrary construction + # So don't attempt to include A[S][int] + continue + else: # base_model not in _assigned_parameters: + # cls is partially parameterized, base_model is original generic + # e.g. cls = B[str, T], base_model = B[S, T] + # Need to determine the mapping for the base_model parameters + mapped_types: Parametrization = { + key: typevars_map.get(value, value) for key, value in _assigned_parameters[cls].items() + } + yield from build_base_model(base_model, mapped_types) + else: + # cls is base generic, so base_class has a distinct base + # can construct the Parameterised base model using typevars_map directly + yield from build_base_model(base_model, typevars_map) + def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any: """Return type with all occurrences of `type_map` keys recursively replaced with their values. diff --git a/pydantic/main.py b/pydantic/main.py index a25e96e..ff63969 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -877,7 +877,7 @@ def create_model( __model_name: str, *, __config__: Optional[Type[BaseConfig]] = None, - __base__: Type['Model'], + __base__: Union[Type['Model'], Tuple[Type['Model'], ...]], __module__: str = __name__, __validators__: Dict[str, classmethod] = None, **field_definitions: Any, @@ -889,7 +889,7 @@ def create_model( __model_name: str, *, __config__: Optional[Type[BaseConfig]] = None, - __base__: Optional[Type['Model']] = None, + __base__: Union[None, Type['Model'], Tuple[Type['Model'], ...]] = None, __module__: str = __name__, __validators__: Dict[str, classmethod] = None, **field_definitions: Any, @@ -910,8 +910,10 @@ def create_model( if __base__ is not None: if __config__ is not None: raise ConfigError('to avoid confusion __config__ and __base__ cannot be used together') + if not isinstance(__base__, tuple): + __base__ = (__base__,) else: - __base__ = cast(Type['Model'], BaseModel) + __base__ = (cast(Type['Model'], BaseModel),) fields = {} annotations = {} @@ -942,7 +944,7 @@ def create_model( if __config__: namespace['Config'] = inherit_config(__config__, BaseConfig) - return type(__model_name, (__base__,), namespace) + return type(__model_name, __base__, namespace) _missing = object() diff --git a/tests/test_generics.py b/tests/test_generics.py index 7e0cd73..fb071b0 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1160,6 +1160,92 @@ def test_generic_annotated(): SomeGenericModel[str](the_alias='qwe') +@skip_36 +def test_generic_subclass(): + T = TypeVar('T') + + class A(GenericModel, Generic[T]): + ... + + class B(A[T], Generic[T]): + ... + + assert B[int].__name__ == 'B[int]' + assert issubclass(B[int], B) + assert issubclass(B[int], A[int]) + assert not issubclass(B[int], A[str]) + + +@skip_36 +def test_generic_subclass_with_partial_application(): + T = TypeVar('T') + S = TypeVar('S') + + class A(GenericModel, Generic[T]): + ... + + class B(A[S], Generic[T, S]): + ... + + PartiallyAppliedB = B[str, T] + assert issubclass(PartiallyAppliedB[int], A[int]) + assert not issubclass(PartiallyAppliedB[int], A[str]) + assert not issubclass(PartiallyAppliedB[str], A[int]) + + +@skip_36 +def test_multilevel_generic_binding(): + T = TypeVar('T') + S = TypeVar('S') + + class A(GenericModel, Generic[T, S]): + ... + + class B(A[str, T], Generic[T]): + ... + + assert B[int].__name__ == 'B[int]' + assert issubclass(B[int], A[str, int]) + assert not issubclass(B[str], A[str, int]) + + +@skip_36 +def test_generic_subclass_with_extra_type(): + T = TypeVar('T') + S = TypeVar('S') + + class A(GenericModel, Generic[T]): + ... + + class B(A[S], Generic[T, S]): + ... + + assert B[int, str].__name__ == 'B[int, str]', B[int, str].__name__ + assert issubclass(B[str, int], B) + assert issubclass(B[str, int], A[int]) + assert not issubclass(B[int, str], A[int]) + + +@skip_36 +def test_multi_inheritance_generic_binding(): + T = TypeVar('T') + + class A(GenericModel, Generic[T]): + ... + + class B(A[int], Generic[T]): + ... + + class C(B[str], Generic[T]): + ... + + assert C[float].__name__ == 'C[float]' + assert issubclass(C[float], B[str]) + assert not issubclass(C[float], B[int]) + assert issubclass(C[float], A[int]) + assert not issubclass(C[float], A[str]) + + @skip_36 def test_parse_generic_json(): T = TypeVar('T')