diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 5b2801f..dfbf69d 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -24,6 +24,7 @@ class Groq(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL DEFAULT_KWARGS = DEFAULT_KWARGS + supports_streaming = True def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @@ -111,3 +112,32 @@ class Groq(BaseProvider): ) return str(response.choices[0].message.content) + + @logger + def generate_stream_text( + self, + prompt: str, + *, + llm_model: str | None = None, + **kwargs, + ) -> str: + """Generate streaming text using the Groq API.""" + messages = [ + {"role": "user", "content": prompt}, + ] + + response = self.client.chat.completions.create( + messages=messages, + model=llm_model or self.DEFAULT_MODEL, + stream=True, + **{**self.DEFAULT_KWARGS, **kwargs}, + ) + + try: + for chunk in response: + if chunk.choices and chunk.choices[0].delta.content: + yield chunk.choices[0].delta.content + except Exception as e: + raise RuntimeError( + f"Failed to generate streaming text with Groq API: {e}" + ) from e