Merge pull request #24 from barisozmen/default_kwargs

Add default kwargs logic to Groq, OpenAI, XAI, and Ollama providers
This commit is contained in:
2024-10-31 19:48:02 -04:00
committed by GitHub
4 changed files with 35 additions and 13 deletions
+6 -3
View File
@@ -10,6 +10,8 @@ from ._base import BaseProvider
PROVIDER_NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
T = TypeVar("T", bound=BaseModel)
@@ -17,6 +19,7 @@ T = TypeVar("T", bound=BaseModel)
class Groq(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -48,7 +51,7 @@ class Groq(BaseProvider):
response = self.client.chat.completions.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
# Get the response content from the Groq response
@@ -73,7 +76,7 @@ class Groq(BaseProvider):
messages=messages,
response_model=response_model,
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response
@@ -91,7 +94,7 @@ class Groq(BaseProvider):
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response.choices[0].message.content
+15 -5
View File
@@ -15,11 +15,17 @@ T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "ollama"
DEFAULT_MODEL = "llama3.2"
DEFAULT_TIMEOUT = 60
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {
"max_tokens": DEFAULT_MAX_TOKENS,
"timeout": DEFAULT_TIMEOUT,
}
class Ollama(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
TIMEOUT = DEFAULT_TIMEOUT
def __init__(self, host_url: str | None = None):
@@ -43,7 +49,7 @@ class Ollama(BaseProvider):
mode=instructor.Mode.JSON,
)
def send_conversation(self, conversation: "Conversation") -> "Message":
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the Ollama API."""
from ..models import Message
@@ -51,7 +57,9 @@ class Ollama(BaseProvider):
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
response = self.client.chat(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs}
)
assistant_message = response.get("message")
@@ -81,18 +89,20 @@ class Ollama(BaseProvider):
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
response_model=response_model,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs}
)
return response
def generate_text(self, prompt: str, *, llm_model: str | None = None) -> str:
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs) -> str:
"""Generate text using the Ollama API."""
messages = [
{"role": "user", "content": prompt},
]
response = self.client.chat(
messages=messages, model=llm_model or self.DEFAULT_MODEL
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs}
)
return response.get("message", {}).get("content", "")
+10 -3
View File
@@ -12,11 +12,14 @@ T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "openai"
DEFAULT_MODEL = "gpt-4o-mini"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class OpenAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -42,7 +45,9 @@ class OpenAI(BaseProvider):
]
response = self.client.chat.completions.create(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs}
)
# Get the response content from the OpenAI response
@@ -74,7 +79,7 @@ class OpenAI(BaseProvider):
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
response_model=response_model,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs}
)
return response
@@ -84,6 +89,8 @@ class OpenAI(BaseProvider):
{"role": "user", "content": prompt},
]
response = self.client.chat.completions.create(
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs}
)
return response.choices[0].message.content
+4 -2
View File
@@ -10,11 +10,13 @@ PROVIDER_NAME = "xai"
DEFAULT_MODEL = "grok-beta"
BASE_URL = "https://api.x.ai/v1"
DEFAULT_MAX_TOKENS = 1000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class XAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -45,7 +47,7 @@ class XAI(BaseProvider):
response = self.client.chat.completions.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
# Get the response content from the OpenAI response
@@ -71,7 +73,7 @@ class XAI(BaseProvider):
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response.choices[0].message.content