mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
gemini works as expected
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import google.generativeai as genai
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -30,12 +29,21 @@ class Gemini(BaseProvider):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.model_name = DEFAULT_MODEL
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
self.model_name = model_name
|
||||
|
||||
@cached_property
|
||||
def client(self, model_name: str = DEFAULT_MODEL):
|
||||
def client(self):
|
||||
"""The raw Gemini client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Gemini API key is required")
|
||||
self.model_name = model_name
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `google-generativeai` package: `pip install google-generativeai`"
|
||||
) from exc
|
||||
genai.configure(api_key=self.api_key)
|
||||
return genai.GenerativeModel(model_name=self.model_name)
|
||||
|
||||
@cached_property
|
||||
@@ -73,8 +81,7 @@ class Gemini(BaseProvider):
|
||||
@logger
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
"""Send a structured response to the Gemini API."""
|
||||
llm_model = kwargs.pop("llm_model", self.model_name)
|
||||
|
||||
kwargs.pop("llm_model")
|
||||
try:
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
@@ -86,13 +93,12 @@ class Gemini(BaseProvider):
|
||||
raise RuntimeError(
|
||||
f"Failed to send structured response to Gemini API: {e}"
|
||||
) from e
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate text using the Gemini API."""
|
||||
llm_model = kwargs.pop("llm_model", self.model_name)
|
||||
|
||||
kwargs.pop("llm_model")
|
||||
try:
|
||||
response = self.client.generate_content(prompt, **kwargs)
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user