added default_kwargs logic to Groq provider

This commit is contained in:
Barış Özmen
2024-11-01 00:14:41 +03:00
parent 7c8f22bef1
commit cbc3739411
+6 -3
View File
@@ -10,6 +10,8 @@ from ._base import BaseProvider
PROVIDER_NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
T = TypeVar("T", bound=BaseModel)
@@ -17,6 +19,7 @@ T = TypeVar("T", bound=BaseModel)
class Groq(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -48,7 +51,7 @@ class Groq(BaseProvider):
response = self.client.chat.completions.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
# Get the response content from the Groq response
@@ -73,7 +76,7 @@ class Groq(BaseProvider):
messages=messages,
response_model=response_model,
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response
@@ -91,7 +94,7 @@ class Groq(BaseProvider):
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response.choices[0].message.content