Add multiple modalities: tools, functions, json_mode (#218)

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
Jason Liu
2023-11-25 13:56:52 -05:00
committed by GitHub
parent 7de55a9336
commit 359c5f9295
7 changed files with 375 additions and 128 deletions
+112 -48
View File
@@ -4,10 +4,12 @@ from json import JSONDecodeError
from typing import Callable, Optional, Type, Union
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat import ChatCompletion
from pydantic import BaseModel, ValidationError
from .function_calls import OpenAISchema, openai_schema
from .function_calls import OpenAISchema, openai_schema, Mode
import warnings
OVERRIDE_DOCS = """
Creates a new chat completion for the provided messages and parameters.
@@ -29,16 +31,68 @@ Parameters:
"""
def handle_response_model(response_model: Type[BaseModel], kwargs):
def dump_message(message) -> dict:
"""Dumps a message to a dict, to be returned to the OpenAI API.
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
if it isn't used.
"""
dumped_message = message.model_dump()
if not dumped_message.get("tool_calls"):
del dumped_message["tool_calls"]
return {k: v for k, v in dumped_message.items() if v}
def handle_response_model(
*,
response_model: Type[BaseModel],
kwargs,
mode: Mode = Mode.FUNCTIONS,
):
new_kwargs = kwargs.copy()
if response_model is not None:
if not issubclass(response_model, OpenAISchema):
response_model = openai_schema(response_model) # type: ignore
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
if mode == Mode.FUNCTIONS:
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
new_kwargs["function_call"] = {
"name": response_model.openai_schema["name"]
} # type: ignore
elif mode == Mode.TOOLS:
new_kwargs["tools"] = [
{
"type": "function",
"function": response_model.openai_schema,
}
]
new_kwargs["tool_choice"] = {
"type": "function",
"function": {"name": response_model.openai_schema["name"]},
}
elif mode == Mode.JSON:
new_kwargs["response_format"] = {"type": "json_object"}
# check that the first message is a system message
# if it is not, add a system message to the beginning
message = f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{response_model.model_json_schema()['properties']}"
if new_kwargs["messages"][0]["role"] != "system":
new_kwargs["messages"].insert(
0,
{
"role": "system",
"content": message,
},
)
# if the first message is a system append the schema to the end
if new_kwargs["messages"][0]["role"] == "system":
new_kwargs["messages"][0]["content"] += f"\n\n{message}"
else:
raise ValueError(f"Invalid patch mode: {mode}")
if new_kwargs.get("stream", False) and response_model is not None:
import warnings
raise NotImplementedError("stream=True is not supported when using response_model parameter")
warnings.warn(
"stream=True is not supported when using response_model parameter"
@@ -48,7 +102,12 @@ def handle_response_model(response_model: Type[BaseModel], kwargs):
def process_response(
response, response_model, validation_context: dict = None, strict=None
response,
*,
response_model: Type[BaseModel],
validation_context: dict = None,
strict=None,
mode: Mode = Mode.FUNCTIONS,
): # type: ignore
"""Processes a OpenAI response with the response model, if available
It can use `validation_context` and `strict` to validate the response
@@ -62,25 +121,13 @@ def process_response(
"""
if response_model is not None:
model = response_model.from_response(
response, validation_context=validation_context, strict=strict
response, validation_context=validation_context, strict=strict, mode=mode
)
model._raw_response = response
return model
return response
def dump_message(message: ChatCompletionMessage) -> dict:
"""Dumps a message to a dict, to be returned to the OpenAI API.
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
if it isn't used.
"""
dumped_message = message.model_dump()
if not dumped_message.get("tool_calls"):
del dumped_message["tool_calls"]
return dumped_message
async def retry_async(
func,
response_model,
@@ -89,22 +136,21 @@ async def retry_async(
kwargs,
max_retries,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
):
retries = 0
while retries <= max_retries:
try:
response: ChatCompletion = await func(*args, **kwargs)
return (
process_response(
response,
response_model,
validation_context,
strict=strict,
),
None,
return process_response(
response,
response_model=response_model,
validation_context=validation_context,
strict=strict,
mode=mode,
)
except (ValidationError, JSONDecodeError) as e:
kwargs["messages"].append(response.choices[0].message) # type: ignore
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
kwargs["messages"].append(
{
"role": "user",
@@ -124,20 +170,22 @@ def retry_sync(
kwargs,
max_retries,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
):
retries = 0
while retries <= max_retries:
# Excepts ValidationError, and JSONDecodeError
try:
response = func(*args, **kwargs)
return (
process_response(
response, response_model, validation_context, strict=strict
),
None,
return process_response(
response,
response_model=response_model,
validation_context=validation_context,
strict=strict,
mode=mode,
)
except (ValidationError, JSONDecodeError) as e:
kwargs["messages"].append(dump_message(response.choices[0].message))
kwargs["messages"].append(response.choices[0].message)
kwargs["messages"].append(
{
"role": "user",
@@ -156,7 +204,9 @@ def is_async(func: Callable) -> bool:
)
def wrap_chatcompletion(func: Callable) -> Callable:
def wrap_chatcompletion(
func: Callable, mode: Mode = Mode.FUNCTIONS
) -> Callable:
func_is_async = is_async(func)
@wraps(func)
@@ -167,17 +217,22 @@ def wrap_chatcompletion(func: Callable) -> Callable:
*args,
**kwargs,
):
response_model, new_kwargs = handle_response_model(response_model, kwargs) # type: ignore
response, error = await retry_async(
if mode == Mode.TOOLS:
max_retries = 0
warnings.warn("max_retries is not supported when using tool calls")
response_model, new_kwargs = handle_response_model(
response_model=response_model, kwargs=kwargs, mode=mode
) # type: ignore
response = await retry_async(
func=func,
response_model=response_model,
validation_context=validation_context,
max_retries=max_retries,
args=args,
kwargs=new_kwargs,
mode=mode,
) # type: ignore
if error:
raise ValueError(error)
return response
@wraps(func)
@@ -188,17 +243,22 @@ def wrap_chatcompletion(func: Callable) -> Callable:
*args,
**kwargs,
):
response_model, new_kwargs = handle_response_model(response_model, kwargs) # type: ignore
response, error = retry_sync(
if mode == Mode.TOOLS:
max_retries = 0
warnings.warn("max_retries is not supported when using tool calls")
response_model, new_kwargs = handle_response_model(
response_model=response_model, kwargs=kwargs, mode=mode
) # type: ignore
response = retry_sync(
func=func,
response_model=response_model,
validation_context=validation_context,
max_retries=max_retries,
args=args,
kwargs=new_kwargs,
mode=mode,
) # type: ignore
if error:
raise ValueError(error)
return response
wrapper_function = (
@@ -208,7 +268,9 @@ def wrap_chatcompletion(func: Callable) -> Callable:
return wrapper_function
def patch(client: Union[OpenAI, AsyncOpenAI]):
def patch(
client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS
):
"""
Patch the `client.chat.completions.create` method
@@ -220,11 +282,13 @@ def patch(client: Union[OpenAI, AsyncOpenAI]):
- `strict` parameter to use strict json parsing
"""
client.chat.completions.create = wrap_chatcompletion(client.chat.completions.create)
client.chat.completions.create = wrap_chatcompletion(
client.chat.completions.create, mode=mode
)
return client
def apatch(client: AsyncOpenAI):
def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
"""
No longer necessary, use `patch` instead.
@@ -237,4 +301,4 @@ def apatch(client: AsyncOpenAI):
- `validation_context` parameter to validate the response using the pydantic model
- `strict` parameter to use strict json parsing
"""
return patch(client)
return patch(client, mode=mode)