fix openai

This commit is contained in:
Luciano Scarpulla
2024-11-15 12:09:39 +08:00
parent 107f983a18
commit a97f9be2c8
2 changed files with 60 additions and 53 deletions
+58 -51
View File
@@ -1,8 +1,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar
import instructor
from google.generativeai.responder import Callable
from pydantic import BaseModel
from ..logging import logger
@@ -117,61 +116,64 @@ class OpenAI(BaseProvider):
"""A OpenAI client with Instructor."""
return instructor.from_openai(self.client)
@logger
def send_conversation(
self,
conversation: "Conversation",
tools: list[Callable | BaseTool] | None = None,
**kwargs,
) -> "Message":
"""Send a conversation to the OpenAI API."""
from ..models import Message
# Format messages from conversation
formatted_messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
# Set up tools if provided
converted_tools = self.make_tools(tools)
tools_config = (
[t.get_input_schema() for t in converted_tools] if tools else None
)
# Merge all kwargs
request_kwargs = {
**self.DEFAULT_KWARGS,
@logger
def send_conversation(
self,
conversation: "Conversation",
tools: list[Callable | BaseTool] | None = None,
**kwargs,
"model": conversation.llm_model or self.DEFAULT_MODEL,
"messages": formatted_messages,
}
) -> "Message":
"""Send a conversation to the OpenAI API."""
from ..models import Message
if tools_config:
request_kwargs["tools"] = tools_config
# Format messages from conversation
formatted_messages = [
{"role": msg.role, "content": msg.text}
for msg in conversation.messages
]
# Make initial API call
response = self.client.chat.completions.create(**request_kwargs)
# Set up tools if provided
converted_tools = self.make_tools(tools)
tools_config = (
[t.get_input_schema() for t in converted_tools] if tools else None
)
# Handle tool responses if needed
while response.choices[0].message.tool_calls:
# Handle each tool call
for tool in converted_tools:
tool.handle(response, formatted_messages)
if tool.is_executed():
# Make another API call with the updated messages
response = self.client.chat.completions.create(**request_kwargs)
tool.reset_result()
# Merge all kwargs
request_kwargs = {
**self.DEFAULT_KWARGS,
**kwargs,
"model": conversation.llm_model or self.DEFAULT_MODEL,
"messages": formatted_messages,
}
final_message = response.choices[0].message.content
if tools_config:
request_kwargs["tools"] = tools_config
return Message(
role="assistant",
text=final_message or "",
raw=response,
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
llm_provider=self.NAME,
)
# Make initial API call
response = self.client.chat.completions.create(**request_kwargs)
# Handle tool responses if needed
while response.choices[0].message.tool_calls:
print(response)
# Handle each tool call
for tool in converted_tools:
tool.handle(response, formatted_messages)
if tool.is_executed():
# Make another API call with the updated messages
response = self.client.chat.completions.create(
**request_kwargs
)
tool.reset_result()
final_message = response.choices[0].message.content
return Message(
role="assistant",
text=final_message or "",
raw=response,
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
llm_provider=self.NAME,
)
@logger
def structured_response(
@@ -231,3 +233,8 @@ def send_conversation(
for chunk in response:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
@cached_property
def tool(self) -> Type[BaseTool]:
"""The tool implementation for OpenAI."""
return OpenAITool
+2 -2
View File
@@ -4,12 +4,12 @@ import pytest
from pydantic import Field
import simplemind as sm
from simplemind.providers import Anthropic, BaseTool
from simplemind.providers import Anthropic, BaseTool, OpenAI
MODELS = [
Anthropic,
# Gemini,
# OpenAI,
OpenAI,
# Groq,
# Ollama,
# Amazon