From cbc3739411331233f9b8bb702ab8008a44750645 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bar=C4=B1=C5=9F=20=C3=96zmen?= Date: Fri, 1 Nov 2024 00:14:41 +0300 Subject: [PATCH] added default_kwargs logic to Groq provider --- simplemind/providers/groq.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 42d8090..5a033be 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -10,6 +10,8 @@ from ._base import BaseProvider PROVIDER_NAME = "groq" DEFAULT_MODEL = "llama3-8b-8192" +DEFAULT_MAX_TOKENS = 1_000 +DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} T = TypeVar("T", bound=BaseModel) @@ -17,6 +19,7 @@ T = TypeVar("T", bound=BaseModel) class Groq(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL + DEFAULT_KWARGS = DEFAULT_KWARGS def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @@ -48,7 +51,7 @@ class Groq(BaseProvider): response = self.client.chat.completions.create( model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, - **kwargs, + **{**self.DEFAULT_KWARGS, **kwargs}, ) # Get the response content from the Groq response @@ -73,7 +76,7 @@ class Groq(BaseProvider): messages=messages, response_model=response_model, model=kwargs.pop("llm_model", self.DEFAULT_MODEL), - **kwargs, + **{**self.DEFAULT_KWARGS, **kwargs}, ) return response @@ -91,7 +94,7 @@ class Groq(BaseProvider): response = self.client.chat.completions.create( messages=messages, model=llm_model or self.DEFAULT_MODEL, - **kwargs, + **{**self.DEFAULT_KWARGS, **kwargs}, ) return response.choices[0].message.content