From e1ae058afbf366b796c7a73a65497dbd98255b3f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 20 Dec 2019 15:08:45 +0000 Subject: [PATCH] Consistent checks for sequence like objects (#1111) --- changes/1090-samuelcolvin.md | 1 + pydantic/env_settings.py | 12 +++++++----- pydantic/fields.py | 2 +- pydantic/main.py | 4 ++-- pydantic/schema.py | 4 ++-- tests/test_schema.py | 21 ++++++++++++++++++++- tests/test_settings.py | 14 ++++++++++++++ 7 files changed, 47 insertions(+), 11 deletions(-) create mode 100644 changes/1090-samuelcolvin.md diff --git a/changes/1090-samuelcolvin.md b/changes/1090-samuelcolvin.md new file mode 100644 index 0000000..12dcd88 --- /dev/null +++ b/changes/1090-samuelcolvin.md @@ -0,0 +1 @@ +Consistent checks for sequence like objects. diff --git a/pydantic/env_settings.py b/pydantic/env_settings.py index e582a89..4f3d190 100644 --- a/pydantic/env_settings.py +++ b/pydantic/env_settings.py @@ -1,11 +1,11 @@ import os import warnings -from typing import Any, Dict, Iterable, Mapping, Optional +from typing import AbstractSet, Any, Dict, List, Mapping, Optional, Union from .fields import ModelField from .main import BaseModel, Extra from .typing import display_as_type -from .utils import deep_update +from .utils import deep_update, sequence_like class SettingsError(ValueError): @@ -65,7 +65,7 @@ class BaseSettings(BaseModel): @classmethod def prepare_field(cls, field: ModelField) -> None: - env_names: Iterable[str] + env_names: Union[List[str], AbstractSet[str]] env = field.field_info.extra.get('env') if env is None: if field.has_alias: @@ -75,11 +75,13 @@ class BaseSettings(BaseModel): 'See https://pydantic-docs.helpmanual.io/usage/settings/#environment-variable-names', FutureWarning, ) - env_names = [cls.env_prefix + field.name] + env_names = {cls.env_prefix + field.name} elif isinstance(env, str): env_names = {env} - elif isinstance(env, (list, set, tuple)): + elif isinstance(env, (set, frozenset)): env_names = env + elif sequence_like(env): + env_names = list(env) else: raise TypeError(f'invalid field env: {env!r} ({display_as_type(env)}); should be string, list or set') diff --git a/pydantic/fields.py b/pydantic/fields.py index c522716..0af6ff7 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -640,7 +640,7 @@ class ModelField(Representation): return ( self.shape != SHAPE_SINGLETON - or lenient_issubclass(self.type_, (BaseModel, list, set, dict)) + or lenient_issubclass(self.type_, (BaseModel, list, set, frozenset, dict)) or hasattr(self.type_, '__pydantic_model__') # pydantic dataclass ) diff --git a/pydantic/main.py b/pydantic/main.py index 42d8b27..b1d38d0 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -18,7 +18,7 @@ from .parse import Protocol, load_file, load_str_bytes from .schema import model_schema from .types import PyObject, StrBytes from .typing import AnyCallable, AnyType, ForwardRef, is_classvar, resolve_annotations, update_field_forward_refs -from .utils import GetterDict, Representation, ValueItems, lenient_issubclass, validate_field_name +from .utils import GetterDict, Representation, ValueItems, lenient_issubclass, sequence_like, validate_field_name if TYPE_CHECKING: from .class_validators import ValidatorListDict @@ -595,7 +595,7 @@ class BaseModel(metaclass=ModelMetaclass): and (not value_include or value_include.is_included(k_)) } - elif isinstance(v, (list, set, tuple)): + elif sequence_like(v): return type(v)( cls._get_value( v_, diff --git a/pydantic/schema.py b/pydantic/schema.py index 523c4a2..2c3b1e7 100644 --- a/pydantic/schema.py +++ b/pydantic/schema.py @@ -72,7 +72,7 @@ from .typing import ( literal_values, new_type_supertype, ) -from .utils import get_model, lenient_issubclass +from .utils import get_model, lenient_issubclass, sequence_like if TYPE_CHECKING: from .main import BaseModel # noqa: F401 @@ -739,7 +739,7 @@ def multivalue_literal_field_for_schema(values: Tuple[Any, ...], field: ModelFie def encode_default(dft: Any) -> Any: if isinstance(dft, (int, float, str)): return dft - elif isinstance(dft, (tuple, list, set)): + elif sequence_like(dft): t = type(dft) return t(encode_default(v) for v in dft) elif isinstance(dft, dict): diff --git a/tests/test_schema.py b/tests/test_schema.py index 3139309..c1fd59f 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -6,7 +6,7 @@ from decimal import Decimal from enum import Enum, IntEnum from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network from pathlib import Path -from typing import Any, Callable, Dict, List, NewType, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, FrozenSet, List, NewType, Optional, Set, Tuple, Union from uuid import UUID import pytest @@ -1704,3 +1704,22 @@ def test_schema_attributes(): 'properties': {'example': {'title': 'Example', 'enum': ['GT', 'LT', 'GE', 'LE', 'ML', 'MO', 'RE']}}, 'required': ['example'], } + + +def test_frozen_set(): + class Model(BaseModel): + a: FrozenSet[int] = frozenset({1, 2, 3}) + + assert Model.schema() == { + 'title': 'Model', + 'type': 'object', + 'properties': { + 'a': { + 'title': 'A', + 'default': frozenset({1, 2, 3}), + 'type': 'array', + 'items': {'type': 'integer'}, + 'uniqueItems': True, + }, + }, + } diff --git a/tests/test_settings.py b/tests/test_settings.py index 558781c..cdcd667 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -391,3 +391,17 @@ def test_prefix_on_parent(env): env.set('PREFIX_VAR', 'new') assert MyBaseSettings().dict() == {'var': 'old'} assert MySubSettings().dict() == {'var': 'new'} + + +def test_frozenset(env): + class Settings(BaseSettings): + foo: str = 'default foo' + + class Config: + fields = {'foo': {'env': frozenset(['foo_a', 'foo_b'])}} + + assert Settings.__fields__['foo'].field_info.extra['env_names'] == frozenset({'foo_a', 'foo_b'}) + + assert Settings().dict() == {'foo': 'default foo'} + env.set('foo_a', 'x') + assert Settings().dict() == {'foo': 'x'}