Refactor OpenAI provider code for improved readability and consistency

This commit is contained in:
2025-02-08 19:13:15 -05:00
parent 3421de0fc1
commit 9ccef9abdc
+27 -56
View File
@@ -38,9 +38,7 @@ class OpenAITool(BaseTool):
# Check if there's a tool call # Check if there's a tool call
if assistant_message.tool_calls: if assistant_message.tool_calls:
tool_call = assistant_message.tool_calls[ tool_call = assistant_message.tool_calls[0] # Get the first tool call
0
] # Get the first tool call
if tool_call.function.name == self.name: if tool_call.function.name == self.name:
# Execute the function # Execute the function
import json import json
@@ -128,8 +126,7 @@ class OpenAI(BaseProvider):
# Format messages from conversation # Format messages from conversation
formatted_messages = [ formatted_messages = [
{"role": msg.role, "content": msg.text} {"role": msg.role, "content": msg.text} for msg in conversation.messages
for msg in conversation.messages
] ]
# Set up tools if provided # Set up tools if provided
@@ -154,15 +151,12 @@ class OpenAI(BaseProvider):
# Handle tool responses if needed # Handle tool responses if needed
while response.choices[0].message.tool_calls: while response.choices[0].message.tool_calls:
print(response)
# Handle each tool call # Handle each tool call
for tool in converted_tools: for tool in converted_tools:
tool.handle(response, formatted_messages) tool.handle(response, formatted_messages)
if tool.is_executed(): if tool.is_executed():
# Make another API call with the updated messages # Make another API call with the updated messages
response = self.client.chat.completions.create( response = self.client.chat.completions.create(**request_kwargs)
**request_kwargs
)
tool.reset_result() tool.reset_result()
final_message = response.choices[0].message.content final_message = response.choices[0].message.content
@@ -188,25 +182,14 @@ class OpenAI(BaseProvider):
"""Get a structured response from the OpenAI API.""" """Get a structured response from the OpenAI API."""
# Ensure messages are provided in kwargs # Ensure messages are provided in kwargs
messages = [ messages = [
{ {"role": "user", "content": [{"type": "text", "text": prompt}]},
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
},
] ]
"""Add an image (url or base64-encoded) to the message if provided.""" """Add an image (url or base64-encoded) to the message if provided."""
if image_url: if image_url:
messages[0]['content'].append({ messages[0]["content"].append(
"type": "image_url", {"type": "image_url", "image_url": {"url": image_url}}
"image_url": { )
"url": image_url
}
})
response = self.structured_client.chat.completions.create( response = self.structured_client.chat.completions.create(
messages=messages, messages=messages,
@@ -218,29 +201,23 @@ class OpenAI(BaseProvider):
@logger @logger
def generate_text( def generate_text(
self, prompt: str, *, llm_model: str | None = None, image_url: str | None = None, **kwargs self,
prompt: str,
*,
llm_model: str | None = None,
image_url: str | None = None,
**kwargs,
): ):
"""Generate text using the OpenAI API.""" """Generate text using the OpenAI API."""
messages = [ messages = [
{ {"role": "user", "content": [{"type": "text", "text": prompt}]},
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
},
] ]
"""Add an image (url or base64-encoded) to the message if provided.""" """Add an image (url or base64-encoded) to the message if provided."""
if image_url: if image_url:
messages[0]['content'].append({ messages[0]["content"].append(
"type": "image_url", {"type": "image_url", "image_url": {"url": image_url}}
"image_url": { )
"url": image_url
}
})
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
messages=messages, messages=messages,
@@ -251,32 +228,26 @@ class OpenAI(BaseProvider):
@logger @logger
def generate_stream_text( def generate_stream_text(
self, prompt: str, *, llm_model: str | None = None, image_url: str | None = None, **kwargs self,
prompt: str,
*,
llm_model: str | None = None,
image_url: str | None = None,
**kwargs,
) -> Iterator[str]: ) -> Iterator[str]:
"""Generate streaming text using the OpenAI API. """Generate streaming text using the OpenAI API.
Yields chunks of text as they are generated by the model. Yields chunks of text as they are generated by the model.
""" """
messages = [ messages = [
{ {"role": "user", "content": [{"type": "text", "text": prompt}]},
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
},
] ]
"""Add an image (url or base64-encoded) to the message if provided.""" """Add an image (url or base64-encoded) to the message if provided."""
if image_url: if image_url:
messages[0]['content'].append({ messages[0]["content"].append(
"type": "image_url", {"type": "image_url", "image_url": {"url": image_url}}
"image_url": { )
"url": image_url
}
})
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
messages=messages, messages=messages,