diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 4798933..0809913 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -24,6 +24,7 @@ class Anthropic(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL DEFAULT_KWARGS = DEFAULT_KWARGS + supports_streaming = True def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @@ -107,3 +108,20 @@ class Anthropic(BaseProvider): ) return response.content[0].text + + @logger + def generate_stream_text(self, prompt: str, *, llm_model: str, **kwargs): + # Prepare the messages. + messages = [ + {"role": "user", "content": prompt}, + ] + + # Make the request. + with self.client.messages.stream( + model=llm_model or self.DEFAULT_MODEL, + messages=messages, + **{**self.DEFAULT_KWARGS, **kwargs}, + ) as stream: + # Yield each chunk of text from the stream. + for chunk in stream.text_stream: + yield chunk