Add streaming support to Gemini provider and implement generate_stream_text method

This commit is contained in:
2024-11-02 16:48:59 -04:00
parent 028e89b080
commit 72121c121d
+15
View File
@@ -24,6 +24,7 @@ DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
class Gemini(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
supports_streaming = True
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -107,3 +108,17 @@ class Gemini(BaseProvider):
# Handle the exception appropriately, e.g., log the error or raise a custom exception
raise RuntimeError(f"Failed to generate text with Gemini API: {e}") from e
return response.text
@logger
def generate_stream_text(self, prompt: str, **kwargs) -> str:
"""Generate streaming text using the Gemini API."""
kwargs.pop("llm_model", None)
try:
response = self.client.generate_content(prompt, stream=True, **kwargs)
for chunk in response:
if chunk.text:
yield chunk.text
except Exception as e:
raise RuntimeError(
f"Failed to generate streaming text with Gemini API: {e}"
) from e