mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
fix openai
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from google.generativeai.responder import Callable
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
@@ -117,61 +116,64 @@ class OpenAI(BaseProvider):
|
||||
"""A OpenAI client with Instructor."""
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
|
||||
@logger
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
**kwargs,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
# Format messages from conversation
|
||||
formatted_messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
|
||||
# Set up tools if provided
|
||||
converted_tools = self.make_tools(tools)
|
||||
tools_config = (
|
||||
[t.get_input_schema() for t in converted_tools] if tools else None
|
||||
)
|
||||
|
||||
# Merge all kwargs
|
||||
request_kwargs = {
|
||||
**self.DEFAULT_KWARGS,
|
||||
@logger
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
**kwargs,
|
||||
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||
"messages": formatted_messages,
|
||||
}
|
||||
) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
if tools_config:
|
||||
request_kwargs["tools"] = tools_config
|
||||
# Format messages from conversation
|
||||
formatted_messages = [
|
||||
{"role": msg.role, "content": msg.text}
|
||||
for msg in conversation.messages
|
||||
]
|
||||
|
||||
# Make initial API call
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
# Set up tools if provided
|
||||
converted_tools = self.make_tools(tools)
|
||||
tools_config = (
|
||||
[t.get_input_schema() for t in converted_tools] if tools else None
|
||||
)
|
||||
|
||||
# Handle tool responses if needed
|
||||
while response.choices[0].message.tool_calls:
|
||||
# Handle each tool call
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, formatted_messages)
|
||||
if tool.is_executed():
|
||||
# Make another API call with the updated messages
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
tool.reset_result()
|
||||
# Merge all kwargs
|
||||
request_kwargs = {
|
||||
**self.DEFAULT_KWARGS,
|
||||
**kwargs,
|
||||
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||
"messages": formatted_messages,
|
||||
}
|
||||
|
||||
final_message = response.choices[0].message.content
|
||||
if tools_config:
|
||||
request_kwargs["tools"] = tools_config
|
||||
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=final_message or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
# Make initial API call
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
|
||||
# Handle tool responses if needed
|
||||
while response.choices[0].message.tool_calls:
|
||||
print(response)
|
||||
# Handle each tool call
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, formatted_messages)
|
||||
if tool.is_executed():
|
||||
# Make another API call with the updated messages
|
||||
response = self.client.chat.completions.create(
|
||||
**request_kwargs
|
||||
)
|
||||
tool.reset_result()
|
||||
|
||||
final_message = response.choices[0].message.content
|
||||
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=final_message or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
def structured_response(
|
||||
@@ -231,3 +233,8 @@ def send_conversation(
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
@cached_property
|
||||
def tool(self) -> Type[BaseTool]:
|
||||
"""The tool implementation for OpenAI."""
|
||||
return OpenAITool
|
||||
|
||||
+2
-2
@@ -4,12 +4,12 @@ import pytest
|
||||
from pydantic import Field
|
||||
|
||||
import simplemind as sm
|
||||
from simplemind.providers import Anthropic, BaseTool
|
||||
from simplemind.providers import Anthropic, BaseTool, OpenAI
|
||||
|
||||
MODELS = [
|
||||
Anthropic,
|
||||
# Gemini,
|
||||
# OpenAI,
|
||||
OpenAI,
|
||||
# Groq,
|
||||
# Ollama,
|
||||
# Amazon
|
||||
|
||||
Reference in New Issue
Block a user