mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 06:46:18 +00:00
Refactor OpenAI provider code for improved readability and consistency
This commit is contained in:
@@ -38,9 +38,7 @@ class OpenAITool(BaseTool):
|
|||||||
|
|
||||||
# Check if there's a tool call
|
# Check if there's a tool call
|
||||||
if assistant_message.tool_calls:
|
if assistant_message.tool_calls:
|
||||||
tool_call = assistant_message.tool_calls[
|
tool_call = assistant_message.tool_calls[0] # Get the first tool call
|
||||||
0
|
|
||||||
] # Get the first tool call
|
|
||||||
if tool_call.function.name == self.name:
|
if tool_call.function.name == self.name:
|
||||||
# Execute the function
|
# Execute the function
|
||||||
import json
|
import json
|
||||||
@@ -128,8 +126,7 @@ class OpenAI(BaseProvider):
|
|||||||
|
|
||||||
# Format messages from conversation
|
# Format messages from conversation
|
||||||
formatted_messages = [
|
formatted_messages = [
|
||||||
{"role": msg.role, "content": msg.text}
|
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||||
for msg in conversation.messages
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Set up tools if provided
|
# Set up tools if provided
|
||||||
@@ -154,15 +151,12 @@ class OpenAI(BaseProvider):
|
|||||||
|
|
||||||
# Handle tool responses if needed
|
# Handle tool responses if needed
|
||||||
while response.choices[0].message.tool_calls:
|
while response.choices[0].message.tool_calls:
|
||||||
print(response)
|
|
||||||
# Handle each tool call
|
# Handle each tool call
|
||||||
for tool in converted_tools:
|
for tool in converted_tools:
|
||||||
tool.handle(response, formatted_messages)
|
tool.handle(response, formatted_messages)
|
||||||
if tool.is_executed():
|
if tool.is_executed():
|
||||||
# Make another API call with the updated messages
|
# Make another API call with the updated messages
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(**request_kwargs)
|
||||||
**request_kwargs
|
|
||||||
)
|
|
||||||
tool.reset_result()
|
tool.reset_result()
|
||||||
|
|
||||||
final_message = response.choices[0].message.content
|
final_message = response.choices[0].message.content
|
||||||
@@ -188,25 +182,14 @@ class OpenAI(BaseProvider):
|
|||||||
"""Get a structured response from the OpenAI API."""
|
"""Get a structured response from the OpenAI API."""
|
||||||
# Ensure messages are provided in kwargs
|
# Ensure messages are provided in kwargs
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": prompt
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
"""Add an image (url or base64-encoded) to the message if provided."""
|
"""Add an image (url or base64-encoded) to the message if provided."""
|
||||||
if image_url:
|
if image_url:
|
||||||
messages[0]['content'].append({
|
messages[0]["content"].append(
|
||||||
"type": "image_url",
|
{"type": "image_url", "image_url": {"url": image_url}}
|
||||||
"image_url": {
|
)
|
||||||
"url": image_url
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
response = self.structured_client.chat.completions.create(
|
response = self.structured_client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@@ -218,29 +201,23 @@ class OpenAI(BaseProvider):
|
|||||||
|
|
||||||
@logger
|
@logger
|
||||||
def generate_text(
|
def generate_text(
|
||||||
self, prompt: str, *, llm_model: str | None = None, image_url: str | None = None, **kwargs
|
self,
|
||||||
|
prompt: str,
|
||||||
|
*,
|
||||||
|
llm_model: str | None = None,
|
||||||
|
image_url: str | None = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Generate text using the OpenAI API."""
|
"""Generate text using the OpenAI API."""
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": prompt
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
"""Add an image (url or base64-encoded) to the message if provided."""
|
"""Add an image (url or base64-encoded) to the message if provided."""
|
||||||
if image_url:
|
if image_url:
|
||||||
messages[0]['content'].append({
|
messages[0]["content"].append(
|
||||||
"type": "image_url",
|
{"type": "image_url", "image_url": {"url": image_url}}
|
||||||
"image_url": {
|
)
|
||||||
"url": image_url
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@@ -251,32 +228,26 @@ class OpenAI(BaseProvider):
|
|||||||
|
|
||||||
@logger
|
@logger
|
||||||
def generate_stream_text(
|
def generate_stream_text(
|
||||||
self, prompt: str, *, llm_model: str | None = None, image_url: str | None = None, **kwargs
|
self,
|
||||||
|
prompt: str,
|
||||||
|
*,
|
||||||
|
llm_model: str | None = None,
|
||||||
|
image_url: str | None = None,
|
||||||
|
**kwargs,
|
||||||
) -> Iterator[str]:
|
) -> Iterator[str]:
|
||||||
"""Generate streaming text using the OpenAI API.
|
"""Generate streaming text using the OpenAI API.
|
||||||
|
|
||||||
Yields chunks of text as they are generated by the model.
|
Yields chunks of text as they are generated by the model.
|
||||||
"""
|
"""
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": prompt
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
"""Add an image (url or base64-encoded) to the message if provided."""
|
"""Add an image (url or base64-encoded) to the message if provided."""
|
||||||
if image_url:
|
if image_url:
|
||||||
messages[0]['content'].append({
|
messages[0]["content"].append(
|
||||||
"type": "image_url",
|
{"type": "image_url", "image_url": {"url": image_url}}
|
||||||
"image_url": {
|
)
|
||||||
"url": image_url
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|||||||
Reference in New Issue
Block a user