diff --git a/.travis.yml b/.travis.yml index ce60100..9f05c7e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,7 +14,6 @@ install: - pip freeze script: -# test without cython but with ujson and email-validator - python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(1 if pydantic.compiled else 0)" - make test @@ -40,7 +39,6 @@ jobs: python: 3.6 name: 'Cython: 3.6' script: - # test with cython, ujson and email-validator - make build-cython-trace - python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(0 if pydantic.compiled else 1)" - make test @@ -50,7 +48,6 @@ jobs: python: 3.7 name: 'Cython: 3.7' script: - # test with cython, ujson and email-validator - make build-cython-trace - python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(0 if pydantic.compiled else 1)" - make test @@ -61,8 +58,7 @@ jobs: python: 3.6 name: 'Without Deps 3.6' script: - # test without cython, ujson and email-validator - - pip uninstall -y ujson email-validator + - pip uninstall -y ujson email-validator typing-extensions - make test env: - 'DEPS=no' @@ -70,8 +66,7 @@ jobs: python: 3.7 name: 'Without Deps 3.7' script: - # test without cython, ujson and email-validator - - pip uninstall -y ujson email-validator cython + - pip uninstall -y ujson email-validator cython typing-extensions - make test env: - 'DEPS=no' @@ -80,7 +75,6 @@ jobs: python: 3.7 name: 'Benchmarks' script: - # default install skips cython compilation, need to compile for benchmarks - make build-cython - BENCHMARK_REPEATS=1 make benchmark-all after_success: skip diff --git a/pydantic/errors.py b/pydantic/errors.py index 42c5f05..eda450a 100644 --- a/pydantic/errors.py +++ b/pydantic/errors.py @@ -49,7 +49,10 @@ class NoneIsAllowedError(PydanticTypeError): class WrongConstantError(PydanticValueError): code = 'const' - msg_template = 'expected constant value {const!r}' + + def __str__(self) -> str: + permitted = ', '.join(repr(v) for v in self.ctx['permitted']) # type: ignore + return f'unexpected value; permitted: {permitted}' class BytesError(PydanticTypeError): diff --git a/pydantic/fields.py b/pydantic/fields.py index 8e62007..f512385 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -26,6 +26,11 @@ from .types import Json, JsonWrapper from .utils import AnyCallable, AnyType, Callable, ForwardRef, display_as_type, lenient_issubclass, sequence_like from .validators import NoneType, constant_validator, dict_validator, find_validators +try: + from typing_extensions import Literal +except ImportError: + Literal = None # type: ignore + Required: Any = Ellipsis if TYPE_CHECKING: # pragma: no cover @@ -187,6 +192,8 @@ class Field: return if origin is Callable: return + if Literal is not None and origin is Literal: + return if origin is Union: types_ = [] for type_ in self.type_.__args__: # type: ignore diff --git a/pydantic/utils.py b/pydantic/utils.py index 01f630a..c30031b 100644 --- a/pydantic/utils.py +++ b/pydantic/utils.py @@ -24,6 +24,11 @@ from typing import ( # type: ignore import pydantic +try: + from typing_extensions import Literal +except ImportError: + Literal = None # type: ignore + try: import email_validator except ImportError: @@ -285,6 +290,18 @@ def is_callable_type(type_: AnyType) -> bool: return type_ is Callable or getattr(type_, '__origin__', None) is Callable +if sys.version_info >= (3, 7): + + def is_literal_type(type_: AnyType) -> bool: + return Literal is not None and getattr(type_, '__origin__', None) is Literal + + +else: + + def is_literal_type(type_: AnyType) -> bool: + return Literal is not None and hasattr(type_, '__values__') and type_ == Literal[type_.__values__] + + def _check_classvar(v: AnyType) -> bool: return type(v) == type(ClassVar) and (sys.version_info < (3, 7) or getattr(v, '_name', None) == 'ClassVar') diff --git a/pydantic/validators.py b/pydantic/validators.py index 62c5444..f3d3aca 100644 --- a/pydantic/validators.py +++ b/pydantic/validators.py @@ -1,4 +1,5 @@ import re +import sys from collections import OrderedDict from datetime import date, datetime, time, timedelta from decimal import Decimal, DecimalException @@ -24,7 +25,16 @@ from uuid import UUID from . import errors from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time -from .utils import AnyCallable, AnyType, ForwardRef, change_exception, display_as_type, is_callable_type, sequence_like +from .utils import ( + AnyCallable, + AnyType, + ForwardRef, + change_exception, + display_as_type, + is_callable_type, + is_literal_type, + sequence_like, +) if TYPE_CHECKING: # pragma: no cover from .fields import Field @@ -140,7 +150,7 @@ def constant_validator(v: 'Any', field: 'Field') -> 'Any': Schema. """ if v != field.default: - raise errors.WrongConstantError(given=v, const=field.default) + raise errors.WrongConstantError(given=v, permitted=[field.default]) return v @@ -334,6 +344,21 @@ def callable_validator(v: Any) -> AnyCallable: raise errors.CallableError(value=v) +def make_literal_validator(type_: Any) -> Callable[[Any], Any]: + if sys.version_info >= (3, 7): + permitted_choices = type_.__args__ + else: + permitted_choices = type_.__values__ + allowed_choices_set = set(permitted_choices) + + def literal_validator(v: Any) -> Any: + if v not in allowed_choices_set: + raise errors.WrongConstantError(given=v, permitted=permitted_choices) + return v + + return literal_validator + + T = TypeVar('T') @@ -409,7 +434,9 @@ _VALIDATORS: List[Tuple[AnyType, List[Any]]] = [ ] -def find_validators(type_: AnyType, config: Type['BaseConfig']) -> Generator[AnyCallable, None, None]: +def find_validators( # noqa: C901 (ignore complexity) + type_: AnyType, config: Type['BaseConfig'] +) -> Generator[AnyCallable, None, None]: if type_ is Any: return type_type = type(type_) @@ -421,6 +448,9 @@ def find_validators(type_: AnyType, config: Type['BaseConfig']) -> Generator[Any if is_callable_type(type_): yield callable_validator return + if is_literal_type(type_): + yield make_literal_validator(type_) + return supertype = _find_supertype(type_) if supertype is not None: diff --git a/requirements.txt b/requirements.txt index ed4761f..38a7e1a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ ujson==1.35 email-validator==1.0.4 dataclasses==0.6; python_version < '3.7' +typing-extensions==3.7.2 diff --git a/setup.py b/setup.py index 198daec..2a4552d 100644 --- a/setup.py +++ b/setup.py @@ -101,6 +101,7 @@ setup( extras_require={ 'ujson': ['ujson>=1.35'], 'email': ['email-validator>=1.0.3'], + 'typing_extensions': ['typing-extensions>=3.7.2'] }, ext_modules=ext_modules, ) diff --git a/tests/test_main.py b/tests/test_main.py index a165abe..ca06bae 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -391,9 +391,9 @@ def test_const_with_wrong_value(): assert exc_info.value.errors() == [ { 'loc': ('a',), - 'msg': 'expected constant value 3', + 'msg': 'unexpected value; permitted: 3', 'type': 'value_error.const', - 'ctx': {'given': 4, 'const': 3}, + 'ctx': {'given': 4, 'permitted': [3]}, } ] diff --git a/tests/test_types.py b/tests/test_types.py index 94eb81e..77110d0 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -48,6 +48,11 @@ try: except ImportError: email_validator = None +try: + import typing_extensions +except ImportError: + typing_extensions = None + class ConBytesModel(BaseModel): v: conbytes(max_length=10) = b'foobar' @@ -1661,3 +1666,40 @@ def test_generic_without_params_error(): {'loc': ('generic_list',), 'msg': 'value is not a valid list', 'type': 'type_error.list'}, {'loc': ('generic_dict',), 'msg': 'value is not a valid dict', 'type': 'type_error.dict'}, ] + + +@pytest.mark.skipif(not typing_extensions, reason='typing_extensions not installed') +def test_literal_single(): + class Model(BaseModel): + a: typing_extensions.Literal['a'] + + Model(a='a') + with pytest.raises(ValidationError) as exc_info: + Model(a='b') + assert exc_info.value.errors() == [ + { + 'loc': ('a',), + 'msg': "unexpected value; permitted: 'a'", + 'type': 'value_error.const', + 'ctx': {'given': 'b', 'permitted': ('a',)}, + } + ] + + +@pytest.mark.skipif(not typing_extensions, reason='typing_extensions not installed') +def test_literal_multiple(): + class Model(BaseModel): + a_or_b: typing_extensions.Literal['a', 'b'] + + Model(a_or_b='a') + Model(a_or_b='b') + with pytest.raises(ValidationError) as exc_info: + Model(a_or_b='c') + assert exc_info.value.errors() == [ + { + 'loc': ('a_or_b',), + 'msg': "unexpected value; permitted: 'a', 'b'", + 'type': 'value_error.const', + 'ctx': {'given': 'c', 'permitted': ('a', 'b')}, + } + ]