fix: update Gemini provider to handle unimplemented features and improve error handling

This commit is contained in:
2024-10-31 12:09:23 -04:00
parent fb6c4c289b
commit 222d3025b1
+37 -32
View File
@@ -1,3 +1,6 @@
# TODO: this is a placeholder file for the Gemini provider
# IT is not currently working as desired.
from functools import cached_property
from typing import Type, TypeVar
@@ -37,34 +40,38 @@ class Gemini(BaseProvider):
def send_conversation(self, conversation: "Conversation") -> "Message":
"""Send a conversation to the Gemini API."""
from ..models import Message
messages = [
{
"role": msg.role,
"content": msg.text,
"metadata": msg.meta or {},
}
for msg in conversation.messages
]
try:
response = self.structured_client.chat.completions.create(
messages=messages, response_model=None
)
except Exception as e:
# Handle the exception appropriately, e.g., log the error or raise a custom exception
raise RuntimeError(f"Failed to send conversation to Gemini API: {e}") from e
# Create and return a properly formatted Message instance
return Message(
role="assistant",
text=str(response),
raw=response,
llm_model=self.model_name,
llm_provider=PROVIDER_NAME,
raise NotImplementedError(
"Gemini does not support conversation-based completions"
)
# from ..models import Message
# messages = [
# {
# "role": msg.role,
# "content": msg.text,
# "metadata": msg.meta or {},
# }
# for msg in conversation.messages
# ]
# try:
# response = self.client.chat.completions.create(
# messages=messages, response_model=None
# )
# except Exception as e:
# # Handle the exception appropriately, e.g., log the error or raise a custom exception
# raise RuntimeError(f"Failed to send conversation to Gemini API: {e}") from e
# # Create and return a properly formatted Message instance
# return Message(
# role="assistant",
# text=str(response),
# raw=response,
# llm_model=self.model_name,
# llm_provider=PROVIDER_NAME,
# )
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
"""Send a structured response to the Gemini API."""
try:
@@ -82,13 +89,11 @@ class Gemini(BaseProvider):
def generate_text(self, prompt: str, **kwargs) -> str:
"""Generate text using the Gemini API."""
llm_model = kwargs.pop("llm_model", self.model_name)
try:
response = self.structured_client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
response_model=None,
**kwargs,
)
response = self.client.generate_content(prompt, **kwargs)
except Exception as e:
# Handle the exception appropriately, e.g., log the error or raise a custom exception
raise RuntimeError(f"Failed to generate text with Gemini API: {e}") from e
return str(response)
return response.result