diff --git a/simplemind/providers/amazon.py b/simplemind/providers/amazon.py index c04fc01..c7c7363 100644 --- a/simplemind/providers/amazon.py +++ b/simplemind/providers/amazon.py @@ -92,3 +92,24 @@ class Amazon(BaseProvider): ) return response.content[0].text + + def generate_stream_text(self, prompt, *, llm_model, **kwargs): + """Generate streaming text using the Amazon API.""" + + # Prepare the messages. + messages = [ + {"role": "user", "content": prompt}, + ] + + # Send the request to the API. + response = self.client.messages.create( + model=llm_model or self.DEFAULT_MODEL, + messages=messages, + stream=True, + **kwargs, + ) + + # Yield the text chunks. + for chunk in response: + if chunk.text: + yield chunk.text