basic validation working

This commit is contained in:
Samuel Colvin
2017-05-04 19:07:34 +01:00
parent a8e844dad5
commit 795e3604ef
4 changed files with 280 additions and 68 deletions
+159 -13
View File
@@ -1,27 +1,173 @@
from typing import Any, Type
from collections import OrderedDict
from functools import partial, wraps
from inspect import signature
from pathlib import Path
from typing import Any, Callable, Dict, List, Type
class BaseField:
def str_validator(v) -> str: # TODO config
if isinstance(v, str):
return v
elif isinstance(v, bytes):
return v.decode()
return str(v)
def bytes_validator(v) -> bytes:
if isinstance(v, bytes):
return v
return str_validator(v).encode()
BOOL_STRINGS = {
'1',
'TRUE',
'ON',
}
def bool_validator(v) -> bool:
if isinstance(v, bool):
return v
if isinstance(v, bytes):
v = v.decode()
if isinstance(v, str):
return v.upper() in BOOL_STRINGS
return bool(v)
def number_size_validator(v, *, config):
if config.min_number_size <= v <= config.max_number_size:
raise ValueError(f'size not in range {config.min_number_size} to {config.max_number_size}')
return v
def anystr_length_validator(v, *, config):
if config.max_anystr_length <= len(v) <= config.max_anystr_length:
raise ValueError(f'length not in range {config.max_anystr_length} to {config.max_anystr_length}')
return v
class ValidatorsLookup:
def __init__(self):
self._validators_lookup: Dict[Type, List[Callable]] = {
int: [int, number_size_validator],
float: [float, number_size_validator],
Path: [Path],
str: [str_validator, anystr_length_validator],
bytes: [bytes_validator, anystr_length_validator],
bool: [bool_validator],
# TODO list, List, Dict, Union, datetime, date, time, custom types
}
self._validators_lookup_subclasses = []
def find(self, type_):
try:
return self._validators_lookup[type_]
except KeyError:
raise RuntimeError(f'no validator found for {type_}')
def register(self, type_, *validators_):
self._validators_lookup[type_] = list(validators_)
validators_lookup = ValidatorsLookup()
def wrap_validator(func, config):
multi = False
try:
multi = len(signature(func).parameters) > 1
except ValueError:
# happens on builtins like float
pass
if multi:
return wraps(func)(partial(func, config=config))
return func
class Field:
__slots__ = 'type_', 'validators', 'default', 'required', 'name', 'description', 'info'
def __init__(
self, *,
type_: Type,
validators: List[Callable]=None,
default: Any=None,
v_type: Type=None,
required: bool=False,
name: str=None,
description: str=None):
if default and v_type:
raise RuntimeError('"default" and "v_type" cannot both be defined.')
elif default and required:
raise RuntimeError('It doesn\'t make sense to have "default" set and required=True.')
if default:
self.default = default
self.v_type = type(default)
else:
self.v_type = v_type
if default and required:
raise RuntimeError('It doesn\'t make sense to have `default` set and `required=True`.')
self.type_ = type_
self.validators = validators
self.default = default
self.required = required
self.name = name
self.description = description
def prepare(self, name, config, class_validators):
self.name = self.name or name
if self.default and self.type_ is None:
self.type_ = type(self.default)
class EnvField(BaseField):
if self.type_ is None:
raise RuntimeError(f'unable to infer type for {self.name}')
override_validator = class_validators.get(f'validate_{self.name}_override')
if override_validator:
self.validators = [override_validator]
self.validators = self.validators or self._find_validator()
self.validators.insert(0, class_validators.get(f'validate_{self.name}_pre'))
self.validators.append(class_validators.get(f'validate_{self.name}'))
self.validators.append(class_validators.get(f'validate_{self.name}_post'))
self.validators = tuple(wrap_validator(v, config) for v in self.validators if v)
self.info = OrderedDict([
('type', self.type_.__name__),
('default', self.default),
('required', self.required),
('validators', [f.__qualname__ for f in self.validators])
])
if self.required:
self.info.pop('default')
if self.description:
self.info['description'] = self.description
def _find_validator(self):
get_validators = getattr(self.type_, 'get_validators', None)
if get_validators:
return list(get_validators())
return validators_lookup.find(self.type_)
def validate(self, v):
for validator in self.validators:
v = validator(v)
return v
@classmethod
def infer(cls, *, name, value, annotation, config, class_validators):
required = value == Ellipsis
instance = cls(
type_=annotation,
default=None if required else value,
required=required
)
instance.prepare(name, config, class_validators)
return instance
def __repr__(self):
return f'<Field: {self}>'
def __str__(self):
return ', '.join(f'{k}={v!r}' for k, v in self.info.items())
class EnvField(Field):
def __init__(self, *, env=None, **kwargs):
super().__init__(**kwargs)
self.env_var_name = env
+82 -21
View File
@@ -1,4 +1,28 @@
from collections import OrderedDict
import json
from types import FunctionType
from collections import OrderedDict, namedtuple
from typing import Any, Dict
from pydantic.fields import Field
DEFAULT_CONFIG: Dict[str, Any] = dict(
min_anystr_length=0,
max_anystr_length=2**16,
min_number_size=-2**64,
max_number_size=2**64,
)
Config = namedtuple('Config', list(DEFAULT_CONFIG.keys()))
def get_config(config_class):
if config_class:
for k, v in DEFAULT_CONFIG.items():
if not hasattr(config_class, k):
setattr(config_class, k, v)
else:
config_class = Config(**DEFAULT_CONFIG)
return config_class
class MetaModel(type):
@@ -8,34 +32,70 @@ class MetaModel(type):
def __new__(mcs, name, bases, namespace):
fields = OrderedDict()
base_config = None
for base in reversed(bases):
if issubclass(base, BaseModel) and base != BaseModel:
fields.update(base.fields)
fields.update(base.__fields__)
base_config = base.config
annotations = namespace.get('__annotations__')
if annotations:
print(f'class {name}')
fields.update(annotations)
print(fields)
config = get_config(namespace.get('Config', base_config))
class_validators = {n: f for n, f in namespace.items()
if n.startswith('validate_') and isinstance(f, FunctionType)}
for var_name, value in namespace.items():
if var_name.startswith('_') or isinstance(value, (property, FunctionType, type)):
continue
field = Field.infer(
name=var_name,
value=value,
annotation=annotations.get(var_name),
config=config,
class_validators=class_validators,
)
fields[field.name] = field
namespace.update(
fields=fields
config=config,
__fields__=fields,
)
return super().__new__(mcs, name, bases, namespace)
class ValidationError(ValueError):
def __init__(self, errors):
self.errors = errors
super().__init__(f'{len(self.errors)} errors validating input: {json.dumps(errors)}')
class BaseModel(metaclass=MetaModel):
def __init__(self, **custom_settings):
"""
:param custom_settings: Custom settings to override defaults, only attributes already defined can be set.
"""
self._dict = {
# **self._substitute_environ(custom_settings),
**self._get_custom_settings(custom_settings),
}
[setattr(self, k, v) for k, v in self._dict.items()]
__fields__ = {} # populated by the metaclass
__values__ = {}
def __init__(self, **values):
errors = OrderedDict()
for name, field in self.__fields__.items():
value = values.get(name)
if not value:
if field.required:
errors[name] = {'type': 'Missing', 'msg': 'field required'}
continue
try:
value = field.validate(value)
except (ValueError, TypeError) as e:
errors[name] = {'type': e.__class__.__name__, 'msg': str(e)}
else:
self.__values__[name] = value
setattr(self, name, value)
if errors:
raise ValidationError(errors)
@property
def dict(self):
return self._dict
def values(self):
return self.__values__
@property
def fields(self):
return self.__fields__
def _get_custom_settings(self, custom_settings):
d = {}
@@ -46,8 +106,9 @@ class BaseModel(metaclass=MetaModel):
return d
def __iter__(self):
# so `dict(settings)` works
yield from self._dict.items()
# so `dict(model)` works
yield from self.__values__.items()
def __repr__(self):
return '<{} {}>'.format(self.__class__.__name__, ' '.join('{}={!r}'.format(k, v) for k, v in self.dict.items()))
return '<{} {}>'.format(self.__class__.__name__, ' '.join('{}={!r}'.format(k, v)
for k, v in self.__values__.items()))
-34
View File
@@ -1,34 +0,0 @@
import ast
import sys
from pathlib import Path
def find_fields(module, cls_name):
path = Path(sys.modules[module].__file__).resolve()
file_node = ast.parse(path.read_text(), filename=path.name)
cls_node = None
for n in file_node.body:
if isinstance(n, ast.ClassDef) and n.name == cls_name:
cls_node = n
break
if cls_name is None:
raise RuntimeError(f"can't find {cls_name} in {file_node}")
_expression = None
for n in cls_node.body:
if isinstance(n, ast.Expr) and isinstance(n.value, ast.Str):
_expression = n.value.s
continue
# print(ast.dump(n))
if not isinstance(n, (ast.AnnAssign, ast.Assign)):
_expression = None
continue
target = getattr(n, 'target', None) or n.targets[0]
name = target.id
if name.startswith('_'):
_expression = None
continue
yield name, _expression
_expression = None
+39
View File
@@ -3,3 +3,42 @@ json
JsonList
JsonDict
"""
from typing import Type
from pydantic.fields import str_validator
class ConstrainedStr(str):
min_length = None
max_length = None
curtail_length = None
@classmethod
def get_validators(cls):
yield str_validator
yield cls.validate
@classmethod
def validate(cls, value):
l = len(value)
if cls.min_length and l < cls.min_length:
raise ValueError(f'length less than minimum allowed length {cls.min_length}')
if cls.curtail_length:
if l > cls.curtail_length:
value = value[:cls.curtail_length]
elif cls.max_length and l > cls.max_length:
raise ValueError(f'length greater than maximum allowed length {cls.max_length}')
return value
def constr(*, min_length=0, max_length=2**16, curtail_length=None) -> Type[str]:
# use kwargs then define conf in a dict to aid with IDE type hinting
namespace = dict(
min_length=min_length,
max_length=max_length,
curtail_length=curtail_length,
)
return type('ConstrainedStrValue', (ConstrainedStr,), namespace)