From 37a9333be30eaba1ef066b6ce55c9e9596e0e08b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bar=C4=B1=C5=9F=20=C3=96zmen?= Date: Fri, 1 Nov 2024 00:15:49 +0300 Subject: [PATCH] added default_kwargs logic to OpenAI provider --- simplemind/providers/openai.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index b6fa568..a756835 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -12,11 +12,14 @@ T = TypeVar("T", bound=BaseModel) PROVIDER_NAME = "openai" DEFAULT_MODEL = "gpt-4o-mini" +DEFAULT_MAX_TOKENS = 1_000 +DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} class OpenAI(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL + DEFAULT_KWARGS = DEFAULT_KWARGS def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @@ -42,7 +45,9 @@ class OpenAI(BaseProvider): ] response = self.client.chat.completions.create( - model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs + model=conversation.llm_model or DEFAULT_MODEL, + messages=messages, + **{**self.DEFAULT_KWARGS, **kwargs} ) # Get the response content from the OpenAI response @@ -74,7 +79,7 @@ class OpenAI(BaseProvider): messages=messages, model=llm_model or self.DEFAULT_MODEL, response_model=response_model, - **kwargs, + **{**self.DEFAULT_KWARGS, **kwargs} ) return response @@ -84,6 +89,8 @@ class OpenAI(BaseProvider): {"role": "user", "content": prompt}, ] response = self.client.chat.completions.create( - messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs + messages=messages, + model=llm_model or self.DEFAULT_MODEL, + **{**self.DEFAULT_KWARGS, **kwargs} ) return response.choices[0].message.content