Refactor message parameter in Anthropic provider

This commit is contained in:
2024-10-31 16:33:01 -04:00
parent 48ac97f070
commit febf5473d5
3 changed files with 36 additions and 2 deletions
+1
View File
@@ -72,6 +72,7 @@ class Groq(BaseProvider):
response = self.structured_client.chat.completions.create(
messages=messages,
response_model=response_model,
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
**kwargs,
)
return response
+7 -2
View File
@@ -65,7 +65,12 @@ class Ollama(BaseProvider):
)
def structured_response(
self, prompt: str, response_model: Type[T], *, llm_model: str, **kwargs
self,
prompt: str,
response_model: Type[T],
*,
llm_model: str | None = None,
**kwargs,
) -> T:
"""Get a structured response from the Ollama API."""
messages = [
@@ -80,7 +85,7 @@ class Ollama(BaseProvider):
)
return response
def generate_text(self, prompt: str, *, llm_model: str) -> str:
def generate_text(self, prompt: str, *, llm_model: str | None = None) -> str:
"""Generate text using the Ollama API."""
messages = [
{"role": "user", "content": prompt},
+28
View File
@@ -0,0 +1,28 @@
import pytest
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama
from pydantic import BaseModel
class ResponseModel(BaseModel):
result: int
@pytest.mark.parametrize(
"provider_cls",
[
Anthropic,
Gemini,
OpenAI,
Groq,
Ollama,
],
)
def test_generate_data(provider_cls):
provider = provider_cls()
prompt = "What is 2+2?"
data = provider.structured_response(prompt=prompt, response_model=ResponseModel)
assert isinstance(data, ResponseModel)
assert type(data.result) == int