From febf5473d5a1a78d81b194ae329b24ef65b353b0 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Thu, 31 Oct 2024 16:33:01 -0400 Subject: [PATCH] Refactor message parameter in Anthropic provider --- simplemind/providers/groq.py | 1 + simplemind/providers/ollama.py | 9 +++++++-- tests/test_generate_data.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 tests/test_generate_data.py diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index d1a3c8e..42d8090 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -72,6 +72,7 @@ class Groq(BaseProvider): response = self.structured_client.chat.completions.create( messages=messages, response_model=response_model, + model=kwargs.pop("llm_model", self.DEFAULT_MODEL), **kwargs, ) return response diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index 7d1ac22..12f241c 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -65,7 +65,12 @@ class Ollama(BaseProvider): ) def structured_response( - self, prompt: str, response_model: Type[T], *, llm_model: str, **kwargs + self, + prompt: str, + response_model: Type[T], + *, + llm_model: str | None = None, + **kwargs, ) -> T: """Get a structured response from the Ollama API.""" messages = [ @@ -80,7 +85,7 @@ class Ollama(BaseProvider): ) return response - def generate_text(self, prompt: str, *, llm_model: str) -> str: + def generate_text(self, prompt: str, *, llm_model: str | None = None) -> str: """Generate text using the Ollama API.""" messages = [ {"role": "user", "content": prompt}, diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py new file mode 100644 index 0000000..bea35c8 --- /dev/null +++ b/tests/test_generate_data.py @@ -0,0 +1,28 @@ +import pytest + +from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama +from pydantic import BaseModel + + +class ResponseModel(BaseModel): + result: int + + +@pytest.mark.parametrize( + "provider_cls", + [ + Anthropic, + Gemini, + OpenAI, + Groq, + Ollama, + ], +) +def test_generate_data(provider_cls): + provider = provider_cls() + prompt = "What is 2+2?" + + data = provider.structured_response(prompt=prompt, response_model=ResponseModel) + + assert isinstance(data, ResponseModel) + assert type(data.result) == int