doc: better partial support (#527)

This commit is contained in:
Michael Hahn
2024-03-22 17:28:57 -07:00
committed by GitHub
parent 726ca86c95
commit 7fcad59cef
3 changed files with 189 additions and 36 deletions
+111 -29
View File
@@ -21,6 +21,7 @@ from typing import (
TypeVar,
)
from copy import deepcopy
from functools import lru_cache
from instructor.mode import Mode
from instructor.dsl.partialjson import JSONParser
@@ -30,7 +31,72 @@ parser = JSONParser()
T_Model = TypeVar("T_Model", bound=BaseModel)
class MakeFieldsOptional:
pass
def _make_field_optional(
field: FieldInfo,
) -> tuple[object, FieldInfo]:
tmp_field = deepcopy(field)
annotation = field.annotation
# Handle generics (like List, Dict, etc.)
if get_origin(annotation) is not None:
# Get the generic base (like List, Dict) and its arguments (like User in List[User])
generic_base = get_origin(annotation)
generic_args = get_args(annotation)
# Recursively apply Partial to each of the generic arguments
modified_args = tuple(
(
Partial[arg, MakeFieldsOptional] # type: ignore[valid-type]
if isinstance(arg, type) and issubclass(arg, BaseModel)
else arg
)
for arg in generic_args
)
# Reconstruct the generic type with modified arguments
tmp_field.annotation = (
Optional[generic_base[modified_args]] if generic_base else None
)
tmp_field.default = None
# If the field is a BaseModel, then recursively convert it's
# attributes to optionals.
elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
tmp_field.annotation = Optional[Partial[annotation, MakeFieldsOptional]] # type: ignore[assignment, valid-type]
tmp_field.default = {}
else:
tmp_field.annotation = Optional[field.annotation] # type: ignore[assignment]
tmp_field.default = None
return tmp_field.annotation, tmp_field
class PartialBase(Generic[T_Model]):
@classmethod
@lru_cache(maxsize=None)
def get_partial_model(cls) -> type[T_Model]:
"""Return a partial model we can use to validate partial results."""
assert issubclass(
cls, BaseModel
), f"{cls.__name__} must be a subclass of BaseModel"
return create_model(
__model_name=(
cls.__name__
if cls.__name__.startswith("Partial")
else f"Partial{cls.__name__}"
),
__base__=cls,
__module__=cls.__module__,
**{
field_name: _make_field_optional(field_info)
for field_name, field_info in cls.model_fields.items()
},
) # type: ignore[all]
@classmethod
def from_streaming_response(
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
@@ -59,6 +125,7 @@ class PartialBase(Generic[T_Model]):
) -> Generator[T_Model, None, None]:
prev_obj = None
potential_object = ""
partial_model = cls.get_partial_model()
for chunk in json_chunks:
potential_object += chunk
@@ -67,11 +134,11 @@ class PartialBase(Generic[T_Model]):
parser.parse(potential_object) if potential_object.strip() else None
)
if task_json:
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
obj = partial_model.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
if obj != prev_obj:
obj.__dict__[
"chunk"
] = chunk # Provide the raw chunk for debugging and benchmarking
obj.__dict__["chunk"] = (
chunk # Provide the raw chunk for debugging and benchmarking
)
prev_obj = obj
yield obj
@@ -81,6 +148,7 @@ class PartialBase(Generic[T_Model]):
) -> AsyncGenerator[T_Model, None]:
potential_object = ""
prev_obj = None
partial_model = cls.get_partial_model()
async for chunk in json_chunks:
potential_object += chunk
@@ -89,11 +157,11 @@ class PartialBase(Generic[T_Model]):
parser.parse(potential_object) if potential_object.strip() else None
)
if task_json:
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
obj = partial_model.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
if obj != prev_obj:
obj.__dict__[
"chunk"
] = chunk # Provide the raw chunk for debugging and benchmarking
obj.__dict__["chunk"] = (
chunk # Provide the raw chunk for debugging and benchmarking
)
prev_obj = obj
yield obj
@@ -145,11 +213,10 @@ class PartialBase(Generic[T_Model]):
class Partial(Generic[T_Model]):
"""Generate a new class with all attributes optionals.
"""Generate a new class which has PartialBase as a base class.
Notes:
This will wrap a class inheriting form BaseModel and will recursively
convert all its attributes and its children's attributes to optionals.
This will enable partial validation of the model while streaming.
Example:
Partial[SomeModel]
@@ -181,13 +248,23 @@ class Partial(Generic[T_Model]):
def __class_getitem__( # type: ignore[override]
cls,
wrapped_class: type[T_Model],
wrapped_class: type[T_Model] | tuple[type[T_Model], type[MakeFieldsOptional]],
) -> type[T_Model]:
"""Convert model to a partial model with all fields being optionals."""
"""Convert model to one that inherits from PartialBase.
def _make_field_optional(
field: FieldInfo,
) -> tuple[object, FieldInfo]:
We don't make the fields optional at this point, we just wrap them with `Partial` so the names of the nested models will be
`Partial{ModelName}`. We want the output of `model_json_schema()` to
reflect the name change, but everything else should be the same as the
original model. During validation, we'll generate a true partial model
to support partially defined fields.
"""
make_fields_optional = None
if isinstance(wrapped_class, tuple):
wrapped_class, make_fields_optional = wrapped_class
def _wrap_models(field: FieldInfo) -> tuple[object, FieldInfo]:
tmp_field = deepcopy(field)
annotation = field.annotation
@@ -200,33 +277,38 @@ class Partial(Generic[T_Model]):
# Recursively apply Partial to each of the generic arguments
modified_args = tuple(
Partial[arg] # type: ignore[valid-type]
if isinstance(arg, type) and issubclass(arg, BaseModel)
else arg
(
Partial[arg] # type: ignore[valid-type]
if isinstance(arg, type) and issubclass(arg, BaseModel)
else arg
)
for arg in generic_args
)
# Reconstruct the generic type with modified arguments
tmp_field.annotation = (
Optional[generic_base[modified_args]] if generic_base else None
generic_base[modified_args] if generic_base else None
)
tmp_field.default = None
# If the field is a BaseModel, then recursively convert it's
# attributes to optionals.
elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
tmp_field.annotation = Optional[Partial[annotation]] # type: ignore[assignment, valid-type]
tmp_field.default = {}
else:
tmp_field.annotation = Optional[field.annotation] # type: ignore[assignment]
tmp_field.default = None
tmp_field.annotation = Partial[annotation] # type: ignore[assignment, valid-type]
return tmp_field.annotation, tmp_field
return create_model(
__model_name=f"Partial{wrapped_class.__name__}",
__model_name=(
wrapped_class.__name__
if wrapped_class.__name__.startswith("Partial")
else f"Partial{wrapped_class.__name__}"
),
__base__=(wrapped_class, PartialBase),
__module__=wrapped_class.__module__,
**{
field_name: _make_field_optional(field_info)
for field_name, field_info in wrapped_class.__fields__.items()
field_name: (
_make_field_optional(field_info)
if make_fields_optional is not None
else _wrap_models(field_info)
)
for field_name, field_info in wrapped_class.model_fields.items()
},
) # type: ignore[all]
Generated
+10 -7
View File
@@ -125,7 +125,7 @@ files = [
name = "anthropic"
version = "0.18.1"
description = "The official Python library for the anthropic API"
optional = false
optional = true
python-versions = ">=3.7"
files = [
{file = "anthropic-0.18.1-py3-none-any.whl", hash = "sha256:b85aee64f619ce1b1964ba733a09adc4053e7bc4e6d4186001229ec191099dcf"},
@@ -826,7 +826,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc
name = "filelock"
version = "3.13.1"
description = "A platform independent file lock."
optional = false
optional = true
python-versions = ">=3.8"
files = [
{file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"},
@@ -928,7 +928,7 @@ files = [
name = "fsspec"
version = "2024.3.1"
description = "File-system specification"
optional = false
optional = true
python-versions = ">=3.8"
files = [
{file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"},
@@ -1091,7 +1091,7 @@ socks = ["socksio (==1.*)"]
name = "huggingface-hub"
version = "0.21.4"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
optional = true
python-versions = ">=3.8.0"
files = [
{file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"},
@@ -3268,7 +3268,7 @@ test = ["flake8", "isort", "pytest"]
name = "tokenizers"
version = "0.15.2"
description = ""
optional = false
optional = true
python-versions = ">=3.7"
files = [
{file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"},
@@ -3595,7 +3595,7 @@ files = [
name = "xmltodict"
version = "0.13.0"
description = "Makes working with XML feel like you are working with JSON"
optional = false
optional = true
python-versions = ">=3.4"
files = [
{file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
@@ -3705,7 +3705,10 @@ files = [
idna = ">=2.0"
multidict = ">=4.0"
[extras]
anthropic = ["anthropic", "xmltodict"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "5780411ed7df9fcd0c67d664cd4898ba4be9fb7ee7a74538e8e5937a72359142"
content-hash = "35fcdd72be242b352d6d5790f4f37d4593c3047a33ed77bca9575b08e2c485a9"
+68
View File
@@ -0,0 +1,68 @@
# type: ignore[all]
from pydantic import BaseModel
from instructor.dsl.partial import Partial
class SampleNestedPartial(BaseModel):
b: int
class SamplePartial(BaseModel):
a: int
b: SampleNestedPartial
def test_partial():
partial = Partial[SamplePartial]
assert partial.model_json_schema() == {
"$defs": {
"PartialSampleNestedPartial": {
"properties": {"b": {"title": "B", "type": "integer"}},
"required": ["b"],
"title": "PartialSampleNestedPartial",
"type": "object",
}
},
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"$ref": "#/$defs/PartialSampleNestedPartial"},
},
"required": ["a", "b"],
"title": "PartialSamplePartial",
"type": "object",
}, "Wrapped model JSON schema has changed"
assert partial.get_partial_model().model_json_schema() == {
"$defs": {
"PartialSampleNestedPartial": {
"properties": {
"b": {
"anyOf": [{"type": "integer"}, {"type": "null"}],
"default": None,
"title": "B",
}
},
"title": "PartialSampleNestedPartial",
"type": "object",
}
},
"properties": {
"a": {
"anyOf": [{"type": "integer"}, {"type": "null"}],
"default": None,
"title": "A",
},
"b": {
"anyOf": [
{"$ref": "#/$defs/PartialSampleNestedPartial"},
{"type": "null"},
],
"default": {},
},
},
"title": "PartialSamplePartial",
"type": "object",
}, "Partial model JSON schema has changed"
for model in partial.model_from_chunks(['{"b": {"b": 1}}']):
assert model.model_dump() == {"a": None, "b": {"b": 1}}