mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Mc/toolmaxretry (#323)
This commit is contained in:
+20
-9
@@ -50,6 +50,7 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
"content": message.content or "",
|
||||
}
|
||||
if message.tool_calls is not None:
|
||||
ret["tool_calls"] = message.model_dump()["tool_calls"]
|
||||
ret["content"] += json.dumps(message.model_dump()["tool_calls"])
|
||||
if message.function_call is not None:
|
||||
ret["content"] += json.dumps(message.model_dump()["function_call"])
|
||||
@@ -240,6 +241,15 @@ async def retry_async(
|
||||
logger.exception(f"Retrying, exception: {e}")
|
||||
logger.debug(f"Error response: {response}")
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
|
||||
if mode == Mode.TOOLS:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": response.choices[0].message.tool_calls[0].id,
|
||||
"name": response.choices[0].message.tool_calls[0].function.name,
|
||||
"content": "failure"
|
||||
}
|
||||
)
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
@@ -286,6 +296,15 @@ def retry_sync(
|
||||
logger.exception(f"Retrying, exception: {e}")
|
||||
logger.debug(f"Error response: {response}")
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message))
|
||||
if mode == Mode.TOOLS:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": response.choices[0].message.tool_calls[0].id,
|
||||
"name": response.choices[0].message.tool_calls[0].function.name,
|
||||
"content": "failure"
|
||||
}
|
||||
)
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
@@ -323,10 +342,6 @@ def wrap_chatcompletion(func: Callable, mode: Mode = Mode.FUNCTIONS) -> Callable
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if mode == Mode.TOOLS:
|
||||
max_retries = 0
|
||||
logger.warning("max_retries is not supported when using tool calls")
|
||||
|
||||
response_model, new_kwargs = handle_response_model(
|
||||
response_model=response_model, kwargs=kwargs, mode=mode
|
||||
) # type: ignore
|
||||
@@ -349,10 +364,6 @@ def wrap_chatcompletion(func: Callable, mode: Mode = Mode.FUNCTIONS) -> Callable
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if mode == Mode.TOOLS:
|
||||
max_retries = 0
|
||||
logger.warning("max_retries is not supported when using tool calls")
|
||||
|
||||
response_model, new_kwargs = handle_response_model(
|
||||
response_model=response_model, kwargs=kwargs, mode=mode
|
||||
) # type: ignore
|
||||
@@ -406,4 +417,4 @@ def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
|
||||
- `validation_context` parameter to validate the response using the pydantic model
|
||||
- `strict` parameter to use strict json parsing
|
||||
"""
|
||||
return patch(client, mode=mode)
|
||||
return patch(client, mode=mode)
|
||||
+5
-2
@@ -92,6 +92,7 @@ def test_override_docs():
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": 'Hello, world![{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]',
|
||||
"tool_calls": [{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}],
|
||||
},
|
||||
),
|
||||
(
|
||||
@@ -110,6 +111,7 @@ def test_override_docs():
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]',
|
||||
"tool_calls": [{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}],
|
||||
},
|
||||
),
|
||||
(
|
||||
@@ -151,7 +153,7 @@ def test_override_docs():
|
||||
"tool_calls and no content and function_call",
|
||||
ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
content="",
|
||||
function_call=FunctionCall(arguments="", name="test_tool"),
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
@@ -164,6 +166,7 @@ def test_override_docs():
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]{"arguments": "", "name": "test_tool"}',
|
||||
"tool_calls": [{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]
|
||||
},
|
||||
),
|
||||
],
|
||||
@@ -173,4 +176,4 @@ def test_dump_message(
|
||||
message: ChatCompletionMessage,
|
||||
expected: ChatCompletionMessageParam,
|
||||
):
|
||||
assert dump_message(message) == expected, name_of_test
|
||||
assert dump_message(message) == expected, name_of_test
|
||||
Reference in New Issue
Block a user