mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
doc: better partial support (#527)
This commit is contained in:
+111
-29
@@ -21,6 +21,7 @@ from typing import (
|
||||
TypeVar,
|
||||
)
|
||||
from copy import deepcopy
|
||||
from functools import lru_cache
|
||||
|
||||
from instructor.mode import Mode
|
||||
from instructor.dsl.partialjson import JSONParser
|
||||
@@ -30,7 +31,72 @@ parser = JSONParser()
|
||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||
|
||||
|
||||
class MakeFieldsOptional:
|
||||
pass
|
||||
|
||||
|
||||
def _make_field_optional(
|
||||
field: FieldInfo,
|
||||
) -> tuple[object, FieldInfo]:
|
||||
tmp_field = deepcopy(field)
|
||||
|
||||
annotation = field.annotation
|
||||
|
||||
# Handle generics (like List, Dict, etc.)
|
||||
if get_origin(annotation) is not None:
|
||||
# Get the generic base (like List, Dict) and its arguments (like User in List[User])
|
||||
generic_base = get_origin(annotation)
|
||||
generic_args = get_args(annotation)
|
||||
|
||||
# Recursively apply Partial to each of the generic arguments
|
||||
modified_args = tuple(
|
||||
(
|
||||
Partial[arg, MakeFieldsOptional] # type: ignore[valid-type]
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel)
|
||||
else arg
|
||||
)
|
||||
for arg in generic_args
|
||||
)
|
||||
|
||||
# Reconstruct the generic type with modified arguments
|
||||
tmp_field.annotation = (
|
||||
Optional[generic_base[modified_args]] if generic_base else None
|
||||
)
|
||||
tmp_field.default = None
|
||||
# If the field is a BaseModel, then recursively convert it's
|
||||
# attributes to optionals.
|
||||
elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
||||
tmp_field.annotation = Optional[Partial[annotation, MakeFieldsOptional]] # type: ignore[assignment, valid-type]
|
||||
tmp_field.default = {}
|
||||
else:
|
||||
tmp_field.annotation = Optional[field.annotation] # type: ignore[assignment]
|
||||
tmp_field.default = None
|
||||
return tmp_field.annotation, tmp_field
|
||||
|
||||
|
||||
class PartialBase(Generic[T_Model]):
|
||||
@classmethod
|
||||
@lru_cache(maxsize=None)
|
||||
def get_partial_model(cls) -> type[T_Model]:
|
||||
"""Return a partial model we can use to validate partial results."""
|
||||
assert issubclass(
|
||||
cls, BaseModel
|
||||
), f"{cls.__name__} must be a subclass of BaseModel"
|
||||
|
||||
return create_model(
|
||||
__model_name=(
|
||||
cls.__name__
|
||||
if cls.__name__.startswith("Partial")
|
||||
else f"Partial{cls.__name__}"
|
||||
),
|
||||
__base__=cls,
|
||||
__module__=cls.__module__,
|
||||
**{
|
||||
field_name: _make_field_optional(field_info)
|
||||
for field_name, field_info in cls.model_fields.items()
|
||||
},
|
||||
) # type: ignore[all]
|
||||
|
||||
@classmethod
|
||||
def from_streaming_response(
|
||||
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
|
||||
@@ -59,6 +125,7 @@ class PartialBase(Generic[T_Model]):
|
||||
) -> Generator[T_Model, None, None]:
|
||||
prev_obj = None
|
||||
potential_object = ""
|
||||
partial_model = cls.get_partial_model()
|
||||
for chunk in json_chunks:
|
||||
potential_object += chunk
|
||||
|
||||
@@ -67,11 +134,11 @@ class PartialBase(Generic[T_Model]):
|
||||
parser.parse(potential_object) if potential_object.strip() else None
|
||||
)
|
||||
if task_json:
|
||||
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
|
||||
obj = partial_model.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
|
||||
if obj != prev_obj:
|
||||
obj.__dict__[
|
||||
"chunk"
|
||||
] = chunk # Provide the raw chunk for debugging and benchmarking
|
||||
obj.__dict__["chunk"] = (
|
||||
chunk # Provide the raw chunk for debugging and benchmarking
|
||||
)
|
||||
prev_obj = obj
|
||||
yield obj
|
||||
|
||||
@@ -81,6 +148,7 @@ class PartialBase(Generic[T_Model]):
|
||||
) -> AsyncGenerator[T_Model, None]:
|
||||
potential_object = ""
|
||||
prev_obj = None
|
||||
partial_model = cls.get_partial_model()
|
||||
async for chunk in json_chunks:
|
||||
potential_object += chunk
|
||||
|
||||
@@ -89,11 +157,11 @@ class PartialBase(Generic[T_Model]):
|
||||
parser.parse(potential_object) if potential_object.strip() else None
|
||||
)
|
||||
if task_json:
|
||||
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
|
||||
obj = partial_model.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
|
||||
if obj != prev_obj:
|
||||
obj.__dict__[
|
||||
"chunk"
|
||||
] = chunk # Provide the raw chunk for debugging and benchmarking
|
||||
obj.__dict__["chunk"] = (
|
||||
chunk # Provide the raw chunk for debugging and benchmarking
|
||||
)
|
||||
prev_obj = obj
|
||||
yield obj
|
||||
|
||||
@@ -145,11 +213,10 @@ class PartialBase(Generic[T_Model]):
|
||||
|
||||
|
||||
class Partial(Generic[T_Model]):
|
||||
"""Generate a new class with all attributes optionals.
|
||||
"""Generate a new class which has PartialBase as a base class.
|
||||
|
||||
Notes:
|
||||
This will wrap a class inheriting form BaseModel and will recursively
|
||||
convert all its attributes and its children's attributes to optionals.
|
||||
This will enable partial validation of the model while streaming.
|
||||
|
||||
Example:
|
||||
Partial[SomeModel]
|
||||
@@ -181,13 +248,23 @@ class Partial(Generic[T_Model]):
|
||||
|
||||
def __class_getitem__( # type: ignore[override]
|
||||
cls,
|
||||
wrapped_class: type[T_Model],
|
||||
wrapped_class: type[T_Model] | tuple[type[T_Model], type[MakeFieldsOptional]],
|
||||
) -> type[T_Model]:
|
||||
"""Convert model to a partial model with all fields being optionals."""
|
||||
"""Convert model to one that inherits from PartialBase.
|
||||
|
||||
def _make_field_optional(
|
||||
field: FieldInfo,
|
||||
) -> tuple[object, FieldInfo]:
|
||||
We don't make the fields optional at this point, we just wrap them with `Partial` so the names of the nested models will be
|
||||
`Partial{ModelName}`. We want the output of `model_json_schema()` to
|
||||
reflect the name change, but everything else should be the same as the
|
||||
original model. During validation, we'll generate a true partial model
|
||||
to support partially defined fields.
|
||||
|
||||
"""
|
||||
|
||||
make_fields_optional = None
|
||||
if isinstance(wrapped_class, tuple):
|
||||
wrapped_class, make_fields_optional = wrapped_class
|
||||
|
||||
def _wrap_models(field: FieldInfo) -> tuple[object, FieldInfo]:
|
||||
tmp_field = deepcopy(field)
|
||||
|
||||
annotation = field.annotation
|
||||
@@ -200,33 +277,38 @@ class Partial(Generic[T_Model]):
|
||||
|
||||
# Recursively apply Partial to each of the generic arguments
|
||||
modified_args = tuple(
|
||||
Partial[arg] # type: ignore[valid-type]
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel)
|
||||
else arg
|
||||
(
|
||||
Partial[arg] # type: ignore[valid-type]
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel)
|
||||
else arg
|
||||
)
|
||||
for arg in generic_args
|
||||
)
|
||||
|
||||
# Reconstruct the generic type with modified arguments
|
||||
tmp_field.annotation = (
|
||||
Optional[generic_base[modified_args]] if generic_base else None
|
||||
generic_base[modified_args] if generic_base else None
|
||||
)
|
||||
tmp_field.default = None
|
||||
# If the field is a BaseModel, then recursively convert it's
|
||||
# attributes to optionals.
|
||||
elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
||||
tmp_field.annotation = Optional[Partial[annotation]] # type: ignore[assignment, valid-type]
|
||||
tmp_field.default = {}
|
||||
else:
|
||||
tmp_field.annotation = Optional[field.annotation] # type: ignore[assignment]
|
||||
tmp_field.default = None
|
||||
tmp_field.annotation = Partial[annotation] # type: ignore[assignment, valid-type]
|
||||
return tmp_field.annotation, tmp_field
|
||||
|
||||
return create_model(
|
||||
__model_name=f"Partial{wrapped_class.__name__}",
|
||||
__model_name=(
|
||||
wrapped_class.__name__
|
||||
if wrapped_class.__name__.startswith("Partial")
|
||||
else f"Partial{wrapped_class.__name__}"
|
||||
),
|
||||
__base__=(wrapped_class, PartialBase),
|
||||
__module__=wrapped_class.__module__,
|
||||
**{
|
||||
field_name: _make_field_optional(field_info)
|
||||
for field_name, field_info in wrapped_class.__fields__.items()
|
||||
field_name: (
|
||||
_make_field_optional(field_info)
|
||||
if make_fields_optional is not None
|
||||
else _wrap_models(field_info)
|
||||
)
|
||||
for field_name, field_info in wrapped_class.model_fields.items()
|
||||
},
|
||||
) # type: ignore[all]
|
||||
|
||||
Generated
+10
-7
@@ -125,7 +125,7 @@ files = [
|
||||
name = "anthropic"
|
||||
version = "0.18.1"
|
||||
description = "The official Python library for the anthropic API"
|
||||
optional = false
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "anthropic-0.18.1-py3-none-any.whl", hash = "sha256:b85aee64f619ce1b1964ba733a09adc4053e7bc4e6d4186001229ec191099dcf"},
|
||||
@@ -826,7 +826,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc
|
||||
name = "filelock"
|
||||
version = "3.13.1"
|
||||
description = "A platform independent file lock."
|
||||
optional = false
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"},
|
||||
@@ -928,7 +928,7 @@ files = [
|
||||
name = "fsspec"
|
||||
version = "2024.3.1"
|
||||
description = "File-system specification"
|
||||
optional = false
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"},
|
||||
@@ -1091,7 +1091,7 @@ socks = ["socksio (==1.*)"]
|
||||
name = "huggingface-hub"
|
||||
version = "0.21.4"
|
||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||
optional = false
|
||||
optional = true
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"},
|
||||
@@ -3268,7 +3268,7 @@ test = ["flake8", "isort", "pytest"]
|
||||
name = "tokenizers"
|
||||
version = "0.15.2"
|
||||
description = ""
|
||||
optional = false
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"},
|
||||
@@ -3595,7 +3595,7 @@ files = [
|
||||
name = "xmltodict"
|
||||
version = "0.13.0"
|
||||
description = "Makes working with XML feel like you are working with JSON"
|
||||
optional = false
|
||||
optional = true
|
||||
python-versions = ">=3.4"
|
||||
files = [
|
||||
{file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
|
||||
@@ -3705,7 +3705,10 @@ files = [
|
||||
idna = ">=2.0"
|
||||
multidict = ">=4.0"
|
||||
|
||||
[extras]
|
||||
anthropic = ["anthropic", "xmltodict"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "5780411ed7df9fcd0c67d664cd4898ba4be9fb7ee7a74538e8e5937a72359142"
|
||||
content-hash = "35fcdd72be242b352d6d5790f4f37d4593c3047a33ed77bca9575b08e2c485a9"
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
# type: ignore[all]
|
||||
from pydantic import BaseModel
|
||||
|
||||
from instructor.dsl.partial import Partial
|
||||
|
||||
|
||||
class SampleNestedPartial(BaseModel):
|
||||
b: int
|
||||
|
||||
|
||||
class SamplePartial(BaseModel):
|
||||
a: int
|
||||
b: SampleNestedPartial
|
||||
|
||||
|
||||
def test_partial():
|
||||
partial = Partial[SamplePartial]
|
||||
assert partial.model_json_schema() == {
|
||||
"$defs": {
|
||||
"PartialSampleNestedPartial": {
|
||||
"properties": {"b": {"title": "B", "type": "integer"}},
|
||||
"required": ["b"],
|
||||
"title": "PartialSampleNestedPartial",
|
||||
"type": "object",
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"$ref": "#/$defs/PartialSampleNestedPartial"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"title": "PartialSamplePartial",
|
||||
"type": "object",
|
||||
}, "Wrapped model JSON schema has changed"
|
||||
assert partial.get_partial_model().model_json_schema() == {
|
||||
"$defs": {
|
||||
"PartialSampleNestedPartial": {
|
||||
"properties": {
|
||||
"b": {
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
"default": None,
|
||||
"title": "B",
|
||||
}
|
||||
},
|
||||
"title": "PartialSampleNestedPartial",
|
||||
"type": "object",
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
"a": {
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
"default": None,
|
||||
"title": "A",
|
||||
},
|
||||
"b": {
|
||||
"anyOf": [
|
||||
{"$ref": "#/$defs/PartialSampleNestedPartial"},
|
||||
{"type": "null"},
|
||||
],
|
||||
"default": {},
|
||||
},
|
||||
},
|
||||
"title": "PartialSamplePartial",
|
||||
"type": "object",
|
||||
}, "Partial model JSON schema has changed"
|
||||
|
||||
for model in partial.model_from_chunks(['{"b": {"b": 1}}']):
|
||||
assert model.model_dump() == {"a": None, "b": {"b": 1}}
|
||||
Reference in New Issue
Block a user