gemini works as expected

This commit is contained in:
Siddhesh Agarwal
2024-11-01 14:55:22 +05:30
parent 1455b5ba13
commit a2597709d2
+14 -8
View File
@@ -4,7 +4,6 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
import google.generativeai as genai
import instructor
from pydantic import BaseModel
@@ -30,12 +29,21 @@ class Gemini(BaseProvider):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
self.model_name = DEFAULT_MODEL
def set_model(self, model_name: str):
self.model_name = model_name
@cached_property
def client(self, model_name: str = DEFAULT_MODEL):
def client(self):
"""The raw Gemini client."""
if not self.api_key:
raise ValueError("Gemini API key is required")
self.model_name = model_name
try:
import google.generativeai as genai
except ImportError as exc:
raise ImportError(
"Please install the `google-generativeai` package: `pip install google-generativeai`"
) from exc
genai.configure(api_key=self.api_key)
return genai.GenerativeModel(model_name=self.model_name)
@cached_property
@@ -73,8 +81,7 @@ class Gemini(BaseProvider):
@logger
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
"""Send a structured response to the Gemini API."""
llm_model = kwargs.pop("llm_model", self.model_name)
kwargs.pop("llm_model")
try:
response = self.structured_client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
@@ -86,13 +93,12 @@ class Gemini(BaseProvider):
raise RuntimeError(
f"Failed to send structured response to Gemini API: {e}"
) from e
return response
return response_model.model_validate(response)
@logger
def generate_text(self, prompt: str, **kwargs) -> str:
"""Generate text using the Gemini API."""
llm_model = kwargs.pop("llm_model", self.model_name)
kwargs.pop("llm_model")
try:
response = self.client.generate_content(prompt, **kwargs)
except Exception as e: