mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
added default_kwargs logic to Groq provider
This commit is contained in:
@@ -10,6 +10,8 @@ from ._base import BaseProvider
|
||||
|
||||
PROVIDER_NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
@@ -17,6 +19,7 @@ T = TypeVar("T", bound=BaseModel)
|
||||
class Groq(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
@@ -48,7 +51,7 @@ class Groq(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Get the response content from the Groq response
|
||||
@@ -73,7 +76,7 @@ class Groq(BaseProvider):
|
||||
messages=messages,
|
||||
response_model=response_model,
|
||||
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -91,7 +94,7 @@ class Groq(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
Reference in New Issue
Block a user