mirror of
https://github.com/kennethreitz/pydantic.git
synced 2026-06-05 23:00:18 +00:00
basic validation working
This commit is contained in:
+159
-13
@@ -1,27 +1,173 @@
|
||||
from typing import Any, Type
|
||||
from collections import OrderedDict
|
||||
from functools import partial, wraps
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Type
|
||||
|
||||
|
||||
class BaseField:
|
||||
def str_validator(v) -> str: # TODO config
|
||||
if isinstance(v, str):
|
||||
return v
|
||||
elif isinstance(v, bytes):
|
||||
return v.decode()
|
||||
return str(v)
|
||||
|
||||
|
||||
def bytes_validator(v) -> bytes:
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
return str_validator(v).encode()
|
||||
|
||||
|
||||
BOOL_STRINGS = {
|
||||
'1',
|
||||
'TRUE',
|
||||
'ON',
|
||||
}
|
||||
|
||||
|
||||
def bool_validator(v) -> bool:
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode()
|
||||
if isinstance(v, str):
|
||||
return v.upper() in BOOL_STRINGS
|
||||
return bool(v)
|
||||
|
||||
|
||||
def number_size_validator(v, *, config):
|
||||
if config.min_number_size <= v <= config.max_number_size:
|
||||
raise ValueError(f'size not in range {config.min_number_size} to {config.max_number_size}')
|
||||
return v
|
||||
|
||||
|
||||
def anystr_length_validator(v, *, config):
|
||||
if config.max_anystr_length <= len(v) <= config.max_anystr_length:
|
||||
raise ValueError(f'length not in range {config.max_anystr_length} to {config.max_anystr_length}')
|
||||
return v
|
||||
|
||||
|
||||
class ValidatorsLookup:
|
||||
def __init__(self):
|
||||
self._validators_lookup: Dict[Type, List[Callable]] = {
|
||||
int: [int, number_size_validator],
|
||||
float: [float, number_size_validator],
|
||||
Path: [Path],
|
||||
str: [str_validator, anystr_length_validator],
|
||||
bytes: [bytes_validator, anystr_length_validator],
|
||||
bool: [bool_validator],
|
||||
# TODO list, List, Dict, Union, datetime, date, time, custom types
|
||||
}
|
||||
self._validators_lookup_subclasses = []
|
||||
|
||||
def find(self, type_):
|
||||
try:
|
||||
return self._validators_lookup[type_]
|
||||
except KeyError:
|
||||
raise RuntimeError(f'no validator found for {type_}')
|
||||
|
||||
def register(self, type_, *validators_):
|
||||
self._validators_lookup[type_] = list(validators_)
|
||||
|
||||
|
||||
validators_lookup = ValidatorsLookup()
|
||||
|
||||
|
||||
def wrap_validator(func, config):
|
||||
multi = False
|
||||
try:
|
||||
multi = len(signature(func).parameters) > 1
|
||||
except ValueError:
|
||||
# happens on builtins like float
|
||||
pass
|
||||
if multi:
|
||||
return wraps(func)(partial(func, config=config))
|
||||
return func
|
||||
|
||||
|
||||
class Field:
|
||||
__slots__ = 'type_', 'validators', 'default', 'required', 'name', 'description', 'info'
|
||||
|
||||
def __init__(
|
||||
self, *,
|
||||
type_: Type,
|
||||
validators: List[Callable]=None,
|
||||
default: Any=None,
|
||||
v_type: Type=None,
|
||||
required: bool=False,
|
||||
name: str=None,
|
||||
description: str=None):
|
||||
if default and v_type:
|
||||
raise RuntimeError('"default" and "v_type" cannot both be defined.')
|
||||
elif default and required:
|
||||
raise RuntimeError('It doesn\'t make sense to have "default" set and required=True.')
|
||||
if default:
|
||||
self.default = default
|
||||
self.v_type = type(default)
|
||||
else:
|
||||
self.v_type = v_type
|
||||
|
||||
if default and required:
|
||||
raise RuntimeError('It doesn\'t make sense to have `default` set and `required=True`.')
|
||||
|
||||
self.type_ = type_
|
||||
self.validators = validators
|
||||
self.default = default
|
||||
self.required = required
|
||||
self.name = name
|
||||
self.description = description
|
||||
|
||||
def prepare(self, name, config, class_validators):
|
||||
self.name = self.name or name
|
||||
if self.default and self.type_ is None:
|
||||
self.type_ = type(self.default)
|
||||
|
||||
class EnvField(BaseField):
|
||||
if self.type_ is None:
|
||||
raise RuntimeError(f'unable to infer type for {self.name}')
|
||||
|
||||
override_validator = class_validators.get(f'validate_{self.name}_override')
|
||||
if override_validator:
|
||||
self.validators = [override_validator]
|
||||
|
||||
self.validators = self.validators or self._find_validator()
|
||||
|
||||
self.validators.insert(0, class_validators.get(f'validate_{self.name}_pre'))
|
||||
self.validators.append(class_validators.get(f'validate_{self.name}'))
|
||||
self.validators.append(class_validators.get(f'validate_{self.name}_post'))
|
||||
|
||||
self.validators = tuple(wrap_validator(v, config) for v in self.validators if v)
|
||||
self.info = OrderedDict([
|
||||
('type', self.type_.__name__),
|
||||
('default', self.default),
|
||||
('required', self.required),
|
||||
('validators', [f.__qualname__ for f in self.validators])
|
||||
])
|
||||
if self.required:
|
||||
self.info.pop('default')
|
||||
if self.description:
|
||||
self.info['description'] = self.description
|
||||
|
||||
def _find_validator(self):
|
||||
get_validators = getattr(self.type_, 'get_validators', None)
|
||||
if get_validators:
|
||||
return list(get_validators())
|
||||
return validators_lookup.find(self.type_)
|
||||
|
||||
def validate(self, v):
|
||||
for validator in self.validators:
|
||||
v = validator(v)
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def infer(cls, *, name, value, annotation, config, class_validators):
|
||||
required = value == Ellipsis
|
||||
instance = cls(
|
||||
type_=annotation,
|
||||
default=None if required else value,
|
||||
required=required
|
||||
)
|
||||
instance.prepare(name, config, class_validators)
|
||||
return instance
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Field: {self}>'
|
||||
|
||||
def __str__(self):
|
||||
return ', '.join(f'{k}={v!r}' for k, v in self.info.items())
|
||||
|
||||
|
||||
class EnvField(Field):
|
||||
def __init__(self, *, env=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.env_var_name = env
|
||||
|
||||
+82
-21
@@ -1,4 +1,28 @@
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
from types import FunctionType
|
||||
from collections import OrderedDict, namedtuple
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic.fields import Field
|
||||
|
||||
|
||||
DEFAULT_CONFIG: Dict[str, Any] = dict(
|
||||
min_anystr_length=0,
|
||||
max_anystr_length=2**16,
|
||||
min_number_size=-2**64,
|
||||
max_number_size=2**64,
|
||||
)
|
||||
Config = namedtuple('Config', list(DEFAULT_CONFIG.keys()))
|
||||
|
||||
|
||||
def get_config(config_class):
|
||||
if config_class:
|
||||
for k, v in DEFAULT_CONFIG.items():
|
||||
if not hasattr(config_class, k):
|
||||
setattr(config_class, k, v)
|
||||
else:
|
||||
config_class = Config(**DEFAULT_CONFIG)
|
||||
return config_class
|
||||
|
||||
|
||||
class MetaModel(type):
|
||||
@@ -8,34 +32,70 @@ class MetaModel(type):
|
||||
|
||||
def __new__(mcs, name, bases, namespace):
|
||||
fields = OrderedDict()
|
||||
base_config = None
|
||||
for base in reversed(bases):
|
||||
if issubclass(base, BaseModel) and base != BaseModel:
|
||||
fields.update(base.fields)
|
||||
fields.update(base.__fields__)
|
||||
base_config = base.config
|
||||
|
||||
annotations = namespace.get('__annotations__')
|
||||
if annotations:
|
||||
print(f'class {name}')
|
||||
fields.update(annotations)
|
||||
print(fields)
|
||||
config = get_config(namespace.get('Config', base_config))
|
||||
class_validators = {n: f for n, f in namespace.items()
|
||||
if n.startswith('validate_') and isinstance(f, FunctionType)}
|
||||
|
||||
for var_name, value in namespace.items():
|
||||
if var_name.startswith('_') or isinstance(value, (property, FunctionType, type)):
|
||||
continue
|
||||
field = Field.infer(
|
||||
name=var_name,
|
||||
value=value,
|
||||
annotation=annotations.get(var_name),
|
||||
config=config,
|
||||
class_validators=class_validators,
|
||||
)
|
||||
fields[field.name] = field
|
||||
namespace.update(
|
||||
fields=fields
|
||||
config=config,
|
||||
__fields__=fields,
|
||||
)
|
||||
return super().__new__(mcs, name, bases, namespace)
|
||||
|
||||
|
||||
class ValidationError(ValueError):
|
||||
def __init__(self, errors):
|
||||
self.errors = errors
|
||||
super().__init__(f'{len(self.errors)} errors validating input: {json.dumps(errors)}')
|
||||
|
||||
|
||||
class BaseModel(metaclass=MetaModel):
|
||||
def __init__(self, **custom_settings):
|
||||
"""
|
||||
:param custom_settings: Custom settings to override defaults, only attributes already defined can be set.
|
||||
"""
|
||||
self._dict = {
|
||||
# **self._substitute_environ(custom_settings),
|
||||
**self._get_custom_settings(custom_settings),
|
||||
}
|
||||
[setattr(self, k, v) for k, v in self._dict.items()]
|
||||
__fields__ = {} # populated by the metaclass
|
||||
__values__ = {}
|
||||
|
||||
def __init__(self, **values):
|
||||
errors = OrderedDict()
|
||||
for name, field in self.__fields__.items():
|
||||
value = values.get(name)
|
||||
if not value:
|
||||
if field.required:
|
||||
errors[name] = {'type': 'Missing', 'msg': 'field required'}
|
||||
continue
|
||||
try:
|
||||
value = field.validate(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
errors[name] = {'type': e.__class__.__name__, 'msg': str(e)}
|
||||
else:
|
||||
self.__values__[name] = value
|
||||
setattr(self, name, value)
|
||||
if errors:
|
||||
raise ValidationError(errors)
|
||||
|
||||
@property
|
||||
def dict(self):
|
||||
return self._dict
|
||||
def values(self):
|
||||
return self.__values__
|
||||
|
||||
@property
|
||||
def fields(self):
|
||||
return self.__fields__
|
||||
|
||||
def _get_custom_settings(self, custom_settings):
|
||||
d = {}
|
||||
@@ -46,8 +106,9 @@ class BaseModel(metaclass=MetaModel):
|
||||
return d
|
||||
|
||||
def __iter__(self):
|
||||
# so `dict(settings)` works
|
||||
yield from self._dict.items()
|
||||
# so `dict(model)` works
|
||||
yield from self.__values__.items()
|
||||
|
||||
def __repr__(self):
|
||||
return '<{} {}>'.format(self.__class__.__name__, ' '.join('{}={!r}'.format(k, v) for k, v in self.dict.items()))
|
||||
return '<{} {}>'.format(self.__class__.__name__, ' '.join('{}={!r}'.format(k, v)
|
||||
for k, v in self.__values__.items()))
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import ast
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def find_fields(module, cls_name):
|
||||
|
||||
path = Path(sys.modules[module].__file__).resolve()
|
||||
|
||||
file_node = ast.parse(path.read_text(), filename=path.name)
|
||||
cls_node = None
|
||||
for n in file_node.body:
|
||||
if isinstance(n, ast.ClassDef) and n.name == cls_name:
|
||||
cls_node = n
|
||||
break
|
||||
if cls_name is None:
|
||||
raise RuntimeError(f"can't find {cls_name} in {file_node}")
|
||||
_expression = None
|
||||
for n in cls_node.body:
|
||||
if isinstance(n, ast.Expr) and isinstance(n.value, ast.Str):
|
||||
_expression = n.value.s
|
||||
continue
|
||||
# print(ast.dump(n))
|
||||
if not isinstance(n, (ast.AnnAssign, ast.Assign)):
|
||||
_expression = None
|
||||
continue
|
||||
target = getattr(n, 'target', None) or n.targets[0]
|
||||
name = target.id
|
||||
if name.startswith('_'):
|
||||
_expression = None
|
||||
continue
|
||||
|
||||
yield name, _expression
|
||||
_expression = None
|
||||
@@ -3,3 +3,42 @@ json
|
||||
JsonList
|
||||
JsonDict
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from pydantic.fields import str_validator
|
||||
|
||||
|
||||
class ConstrainedStr(str):
|
||||
min_length = None
|
||||
max_length = None
|
||||
curtail_length = None
|
||||
|
||||
@classmethod
|
||||
def get_validators(cls):
|
||||
yield str_validator
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value):
|
||||
l = len(value)
|
||||
if cls.min_length and l < cls.min_length:
|
||||
raise ValueError(f'length less than minimum allowed length {cls.min_length}')
|
||||
|
||||
if cls.curtail_length:
|
||||
if l > cls.curtail_length:
|
||||
value = value[:cls.curtail_length]
|
||||
elif cls.max_length and l > cls.max_length:
|
||||
raise ValueError(f'length greater than maximum allowed length {cls.max_length}')
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def constr(*, min_length=0, max_length=2**16, curtail_length=None) -> Type[str]:
|
||||
# use kwargs then define conf in a dict to aid with IDE type hinting
|
||||
namespace = dict(
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
curtail_length=curtail_length,
|
||||
)
|
||||
return type('ConstrainedStrValue', (ConstrainedStr,), namespace)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user