mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Add multiple modalities: tools, functions, json_mode (#218)
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
+112
-48
@@ -4,10 +4,12 @@ from json import JSONDecodeError
|
||||
from typing import Callable, Optional, Type, Union
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
from openai.types.chat import ChatCompletion
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from .function_calls import OpenAISchema, openai_schema
|
||||
from .function_calls import OpenAISchema, openai_schema, Mode
|
||||
|
||||
import warnings
|
||||
|
||||
OVERRIDE_DOCS = """
|
||||
Creates a new chat completion for the provided messages and parameters.
|
||||
@@ -29,16 +31,68 @@ Parameters:
|
||||
"""
|
||||
|
||||
|
||||
def handle_response_model(response_model: Type[BaseModel], kwargs):
|
||||
def dump_message(message) -> dict:
|
||||
"""Dumps a message to a dict, to be returned to the OpenAI API.
|
||||
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
|
||||
if it isn't used.
|
||||
"""
|
||||
dumped_message = message.model_dump()
|
||||
if not dumped_message.get("tool_calls"):
|
||||
del dumped_message["tool_calls"]
|
||||
return {k: v for k, v in dumped_message.items() if v}
|
||||
|
||||
|
||||
def handle_response_model(
|
||||
*,
|
||||
response_model: Type[BaseModel],
|
||||
kwargs,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
new_kwargs = kwargs.copy()
|
||||
if response_model is not None:
|
||||
if not issubclass(response_model, OpenAISchema):
|
||||
response_model = openai_schema(response_model) # type: ignore
|
||||
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
|
||||
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
|
||||
|
||||
if mode == Mode.FUNCTIONS:
|
||||
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
|
||||
new_kwargs["function_call"] = {
|
||||
"name": response_model.openai_schema["name"]
|
||||
} # type: ignore
|
||||
elif mode == Mode.TOOLS:
|
||||
new_kwargs["tools"] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": response_model.openai_schema,
|
||||
}
|
||||
]
|
||||
new_kwargs["tool_choice"] = {
|
||||
"type": "function",
|
||||
"function": {"name": response_model.openai_schema["name"]},
|
||||
}
|
||||
elif mode == Mode.JSON:
|
||||
new_kwargs["response_format"] = {"type": "json_object"}
|
||||
|
||||
# check that the first message is a system message
|
||||
# if it is not, add a system message to the beginning
|
||||
message = f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{response_model.model_json_schema()['properties']}"
|
||||
|
||||
if new_kwargs["messages"][0]["role"] != "system":
|
||||
new_kwargs["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": message,
|
||||
},
|
||||
)
|
||||
|
||||
# if the first message is a system append the schema to the end
|
||||
if new_kwargs["messages"][0]["role"] == "system":
|
||||
new_kwargs["messages"][0]["content"] += f"\n\n{message}"
|
||||
else:
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
if new_kwargs.get("stream", False) and response_model is not None:
|
||||
import warnings
|
||||
raise NotImplementedError("stream=True is not supported when using response_model parameter")
|
||||
|
||||
warnings.warn(
|
||||
"stream=True is not supported when using response_model parameter"
|
||||
@@ -48,7 +102,12 @@ def handle_response_model(response_model: Type[BaseModel], kwargs):
|
||||
|
||||
|
||||
def process_response(
|
||||
response, response_model, validation_context: dict = None, strict=None
|
||||
response,
|
||||
*,
|
||||
response_model: Type[BaseModel],
|
||||
validation_context: dict = None,
|
||||
strict=None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
): # type: ignore
|
||||
"""Processes a OpenAI response with the response model, if available
|
||||
It can use `validation_context` and `strict` to validate the response
|
||||
@@ -62,25 +121,13 @@ def process_response(
|
||||
"""
|
||||
if response_model is not None:
|
||||
model = response_model.from_response(
|
||||
response, validation_context=validation_context, strict=strict
|
||||
response, validation_context=validation_context, strict=strict, mode=mode
|
||||
)
|
||||
model._raw_response = response
|
||||
return model
|
||||
return response
|
||||
|
||||
|
||||
def dump_message(message: ChatCompletionMessage) -> dict:
|
||||
"""Dumps a message to a dict, to be returned to the OpenAI API.
|
||||
|
||||
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
|
||||
if it isn't used.
|
||||
"""
|
||||
dumped_message = message.model_dump()
|
||||
if not dumped_message.get("tool_calls"):
|
||||
del dumped_message["tool_calls"]
|
||||
return dumped_message
|
||||
|
||||
|
||||
async def retry_async(
|
||||
func,
|
||||
response_model,
|
||||
@@ -89,22 +136,21 @@ async def retry_async(
|
||||
kwargs,
|
||||
max_retries,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
response: ChatCompletion = await func(*args, **kwargs)
|
||||
return (
|
||||
process_response(
|
||||
response,
|
||||
response_model,
|
||||
validation_context,
|
||||
strict=strict,
|
||||
),
|
||||
None,
|
||||
return process_response(
|
||||
response,
|
||||
response_model=response_model,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
kwargs["messages"].append(response.choices[0].message) # type: ignore
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
@@ -124,20 +170,22 @@ def retry_sync(
|
||||
kwargs,
|
||||
max_retries,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
# Excepts ValidationError, and JSONDecodeError
|
||||
try:
|
||||
response = func(*args, **kwargs)
|
||||
return (
|
||||
process_response(
|
||||
response, response_model, validation_context, strict=strict
|
||||
),
|
||||
None,
|
||||
return process_response(
|
||||
response,
|
||||
response_model=response_model,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message))
|
||||
kwargs["messages"].append(response.choices[0].message)
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
@@ -156,7 +204,9 @@ def is_async(func: Callable) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def wrap_chatcompletion(func: Callable) -> Callable:
|
||||
def wrap_chatcompletion(
|
||||
func: Callable, mode: Mode = Mode.FUNCTIONS
|
||||
) -> Callable:
|
||||
func_is_async = is_async(func)
|
||||
|
||||
@wraps(func)
|
||||
@@ -167,17 +217,22 @@ def wrap_chatcompletion(func: Callable) -> Callable:
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
response_model, new_kwargs = handle_response_model(response_model, kwargs) # type: ignore
|
||||
response, error = await retry_async(
|
||||
if mode == Mode.TOOLS:
|
||||
max_retries = 0
|
||||
warnings.warn("max_retries is not supported when using tool calls")
|
||||
|
||||
response_model, new_kwargs = handle_response_model(
|
||||
response_model=response_model, kwargs=kwargs, mode=mode
|
||||
) # type: ignore
|
||||
response = await retry_async(
|
||||
func=func,
|
||||
response_model=response_model,
|
||||
validation_context=validation_context,
|
||||
max_retries=max_retries,
|
||||
args=args,
|
||||
kwargs=new_kwargs,
|
||||
mode=mode,
|
||||
) # type: ignore
|
||||
if error:
|
||||
raise ValueError(error)
|
||||
return response
|
||||
|
||||
@wraps(func)
|
||||
@@ -188,17 +243,22 @@ def wrap_chatcompletion(func: Callable) -> Callable:
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
response_model, new_kwargs = handle_response_model(response_model, kwargs) # type: ignore
|
||||
response, error = retry_sync(
|
||||
if mode == Mode.TOOLS:
|
||||
max_retries = 0
|
||||
warnings.warn("max_retries is not supported when using tool calls")
|
||||
|
||||
response_model, new_kwargs = handle_response_model(
|
||||
response_model=response_model, kwargs=kwargs, mode=mode
|
||||
) # type: ignore
|
||||
response = retry_sync(
|
||||
func=func,
|
||||
response_model=response_model,
|
||||
validation_context=validation_context,
|
||||
max_retries=max_retries,
|
||||
args=args,
|
||||
kwargs=new_kwargs,
|
||||
mode=mode,
|
||||
) # type: ignore
|
||||
if error:
|
||||
raise ValueError(error)
|
||||
return response
|
||||
|
||||
wrapper_function = (
|
||||
@@ -208,7 +268,9 @@ def wrap_chatcompletion(func: Callable) -> Callable:
|
||||
return wrapper_function
|
||||
|
||||
|
||||
def patch(client: Union[OpenAI, AsyncOpenAI]):
|
||||
def patch(
|
||||
client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS
|
||||
):
|
||||
"""
|
||||
Patch the `client.chat.completions.create` method
|
||||
|
||||
@@ -220,11 +282,13 @@ def patch(client: Union[OpenAI, AsyncOpenAI]):
|
||||
- `strict` parameter to use strict json parsing
|
||||
"""
|
||||
|
||||
client.chat.completions.create = wrap_chatcompletion(client.chat.completions.create)
|
||||
client.chat.completions.create = wrap_chatcompletion(
|
||||
client.chat.completions.create, mode=mode
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def apatch(client: AsyncOpenAI):
|
||||
def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
|
||||
"""
|
||||
No longer necessary, use `patch` instead.
|
||||
|
||||
@@ -237,4 +301,4 @@ def apatch(client: AsyncOpenAI):
|
||||
- `validation_context` parameter to validate the response using the pydantic model
|
||||
- `strict` parameter to use strict json parsing
|
||||
"""
|
||||
return patch(client)
|
||||
return patch(client, mode=mode)
|
||||
|
||||
Reference in New Issue
Block a user