Mc/toolmaxretry (#323)

This commit is contained in:
casty
2024-01-05 14:49:36 -05:00
committed by GitHub
parent 05fcf0b585
commit c16d622e88
2 changed files with 25 additions and 11 deletions
+20 -9
View File
@@ -50,6 +50,7 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
"content": message.content or "",
}
if message.tool_calls is not None:
ret["tool_calls"] = message.model_dump()["tool_calls"]
ret["content"] += json.dumps(message.model_dump()["tool_calls"])
if message.function_call is not None:
ret["content"] += json.dumps(message.model_dump()["function_call"])
@@ -240,6 +241,15 @@ async def retry_async(
logger.exception(f"Retrying, exception: {e}")
logger.debug(f"Error response: {response}")
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
if mode == Mode.TOOLS:
kwargs["messages"].append(
{
"role": "tool",
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"name": response.choices[0].message.tool_calls[0].function.name,
"content": "failure"
}
)
kwargs["messages"].append(
{
"role": "user",
@@ -286,6 +296,15 @@ def retry_sync(
logger.exception(f"Retrying, exception: {e}")
logger.debug(f"Error response: {response}")
kwargs["messages"].append(dump_message(response.choices[0].message))
if mode == Mode.TOOLS:
kwargs["messages"].append(
{
"role": "tool",
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"name": response.choices[0].message.tool_calls[0].function.name,
"content": "failure"
}
)
kwargs["messages"].append(
{
"role": "user",
@@ -323,10 +342,6 @@ def wrap_chatcompletion(func: Callable, mode: Mode = Mode.FUNCTIONS) -> Callable
*args,
**kwargs,
):
if mode == Mode.TOOLS:
max_retries = 0
logger.warning("max_retries is not supported when using tool calls")
response_model, new_kwargs = handle_response_model(
response_model=response_model, kwargs=kwargs, mode=mode
) # type: ignore
@@ -349,10 +364,6 @@ def wrap_chatcompletion(func: Callable, mode: Mode = Mode.FUNCTIONS) -> Callable
*args,
**kwargs,
):
if mode == Mode.TOOLS:
max_retries = 0
logger.warning("max_retries is not supported when using tool calls")
response_model, new_kwargs = handle_response_model(
response_model=response_model, kwargs=kwargs, mode=mode
) # type: ignore
@@ -406,4 +417,4 @@ def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
- `validation_context` parameter to validate the response using the pydantic model
- `strict` parameter to use strict json parsing
"""
return patch(client, mode=mode)
return patch(client, mode=mode)
+5 -2
View File
@@ -92,6 +92,7 @@ def test_override_docs():
{
"role": "assistant",
"content": 'Hello, world![{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]',
"tool_calls": [{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}],
},
),
(
@@ -110,6 +111,7 @@ def test_override_docs():
{
"role": "assistant",
"content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]',
"tool_calls": [{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}],
},
),
(
@@ -151,7 +153,7 @@ def test_override_docs():
"tool_calls and no content and function_call",
ChatCompletionMessage(
role="assistant",
content=None,
content="",
function_call=FunctionCall(arguments="", name="test_tool"),
tool_calls=[
ChatCompletionMessageToolCall(
@@ -164,6 +166,7 @@ def test_override_docs():
{
"role": "assistant",
"content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]{"arguments": "", "name": "test_tool"}',
"tool_calls": [{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]
},
),
],
@@ -173,4 +176,4 @@ def test_dump_message(
message: ChatCompletionMessage,
expected: ChatCompletionMessageParam,
):
assert dump_message(message) == expected, name_of_test
assert dump_message(message) == expected, name_of_test