From 2319fff07ee34ce31eace7e3a0e7c1a99e921fa4 Mon Sep 17 00:00:00 2001
From: Jason Liu
Date: Mon, 19 Feb 2024 16:49:33 -0500
Subject: [PATCH] feat(response model): introduce handling of simple types (#447)

---
 instructor/dsl/__init__.py        |   3 +
 instructor/dsl/simple_type.py     |  64 +++++++++++++++++
 instructor/patch.py               |  20 +++++-
 tests/openai/test_simple_types.py | 110 ++++++++++++++++++++++++++++++
 tests/test_simple_types.py        |  68 ++++++++++++++++++
 5 files changed, 264 insertions(+), 1 deletion(-)
 create mode 100644 instructor/dsl/simple_type.py
 create mode 100644 tests/openai/test_simple_types.py
 create mode 100644 tests/test_simple_types.py

diff --git a/instructor/dsl/__init__.py b/instructor/dsl/__init__.py
index 83ae1de..31b4d78 100644
--- a/instructor/dsl/__init__.py
+++ b/instructor/dsl/__init__.py
@@ -3,6 +3,7 @@ from .maybe import Maybe
 from .partial import Partial
 from .validators import llm_validator, openai_moderation
 from .citation import CitationMixin
+from .simple_type import is_simple_type, ModelAdapter
 
 __all__ = [  # noqa: F405
     "CitationMixin",
@@ -11,4 +12,6 @@ __all__ = [  # noqa: F405
     "Partial",
     "llm_validator",
     "openai_moderation",
+    "is_simple_type",
+    "ModelAdapter",
 ]
diff --git a/instructor/dsl/simple_type.py b/instructor/dsl/simple_type.py
new file mode 100644
index 0000000..42f1482
--- /dev/null
+++ b/instructor/dsl/simple_type.py
@@ -0,0 +1,64 @@
+from inspect import isclass
+import typing
+from pydantic import BaseModel, create_model
+from enum import Enum
+
+
+from instructor.dsl.partial import Partial
+from instructor.function_calls import OpenAISchema
+
+
+T = typing.TypeVar("T")
+
+
+class AdapterBase(BaseModel):
+    pass
+
+
+class ModelAdapter(typing.Generic[T]):
+    """
+    Accepts a response model and returns a BaseModel with the response model as the content.
+    """
+
+    def __class_getitem__(cls, response_model) -> typing.Type[BaseModel]:
+        assert is_simple_type(response_model), "Only simple types are supported"
+        tmp = create_model(
+            "Response",
+            content=(response_model, ...),
+            __doc__="Correctly Formatted and Extracted Response.",
+            __base__=(AdapterBase, OpenAISchema),
+        )
+        return tmp
+
+
+def is_simple_type(response_model) -> bool:
+    # ! we're getting mixes between classes and instances due to how we handle some
+    # ! response model types, we should fix this in later PRs
+    if isclass(response_model) and issubclass(response_model, BaseModel):
+        return False
+
+    if typing.get_origin(response_model) in {typing.Iterable, Partial}:
+        # These are reserved for streaming types, so they are never treated as simple
+        return False
+
+    if response_model in {
+        str,
+        int,
+        float,
+        bool,
+    }:
+        return True
+
+    # If the response_model is a simple type like annotated
+    if typing.get_origin(response_model) in {
+        typing.Annotated,
+        typing.Literal,
+        typing.Union,
+        list,  # origin of List[T] is list
+    }:
+        return True
+
+    if isclass(response_model) and issubclass(response_model, Enum):
+        return True
+
+    return False
diff --git a/instructor/patch.py b/instructor/patch.py
index b56e186..158fa84 100644
--- a/instructor/patch.py
+++ b/instructor/patch.py
@@ -30,6 +30,7 @@ from pydantic import BaseModel, ValidationError
 from instructor.dsl.iterable import IterableModel, IterableBase
 from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
 from instructor.dsl.partial import PartialBase
+from instructor.dsl.simple_type import ModelAdapter, AdapterBase, is_simple_type
 
 from .function_calls import Mode, OpenAISchema, openai_schema
 
@@ -80,6 +81,12 @@ def handle_response_model(
     """
     new_kwargs = kwargs.copy()
     if response_model is not None:
+        # Handles the case where the response_model is a simple type
+        # Literal, Annotated, Union, str, int, float, bool, Enum
+        # We wrap the response_model in a ModelAdapter that sets 'content' as the response
+        if is_simple_type(response_model):
+            response_model = ModelAdapter[response_model]
+
         # This a special case for parallel tools
         if mode == Mode.PARALLEL_TOOLS:
             assert (
@@ -213,11 +220,17 @@ def process_response(
     # ? This really hints at the fact that we need a better way of
     # ? attaching usage data and the raw response to the model we return.
     if isinstance(model, IterableBase):
+        logger.debug("Returning tasks from IterableBase")
         return [task for task in model.tasks]
 
     if isinstance(response_model, ParallelBase):
+        logger.debug("Returning model from ParallelBase")
         return model
 
+    if isinstance(model, AdapterBase):
+        logger.debug("Returning model from AdapterBase")
+        return model.content
+
     model._raw_response = response
     return model
 
@@ -266,12 +279,17 @@ async def process_response_async(
     # ? This really hints at the fact that we need a better way of
     # ? attaching usage data and the raw response to the model we return.
     if isinstance(model, IterableBase):
-        #! If the response model is a multitask, return the tasks
+        logger.debug("Returning tasks from IterableBase")
         return [task for task in model.tasks]
 
     if isinstance(response_model, ParallelBase):
+        logger.debug("Returning model from ParallelBase")
         return model
 
+    if isinstance(model, AdapterBase):
+        logger.debug("Returning model from AdapterBase")
+        return model.content
+
     model._raw_response = response
     return model
 
diff --git a/tests/openai/test_simple_types.py b/tests/openai/test_simple_types.py
new file mode 100644
index 0000000..adb2562
--- /dev/null
+++ b/tests/openai/test_simple_types.py
@@ -0,0 +1,110 @@
+import pytest
+import instructor
+import enum
+
+from typing import Annotated, Literal, Union
+from pydantic import Field
+
+
+@pytest.mark.asyncio
+async def test_response_simple_types(aclient):
+    client = instructor.patch(aclient, mode=instructor.Mode.TOOLS)
+
+    for response_model in [int, bool, str]:
+        response = await client.chat.completions.create(
+            model="gpt-3.5-turbo",
+            response_model=response_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Produce a Random but correct response given the desired output",
+                },
+            ],
+        )
+        assert type(response) == response_model
+
+
+@pytest.mark.asyncio
+async def test_annotate(aclient):
+    client = instructor.patch(aclient, mode=instructor.Mode.TOOLS)
+
+    response = await client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        response_model=Annotated[int, Field(description="test")],
+        messages=[
+            {
+                "role": "user",
+                "content": "Produce a Random but correct response given the desired output",
+            },
+        ],
+    )
+    assert type(response) == int
+
+
+def test_literal(client):
+    client = instructor.patch(client, mode=instructor.Mode.TOOLS)
+
+    response = client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        response_model=Literal["1231", "212", "331"],
+        messages=[
+            {
+                "role": "user",
+                "content": "Produce a Random but correct response given the desired output",
+            },
+        ],
+    )
+    assert response in ["1231", "212", "331"]
+
+
+def test_union(client):
+    client = instructor.patch(client, mode=instructor.Mode.TOOLS)
+
+    response = client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        response_model=Union[int, str],
+        messages=[
+            {
+                "role": "user",
+                "content": "Produce a Random but correct response given the desired output",
+            },
+        ],
+    )
+    assert type(response) in [int, str]
+
+
+def test_enum(client):
+    class Options(enum.Enum):
+        A = "A"
+        B = "B"
+        C = "C"
+
+    client = instructor.patch(client, mode=instructor.Mode.TOOLS)
+
+    response = client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        response_model=Options,
+        messages=[
+            {
+                "role": "user",
+                "content": "Produce a Random but correct response given the desired output",
+            },
+        ],
+    )
+    assert response in [Options.A, Options.B, Options.C]
+
+
+def test_bool(client):
+    client = instructor.patch(client, mode=instructor.Mode.TOOLS)
+
+    response = client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        response_model=bool,
+        messages=[
+            {
+                "role": "user",
+                "content": "Produce a Random but correct response given the desired output",
+            },
+        ],
+    )
+    assert type(response) == bool
diff --git a/tests/test_simple_types.py b/tests/test_simple_types.py
new file mode 100644
index 0000000..05239d2
--- /dev/null
+++ b/tests/test_simple_types.py
@@ -0,0 +1,68 @@
+from instructor.dsl import is_simple_type, Partial
+from pydantic import BaseModel
+
+
+def test_enum_simple():
+    from enum import Enum
+
+    class Color(Enum):
+        RED = 1
+        GREEN = 2
+        BLUE = 3
+
+    assert is_simple_type(Color), "Failed for type: " + str(Color)
+
+
+def test_standard_types():
+    for t in [str, int, float, bool]:
+        assert is_simple_type(t), "Failed for type: " + str(t)
+
+
+def test_partial_not_simple():
+    class SampleModel(BaseModel):
+        data: int
+
+    assert not is_simple_type(Partial[SampleModel]), "Failed for type: " + str(
+        Partial[SampleModel]
+    )
+
+
+def test_annotated_simple():
+    from pydantic import Field
+    from typing import Annotated
+
+    new_type = Annotated[int, Field(description="test")]
+
+    assert is_simple_type(new_type), "Failed for type: " + str(new_type)
+
+
+def test_literal_simple():
+    from typing import Literal
+
+    new_type = Literal[1, 2, 3]
+
+    assert is_simple_type(new_type), "Failed for type: " + str(new_type)
+
+
+def test_union_simple():
+    from typing import Union
+
+    new_type = Union[int, str]
+
+    assert is_simple_type(new_type), "Failed for type: " + str(new_type)
+
+
+def test_iterable_not_simple():
+    from typing import Iterable
+
+    new_type = Iterable[int]
+
+    assert not is_simple_type(new_type), "Failed for type: " + str(new_type)
+
+
+def test_list_is_simple():
+    from typing import List
+
+    new_type = List[int]
+
+    assert is_simple_type(new_type), "Failed for type: " + str(new_type)