diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index ec5e6ac..a880fd5 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -29,7 +29,7 @@ class Groq(BaseProvider): """A client patched with Instructor.""" return instructor.from_groq(self.client) - def send_conversation(self, conversation: "Conversation"): + def send_conversation(self, conversation: Conversation) -> Message: """Send a conversation to the Groq API.""" messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages @@ -50,3 +50,26 @@ class Groq(BaseProvider): llm_model=conversation.llm_model or DEFAULT_MODEL, llm_provider=PROVIDER_NAME, ) + + def structured_response(self, prompt: str, response_model): + # Ensure messages are provided in kwargs + messages = [ + {"role": "user", "content": prompt}, + ] + + response = self.structured_client.chat.completions.create( + messages=messages, + response_model=response_model, + ) + return response + + def generate_text(self, prompt: str, *, llm_model: str): + messages = [ + {"role": "user", "content": prompt}, + ] + + response = self.structured_client.chat.completions.create( + messages=messages, model=llm_model + ) + + return response.choices[0].message.content