diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 42d8090..5a033be 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -10,6 +10,8 @@ from ._base import BaseProvider PROVIDER_NAME = "groq" DEFAULT_MODEL = "llama3-8b-8192" +DEFAULT_MAX_TOKENS = 1_000 +DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} T = TypeVar("T", bound=BaseModel) @@ -17,6 +19,7 @@ T = TypeVar("T", bound=BaseModel) class Groq(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL + DEFAULT_KWARGS = DEFAULT_KWARGS def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @@ -48,7 +51,7 @@ class Groq(BaseProvider): response = self.client.chat.completions.create( model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, - **kwargs, + **{**self.DEFAULT_KWARGS, **kwargs}, ) # Get the response content from the Groq response @@ -73,7 +76,7 @@ class Groq(BaseProvider): messages=messages, response_model=response_model, model=kwargs.pop("llm_model", self.DEFAULT_MODEL), - **kwargs, + **{**self.DEFAULT_KWARGS, **kwargs}, ) return response @@ -91,7 +94,7 @@ class Groq(BaseProvider): response = self.client.chat.completions.create( messages=messages, model=llm_model or self.DEFAULT_MODEL, - **kwargs, + **{**self.DEFAULT_KWARGS, **kwargs}, ) return response.choices[0].message.content diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index 12f241c..6eceaa3 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -15,11 +15,17 @@ T = TypeVar("T", bound=BaseModel) PROVIDER_NAME = "ollama" DEFAULT_MODEL = "llama3.2" DEFAULT_TIMEOUT = 60 +DEFAULT_MAX_TOKENS = 1_000 +DEFAULT_KWARGS = { + "max_tokens": DEFAULT_MAX_TOKENS, + "timeout": DEFAULT_TIMEOUT, +} class Ollama(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL + DEFAULT_KWARGS = DEFAULT_KWARGS TIMEOUT = DEFAULT_TIMEOUT def __init__(self, host_url: str | None = None): @@ -43,7 +49,7 @@ class Ollama(BaseProvider): mode=instructor.Mode.JSON, ) - def send_conversation(self, conversation: "Conversation") -> "Message": + def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message": """Send a conversation to the Ollama API.""" from ..models import Message @@ -51,7 +57,9 @@ class Ollama(BaseProvider): {"role": msg.role, "content": msg.text} for msg in conversation.messages ] response = self.client.chat( - model=conversation.llm_model or DEFAULT_MODEL, messages=messages + model=conversation.llm_model or DEFAULT_MODEL, + messages=messages, + **{**self.DEFAULT_KWARGS, **kwargs} ) assistant_message = response.get("message") @@ -81,18 +89,20 @@ class Ollama(BaseProvider): messages=messages, model=llm_model or self.DEFAULT_MODEL, response_model=response_model, - **kwargs, + **{**self.DEFAULT_KWARGS, **kwargs} ) return response - def generate_text(self, prompt: str, *, llm_model: str | None = None) -> str: + def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs) -> str: """Generate text using the Ollama API.""" messages = [ {"role": "user", "content": prompt}, ] response = self.client.chat( - messages=messages, model=llm_model or self.DEFAULT_MODEL + messages=messages, + model=llm_model or self.DEFAULT_MODEL, + **{**self.DEFAULT_KWARGS, **kwargs} ) return response.get("message", {}).get("content", "") diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index b6fa568..a756835 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -12,11 +12,14 @@ T = TypeVar("T", bound=BaseModel) PROVIDER_NAME = "openai" DEFAULT_MODEL = "gpt-4o-mini" +DEFAULT_MAX_TOKENS = 1_000 +DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} class OpenAI(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL + DEFAULT_KWARGS = DEFAULT_KWARGS def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @@ -42,7 +45,9 @@ class OpenAI(BaseProvider): ] response = self.client.chat.completions.create( - model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs + model=conversation.llm_model or DEFAULT_MODEL, + messages=messages, + **{**self.DEFAULT_KWARGS, **kwargs} ) # Get the response content from the OpenAI response @@ -74,7 +79,7 @@ class OpenAI(BaseProvider): messages=messages, model=llm_model or self.DEFAULT_MODEL, response_model=response_model, - **kwargs, + **{**self.DEFAULT_KWARGS, **kwargs} ) return response @@ -84,6 +89,8 @@ class OpenAI(BaseProvider): {"role": "user", "content": prompt}, ] response = self.client.chat.completions.create( - messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs + messages=messages, + model=llm_model or self.DEFAULT_MODEL, + **{**self.DEFAULT_KWARGS, **kwargs} ) return response.choices[0].message.content diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index dd32e0b..28ea14d 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -10,11 +10,13 @@ PROVIDER_NAME = "xai" DEFAULT_MODEL = "grok-beta" BASE_URL = "https://api.x.ai/v1" DEFAULT_MAX_TOKENS = 1000 +DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} class XAI(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL + DEFAULT_KWARGS = DEFAULT_KWARGS def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @@ -45,7 +47,7 @@ class XAI(BaseProvider): response = self.client.chat.completions.create( model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, - **kwargs, + **{**self.DEFAULT_KWARGS, **kwargs}, ) # Get the response content from the OpenAI response @@ -71,7 +73,7 @@ class XAI(BaseProvider): response = self.client.chat.completions.create( messages=messages, model=llm_model or self.DEFAULT_MODEL, - **kwargs, + **{**self.DEFAULT_KWARGS, **kwargs}, ) return response.choices[0].message.content