Fix #244 -- Refactor dump_message function to force content key (#245)

Co-authored-by: Jason Liu <jxnl@users.noreply.github.com>
This commit is contained in:
Guillaume Pouyat
2023-12-02 01:01:33 +01:00
committed by GitHub
parent 0477f58d0f
commit a5ea6e5c41
2 changed files with 136 additions and 18 deletions
+23 -13
View File
@@ -1,17 +1,22 @@
import inspect
from functools import wraps
from instructor.dsl.multitask import MultiTask, MultiTaskBase
from json import JSONDecodeError
from typing import get_origin, get_args, Callable, Optional, Type, Union
import json
import warnings
from collections.abc import Iterable
from functools import wraps
from json import JSONDecodeError
from typing import Callable, Optional, Type, Union, get_args, get_origin
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletion
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessage,
ChatCompletionMessageParam,
)
from pydantic import BaseModel, ValidationError
from .function_calls import OpenAISchema, openai_schema, Mode
from instructor.dsl.multitask import MultiTask, MultiTaskBase
import warnings
from .function_calls import Mode, OpenAISchema, openai_schema
OVERRIDE_DOCS = """
Creates a new chat completion for the provided messages and parameters.
@@ -33,15 +38,20 @@ Parameters:
"""
def dump_message(message) -> dict:
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
"""Dumps a message to a dict, to be returned to the OpenAI API.
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
if it isn't used.
"""
dumped_message = message.model_dump()
if not dumped_message.get("tool_calls"):
del dumped_message["tool_calls"]
return {k: v for k, v in dumped_message.items() if v}
ret: ChatCompletionMessageParam = {
"role": message.role,
"content": message.content or "",
}
if message.tool_calls is not None:
ret["content"] += json.dumps(message.model_dump()["tool_calls"])
if message.function_call is not None:
ret["content"] += json.dumps(message.model_dump()["function_call"])
return ret
def handle_response_model(
@@ -202,7 +212,7 @@ def retry_sync(
mode=mode,
)
except (ValidationError, JSONDecodeError) as e:
kwargs["messages"].append(response.choices[0].message)
kwargs["messages"].append(dump_message(response.choices[0].message))
kwargs["messages"].append(
{
"role": "user",
+113 -5
View File
@@ -1,11 +1,16 @@
import functools
import pytest
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)
import instructor
from openai import OpenAI, AsyncOpenAI
from instructor.patch import is_async, wrap_chatcompletion, OVERRIDE_DOCS
from instructor.patch import OVERRIDE_DOCS, dump_message, is_async, wrap_chatcompletion
def test_patch_completes_successfully():
@@ -66,3 +71,106 @@ def test_override_docs():
assert (
"response_model" in OVERRIDE_DOCS
), "response_model should be in OVERRIDE_DOCS"
@pytest.mark.parametrize(
"name_of_test, message, expected",
[
(
"tool_calls and content and no function_call",
ChatCompletionMessage(
role="assistant",
content="Hello, world!",
tool_calls=[
ChatCompletionMessageToolCall(
id="test_tool",
function=Function(arguments="", name="test_tool"),
type="function",
)
],
),
{
"role": "assistant",
"content": 'Hello, world![{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]',
},
),
(
"tool_calls and no content and no function_call",
ChatCompletionMessage(
role="assistant",
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="test_tool",
function=Function(arguments="", name="test_tool"),
type="function",
)
],
),
{
"role": "assistant",
"content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]',
},
),
(
"no tool_calls and no content no function_call",
ChatCompletionMessage(
role="assistant",
content=None,
),
{
"role": "assistant",
"content": "",
},
),
(
"no tool_calls and content and function_call",
ChatCompletionMessage(
role="assistant",
content="Hello, world!",
function_call=FunctionCall(arguments="", name="test_tool"),
),
{
"role": "assistant",
"content": 'Hello, world!{"arguments": "", "name": "test_tool"}',
},
),
(
"no tool_calls and no content and function_call",
ChatCompletionMessage(
role="assistant",
content=None,
function_call=FunctionCall(arguments="", name="test_tool"),
),
{
"role": "assistant",
"content": '{"arguments": "", "name": "test_tool"}',
},
),
(
"tool_calls and no content and function_call",
ChatCompletionMessage(
role="assistant",
content=None,
function_call=FunctionCall(arguments="", name="test_tool"),
tool_calls=[
ChatCompletionMessageToolCall(
id="test_tool",
function=Function(arguments="", name="test_tool"),
type="function",
)
],
),
{
"role": "assistant",
"content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]{"arguments": "", "name": "test_tool"}',
},
),
],
)
def test_dump_message(
name_of_test: str,
message: ChatCompletionMessage,
expected: ChatCompletionMessageParam,
):
assert dump_message(message) == expected, name_of_test