mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
add openai
This commit is contained in:
+107
-29
@@ -17,7 +17,63 @@ T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class OpenAITool(BaseTool):
|
||||
def get_schema(self):
|
||||
def get_response_schema(self):
|
||||
assert self.is_executed, f"Tool {self.name} was not executed."
|
||||
assert isinstance(
|
||||
self.tool_id, str
|
||||
), f"Expected str for `tool_id` got {self.tool_id!r}"
|
||||
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": self.tool_id,
|
||||
"content": self.function_result,
|
||||
}
|
||||
|
||||
@logger
|
||||
def handle(self, response, messages) -> None:
|
||||
"""Handle the tool execution result from an API response."""
|
||||
tool_used = False
|
||||
|
||||
# Get the message from the response
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# Check if there's a tool call
|
||||
if assistant_message.tool_calls:
|
||||
tool_call = assistant_message.tool_calls[
|
||||
0
|
||||
] # Get the first tool call
|
||||
if tool_call.function.name == self.name:
|
||||
# Execute the function
|
||||
import json
|
||||
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
self.function_result = str(self.raw_func(**function_args))
|
||||
self.tool_id = tool_call.id
|
||||
tool_used = True
|
||||
|
||||
# Add assistant's message with tool call
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
if tool_used:
|
||||
# Add tool response message
|
||||
messages.append(self.get_response_schema())
|
||||
|
||||
def get_input_schema(self):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
@@ -61,39 +117,61 @@ class OpenAI(BaseProvider):
|
||||
"""A OpenAI client with Instructor."""
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable] | None = None,
|
||||
|
||||
@logger
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
**kwargs,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
# Format messages from conversation
|
||||
formatted_messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
|
||||
# Set up tools if provided
|
||||
converted_tools = self.make_tools(tools)
|
||||
tools_config = (
|
||||
[t.get_input_schema() for t in converted_tools] if tools else None
|
||||
)
|
||||
|
||||
# Merge all kwargs
|
||||
request_kwargs = {
|
||||
**self.DEFAULT_KWARGS,
|
||||
**kwargs,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||
"messages": formatted_messages,
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text}
|
||||
for msg in conversation.messages
|
||||
]
|
||||
if tools_config:
|
||||
request_kwargs["tools"] = tools_config
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
tools=self.make_tools(tools),
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
# Make initial API call
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
assistant_message = response.choices[0].message
|
||||
# Handle tool responses if needed
|
||||
while response.choices[0].message.tool_calls:
|
||||
# Handle each tool call
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, formatted_messages)
|
||||
if tool.is_executed():
|
||||
# Make another API call with the updated messages
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
tool.reset_result()
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message.content or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
final_message = response.choices[0].message.content
|
||||
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=final_message or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
def structured_response(
|
||||
|
||||
Reference in New Issue
Block a user