mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
Refactor message parameter in Anthropic provider
This commit is contained in:
@@ -72,6 +72,7 @@ class Groq(BaseProvider):
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages,
|
||||
response_model=response_model,
|
||||
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
|
||||
**kwargs,
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -65,7 +65,12 @@ class Ollama(BaseProvider):
|
||||
)
|
||||
|
||||
def structured_response(
|
||||
self, prompt: str, response_model: Type[T], *, llm_model: str, **kwargs
|
||||
self,
|
||||
prompt: str,
|
||||
response_model: Type[T],
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Get a structured response from the Ollama API."""
|
||||
messages = [
|
||||
@@ -80,7 +85,7 @@ class Ollama(BaseProvider):
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt: str, *, llm_model: str) -> str:
|
||||
def generate_text(self, prompt: str, *, llm_model: str | None = None) -> str:
|
||||
"""Generate text using the Ollama API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
result: int
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
[
|
||||
Anthropic,
|
||||
Gemini,
|
||||
OpenAI,
|
||||
Groq,
|
||||
Ollama,
|
||||
],
|
||||
)
|
||||
def test_generate_data(provider_cls):
|
||||
provider = provider_cls()
|
||||
prompt = "What is 2+2?"
|
||||
|
||||
data = provider.structured_response(prompt=prompt, response_model=ResponseModel)
|
||||
|
||||
assert isinstance(data, ResponseModel)
|
||||
assert type(data.result) == int
|
||||
Reference in New Issue
Block a user