mirror of
https://github.com/kennethreitz/pydantic.git
synced 2026-06-05 23:00:18 +00:00
Add support for Literal annotation (#582)
fix #561 * Add support for Literal annotation * Updated requirements.txt * incorporating feedback * skip typing_extensions tests if not installed * missed a spot * address feedback * Make work with python 3.6 * Work for *both* 3.6 and 3.7 * incorporate feedback * fixed naming and quotes * Trying to fix LGTM bot issue
This commit is contained in:
+2
-8
@@ -14,7 +14,6 @@ install:
|
||||
- pip freeze
|
||||
|
||||
script:
|
||||
# test without cython but with ujson and email-validator
|
||||
- python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(1 if pydantic.compiled else 0)"
|
||||
- make test
|
||||
|
||||
@@ -40,7 +39,6 @@ jobs:
|
||||
python: 3.6
|
||||
name: 'Cython: 3.6'
|
||||
script:
|
||||
# test with cython, ujson and email-validator
|
||||
- make build-cython-trace
|
||||
- python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(0 if pydantic.compiled else 1)"
|
||||
- make test
|
||||
@@ -50,7 +48,6 @@ jobs:
|
||||
python: 3.7
|
||||
name: 'Cython: 3.7'
|
||||
script:
|
||||
# test with cython, ujson and email-validator
|
||||
- make build-cython-trace
|
||||
- python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(0 if pydantic.compiled else 1)"
|
||||
- make test
|
||||
@@ -61,8 +58,7 @@ jobs:
|
||||
python: 3.6
|
||||
name: 'Without Deps 3.6'
|
||||
script:
|
||||
# test without cython, ujson and email-validator
|
||||
- pip uninstall -y ujson email-validator
|
||||
- pip uninstall -y ujson email-validator typing-extensions
|
||||
- make test
|
||||
env:
|
||||
- 'DEPS=no'
|
||||
@@ -70,8 +66,7 @@ jobs:
|
||||
python: 3.7
|
||||
name: 'Without Deps 3.7'
|
||||
script:
|
||||
# test without cython, ujson and email-validator
|
||||
- pip uninstall -y ujson email-validator cython
|
||||
- pip uninstall -y ujson email-validator cython typing-extensions
|
||||
- make test
|
||||
env:
|
||||
- 'DEPS=no'
|
||||
@@ -80,7 +75,6 @@ jobs:
|
||||
python: 3.7
|
||||
name: 'Benchmarks'
|
||||
script:
|
||||
# default install skips cython compilation, need to compile for benchmarks
|
||||
- make build-cython
|
||||
- BENCHMARK_REPEATS=1 make benchmark-all
|
||||
after_success: skip
|
||||
|
||||
+4
-1
@@ -49,7 +49,10 @@ class NoneIsAllowedError(PydanticTypeError):
|
||||
|
||||
class WrongConstantError(PydanticValueError):
|
||||
code = 'const'
|
||||
msg_template = 'expected constant value {const!r}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
permitted = ', '.join(repr(v) for v in self.ctx['permitted']) # type: ignore
|
||||
return f'unexpected value; permitted: {permitted}'
|
||||
|
||||
|
||||
class BytesError(PydanticTypeError):
|
||||
|
||||
@@ -26,6 +26,11 @@ from .types import Json, JsonWrapper
|
||||
from .utils import AnyCallable, AnyType, Callable, ForwardRef, display_as_type, lenient_issubclass, sequence_like
|
||||
from .validators import NoneType, constant_validator, dict_validator, find_validators
|
||||
|
||||
try:
|
||||
from typing_extensions import Literal
|
||||
except ImportError:
|
||||
Literal = None # type: ignore
|
||||
|
||||
Required: Any = Ellipsis
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
@@ -187,6 +192,8 @@ class Field:
|
||||
return
|
||||
if origin is Callable:
|
||||
return
|
||||
if Literal is not None and origin is Literal:
|
||||
return
|
||||
if origin is Union:
|
||||
types_ = []
|
||||
for type_ in self.type_.__args__: # type: ignore
|
||||
|
||||
@@ -24,6 +24,11 @@ from typing import ( # type: ignore
|
||||
|
||||
import pydantic
|
||||
|
||||
try:
|
||||
from typing_extensions import Literal
|
||||
except ImportError:
|
||||
Literal = None # type: ignore
|
||||
|
||||
try:
|
||||
import email_validator
|
||||
except ImportError:
|
||||
@@ -285,6 +290,18 @@ def is_callable_type(type_: AnyType) -> bool:
|
||||
return type_ is Callable or getattr(type_, '__origin__', None) is Callable
|
||||
|
||||
|
||||
if sys.version_info >= (3, 7):
|
||||
|
||||
def is_literal_type(type_: AnyType) -> bool:
|
||||
return Literal is not None and getattr(type_, '__origin__', None) is Literal
|
||||
|
||||
|
||||
else:
|
||||
|
||||
def is_literal_type(type_: AnyType) -> bool:
|
||||
return Literal is not None and hasattr(type_, '__values__') and type_ == Literal[type_.__values__]
|
||||
|
||||
|
||||
def _check_classvar(v: AnyType) -> bool:
|
||||
return type(v) == type(ClassVar) and (sys.version_info < (3, 7) or getattr(v, '_name', None) == 'ClassVar')
|
||||
|
||||
|
||||
+33
-3
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from decimal import Decimal, DecimalException
|
||||
@@ -24,7 +25,16 @@ from uuid import UUID
|
||||
|
||||
from . import errors
|
||||
from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
|
||||
from .utils import AnyCallable, AnyType, ForwardRef, change_exception, display_as_type, is_callable_type, sequence_like
|
||||
from .utils import (
|
||||
AnyCallable,
|
||||
AnyType,
|
||||
ForwardRef,
|
||||
change_exception,
|
||||
display_as_type,
|
||||
is_callable_type,
|
||||
is_literal_type,
|
||||
sequence_like,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .fields import Field
|
||||
@@ -140,7 +150,7 @@ def constant_validator(v: 'Any', field: 'Field') -> 'Any':
|
||||
Schema.
|
||||
"""
|
||||
if v != field.default:
|
||||
raise errors.WrongConstantError(given=v, const=field.default)
|
||||
raise errors.WrongConstantError(given=v, permitted=[field.default])
|
||||
|
||||
return v
|
||||
|
||||
@@ -334,6 +344,21 @@ def callable_validator(v: Any) -> AnyCallable:
|
||||
raise errors.CallableError(value=v)
|
||||
|
||||
|
||||
def make_literal_validator(type_: Any) -> Callable[[Any], Any]:
|
||||
if sys.version_info >= (3, 7):
|
||||
permitted_choices = type_.__args__
|
||||
else:
|
||||
permitted_choices = type_.__values__
|
||||
allowed_choices_set = set(permitted_choices)
|
||||
|
||||
def literal_validator(v: Any) -> Any:
|
||||
if v not in allowed_choices_set:
|
||||
raise errors.WrongConstantError(given=v, permitted=permitted_choices)
|
||||
return v
|
||||
|
||||
return literal_validator
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@@ -409,7 +434,9 @@ _VALIDATORS: List[Tuple[AnyType, List[Any]]] = [
|
||||
]
|
||||
|
||||
|
||||
def find_validators(type_: AnyType, config: Type['BaseConfig']) -> Generator[AnyCallable, None, None]:
|
||||
def find_validators( # noqa: C901 (ignore complexity)
|
||||
type_: AnyType, config: Type['BaseConfig']
|
||||
) -> Generator[AnyCallable, None, None]:
|
||||
if type_ is Any:
|
||||
return
|
||||
type_type = type(type_)
|
||||
@@ -421,6 +448,9 @@ def find_validators(type_: AnyType, config: Type['BaseConfig']) -> Generator[Any
|
||||
if is_callable_type(type_):
|
||||
yield callable_validator
|
||||
return
|
||||
if is_literal_type(type_):
|
||||
yield make_literal_validator(type_)
|
||||
return
|
||||
|
||||
supertype = _find_supertype(type_)
|
||||
if supertype is not None:
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
ujson==1.35
|
||||
email-validator==1.0.4
|
||||
dataclasses==0.6; python_version < '3.7'
|
||||
typing-extensions==3.7.2
|
||||
|
||||
@@ -101,6 +101,7 @@ setup(
|
||||
extras_require={
|
||||
'ujson': ['ujson>=1.35'],
|
||||
'email': ['email-validator>=1.0.3'],
|
||||
'typing_extensions': ['typing-extensions>=3.7.2']
|
||||
},
|
||||
ext_modules=ext_modules,
|
||||
)
|
||||
|
||||
+2
-2
@@ -391,9 +391,9 @@ def test_const_with_wrong_value():
|
||||
assert exc_info.value.errors() == [
|
||||
{
|
||||
'loc': ('a',),
|
||||
'msg': 'expected constant value 3',
|
||||
'msg': 'unexpected value; permitted: 3',
|
||||
'type': 'value_error.const',
|
||||
'ctx': {'given': 4, 'const': 3},
|
||||
'ctx': {'given': 4, 'permitted': [3]},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@@ -48,6 +48,11 @@ try:
|
||||
except ImportError:
|
||||
email_validator = None
|
||||
|
||||
try:
|
||||
import typing_extensions
|
||||
except ImportError:
|
||||
typing_extensions = None
|
||||
|
||||
|
||||
class ConBytesModel(BaseModel):
|
||||
v: conbytes(max_length=10) = b'foobar'
|
||||
@@ -1661,3 +1666,40 @@ def test_generic_without_params_error():
|
||||
{'loc': ('generic_list',), 'msg': 'value is not a valid list', 'type': 'type_error.list'},
|
||||
{'loc': ('generic_dict',), 'msg': 'value is not a valid dict', 'type': 'type_error.dict'},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not typing_extensions, reason='typing_extensions not installed')
|
||||
def test_literal_single():
|
||||
class Model(BaseModel):
|
||||
a: typing_extensions.Literal['a']
|
||||
|
||||
Model(a='a')
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Model(a='b')
|
||||
assert exc_info.value.errors() == [
|
||||
{
|
||||
'loc': ('a',),
|
||||
'msg': "unexpected value; permitted: 'a'",
|
||||
'type': 'value_error.const',
|
||||
'ctx': {'given': 'b', 'permitted': ('a',)},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not typing_extensions, reason='typing_extensions not installed')
|
||||
def test_literal_multiple():
|
||||
class Model(BaseModel):
|
||||
a_or_b: typing_extensions.Literal['a', 'b']
|
||||
|
||||
Model(a_or_b='a')
|
||||
Model(a_or_b='b')
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Model(a_or_b='c')
|
||||
assert exc_info.value.errors() == [
|
||||
{
|
||||
'loc': ('a_or_b',),
|
||||
'msg': "unexpected value; permitted: 'a', 'b'",
|
||||
'type': 'value_error.const',
|
||||
'ctx': {'given': 'c', 'permitted': ('a', 'b')},
|
||||
}
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user