diff --git a/changes/1974-uriyyo.md b/changes/1974-uriyyo.md new file mode 100644 index 0000000..7c998cc --- /dev/null +++ b/changes/1974-uriyyo.md @@ -0,0 +1 @@ +Add ability to use `min_length/max_length` constraints with secret types \ No newline at end of file diff --git a/pydantic/schema.py b/pydantic/schema.py index db47075..079e08c 100644 --- a/pydantic/schema.py +++ b/pydantic/schema.py @@ -48,6 +48,8 @@ from .types import ( ConstrainedList, ConstrainedSet, ConstrainedStr, + SecretBytes, + SecretStr, conbytes, condecimal, confloat, @@ -905,7 +907,13 @@ def get_annotation_from_field_info(annotation: Any, field_info: FieldInfo, field attrs: Optional[Tuple[str, ...]] = None constraint_func: Optional[Callable[..., type]] = None if isinstance(type_, type): - if issubclass(type_, str) and not issubclass(type_, (EmailStr, AnyUrl, ConstrainedStr)): + if issubclass(type_, (SecretStr, SecretBytes)): + attrs = ('max_length', 'min_length') + + def constraint_func(**kwargs: Any) -> Type[Any]: + return type(type_.__name__, (type_,), kwargs) + + elif issubclass(type_, str) and not issubclass(type_, (EmailStr, AnyUrl, ConstrainedStr)): attrs = ('max_length', 'min_length', 'regex') constraint_func = constr elif issubclass(type_, bytes): diff --git a/pydantic/types.py b/pydantic/types.py index 79693dd..92ea8f2 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -598,13 +598,24 @@ class Json(metaclass=JsonMeta): class SecretStr: + min_length: OptionalInt = None + max_length: OptionalInt = None + @classmethod def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type='string', writeOnly=True, format='password') + update_not_none( + field_schema, + type='string', + writeOnly=True, + format='password', + minLength=cls.min_length, + maxLength=cls.max_length, + ) @classmethod def __get_validators__(cls) -> 'CallableGenerator': yield cls.validate + yield constr_length_validator @classmethod def validate(cls, value: Any) -> 'SecretStr': @@ -625,6 +636,9 @@ class SecretStr: def __eq__(self, other: Any) -> bool: return isinstance(other, SecretStr) and self.get_secret_value() == other.get_secret_value() + def __len__(self) -> int: + return len(self._secret_value) + def display(self) -> str: warnings.warn('`secret_str.display()` is deprecated, use `str(secret_str)` instead', DeprecationWarning) return str(self) @@ -634,13 +648,24 @@ class SecretStr: class SecretBytes: + min_length: OptionalInt = None + max_length: OptionalInt = None + @classmethod def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type='string', writeOnly=True, format='password') + update_not_none( + field_schema, + type='string', + writeOnly=True, + format='password', + minLength=cls.min_length, + maxLength=cls.max_length, + ) @classmethod def __get_validators__(cls) -> 'CallableGenerator': yield cls.validate + yield constr_length_validator @classmethod def validate(cls, value: Any) -> 'SecretBytes': @@ -661,6 +686,9 @@ class SecretBytes: def __eq__(self, other: Any) -> bool: return isinstance(other, SecretBytes) and self.get_secret_value() == other.get_secret_value() + def __len__(self) -> int: + return len(self._secret_value) + def display(self) -> str: warnings.warn('`secret_bytes.display()` is deprecated, use `str(secret_bytes)` instead', DeprecationWarning) return str(self) diff --git a/tests/test_types.py b/tests/test_types.py index 2eb55bf..7e4c7e0 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -2160,6 +2160,38 @@ def test_secretstr_error(): assert exc_info.value.errors() == [{'loc': ('password',), 'msg': 'str type expected', 'type': 'type_error.str'}] +def test_secretstr_min_max_length(): + class Foobar(BaseModel): + password: SecretStr = Field(min_length=6, max_length=10) + + with pytest.raises(ValidationError) as exc_info: + Foobar(password='') + + assert exc_info.value.errors() == [ + { + 'loc': ('password',), + 'msg': 'ensure this value has at least 6 characters', + 'type': 'value_error.any_str.min_length', + 'ctx': {'limit_value': 6}, + } + ] + + with pytest.raises(ValidationError) as exc_info: + Foobar(password='1' * 20) + + assert exc_info.value.errors() == [ + { + 'loc': ('password',), + 'msg': 'ensure this value has at most 10 characters', + 'type': 'value_error.any_str.max_length', + 'ctx': {'limit_value': 10}, + } + ] + + value = '1' * 8 + assert Foobar(password=value).password.get_secret_value() == value + + def test_secretbytes(): class Foobar(BaseModel): password: SecretBytes @@ -2216,6 +2248,63 @@ def test_secretbytes_error(): assert exc_info.value.errors() == [{'loc': ('password',), 'msg': 'byte type expected', 'type': 'type_error.bytes'}] +def test_secretbytes_min_max_length(): + class Foobar(BaseModel): + password: SecretBytes = Field(min_length=6, max_length=10) + + with pytest.raises(ValidationError) as exc_info: + Foobar(password=b'') + + assert exc_info.value.errors() == [ + { + 'loc': ('password',), + 'msg': 'ensure this value has at least 6 characters', + 'type': 'value_error.any_str.min_length', + 'ctx': {'limit_value': 6}, + } + ] + + with pytest.raises(ValidationError) as exc_info: + Foobar(password=b'1' * 20) + + assert exc_info.value.errors() == [ + { + 'loc': ('password',), + 'msg': 'ensure this value has at most 10 characters', + 'type': 'value_error.any_str.max_length', + 'ctx': {'limit_value': 10}, + } + ] + + value = b'1' * 8 + assert Foobar(password=value).password.get_secret_value() == value + + +@pytest.mark.parametrize('secret_cls', [SecretStr, SecretBytes]) +@pytest.mark.parametrize( + 'field_kw,schema_kw', + [ + [{}, {}], + [{'min_length': 6}, {'minLength': 6}], + [{'max_length': 10}, {'maxLength': 10}], + [{'min_length': 6, 'max_length': 10}, {'minLength': 6, 'maxLength': 10}], + ], + ids=['no-constrains', 'min-constraint', 'max-constraint', 'min-max-constraints'], +) +def test_secrets_schema(secret_cls, field_kw, schema_kw): + class Foobar(BaseModel): + password: secret_cls = Field(**field_kw) + + assert Foobar.schema() == { + 'title': 'Foobar', + 'type': 'object', + 'properties': { + 'password': {'title': 'Password', 'type': 'string', 'writeOnly': True, 'format': 'password', **schema_kw} + }, + 'required': ['password'], + } + + def test_generic_without_params(): class Model(BaseModel): generic_list: List