diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 853481a..87e07b9 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -17,7 +17,63 @@ T = TypeVar("T", bound=BaseModel) class OpenAITool(BaseTool): - def get_schema(self): + def get_response_schema(self): + assert self.is_executed, f"Tool {self.name} was not executed." + assert isinstance( + self.tool_id, str + ), f"Expected str for `tool_id` got {self.tool_id!r}" + + return { + "role": "tool", + "tool_call_id": self.tool_id, + "content": self.function_result, + } + + @logger + def handle(self, response, messages) -> None: + """Handle the tool execution result from an API response.""" + tool_used = False + + # Get the message from the response + assistant_message = response.choices[0].message + + # Check if there's a tool call + if assistant_message.tool_calls: + tool_call = assistant_message.tool_calls[ + 0 + ] # Get the first tool call + if tool_call.function.name == self.name: + # Execute the function + import json + + function_args = json.loads(tool_call.function.arguments) + self.function_result = str(self.raw_func(**function_args)) + self.tool_id = tool_call.id + tool_used = True + + # Add assistant's message with tool call + messages.append( + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + } + ], + } + ) + + if tool_used: + # Add tool response message + messages.append(self.get_response_schema()) + + def get_input_schema(self): return { "type": "function", "function": { @@ -61,39 +117,61 @@ class OpenAI(BaseProvider): """A OpenAI client with Instructor.""" return instructor.from_openai(self.client) - @logger - def send_conversation( - self, - conversation: "Conversation", - tools: list[Callable] | None = None, + +@logger +def send_conversation( + self, + conversation: "Conversation", + tools: list[Callable | BaseTool] | None = None, + **kwargs, +) -> "Message": + """Send a conversation to the OpenAI API.""" + from ..models import Message + + # Format messages from conversation + formatted_messages = [ + {"role": msg.role, "content": msg.text} for msg in conversation.messages + ] + + # Set up tools if provided + converted_tools = self.make_tools(tools) + tools_config = ( + [t.get_input_schema() for t in converted_tools] if tools else None + ) + + # Merge all kwargs + request_kwargs = { + **self.DEFAULT_KWARGS, **kwargs, - ) -> "Message": - """Send a conversation to the OpenAI API.""" - from ..models import Message + "model": conversation.llm_model or self.DEFAULT_MODEL, + "messages": formatted_messages, + } - messages = [ - {"role": msg.role, "content": msg.text} - for msg in conversation.messages - ] + if tools_config: + request_kwargs["tools"] = tools_config - response = self.client.chat.completions.create( - model=conversation.llm_model or self.DEFAULT_MODEL, - messages=messages, - tools=self.make_tools(tools), - **{**self.DEFAULT_KWARGS, **kwargs}, - ) + # Make initial API call + response = self.client.chat.completions.create(**request_kwargs) - # Get the response content from the OpenAI response - assistant_message = response.choices[0].message + # Handle tool responses if needed + while response.choices[0].message.tool_calls: + # Handle each tool call + for tool in converted_tools: + tool.handle(response, formatted_messages) + if tool.is_executed(): + # Make another API call with the updated messages + response = self.client.chat.completions.create(**request_kwargs) + tool.reset_result() - # Create and return a properly formatted Message instance - return Message( - role="assistant", - text=assistant_message.content or "", - raw=response, - llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=self.NAME, - ) + final_message = response.choices[0].message.content + + return Message( + role="assistant", + text=final_message or "", + raw=response, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, + llm_provider=self.NAME, + ) @logger def structured_response(