diff --git a/instructor/patch.py b/instructor/patch.py index 105bed4..458800e 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -50,6 +50,7 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam: "content": message.content or "", } if message.tool_calls is not None: + ret["tool_calls"] = message.model_dump()["tool_calls"] ret["content"] += json.dumps(message.model_dump()["tool_calls"]) if message.function_call is not None: ret["content"] += json.dumps(message.model_dump()["function_call"]) @@ -240,6 +241,15 @@ async def retry_async( logger.exception(f"Retrying, exception: {e}") logger.debug(f"Error response: {response}") kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore + if mode == Mode.TOOLS: + kwargs["messages"].append( + { + "role": "tool", + "tool_call_id": response.choices[0].message.tool_calls[0].id, + "name": response.choices[0].message.tool_calls[0].function.name, + "content": "failure" + } + ) kwargs["messages"].append( { "role": "user", @@ -286,6 +296,15 @@ def retry_sync( logger.exception(f"Retrying, exception: {e}") logger.debug(f"Error response: {response}") kwargs["messages"].append(dump_message(response.choices[0].message)) + if mode == Mode.TOOLS: + kwargs["messages"].append( + { + "role": "tool", + "tool_call_id": response.choices[0].message.tool_calls[0].id, + "name": response.choices[0].message.tool_calls[0].function.name, + "content": "failure" + } + ) kwargs["messages"].append( { "role": "user", @@ -323,10 +342,6 @@ def wrap_chatcompletion(func: Callable, mode: Mode = Mode.FUNCTIONS) -> Callable *args, **kwargs, ): - if mode == Mode.TOOLS: - max_retries = 0 - logger.warning("max_retries is not supported when using tool calls") - response_model, new_kwargs = handle_response_model( response_model=response_model, kwargs=kwargs, mode=mode ) # type: ignore @@ -349,10 +364,6 @@ def wrap_chatcompletion(func: Callable, mode: Mode = Mode.FUNCTIONS) -> Callable *args, **kwargs, ): - if mode == Mode.TOOLS: - max_retries = 0 - logger.warning("max_retries is not supported when using tool calls") - response_model, new_kwargs = handle_response_model( response_model=response_model, kwargs=kwargs, mode=mode ) # type: ignore @@ -406,4 +417,4 @@ def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS): - `validation_context` parameter to validate the response using the pydantic model - `strict` parameter to use strict json parsing """ - return patch(client, mode=mode) + return patch(client, mode=mode) \ No newline at end of file diff --git a/tests/test_patch.py b/tests/test_patch.py index e8d2916..027c132 100644 --- a/tests/test_patch.py +++ b/tests/test_patch.py @@ -92,6 +92,7 @@ def test_override_docs(): { "role": "assistant", "content": 'Hello, world![{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]', + "tool_calls": [{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}], }, ), ( @@ -110,6 +111,7 @@ def test_override_docs(): { "role": "assistant", "content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]', + "tool_calls": [{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}], }, ), ( @@ -151,7 +153,7 @@ def test_override_docs(): "tool_calls and no content and function_call", ChatCompletionMessage( role="assistant", - content=None, + content="", function_call=FunctionCall(arguments="", name="test_tool"), tool_calls=[ ChatCompletionMessageToolCall( @@ -164,6 +166,7 @@ def test_override_docs(): { "role": "assistant", "content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]{"arguments": "", "name": "test_tool"}', + "tool_calls": [{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}] }, ), ], @@ -173,4 +176,4 @@ def test_dump_message( message: ChatCompletionMessage, expected: ChatCompletionMessageParam, ): - assert dump_message(message) == expected, name_of_test + assert dump_message(message) == expected, name_of_test \ No newline at end of file