From a2597709d21a488b01cc696390929cbef82de1f3 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Fri, 1 Nov 2024 14:55:22 +0530 Subject: [PATCH] gemini works as expected --- simplemind/providers/gemini.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/simplemind/providers/gemini.py b/simplemind/providers/gemini.py index cd91471..abd39db 100644 --- a/simplemind/providers/gemini.py +++ b/simplemind/providers/gemini.py @@ -4,7 +4,6 @@ from functools import cached_property from typing import TYPE_CHECKING, Type, TypeVar -import google.generativeai as genai import instructor from pydantic import BaseModel @@ -30,12 +29,21 @@ class Gemini(BaseProvider): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) self.model_name = DEFAULT_MODEL + def set_model(self, model_name: str): + self.model_name = model_name + @cached_property - def client(self, model_name: str = DEFAULT_MODEL): + def client(self): """The raw Gemini client.""" if not self.api_key: raise ValueError("Gemini API key is required") - self.model_name = model_name + try: + import google.generativeai as genai + except ImportError as exc: + raise ImportError( + "Please install the `google-generativeai` package: `pip install google-generativeai`" + ) from exc + genai.configure(api_key=self.api_key) return genai.GenerativeModel(model_name=self.model_name) @cached_property @@ -73,8 +81,7 @@ class Gemini(BaseProvider): @logger def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: """Send a structured response to the Gemini API.""" - llm_model = kwargs.pop("llm_model", self.model_name) - + kwargs.pop("llm_model") try: response = self.structured_client.chat.completions.create( messages=[{"role": "user", "content": prompt}], @@ -86,13 +93,12 @@ class Gemini(BaseProvider): raise RuntimeError( f"Failed to send structured response to Gemini API: {e}" ) from e - return response + return response_model.model_validate(response) @logger def generate_text(self, prompt: str, **kwargs) -> str: """Generate text using the Gemini API.""" - llm_model = kwargs.pop("llm_model", self.model_name) - + kwargs.pop("llm_model") try: response = self.client.generate_content(prompt, **kwargs) except Exception as e: