From 7fcad59cefcc60c210f2bcbcec75370556500f9d Mon Sep 17 00:00:00 2001 From: Michael Hahn Date: Fri, 22 Mar 2024 17:28:57 -0700 Subject: [PATCH] doc: better partial support (#527) --- instructor/dsl/partial.py | 140 ++++++++++++++++++++++++++++++-------- poetry.lock | 17 +++-- tests/dsl/test_partial.py | 68 ++++++++++++++++++ 3 files changed, 189 insertions(+), 36 deletions(-) create mode 100644 tests/dsl/test_partial.py diff --git a/instructor/dsl/partial.py b/instructor/dsl/partial.py index 0778fae..8ef27e1 100644 --- a/instructor/dsl/partial.py +++ b/instructor/dsl/partial.py @@ -21,6 +21,7 @@ from typing import ( TypeVar, ) from copy import deepcopy +from functools import lru_cache from instructor.mode import Mode from instructor.dsl.partialjson import JSONParser @@ -30,7 +31,72 @@ parser = JSONParser() T_Model = TypeVar("T_Model", bound=BaseModel) +class MakeFieldsOptional: + pass + + +def _make_field_optional( + field: FieldInfo, +) -> tuple[object, FieldInfo]: + tmp_field = deepcopy(field) + + annotation = field.annotation + + # Handle generics (like List, Dict, etc.) + if get_origin(annotation) is not None: + # Get the generic base (like List, Dict) and its arguments (like User in List[User]) + generic_base = get_origin(annotation) + generic_args = get_args(annotation) + + # Recursively apply Partial to each of the generic arguments + modified_args = tuple( + ( + Partial[arg, MakeFieldsOptional] # type: ignore[valid-type] + if isinstance(arg, type) and issubclass(arg, BaseModel) + else arg + ) + for arg in generic_args + ) + + # Reconstruct the generic type with modified arguments + tmp_field.annotation = ( + Optional[generic_base[modified_args]] if generic_base else None + ) + tmp_field.default = None + # If the field is a BaseModel, then recursively convert it's + # attributes to optionals. + elif isinstance(annotation, type) and issubclass(annotation, BaseModel): + tmp_field.annotation = Optional[Partial[annotation, MakeFieldsOptional]] # type: ignore[assignment, valid-type] + tmp_field.default = {} + else: + tmp_field.annotation = Optional[field.annotation] # type: ignore[assignment] + tmp_field.default = None + return tmp_field.annotation, tmp_field + + class PartialBase(Generic[T_Model]): + @classmethod + @lru_cache(maxsize=None) + def get_partial_model(cls) -> type[T_Model]: + """Return a partial model we can use to validate partial results.""" + assert issubclass( + cls, BaseModel + ), f"{cls.__name__} must be a subclass of BaseModel" + + return create_model( + __model_name=( + cls.__name__ + if cls.__name__.startswith("Partial") + else f"Partial{cls.__name__}" + ), + __base__=cls, + __module__=cls.__module__, + **{ + field_name: _make_field_optional(field_info) + for field_name, field_info in cls.model_fields.items() + }, + ) # type: ignore[all] + @classmethod def from_streaming_response( cls, completion: Iterable[Any], mode: Mode, **kwargs: Any @@ -59,6 +125,7 @@ class PartialBase(Generic[T_Model]): ) -> Generator[T_Model, None, None]: prev_obj = None potential_object = "" + partial_model = cls.get_partial_model() for chunk in json_chunks: potential_object += chunk @@ -67,11 +134,11 @@ class PartialBase(Generic[T_Model]): parser.parse(potential_object) if potential_object.strip() else None ) if task_json: - obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined] + obj = partial_model.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined] if obj != prev_obj: - obj.__dict__[ - "chunk" - ] = chunk # Provide the raw chunk for debugging and benchmarking + obj.__dict__["chunk"] = ( + chunk # Provide the raw chunk for debugging and benchmarking + ) prev_obj = obj yield obj @@ -81,6 +148,7 @@ class PartialBase(Generic[T_Model]): ) -> AsyncGenerator[T_Model, None]: potential_object = "" prev_obj = None + partial_model = cls.get_partial_model() async for chunk in json_chunks: potential_object += chunk @@ -89,11 +157,11 @@ class PartialBase(Generic[T_Model]): parser.parse(potential_object) if potential_object.strip() else None ) if task_json: - obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined] + obj = partial_model.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined] if obj != prev_obj: - obj.__dict__[ - "chunk" - ] = chunk # Provide the raw chunk for debugging and benchmarking + obj.__dict__["chunk"] = ( + chunk # Provide the raw chunk for debugging and benchmarking + ) prev_obj = obj yield obj @@ -145,11 +213,10 @@ class PartialBase(Generic[T_Model]): class Partial(Generic[T_Model]): - """Generate a new class with all attributes optionals. + """Generate a new class which has PartialBase as a base class. Notes: - This will wrap a class inheriting form BaseModel and will recursively - convert all its attributes and its children's attributes to optionals. + This will enable partial validation of the model while streaming. Example: Partial[SomeModel] @@ -181,13 +248,23 @@ class Partial(Generic[T_Model]): def __class_getitem__( # type: ignore[override] cls, - wrapped_class: type[T_Model], + wrapped_class: type[T_Model] | tuple[type[T_Model], type[MakeFieldsOptional]], ) -> type[T_Model]: - """Convert model to a partial model with all fields being optionals.""" + """Convert model to one that inherits from PartialBase. - def _make_field_optional( - field: FieldInfo, - ) -> tuple[object, FieldInfo]: + We don't make the fields optional at this point, we just wrap them with `Partial` so the names of the nested models will be + `Partial{ModelName}`. We want the output of `model_json_schema()` to + reflect the name change, but everything else should be the same as the + original model. During validation, we'll generate a true partial model + to support partially defined fields. + + """ + + make_fields_optional = None + if isinstance(wrapped_class, tuple): + wrapped_class, make_fields_optional = wrapped_class + + def _wrap_models(field: FieldInfo) -> tuple[object, FieldInfo]: tmp_field = deepcopy(field) annotation = field.annotation @@ -200,33 +277,38 @@ class Partial(Generic[T_Model]): # Recursively apply Partial to each of the generic arguments modified_args = tuple( - Partial[arg] # type: ignore[valid-type] - if isinstance(arg, type) and issubclass(arg, BaseModel) - else arg + ( + Partial[arg] # type: ignore[valid-type] + if isinstance(arg, type) and issubclass(arg, BaseModel) + else arg + ) for arg in generic_args ) # Reconstruct the generic type with modified arguments tmp_field.annotation = ( - Optional[generic_base[modified_args]] if generic_base else None + generic_base[modified_args] if generic_base else None ) - tmp_field.default = None # If the field is a BaseModel, then recursively convert it's # attributes to optionals. elif isinstance(annotation, type) and issubclass(annotation, BaseModel): - tmp_field.annotation = Optional[Partial[annotation]] # type: ignore[assignment, valid-type] - tmp_field.default = {} - else: - tmp_field.annotation = Optional[field.annotation] # type: ignore[assignment] - tmp_field.default = None + tmp_field.annotation = Partial[annotation] # type: ignore[assignment, valid-type] return tmp_field.annotation, tmp_field return create_model( - __model_name=f"Partial{wrapped_class.__name__}", + __model_name=( + wrapped_class.__name__ + if wrapped_class.__name__.startswith("Partial") + else f"Partial{wrapped_class.__name__}" + ), __base__=(wrapped_class, PartialBase), __module__=wrapped_class.__module__, **{ - field_name: _make_field_optional(field_info) - for field_name, field_info in wrapped_class.__fields__.items() + field_name: ( + _make_field_optional(field_info) + if make_fields_optional is not None + else _wrap_models(field_info) + ) + for field_name, field_info in wrapped_class.model_fields.items() }, ) # type: ignore[all] diff --git a/poetry.lock b/poetry.lock index c339194..259f349 100644 --- a/poetry.lock +++ b/poetry.lock @@ -125,7 +125,7 @@ files = [ name = "anthropic" version = "0.18.1" description = "The official Python library for the anthropic API" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "anthropic-0.18.1-py3-none-any.whl", hash = "sha256:b85aee64f619ce1b1964ba733a09adc4053e7bc4e6d4186001229ec191099dcf"}, @@ -826,7 +826,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.13.1" description = "A platform independent file lock." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, @@ -928,7 +928,7 @@ files = [ name = "fsspec" version = "2024.3.1" description = "File-system specification" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"}, @@ -1091,7 +1091,7 @@ socks = ["socksio (==1.*)"] name = "huggingface-hub" version = "0.21.4" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" -optional = false +optional = true python-versions = ">=3.8.0" files = [ {file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"}, @@ -3268,7 +3268,7 @@ test = ["flake8", "isort", "pytest"] name = "tokenizers" version = "0.15.2" description = "" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"}, @@ -3595,7 +3595,7 @@ files = [ name = "xmltodict" version = "0.13.0" description = "Makes working with XML feel like you are working with JSON" -optional = false +optional = true python-versions = ">=3.4" files = [ {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"}, @@ -3705,7 +3705,10 @@ files = [ idna = ">=2.0" multidict = ">=4.0" +[extras] +anthropic = ["anthropic", "xmltodict"] + [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "5780411ed7df9fcd0c67d664cd4898ba4be9fb7ee7a74538e8e5937a72359142" +content-hash = "35fcdd72be242b352d6d5790f4f37d4593c3047a33ed77bca9575b08e2c485a9" diff --git a/tests/dsl/test_partial.py b/tests/dsl/test_partial.py new file mode 100644 index 0000000..b10374b --- /dev/null +++ b/tests/dsl/test_partial.py @@ -0,0 +1,68 @@ +# type: ignore[all] +from pydantic import BaseModel + +from instructor.dsl.partial import Partial + + +class SampleNestedPartial(BaseModel): + b: int + + +class SamplePartial(BaseModel): + a: int + b: SampleNestedPartial + + +def test_partial(): + partial = Partial[SamplePartial] + assert partial.model_json_schema() == { + "$defs": { + "PartialSampleNestedPartial": { + "properties": {"b": {"title": "B", "type": "integer"}}, + "required": ["b"], + "title": "PartialSampleNestedPartial", + "type": "object", + } + }, + "properties": { + "a": {"title": "A", "type": "integer"}, + "b": {"$ref": "#/$defs/PartialSampleNestedPartial"}, + }, + "required": ["a", "b"], + "title": "PartialSamplePartial", + "type": "object", + }, "Wrapped model JSON schema has changed" + assert partial.get_partial_model().model_json_schema() == { + "$defs": { + "PartialSampleNestedPartial": { + "properties": { + "b": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": None, + "title": "B", + } + }, + "title": "PartialSampleNestedPartial", + "type": "object", + } + }, + "properties": { + "a": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": None, + "title": "A", + }, + "b": { + "anyOf": [ + {"$ref": "#/$defs/PartialSampleNestedPartial"}, + {"type": "null"}, + ], + "default": {}, + }, + }, + "title": "PartialSamplePartial", + "type": "object", + }, "Partial model JSON schema has changed" + + for model in partial.model_from_chunks(['{"b": {"b": 1}}']): + assert model.model_dump() == {"a": None, "b": {"b": 1}}