Add multiple modalities: tools, functions, json_mode (#218)

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
Jason Liu
2023-11-25 13:56:52 -05:00
committed by GitHub
parent 7de55a9336
commit 359c5f9295
7 changed files with 375 additions and 128 deletions
+70
View File
@@ -0,0 +1,70 @@
# Patching
Instructor enhances client functionality with three new keywords for backwards compatibility. This allows use of the enhanced client as usual, with structured output benefits.
- `response_model`: Defines the response type for `chat.completions.create`.
- `max_retries`: Determines retry attempts for failed `chat.completions.create` validations.
- `validation_context`: Provides extra context to the validation process.
There are three methods for structured output:
1. **Function Calling**: The primary method. Use this for stability and testing.
2. **Tool Calling**: Useful in specific scenarios; lacks the reasking feature of OpenAI's tool calling API.
3. **JSON Mode**: Offers closer adherence to JSON but with more potential validation errors. Suitable for specific non-function calling clients.
## Function Calling
```python
from openai import OpenAI
import instructor
client = instructor.patch(OpenAI())
```
## Tool Calling
```python
import instructor
from instructor import Mode
client = instructor.patch(OpenAI(), mode=Mode.TOOL_CALL)
```
## JSON Mode
```python
import instructor
from instructor import Mode
from openai import OpenAI
client = instructor.patch(OpenAI(), mode=Mode.JSON)
```
### Schema Integration
In JSON Mode, the schema is part of the system message:
```python
import instructor
from openai import OpenAI
client = instructor.patch(OpenAI())
response = client.chat.completions.create(
model="gpt-3.5-turbo-1106",
response_format={"type": "json_object"},
messages=[
{
"role": "system",
"content": f"Match your response to this json_schema: \n{UserExtract.model_json_schema()['properties']}",
},
{
"role": "user",
"content": "Extract jason is 25 years old",
},
],
)
user = UserExtract.from_response(response, mode=Mode.JSON)
assert user.name.lower() == "jason"
assert user.age == 25
```
+42 -6
View File
@@ -1,9 +1,20 @@
from calendar import c
import json
from docstring_parser import parse
from functools import wraps
from typing import Any, Callable
from pydantic import BaseModel, create_model, validate_arguments
import enum
class Mode(enum.Enum):
"""The mode to use for patching the client"""
FUNCTIONS: str = "function_call"
TOOLS: str = "tool_call"
JSON: str = "json_mode"
class openai_function:
"""
@@ -176,9 +187,9 @@ class OpenAISchema(BaseModel):
def from_response(
cls,
completion,
throw_error: bool = True,
validation_context=None,
strict: bool = None,
mode: Mode = Mode.FUNCTIONS,
):
"""Execute the function from the response of an openai chat completion
@@ -193,11 +204,36 @@ class OpenAISchema(BaseModel):
"""
message = completion.choices[0].message
return cls.model_validate_json(
message.function_call.arguments,
context=validation_context,
strict=strict,
)
if mode == Mode.FUNCTIONS:
assert (
message.function_call.name == cls.openai_schema["name"]
), "Function name does not match"
return cls.model_validate_json(
message.function_call.arguments,
context=validation_context,
strict=strict,
)
elif mode == Mode.TOOLS:
assert (
len(message.tool_calls) == 1
), "Instructor does not support multiple tool calls, use List[Model] instead."
tool_call = message.tool_calls[0]
assert (
tool_call.function.name == cls.openai_schema["name"]
), "Tool name does not match"
return cls.model_validate_json(
tool_call.function.arguments,
context=validation_context,
strict=strict,
)
elif mode == Mode.JSON:
return cls.model_validate_json(
message.content,
context=validation_context,
strict=strict,
)
else:
raise ValueError(f"Invalid patch mode: {mode}")
def openai_schema(cls) -> OpenAISchema:
+112 -48
View File
@@ -4,10 +4,12 @@ from json import JSONDecodeError
from typing import Callable, Optional, Type, Union
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat import ChatCompletion
from pydantic import BaseModel, ValidationError
from .function_calls import OpenAISchema, openai_schema
from .function_calls import OpenAISchema, openai_schema, Mode
import warnings
OVERRIDE_DOCS = """
Creates a new chat completion for the provided messages and parameters.
@@ -29,16 +31,68 @@ Parameters:
"""
def handle_response_model(response_model: Type[BaseModel], kwargs):
def dump_message(message) -> dict:
"""Dumps a message to a dict, to be returned to the OpenAI API.
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
if it isn't used.
"""
dumped_message = message.model_dump()
if not dumped_message.get("tool_calls"):
del dumped_message["tool_calls"]
return {k: v for k, v in dumped_message.items() if v}
def handle_response_model(
*,
response_model: Type[BaseModel],
kwargs,
mode: Mode = Mode.FUNCTIONS,
):
new_kwargs = kwargs.copy()
if response_model is not None:
if not issubclass(response_model, OpenAISchema):
response_model = openai_schema(response_model) # type: ignore
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
if mode == Mode.FUNCTIONS:
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
new_kwargs["function_call"] = {
"name": response_model.openai_schema["name"]
} # type: ignore
elif mode == Mode.TOOLS:
new_kwargs["tools"] = [
{
"type": "function",
"function": response_model.openai_schema,
}
]
new_kwargs["tool_choice"] = {
"type": "function",
"function": {"name": response_model.openai_schema["name"]},
}
elif mode == Mode.JSON:
new_kwargs["response_format"] = {"type": "json_object"}
# check that the first message is a system message
# if it is not, add a system message to the beginning
message = f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{response_model.model_json_schema()['properties']}"
if new_kwargs["messages"][0]["role"] != "system":
new_kwargs["messages"].insert(
0,
{
"role": "system",
"content": message,
},
)
# if the first message is a system append the schema to the end
if new_kwargs["messages"][0]["role"] == "system":
new_kwargs["messages"][0]["content"] += f"\n\n{message}"
else:
raise ValueError(f"Invalid patch mode: {mode}")
if new_kwargs.get("stream", False) and response_model is not None:
import warnings
raise NotImplementedError("stream=True is not supported when using response_model parameter")
warnings.warn(
"stream=True is not supported when using response_model parameter"
@@ -48,7 +102,12 @@ def handle_response_model(response_model: Type[BaseModel], kwargs):
def process_response(
response, response_model, validation_context: dict = None, strict=None
response,
*,
response_model: Type[BaseModel],
validation_context: dict = None,
strict=None,
mode: Mode = Mode.FUNCTIONS,
): # type: ignore
"""Processes a OpenAI response with the response model, if available
It can use `validation_context` and `strict` to validate the response
@@ -62,25 +121,13 @@ def process_response(
"""
if response_model is not None:
model = response_model.from_response(
response, validation_context=validation_context, strict=strict
response, validation_context=validation_context, strict=strict, mode=mode
)
model._raw_response = response
return model
return response
def dump_message(message: ChatCompletionMessage) -> dict:
"""Dumps a message to a dict, to be returned to the OpenAI API.
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
if it isn't used.
"""
dumped_message = message.model_dump()
if not dumped_message.get("tool_calls"):
del dumped_message["tool_calls"]
return dumped_message
async def retry_async(
func,
response_model,
@@ -89,22 +136,21 @@ async def retry_async(
kwargs,
max_retries,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
):
retries = 0
while retries <= max_retries:
try:
response: ChatCompletion = await func(*args, **kwargs)
return (
process_response(
response,
response_model,
validation_context,
strict=strict,
),
None,
return process_response(
response,
response_model=response_model,
validation_context=validation_context,
strict=strict,
mode=mode,
)
except (ValidationError, JSONDecodeError) as e:
kwargs["messages"].append(response.choices[0].message) # type: ignore
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
kwargs["messages"].append(
{
"role": "user",
@@ -124,20 +170,22 @@ def retry_sync(
kwargs,
max_retries,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
):
retries = 0
while retries <= max_retries:
# Excepts ValidationError, and JSONDecodeError
try:
response = func(*args, **kwargs)
return (
process_response(
response, response_model, validation_context, strict=strict
),
None,
return process_response(
response,
response_model=response_model,
validation_context=validation_context,
strict=strict,
mode=mode,
)
except (ValidationError, JSONDecodeError) as e:
kwargs["messages"].append(dump_message(response.choices[0].message))
kwargs["messages"].append(response.choices[0].message)
kwargs["messages"].append(
{
"role": "user",
@@ -156,7 +204,9 @@ def is_async(func: Callable) -> bool:
)
def wrap_chatcompletion(func: Callable) -> Callable:
def wrap_chatcompletion(
func: Callable, mode: Mode = Mode.FUNCTIONS
) -> Callable:
func_is_async = is_async(func)
@wraps(func)
@@ -167,17 +217,22 @@ def wrap_chatcompletion(func: Callable) -> Callable:
*args,
**kwargs,
):
response_model, new_kwargs = handle_response_model(response_model, kwargs) # type: ignore
response, error = await retry_async(
if mode == Mode.TOOLS:
max_retries = 0
warnings.warn("max_retries is not supported when using tool calls")
response_model, new_kwargs = handle_response_model(
response_model=response_model, kwargs=kwargs, mode=mode
) # type: ignore
response = await retry_async(
func=func,
response_model=response_model,
validation_context=validation_context,
max_retries=max_retries,
args=args,
kwargs=new_kwargs,
mode=mode,
) # type: ignore
if error:
raise ValueError(error)
return response
@wraps(func)
@@ -188,17 +243,22 @@ def wrap_chatcompletion(func: Callable) -> Callable:
*args,
**kwargs,
):
response_model, new_kwargs = handle_response_model(response_model, kwargs) # type: ignore
response, error = retry_sync(
if mode == Mode.TOOLS:
max_retries = 0
warnings.warn("max_retries is not supported when using tool calls")
response_model, new_kwargs = handle_response_model(
response_model=response_model, kwargs=kwargs, mode=mode
) # type: ignore
response = retry_sync(
func=func,
response_model=response_model,
validation_context=validation_context,
max_retries=max_retries,
args=args,
kwargs=new_kwargs,
mode=mode,
) # type: ignore
if error:
raise ValueError(error)
return response
wrapper_function = (
@@ -208,7 +268,9 @@ def wrap_chatcompletion(func: Callable) -> Callable:
return wrapper_function
def patch(client: Union[OpenAI, AsyncOpenAI]):
def patch(
client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS
):
"""
Patch the `client.chat.completions.create` method
@@ -220,11 +282,13 @@ def patch(client: Union[OpenAI, AsyncOpenAI]):
- `strict` parameter to use strict json parsing
"""
client.chat.completions.create = wrap_chatcompletion(client.chat.completions.create)
client.chat.completions.create = wrap_chatcompletion(
client.chat.completions.create, mode=mode
)
return client
def apatch(client: AsyncOpenAI):
def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
"""
No longer necessary, use `patch` instead.
@@ -237,4 +301,4 @@ def apatch(client: AsyncOpenAI):
- `validation_context` parameter to validate the response using the pydantic model
- `strict` parameter to use strict json parsing
"""
return patch(client)
return patch(client, mode=mode)
+1
View File
@@ -131,6 +131,7 @@ nav:
- Models: 'concepts/models.md'
- Fields: 'concepts/fields.md'
- Types: 'concepts/types.md'
- Patching: 'concepts/patching.md'
- Streaming: "concepts/lists.md"
- Union: 'concepts/union.md'
- Alias: 'concepts/alias.md'
+60
View File
@@ -0,0 +1,60 @@
from instructor.function_calls import OpenAISchema, Mode
from openai import OpenAI
client = OpenAI()
class UserExtract(OpenAISchema):
name: str
age: int
def test_tool_call():
response = client.chat.completions.create(
model="gpt-3.5-turbo-1106",
messages=[
{
"role": "user",
"content": "Extract jason is 25 years old, mary is 30 years old",
},
],
tools=[
{
"type": "function",
"function": UserExtract.openai_schema,
}
],
tool_choice={
"type": "function",
"function": {"name": UserExtract.openai_schema["name"]},
},
)
response_message = response.choices[0].message
tool_calls = response_message.tool_calls
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "UserExtract"
assert tool_calls[0].function
user = UserExtract.from_response(response, mode=Mode.TOOLS)
assert user.name.lower() == "jason"
assert user.age == 25
def test_json_mode():
response = client.chat.completions.create(
model="gpt-3.5-turbo-1106",
response_format={"type": "json_object"},
messages=[
{
"role": "system",
"content": f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{UserExtract.model_json_schema()['properties']}",
},
{
"role": "user",
"content": "Extract jason is 25 years old",
},
],
)
user = UserExtract.from_response(response, mode=Mode.JSON)
assert user.name.lower() == "jason"
assert user.age == 25
+64 -69
View File
@@ -1,19 +1,67 @@
from pydantic import BaseModel, field_validator
import pytest
import instructor
from instructor import llm_validator
from typing_extensions import Annotated
from pydantic import field_validator, BaseModel, BeforeValidator, ValidationError
from openai import OpenAI, AsyncOpenAI
client = instructor.patch(OpenAI())
from instructor.function_calls import Mode
aclient = instructor.patch(AsyncOpenAI())
client = instructor.patch(OpenAI())
class UserExtract(BaseModel):
name: str
age: int
@pytest.mark.parametrize(
"mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]
)
def test_runmodel(mode):
client = instructor.patch(OpenAI(), mode=mode)
model = client.chat.completions.create(
model="gpt-3.5-turbo-1106",
response_model=UserExtract,
max_retries=2,
messages=[
{"role": "user", "content": "Extract jason is 25 years old"},
],
)
assert isinstance(model, UserExtract), "Should be instance of UserExtract"
assert model.name.lower() == "jason"
assert model.age == 25
assert hasattr(
model, "_raw_response"
), "The raw response should be available from OpenAI"
@pytest.mark.parametrize(
"mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]
)
@pytest.mark.asyncio
async def test_runmodel_async(mode):
aclient = instructor.patch(AsyncOpenAI(), mode=mode)
model = await aclient.chat.completions.create(
model="gpt-3.5-turbo-1106",
response_model=UserExtract,
max_retries=2,
messages=[
{"role": "user", "content": "Extract jason is 25 years old"},
],
)
assert isinstance(model, UserExtract), "Should be instance of UserExtract"
assert model.name.lower() == "jason"
assert model.age == 25
assert hasattr(
model, "_raw_response"
), "The raw response should be available from OpenAI"
class UserExtractValidated(BaseModel):
name: str
age: int
@field_validator("name")
@classmethod
def validate_name(cls, v):
@@ -22,91 +70,38 @@ class UserExtract(BaseModel):
return v
def test_runmodel_validator():
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON])
def test_runmodel_validator(mode):
client = instructor.patch(OpenAI(), mode=mode)
model = client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=UserExtract,
model="gpt-3.5-turbo-1106",
response_model=UserExtractValidated,
max_retries=2,
messages=[
{"role": "user", "content": "Extract jason is 25 years old"},
],
)
assert isinstance(model, UserExtract), "Should be instance of UserExtract"
assert isinstance(model, UserExtractValidated), "Should be instance of UserExtract"
assert model.name == "JASON"
assert hasattr(
model, "_raw_response"
), "The raw response should be available from OpenAI"
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON])
@pytest.mark.asyncio
async def test_runmodel_async_validator():
async def test_runmodel_async_validator(mode):
aclient = instructor.patch(AsyncOpenAI(), mode=mode)
model = await aclient.chat.completions.create(
model="gpt-3.5-turbo",
response_model=UserExtract,
model="gpt-3.5-turbo-1106",
response_model=UserExtractValidated,
max_retries=2,
messages=[
{"role": "user", "content": "Extract jason is 25 years old"},
],
)
assert isinstance(model, UserExtract), "Should be instance of UserExtract"
assert isinstance(model, UserExtractValidated), "Should be instance of UserExtract"
assert model.name == "JASON"
assert hasattr(
model, "_raw_response"
), "The raw response should be available from OpenAI"
class UserExtractSimple(BaseModel):
name: str
age: int
@pytest.mark.asyncio
async def test_async_runmodel():
model = await aclient.chat.completions.create(
model="gpt-3.5-turbo",
response_model=UserExtractSimple,
messages=[
{"role": "user", "content": "Extract jason is 25 years old"},
],
)
assert isinstance(
model, UserExtractSimple
), "Should be instance of UserExtractSimple"
assert model.name.lower() == "jason"
assert hasattr(
model, "_raw_response"
), "The raw response should be available from OpenAI"
def test_runmodel():
model = client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=UserExtractSimple,
messages=[
{"role": "user", "content": "Extract jason is 25 years old"},
],
)
assert isinstance(
model, UserExtractSimple
), "Should be instance of UserExtractSimple"
assert model.name.lower() == "jason"
assert hasattr(
model, "_raw_response"
), "The raw response should be available from OpenAI"
def test_runmodel_validator_error():
class QuestionAnswerNoEvil(BaseModel):
question: str
answer: Annotated[
str,
BeforeValidator(
llm_validator("don't say objectionable things", openai_client=client)
),
]
with pytest.raises(ValidationError):
QuestionAnswerNoEvil(
question="What is the meaning of life?",
answer="The meaning of life is to be evil and steal",
)
+26 -5
View File
@@ -1,17 +1,38 @@
import pytest
import instructor
import instructor
from typing_extensions import Annotated
from pydantic import BaseModel, AfterValidator, ValidationError
from pydantic import BaseModel, AfterValidator, BeforeValidator, ValidationError
from openai import OpenAI
from instructor.dsl.validators import llm_validator
client = instructor.patch(OpenAI())
def test_patch_completes_successfully():
class Response(BaseModel):
message: Annotated[str, AfterValidator(instructor.openai_moderation(client=client))]
message: Annotated[
str, AfterValidator(instructor.openai_moderation(client=client))
]
with pytest.raises(ValidationError) as e:
Response(message="I want to make them suffer the consequences")
Response(message="I want to make them suffer the consequences")
def test_runmodel_validator_error():
class QuestionAnswerNoEvil(BaseModel):
question: str
answer: Annotated[
str,
BeforeValidator(
llm_validator("don't say objectionable things", openai_client=client)
),
]
with pytest.raises(ValidationError):
QuestionAnswerNoEvil(
question="What is the meaning of life?",
answer="The meaning of life is to be evil and steal",
)