feat(response model): introduce handling of simple types (#447)

This commit is contained in:
Jason Liu
2024-02-19 16:49:33 -05:00
committed by GitHub
parent f29f1bd092
commit 2319fff07e
5 changed files with 264 additions and 1 deletions
+3
View File
@@ -3,6 +3,7 @@ from .maybe import Maybe
from .partial import Partial
from .validators import llm_validator, openai_moderation
from .citation import CitationMixin
from .simple_type import is_simple_type, ModelAdapter
__all__ = [ # noqa: F405
"CitationMixin",
@@ -11,4 +12,6 @@ __all__ = [ # noqa: F405
"Partial",
"llm_validator",
"openai_moderation",
"is_simple_type",
"ModelAdapter",
]
+64
View File
@@ -0,0 +1,64 @@
from inspect import isclass
import typing
from pydantic import BaseModel, create_model
from enum import Enum
from instructor.dsl.partial import Partial
from instructor.function_calls import OpenAISchema
T = typing.TypeVar("T")
class AdapterBase(BaseModel):
pass
class ModelAdapter(typing.Generic[T]):
"""
Accepts a response model and returns a BaseModel with the response model as the content.
"""
def __class_getitem__(cls, response_model) -> typing.Type[BaseModel]:
assert is_simple_type(response_model), "Only simple types are supported"
tmp = create_model(
"Response",
content=(response_model, ...),
__doc__="Correctly Formated and Extracted Response.",
__base__=(AdapterBase, OpenAISchema),
)
return tmp
def is_simple_type(response_model) -> bool:
# ! we're getting mixes between classes and instances due to how we handle some
# ! response model types, we should fix this in later PRs
if isclass(response_model) and issubclass(response_model, BaseModel):
return False
if typing.get_origin(response_model) in {typing.Iterable, Partial}:
# These are reserved for streaming types, would be nice to
return False
if response_model in {
str,
int,
float,
bool,
}:
return True
# If the response_model is a simple type like annotated
if typing.get_origin(response_model) in {
typing.Annotated,
typing.Literal,
typing.Union,
list, # origin of List[T] is list
}:
return True
if isclass(response_model) and issubclass(response_model, Enum):
return True
return False
+19 -1
View File
@@ -30,6 +30,7 @@ from pydantic import BaseModel, ValidationError
from instructor.dsl.iterable import IterableModel, IterableBase
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
from instructor.dsl.partial import PartialBase
from instructor.dsl.simple_type import ModelAdapter, AdapterBase, is_simple_type
from .function_calls import Mode, OpenAISchema, openai_schema
@@ -80,6 +81,12 @@ def handle_response_model(
"""
new_kwargs = kwargs.copy()
if response_model is not None:
# Handles the case where the response_model is a simple type
# Literal, Annotated, Union, str, int, float, bool, Enum
# We wrap the response_model in a ModelAdapter that sets 'content' as the response
if is_simple_type(response_model):
response_model = ModelAdapter[response_model]
# This a special case for parallel tools
if mode == Mode.PARALLEL_TOOLS:
assert (
@@ -213,11 +220,17 @@ def process_response(
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
logger.debug(f"Returning takes from IterableBase")
return [task for task in model.tasks]
if isinstance(response_model, ParallelBase):
logger.debug(f"Returning model from ParallelBase")
return model
if isinstance(model, AdapterBase):
logger.debug(f"Returning model from AdapterBase")
return model.content
model._raw_response = response
return model
@@ -266,12 +279,17 @@ async def process_response_async(
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
#! If the response model is a multitask, return the tasks
logger.debug(f"Returning takes from IterableBase")
return [task for task in model.tasks]
if isinstance(response_model, ParallelBase):
logger.debug(f"Returning model from ParallelBase")
return model
if isinstance(model, AdapterBase):
logger.debug(f"Returning model from AdapterBase")
return model.content
model._raw_response = response
return model
+110
View File
@@ -0,0 +1,110 @@
import pytest
import instructor
import enum
from typing import Annotated, Literal, Union
from pydantic import Field
@pytest.mark.asyncio
async def test_response_simple_types(aclient):
client = instructor.patch(aclient, mode=instructor.Mode.TOOLS)
for response_model in [int, bool, str]:
response = await client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=response_model,
messages=[
{
"role": "user",
"content": "Produce a Random but correct response given the desired output",
},
],
)
assert type(response) == response_model
@pytest.mark.asyncio
async def test_annotate(aclient):
client = instructor.patch(aclient, mode=instructor.Mode.TOOLS)
response = await client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=Annotated[int, Field(description="test")],
messages=[
{
"role": "user",
"content": "Produce a Random but correct response given the desired output",
},
],
)
assert type(response) == int
def test_literal(client):
client = instructor.patch(client, mode=instructor.Mode.TOOLS)
response = client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=Literal["1231", "212", "331"],
messages=[
{
"role": "user",
"content": "Produce a Random but correct response given the desired output",
},
],
)
assert response in ["1231", "212", "331"]
def test_union(client):
client = instructor.patch(client, mode=instructor.Mode.TOOLS)
response = client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=Union[int, str],
messages=[
{
"role": "user",
"content": "Produce a Random but correct response given the desired output",
},
],
)
assert type(response) in [int, str]
def test_enum(client):
class Options(enum.Enum):
A = "A"
B = "B"
C = "C"
client = instructor.patch(client, mode=instructor.Mode.TOOLS)
response = client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=Options,
messages=[
{
"role": "user",
"content": "Produce a Random but correct response given the desired output",
},
],
)
assert response in [Options.A, Options.B, Options.C]
def test_bool(client):
client = instructor.patch(client, mode=instructor.Mode.TOOLS)
response = client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=bool,
messages=[
{
"role": "user",
"content": "Produce a Random but correct response given the desired output",
},
],
)
assert type(response) == bool
+68
View File
@@ -0,0 +1,68 @@
from instructor.dsl import is_simple_type, Partial
from pydantic import BaseModel
def test_enum_simple():
from enum import Enum
class Color(Enum):
RED = 1
GREEN = 2
BLUE = 3
assert is_simple_type(Color), "Failed for type: " + str(Color)
def test_standard_types():
for t in [str, int, float, bool]:
assert is_simple_type(t), "Failed for type: " + str(t)
def test_partial_not_simple():
class SampleModel(BaseModel):
data: int
assert not is_simple_type(Partial[SampleModel]), "Failed for type: " + str(
Partial[int]
)
def test_annotated_simple():
from pydantic import Field
from typing import Annotated
new_type = Annotated[int, Field(description="test")]
assert is_simple_type(new_type), "Failed for type: " + str(new_type)
def test_literal_simple():
from typing import Literal
new_type = Literal[1, 2, 3]
assert is_simple_type(new_type), "Failed for type: " + str(new_type)
def test_union_simple():
from typing import Union
new_type = Union[int, str]
assert is_simple_type(new_type), "Failed for type: " + str(new_type)
def test_iterable_not_simple():
from typing import Iterable
new_type = Iterable[int]
assert not is_simple_type(new_type), "Failed for type: " + str(new_type)
def test_list_is_simple():
from typing import List
new_type = List[int]
assert is_simple_type(new_type), "Failed for type: " + str(new_type)