From a97f9be2c8c69a27a06409342b9c5ffd8ca27d27 Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Fri, 15 Nov 2024 12:09:39 +0800 Subject: [PATCH] fix openai --- simplemind/providers/openai.py | 109 ++++++++++++++++++--------------- tests/test_tools.py | 4 +- 2 files changed, 60 insertions(+), 53 deletions(-) diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 87e07b9..25cc3d3 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,8 +1,7 @@ from functools import cached_property -from typing import TYPE_CHECKING, Iterator, Type, TypeVar +from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar import instructor -from google.generativeai.responder import Callable from pydantic import BaseModel from ..logging import logger @@ -117,61 +116,64 @@ class OpenAI(BaseProvider): """A OpenAI client with Instructor.""" return instructor.from_openai(self.client) - -@logger -def send_conversation( - self, - conversation: "Conversation", - tools: list[Callable | BaseTool] | None = None, - **kwargs, -) -> "Message": - """Send a conversation to the OpenAI API.""" - from ..models import Message - - # Format messages from conversation - formatted_messages = [ - {"role": msg.role, "content": msg.text} for msg in conversation.messages - ] - - # Set up tools if provided - converted_tools = self.make_tools(tools) - tools_config = ( - [t.get_input_schema() for t in converted_tools] if tools else None - ) - - # Merge all kwargs - request_kwargs = { - **self.DEFAULT_KWARGS, + @logger + def send_conversation( + self, + conversation: "Conversation", + tools: list[Callable | BaseTool] | None = None, **kwargs, - "model": conversation.llm_model or self.DEFAULT_MODEL, - "messages": formatted_messages, - } + ) -> "Message": + """Send a conversation to the OpenAI API.""" + from ..models import Message - if tools_config: - request_kwargs["tools"] = tools_config + # Format messages from conversation + formatted_messages = [ + {"role": msg.role, "content": msg.text} + for msg in conversation.messages + ] - # Make initial API call - response = self.client.chat.completions.create(**request_kwargs) + # Set up tools if provided + converted_tools = self.make_tools(tools) + tools_config = ( + [t.get_input_schema() for t in converted_tools] if tools else None + ) - # Handle tool responses if needed - while response.choices[0].message.tool_calls: - # Handle each tool call - for tool in converted_tools: - tool.handle(response, formatted_messages) - if tool.is_executed(): - # Make another API call with the updated messages - response = self.client.chat.completions.create(**request_kwargs) - tool.reset_result() + # Merge all kwargs + request_kwargs = { + **self.DEFAULT_KWARGS, + **kwargs, + "model": conversation.llm_model or self.DEFAULT_MODEL, + "messages": formatted_messages, + } - final_message = response.choices[0].message.content + if tools_config: + request_kwargs["tools"] = tools_config - return Message( - role="assistant", - text=final_message or "", - raw=response, - llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=self.NAME, - ) + # Make initial API call + response = self.client.chat.completions.create(**request_kwargs) + + # Handle tool responses if needed + while response.choices[0].message.tool_calls: + print(response) + # Handle each tool call + for tool in converted_tools: + tool.handle(response, formatted_messages) + if tool.is_executed(): + # Make another API call with the updated messages + response = self.client.chat.completions.create( + **request_kwargs + ) + tool.reset_result() + + final_message = response.choices[0].message.content + + return Message( + role="assistant", + text=final_message or "", + raw=response, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, + llm_provider=self.NAME, + ) @logger def structured_response( @@ -231,3 +233,8 @@ def send_conversation( for chunk in response: if chunk.choices[0].delta.content is not None: yield chunk.choices[0].delta.content + + @cached_property + def tool(self) -> Type[BaseTool]: + """The tool implementation for OpenAI.""" + return OpenAITool diff --git a/tests/test_tools.py b/tests/test_tools.py index f4d018e..83ef68e 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -4,12 +4,12 @@ import pytest from pydantic import Field import simplemind as sm -from simplemind.providers import Anthropic, BaseTool +from simplemind.providers import Anthropic, BaseTool, OpenAI MODELS = [ Anthropic, # Gemini, - # OpenAI, + OpenAI, # Groq, # Ollama, # Amazon