From de36bc13283a9744206f9ba934c558449e3f17d4 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Wed, 30 Oct 2024 09:18:24 -0400 Subject: [PATCH] fix default models --- simplemind/providers/anthropic.py | 8 ++++---- simplemind/providers/groq.py | 6 +++--- simplemind/providers/ollama.py | 11 ++++++++--- simplemind/providers/xai.py | 6 +++--- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 5bb2886..1028b09 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -39,7 +39,7 @@ class Anthropic(BaseProvider): ] response = self.client.messages.create( - model=conversation.llm_model or DEFAULT_MODEL, + model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, max_tokens=DEFAULT_MAX_TOKENS, **kwargs, @@ -53,13 +53,13 @@ class Anthropic(BaseProvider): role="assistant", text=assistant_message, raw=response, - llm_model=conversation.llm_model or DEFAULT_MODEL, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, llm_provider=PROVIDER_NAME, ) def structured_response(self, model, response_model, **kwargs): response = self.structured_client.messages.create( - model=model, response_model=response_model, **kwargs + model=model, response_model=response_model or self.DEFAULT_MODEL, **kwargs ) return response @@ -69,7 +69,7 @@ class Anthropic(BaseProvider): ] response = self.client.messages.create( - model=llm_model, + model=llm_model or self.DEFAULT_MODEL, messages=messages, max_tokens=DEFAULT_MAX_TOKENS, **kwargs, diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 3e0dbb4..47f14ae 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -42,7 +42,7 @@ class Groq(BaseProvider): ] response = self.client.chat.completions.create( - model=conversation.llm_model or DEFAULT_MODEL, + model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, **kwargs, ) @@ -55,7 +55,7 @@ class Groq(BaseProvider): role="assistant", text=assistant_message.content or "", raw=response, - llm_model=conversation.llm_model or DEFAULT_MODEL, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, llm_provider=PROVIDER_NAME, ) @@ -85,7 +85,7 @@ class Groq(BaseProvider): response = self.client.chat.completions.create( messages=messages, - model=llm_model, + model=llm_model or self.DEFAULT_MODEL, **kwargs, ) diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index ce96503..5ba019a 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -53,7 +53,7 @@ class Ollama(BaseProvider): role="assistant", text=assistant_message.get("content"), raw=response, - llm_model=conversation.llm_model or DEFAULT_MODEL, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, llm_provider=PROVIDER_NAME, ) @@ -63,7 +63,10 @@ class Ollama(BaseProvider): ] response = self.structured_client.chat.completions.create( - messages=messages, model=llm_model, response_model=response_model, **kwargs + messages=messages, + model=llm_model or self.DEFAULT_MODEL, + response_model=response_model, + **kwargs, ) return response @@ -72,6 +75,8 @@ class Ollama(BaseProvider): {"role": "user", "content": prompt}, ] - response = self.client.chat(messages=messages, model=llm_model) + response = self.client.chat( + messages=messages, model=llm_model or self.DEFAULT_MODEL + ) return response.get("message").get("content") diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index a5dbac6..8b8b84d 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -43,7 +43,7 @@ class XAI(BaseProvider): ] response = self.client.chat.completions.create( - model=conversation.llm_model or DEFAULT_MODEL, + model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, **kwargs, ) @@ -56,7 +56,7 @@ class XAI(BaseProvider): role="assistant", text=assistant_message.content, raw=response, - llm_model=conversation.llm_model or DEFAULT_MODEL, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, llm_provider=PROVIDER_NAME, ) @@ -70,7 +70,7 @@ class XAI(BaseProvider): response = self.client.chat.completions.create( messages=messages, - model=llm_model, + model=llm_model or self.DEFAULT_MODEL, **kwargs, )