diff --git a/pydantic/fields.py b/pydantic/fields.py index a121ef9..c7ea74e 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -1,27 +1,173 @@ -from typing import Any, Type +from collections import OrderedDict +from functools import partial, wraps +from inspect import signature +from pathlib import Path +from typing import Any, Callable, Dict, List, Type -class BaseField: +def str_validator(v) -> str: # TODO config + if isinstance(v, str): + return v + elif isinstance(v, bytes): + return v.decode() + return str(v) + + +def bytes_validator(v) -> bytes: + if isinstance(v, bytes): + return v + return str_validator(v).encode() + + +BOOL_STRINGS = { + '1', + 'TRUE', + 'ON', +} + + +def bool_validator(v) -> bool: + if isinstance(v, bool): + return v + if isinstance(v, bytes): + v = v.decode() + if isinstance(v, str): + return v.upper() in BOOL_STRINGS + return bool(v) + + +def number_size_validator(v, *, config): + if config.min_number_size <= v <= config.max_number_size: + raise ValueError(f'size not in range {config.min_number_size} to {config.max_number_size}') + return v + + +def anystr_length_validator(v, *, config): + if config.max_anystr_length <= len(v) <= config.max_anystr_length: + raise ValueError(f'length not in range {config.max_anystr_length} to {config.max_anystr_length}') + return v + + +class ValidatorsLookup: + def __init__(self): + self._validators_lookup: Dict[Type, List[Callable]] = { + int: [int, number_size_validator], + float: [float, number_size_validator], + Path: [Path], + str: [str_validator, anystr_length_validator], + bytes: [bytes_validator, anystr_length_validator], + bool: [bool_validator], + # TODO list, List, Dict, Union, datetime, date, time, custom types + } + self._validators_lookup_subclasses = [] + + def find(self, type_): + try: + return self._validators_lookup[type_] + except KeyError: + raise RuntimeError(f'no validator found for {type_}') + + def register(self, type_, *validators_): + self._validators_lookup[type_] = list(validators_) + + +validators_lookup = ValidatorsLookup() + + +def wrap_validator(func, config): + multi = False + try: + multi = len(signature(func).parameters) > 1 + except ValueError: + # happens on builtins like float + pass + if multi: + return wraps(func)(partial(func, config=config)) + return func + + +class Field: + __slots__ = 'type_', 'validators', 'default', 'required', 'name', 'description', 'info' + def __init__( self, *, + type_: Type, + validators: List[Callable]=None, default: Any=None, - v_type: Type=None, required: bool=False, + name: str=None, description: str=None): - if default and v_type: - raise RuntimeError('"default" and "v_type" cannot both be defined.') - elif default and required: - raise RuntimeError('It doesn\'t make sense to have "default" set and required=True.') - if default: - self.default = default - self.v_type = type(default) - else: - self.v_type = v_type + + if default and required: + raise RuntimeError('It doesn\'t make sense to have `default` set and `required=True`.') + + self.type_ = type_ + self.validators = validators + self.default = default self.required = required + self.name = name self.description = description + def prepare(self, name, config, class_validators): + self.name = self.name or name + if self.default and self.type_ is None: + self.type_ = type(self.default) -class EnvField(BaseField): + if self.type_ is None: + raise RuntimeError(f'unable to infer type for {self.name}') + + override_validator = class_validators.get(f'validate_{self.name}_override') + if override_validator: + self.validators = [override_validator] + + self.validators = self.validators or self._find_validator() + + self.validators.insert(0, class_validators.get(f'validate_{self.name}_pre')) + self.validators.append(class_validators.get(f'validate_{self.name}')) + self.validators.append(class_validators.get(f'validate_{self.name}_post')) + + self.validators = tuple(wrap_validator(v, config) for v in self.validators if v) + self.info = OrderedDict([ + ('type', self.type_.__name__), + ('default', self.default), + ('required', self.required), + ('validators', [f.__qualname__ for f in self.validators]) + ]) + if self.required: + self.info.pop('default') + if self.description: + self.info['description'] = self.description + + def _find_validator(self): + get_validators = getattr(self.type_, 'get_validators', None) + if get_validators: + return list(get_validators()) + return validators_lookup.find(self.type_) + + def validate(self, v): + for validator in self.validators: + v = validator(v) + return v + + @classmethod + def infer(cls, *, name, value, annotation, config, class_validators): + required = value == Ellipsis + instance = cls( + type_=annotation, + default=None if required else value, + required=required + ) + instance.prepare(name, config, class_validators) + return instance + + def __repr__(self): + return f'' + + def __str__(self): + return ', '.join(f'{k}={v!r}' for k, v in self.info.items()) + + +class EnvField(Field): def __init__(self, *, env=None, **kwargs): super().__init__(**kwargs) self.env_var_name = env diff --git a/pydantic/main.py b/pydantic/main.py index 8437452..03e82fd 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -1,4 +1,28 @@ -from collections import OrderedDict +import json +from types import FunctionType +from collections import OrderedDict, namedtuple +from typing import Any, Dict + +from pydantic.fields import Field + + +DEFAULT_CONFIG: Dict[str, Any] = dict( + min_anystr_length=0, + max_anystr_length=2**16, + min_number_size=-2**64, + max_number_size=2**64, +) +Config = namedtuple('Config', list(DEFAULT_CONFIG.keys())) + + +def get_config(config_class): + if config_class: + for k, v in DEFAULT_CONFIG.items(): + if not hasattr(config_class, k): + setattr(config_class, k, v) + else: + config_class = Config(**DEFAULT_CONFIG) + return config_class class MetaModel(type): @@ -8,34 +32,70 @@ class MetaModel(type): def __new__(mcs, name, bases, namespace): fields = OrderedDict() + base_config = None for base in reversed(bases): if issubclass(base, BaseModel) and base != BaseModel: - fields.update(base.fields) + fields.update(base.__fields__) + base_config = base.config + annotations = namespace.get('__annotations__') - if annotations: - print(f'class {name}') - fields.update(annotations) - print(fields) + config = get_config(namespace.get('Config', base_config)) + class_validators = {n: f for n, f in namespace.items() + if n.startswith('validate_') and isinstance(f, FunctionType)} + + for var_name, value in namespace.items(): + if var_name.startswith('_') or isinstance(value, (property, FunctionType, type)): + continue + field = Field.infer( + name=var_name, + value=value, + annotation=annotations.get(var_name), + config=config, + class_validators=class_validators, + ) + fields[field.name] = field namespace.update( - fields=fields + config=config, + __fields__=fields, ) return super().__new__(mcs, name, bases, namespace) +class ValidationError(ValueError): + def __init__(self, errors): + self.errors = errors + super().__init__(f'{len(self.errors)} errors validating input: {json.dumps(errors)}') + + class BaseModel(metaclass=MetaModel): - def __init__(self, **custom_settings): - """ - :param custom_settings: Custom settings to override defaults, only attributes already defined can be set. - """ - self._dict = { - # **self._substitute_environ(custom_settings), - **self._get_custom_settings(custom_settings), - } - [setattr(self, k, v) for k, v in self._dict.items()] + __fields__ = {} # populated by the metaclass + __values__ = {} + + def __init__(self, **values): + errors = OrderedDict() + for name, field in self.__fields__.items(): + value = values.get(name) + if not value: + if field.required: + errors[name] = {'type': 'Missing', 'msg': 'field required'} + continue + try: + value = field.validate(value) + except (ValueError, TypeError) as e: + errors[name] = {'type': e.__class__.__name__, 'msg': str(e)} + else: + self.__values__[name] = value + setattr(self, name, value) + if errors: + raise ValidationError(errors) @property - def dict(self): - return self._dict + def values(self): + return self.__values__ + + @property + def fields(self): + return self.__fields__ def _get_custom_settings(self, custom_settings): d = {} @@ -46,8 +106,9 @@ class BaseModel(metaclass=MetaModel): return d def __iter__(self): - # so `dict(settings)` works - yield from self._dict.items() + # so `dict(model)` works + yield from self.__values__.items() def __repr__(self): - return '<{} {}>'.format(self.__class__.__name__, ' '.join('{}={!r}'.format(k, v) for k, v in self.dict.items())) + return '<{} {}>'.format(self.__class__.__name__, ' '.join('{}={!r}'.format(k, v) + for k, v in self.__values__.items())) diff --git a/pydantic/meta.py b/pydantic/meta.py deleted file mode 100644 index 0aada73..0000000 --- a/pydantic/meta.py +++ /dev/null @@ -1,34 +0,0 @@ -import ast -import sys -from pathlib import Path - - -def find_fields(module, cls_name): - - path = Path(sys.modules[module].__file__).resolve() - - file_node = ast.parse(path.read_text(), filename=path.name) - cls_node = None - for n in file_node.body: - if isinstance(n, ast.ClassDef) and n.name == cls_name: - cls_node = n - break - if cls_name is None: - raise RuntimeError(f"can't find {cls_name} in {file_node}") - _expression = None - for n in cls_node.body: - if isinstance(n, ast.Expr) and isinstance(n.value, ast.Str): - _expression = n.value.s - continue - # print(ast.dump(n)) - if not isinstance(n, (ast.AnnAssign, ast.Assign)): - _expression = None - continue - target = getattr(n, 'target', None) or n.targets[0] - name = target.id - if name.startswith('_'): - _expression = None - continue - - yield name, _expression - _expression = None diff --git a/pydantic/types.py b/pydantic/types.py index 2d0485c..d4d58d7 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -3,3 +3,42 @@ json JsonList JsonDict """ +from typing import Type + +from pydantic.fields import str_validator + + +class ConstrainedStr(str): + min_length = None + max_length = None + curtail_length = None + + @classmethod + def get_validators(cls): + yield str_validator + yield cls.validate + + @classmethod + def validate(cls, value): + l = len(value) + if cls.min_length and l < cls.min_length: + raise ValueError(f'length less than minimum allowed length {cls.min_length}') + + if cls.curtail_length: + if l > cls.curtail_length: + value = value[:cls.curtail_length] + elif cls.max_length and l > cls.max_length: + raise ValueError(f'length greater than maximum allowed length {cls.max_length}') + + return value + + +def constr(*, min_length=0, max_length=2**16, curtail_length=None) -> Type[str]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict( + min_length=min_length, + max_length=max_length, + curtail_length=curtail_length, + ) + return type('ConstrainedStrValue', (ConstrainedStr,), namespace) +