Add support for Literal annotation (#582)

fix #561

* Add support for Literal annotation

* Updated requirements.txt

* incorporating feedback

* skip typing_extensions tests if not installed

* missed a spot

* address feedback

* Make work with python 3.6

* Work for *both* 3.6 and 3.7

* incorporate feedback

* fixed naming and quotes

* Trying to fix LGTM bot issue
This commit is contained in:
dmontagu
2019-06-25 02:33:21 -07:00
committed by Samuel Colvin
parent 7000a27d56
commit 3ee54ed2bb
9 changed files with 109 additions and 14 deletions
+2 -8
View File
@@ -14,7 +14,6 @@ install:
- pip freeze
script:
# test without cython but with ujson and email-validator
- python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(1 if pydantic.compiled else 0)"
- make test
@@ -40,7 +39,6 @@ jobs:
python: 3.6
name: 'Cython: 3.6'
script:
# test with cython, ujson and email-validator
- make build-cython-trace
- python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(0 if pydantic.compiled else 1)"
- make test
@@ -50,7 +48,6 @@ jobs:
python: 3.7
name: 'Cython: 3.7'
script:
# test with cython, ujson and email-validator
- make build-cython-trace
- python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(0 if pydantic.compiled else 1)"
- make test
@@ -61,8 +58,7 @@ jobs:
python: 3.6
name: 'Without Deps 3.6'
script:
# test without cython, ujson and email-validator
- pip uninstall -y ujson email-validator
- pip uninstall -y ujson email-validator typing-extensions
- make test
env:
- 'DEPS=no'
@@ -70,8 +66,7 @@ jobs:
python: 3.7
name: 'Without Deps 3.7'
script:
# test without cython, ujson and email-validator
- pip uninstall -y ujson email-validator cython
- pip uninstall -y ujson email-validator cython typing-extensions
- make test
env:
- 'DEPS=no'
@@ -80,7 +75,6 @@ jobs:
python: 3.7
name: 'Benchmarks'
script:
# default install skips cython compilation, need to compile for benchmarks
- make build-cython
- BENCHMARK_REPEATS=1 make benchmark-all
after_success: skip
+4 -1
View File
@@ -49,7 +49,10 @@ class NoneIsAllowedError(PydanticTypeError):
class WrongConstantError(PydanticValueError):
code = 'const'
msg_template = 'expected constant value {const!r}'
def __str__(self) -> str:
permitted = ', '.join(repr(v) for v in self.ctx['permitted']) # type: ignore
return f'unexpected value; permitted: {permitted}'
class BytesError(PydanticTypeError):
+7
View File
@@ -26,6 +26,11 @@ from .types import Json, JsonWrapper
from .utils import AnyCallable, AnyType, Callable, ForwardRef, display_as_type, lenient_issubclass, sequence_like
from .validators import NoneType, constant_validator, dict_validator, find_validators
try:
from typing_extensions import Literal
except ImportError:
Literal = None # type: ignore
Required: Any = Ellipsis
if TYPE_CHECKING: # pragma: no cover
@@ -187,6 +192,8 @@ class Field:
return
if origin is Callable:
return
if Literal is not None and origin is Literal:
return
if origin is Union:
types_ = []
for type_ in self.type_.__args__: # type: ignore
+17
View File
@@ -24,6 +24,11 @@ from typing import ( # type: ignore
import pydantic
try:
from typing_extensions import Literal
except ImportError:
Literal = None # type: ignore
try:
import email_validator
except ImportError:
@@ -285,6 +290,18 @@ def is_callable_type(type_: AnyType) -> bool:
return type_ is Callable or getattr(type_, '__origin__', None) is Callable
if sys.version_info >= (3, 7):
def is_literal_type(type_: AnyType) -> bool:
return Literal is not None and getattr(type_, '__origin__', None) is Literal
else:
def is_literal_type(type_: AnyType) -> bool:
return Literal is not None and hasattr(type_, '__values__') and type_ == Literal[type_.__values__]
def _check_classvar(v: AnyType) -> bool:
return type(v) == type(ClassVar) and (sys.version_info < (3, 7) or getattr(v, '_name', None) == 'ClassVar')
+33 -3
View File
@@ -1,4 +1,5 @@
import re
import sys
from collections import OrderedDict
from datetime import date, datetime, time, timedelta
from decimal import Decimal, DecimalException
@@ -24,7 +25,16 @@ from uuid import UUID
from . import errors
from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
from .utils import AnyCallable, AnyType, ForwardRef, change_exception, display_as_type, is_callable_type, sequence_like
from .utils import (
AnyCallable,
AnyType,
ForwardRef,
change_exception,
display_as_type,
is_callable_type,
is_literal_type,
sequence_like,
)
if TYPE_CHECKING: # pragma: no cover
from .fields import Field
@@ -140,7 +150,7 @@ def constant_validator(v: 'Any', field: 'Field') -> 'Any':
Schema.
"""
if v != field.default:
raise errors.WrongConstantError(given=v, const=field.default)
raise errors.WrongConstantError(given=v, permitted=[field.default])
return v
@@ -334,6 +344,21 @@ def callable_validator(v: Any) -> AnyCallable:
raise errors.CallableError(value=v)
def make_literal_validator(type_: Any) -> Callable[[Any], Any]:
if sys.version_info >= (3, 7):
permitted_choices = type_.__args__
else:
permitted_choices = type_.__values__
allowed_choices_set = set(permitted_choices)
def literal_validator(v: Any) -> Any:
if v not in allowed_choices_set:
raise errors.WrongConstantError(given=v, permitted=permitted_choices)
return v
return literal_validator
T = TypeVar('T')
@@ -409,7 +434,9 @@ _VALIDATORS: List[Tuple[AnyType, List[Any]]] = [
]
def find_validators(type_: AnyType, config: Type['BaseConfig']) -> Generator[AnyCallable, None, None]:
def find_validators( # noqa: C901 (ignore complexity)
type_: AnyType, config: Type['BaseConfig']
) -> Generator[AnyCallable, None, None]:
if type_ is Any:
return
type_type = type(type_)
@@ -421,6 +448,9 @@ def find_validators(type_: AnyType, config: Type['BaseConfig']) -> Generator[Any
if is_callable_type(type_):
yield callable_validator
return
if is_literal_type(type_):
yield make_literal_validator(type_)
return
supertype = _find_supertype(type_)
if supertype is not None:
+1
View File
@@ -5,3 +5,4 @@
ujson==1.35
email-validator==1.0.4
dataclasses==0.6; python_version < '3.7'
typing-extensions==3.7.2
+1
View File
@@ -101,6 +101,7 @@ setup(
extras_require={
'ujson': ['ujson>=1.35'],
'email': ['email-validator>=1.0.3'],
'typing_extensions': ['typing-extensions>=3.7.2']
},
ext_modules=ext_modules,
)
+2 -2
View File
@@ -391,9 +391,9 @@ def test_const_with_wrong_value():
assert exc_info.value.errors() == [
{
'loc': ('a',),
'msg': 'expected constant value 3',
'msg': 'unexpected value; permitted: 3',
'type': 'value_error.const',
'ctx': {'given': 4, 'const': 3},
'ctx': {'given': 4, 'permitted': [3]},
}
]
+42
View File
@@ -48,6 +48,11 @@ try:
except ImportError:
email_validator = None
try:
import typing_extensions
except ImportError:
typing_extensions = None
class ConBytesModel(BaseModel):
v: conbytes(max_length=10) = b'foobar'
@@ -1661,3 +1666,40 @@ def test_generic_without_params_error():
{'loc': ('generic_list',), 'msg': 'value is not a valid list', 'type': 'type_error.list'},
{'loc': ('generic_dict',), 'msg': 'value is not a valid dict', 'type': 'type_error.dict'},
]
@pytest.mark.skipif(not typing_extensions, reason='typing_extensions not installed')
def test_literal_single():
class Model(BaseModel):
a: typing_extensions.Literal['a']
Model(a='a')
with pytest.raises(ValidationError) as exc_info:
Model(a='b')
assert exc_info.value.errors() == [
{
'loc': ('a',),
'msg': "unexpected value; permitted: 'a'",
'type': 'value_error.const',
'ctx': {'given': 'b', 'permitted': ('a',)},
}
]
@pytest.mark.skipif(not typing_extensions, reason='typing_extensions not installed')
def test_literal_multiple():
class Model(BaseModel):
a_or_b: typing_extensions.Literal['a', 'b']
Model(a_or_b='a')
Model(a_or_b='b')
with pytest.raises(ValidationError) as exc_info:
Model(a_or_b='c')
assert exc_info.value.errors() == [
{
'loc': ('a_or_b',),
'msg': "unexpected value; permitted: 'a', 'b'",
'type': 'value_error.const',
'ctx': {'given': 'c', 'permitted': ('a', 'b')},
}
]