diff --git a/instructor/patch.py b/instructor/patch.py index 3aabefd..3e364e9 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -1,17 +1,22 @@ import inspect -from functools import wraps -from instructor.dsl.multitask import MultiTask, MultiTaskBase -from json import JSONDecodeError -from typing import get_origin, get_args, Callable, Optional, Type, Union +import json +import warnings from collections.abc import Iterable +from functools import wraps +from json import JSONDecodeError +from typing import Callable, Optional, Type, Union, get_args, get_origin from openai import AsyncOpenAI, OpenAI -from openai.types.chat import ChatCompletion +from openai.types.chat import ( + ChatCompletion, + ChatCompletionMessage, + ChatCompletionMessageParam, +) from pydantic import BaseModel, ValidationError -from .function_calls import OpenAISchema, openai_schema, Mode +from instructor.dsl.multitask import MultiTask, MultiTaskBase -import warnings +from .function_calls import Mode, OpenAISchema, openai_schema OVERRIDE_DOCS = """ Creates a new chat completion for the provided messages and parameters. @@ -33,15 +38,20 @@ Parameters: """ -def dump_message(message) -> dict: +def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam: """Dumps a message to a dict, to be returned to the OpenAI API. Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests if it isn't used. """ - dumped_message = message.model_dump() - if not dumped_message.get("tool_calls"): - del dumped_message["tool_calls"] - return {k: v for k, v in dumped_message.items() if v} + ret: ChatCompletionMessageParam = { + "role": message.role, + "content": message.content or "", + } + if message.tool_calls is not None: + ret["content"] += json.dumps(message.model_dump()["tool_calls"]) + if message.function_call is not None: + ret["content"] += json.dumps(message.model_dump()["function_call"]) + return ret def handle_response_model( @@ -202,7 +212,7 @@ def retry_sync( mode=mode, ) except (ValidationError, JSONDecodeError) as e: - kwargs["messages"].append(response.choices[0].message) + kwargs["messages"].append(dump_message(response.choices[0].message)) kwargs["messages"].append( { "role": "user", diff --git a/tests/test_patch.py b/tests/test_patch.py index d6913d4..e8d2916 100644 --- a/tests/test_patch.py +++ b/tests/test_patch.py @@ -1,11 +1,16 @@ import functools + import pytest +from openai import AsyncOpenAI, OpenAI +from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam +from openai.types.chat.chat_completion_message import FunctionCall +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, + Function, +) + import instructor - -from openai import OpenAI, AsyncOpenAI - - -from instructor.patch import is_async, wrap_chatcompletion, OVERRIDE_DOCS +from instructor.patch import OVERRIDE_DOCS, dump_message, is_async, wrap_chatcompletion def test_patch_completes_successfully(): @@ -66,3 +71,106 @@ def test_override_docs(): assert ( "response_model" in OVERRIDE_DOCS ), "response_model should be in OVERRIDE_DOCS" + + +@pytest.mark.parametrize( + "name_of_test, message, expected", + [ + ( + "tool_calls and content and no function_call", + ChatCompletionMessage( + role="assistant", + content="Hello, world!", + tool_calls=[ + ChatCompletionMessageToolCall( + id="test_tool", + function=Function(arguments="", name="test_tool"), + type="function", + ) + ], + ), + { + "role": "assistant", + "content": 'Hello, world![{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]', + }, + ), + ( + "tool_calls and no content and no function_call", + ChatCompletionMessage( + role="assistant", + content=None, + tool_calls=[ + ChatCompletionMessageToolCall( + id="test_tool", + function=Function(arguments="", name="test_tool"), + type="function", + ) + ], + ), + { + "role": "assistant", + "content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]', + }, + ), + ( + "no tool_calls and no content no function_call", + ChatCompletionMessage( + role="assistant", + content=None, + ), + { + "role": "assistant", + "content": "", + }, + ), + ( + "no tool_calls and content and function_call", + ChatCompletionMessage( + role="assistant", + content="Hello, world!", + function_call=FunctionCall(arguments="", name="test_tool"), + ), + { + "role": "assistant", + "content": 'Hello, world!{"arguments": "", "name": "test_tool"}', + }, + ), + ( + "no tool_calls and no content and function_call", + ChatCompletionMessage( + role="assistant", + content=None, + function_call=FunctionCall(arguments="", name="test_tool"), + ), + { + "role": "assistant", + "content": '{"arguments": "", "name": "test_tool"}', + }, + ), + ( + "tool_calls and no content and function_call", + ChatCompletionMessage( + role="assistant", + content=None, + function_call=FunctionCall(arguments="", name="test_tool"), + tool_calls=[ + ChatCompletionMessageToolCall( + id="test_tool", + function=Function(arguments="", name="test_tool"), + type="function", + ) + ], + ), + { + "role": "assistant", + "content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]{"arguments": "", "name": "test_tool"}', + }, + ), + ], +) +def test_dump_message( + name_of_test: str, + message: ChatCompletionMessage, + expected: ChatCompletionMessageParam, +): + assert dump_message(message) == expected, name_of_test