mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Co-authored-by: Jason Liu <jxnl@users.noreply.github.com>
This commit is contained in:
+23
-13
@@ -1,17 +1,22 @@
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from instructor.dsl.multitask import MultiTask, MultiTaskBase
|
||||
from json import JSONDecodeError
|
||||
from typing import get_origin, get_args, Callable, Optional, Type, Union
|
||||
import json
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from functools import wraps
|
||||
from json import JSONDecodeError
|
||||
from typing import Callable, Optional, Type, Union, get_args, get_origin
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from .function_calls import OpenAISchema, openai_schema, Mode
|
||||
from instructor.dsl.multitask import MultiTask, MultiTaskBase
|
||||
|
||||
import warnings
|
||||
from .function_calls import Mode, OpenAISchema, openai_schema
|
||||
|
||||
OVERRIDE_DOCS = """
|
||||
Creates a new chat completion for the provided messages and parameters.
|
||||
@@ -33,15 +38,20 @@ Parameters:
|
||||
"""
|
||||
|
||||
|
||||
def dump_message(message) -> dict:
|
||||
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
"""Dumps a message to a dict, to be returned to the OpenAI API.
|
||||
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
|
||||
if it isn't used.
|
||||
"""
|
||||
dumped_message = message.model_dump()
|
||||
if not dumped_message.get("tool_calls"):
|
||||
del dumped_message["tool_calls"]
|
||||
return {k: v for k, v in dumped_message.items() if v}
|
||||
ret: ChatCompletionMessageParam = {
|
||||
"role": message.role,
|
||||
"content": message.content or "",
|
||||
}
|
||||
if message.tool_calls is not None:
|
||||
ret["content"] += json.dumps(message.model_dump()["tool_calls"])
|
||||
if message.function_call is not None:
|
||||
ret["content"] += json.dumps(message.model_dump()["function_call"])
|
||||
return ret
|
||||
|
||||
|
||||
def handle_response_model(
|
||||
@@ -202,7 +212,7 @@ def retry_sync(
|
||||
mode=mode,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
kwargs["messages"].append(response.choices[0].message)
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message))
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
+113
-5
@@ -1,11 +1,16 @@
|
||||
import functools
|
||||
|
||||
import pytest
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
import instructor
|
||||
|
||||
from openai import OpenAI, AsyncOpenAI
|
||||
|
||||
|
||||
from instructor.patch import is_async, wrap_chatcompletion, OVERRIDE_DOCS
|
||||
from instructor.patch import OVERRIDE_DOCS, dump_message, is_async, wrap_chatcompletion
|
||||
|
||||
|
||||
def test_patch_completes_successfully():
|
||||
@@ -66,3 +71,106 @@ def test_override_docs():
|
||||
assert (
|
||||
"response_model" in OVERRIDE_DOCS
|
||||
), "response_model should be in OVERRIDE_DOCS"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"name_of_test, message, expected",
|
||||
[
|
||||
(
|
||||
"tool_calls and content and no function_call",
|
||||
ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content="Hello, world!",
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="test_tool",
|
||||
function=Function(arguments="", name="test_tool"),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
),
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": 'Hello, world![{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]',
|
||||
},
|
||||
),
|
||||
(
|
||||
"tool_calls and no content and no function_call",
|
||||
ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="test_tool",
|
||||
function=Function(arguments="", name="test_tool"),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
),
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]',
|
||||
},
|
||||
),
|
||||
(
|
||||
"no tool_calls and no content no function_call",
|
||||
ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
),
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
},
|
||||
),
|
||||
(
|
||||
"no tool_calls and content and function_call",
|
||||
ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content="Hello, world!",
|
||||
function_call=FunctionCall(arguments="", name="test_tool"),
|
||||
),
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": 'Hello, world!{"arguments": "", "name": "test_tool"}',
|
||||
},
|
||||
),
|
||||
(
|
||||
"no tool_calls and no content and function_call",
|
||||
ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
function_call=FunctionCall(arguments="", name="test_tool"),
|
||||
),
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": '{"arguments": "", "name": "test_tool"}',
|
||||
},
|
||||
),
|
||||
(
|
||||
"tool_calls and no content and function_call",
|
||||
ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
function_call=FunctionCall(arguments="", name="test_tool"),
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="test_tool",
|
||||
function=Function(arguments="", name="test_tool"),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
),
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"}, "type": "function"}]{"arguments": "", "name": "test_tool"}',
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_dump_message(
|
||||
name_of_test: str,
|
||||
message: ChatCompletionMessage,
|
||||
expected: ChatCompletionMessageParam,
|
||||
):
|
||||
assert dump_message(message) == expected, name_of_test
|
||||
|
||||
Reference in New Issue
Block a user