added default_kwargs logic to OpenAI provider

This commit is contained in:
Barış Özmen
2024-11-01 00:15:49 +03:00
parent cbc3739411
commit 37a9333be3
+10 -3
View File
@@ -12,11 +12,14 @@ T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "openai"
DEFAULT_MODEL = "gpt-4o-mini"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class OpenAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -42,7 +45,9 @@ class OpenAI(BaseProvider):
]
response = self.client.chat.completions.create(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs}
)
# Get the response content from the OpenAI response
@@ -74,7 +79,7 @@ class OpenAI(BaseProvider):
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
response_model=response_model,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs}
)
return response
@@ -84,6 +89,8 @@ class OpenAI(BaseProvider):
{"role": "user", "content": prompt},
]
response = self.client.chat.completions.create(
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs}
)
return response.choices[0].message.content