Refactor default_kwargs logic in Ollama provider

This commit is contained in:
2024-10-31 19:49:33 -04:00
parent 0795464fd7
commit caceba381d
+7 -9
View File
@@ -15,11 +15,7 @@ T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "ollama"
DEFAULT_MODEL = "llama3.2"
DEFAULT_TIMEOUT = 60
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {
"max_tokens": DEFAULT_MAX_TOKENS,
"timeout": DEFAULT_TIMEOUT,
}
DEFAULT_KWARGS = {}
class Ollama(BaseProvider):
@@ -59,7 +55,7 @@ class Ollama(BaseProvider):
response = self.client.chat(
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs}
**{**self.DEFAULT_KWARGS, **kwargs},
)
assistant_message = response.get("message")
@@ -89,11 +85,13 @@ class Ollama(BaseProvider):
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
response_model=response_model,
**{**self.DEFAULT_KWARGS, **kwargs}
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs) -> str:
def generate_text(
self, prompt: str, *, llm_model: str | None = None, **kwargs
) -> str:
"""Generate text using the Ollama API."""
messages = [
{"role": "user", "content": prompt},
@@ -102,7 +100,7 @@ class Ollama(BaseProvider):
response = self.client.chat(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs}
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response.get("message", {}).get("content", "")