diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index e4b45fd..ce96503 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -1,19 +1,19 @@ import ollama as ol +import instructor +from openai import OpenAI from ._base import BaseProvider from ..settings import settings PROVIDER_NAME = "ollama" DEFAULT_MODEL = "llama3.2" -TIMEOUT = 60 -NOT_IMPLEMENTED_REASON = """ -# TODO: instructor does not natively support ollama. -# Alternate python dependency may be required -""" +DEFAULT_TIMEOUT = 60 + class Ollama(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL + TIMEOUT = DEFAULT_TIMEOUT def __init__(self, host_url: str = None): self.host_url = host_url or settings.OLLAMA_HOST_URL @@ -23,18 +23,23 @@ class Ollama(BaseProvider): """The raw Ollama client.""" if not self.host_url: raise ValueError("No ollama host url provided") - return ol.Client( - timeout=TIMEOUT, - host=self.host_url) + return ol.Client(timeout=self.TIMEOUT, host=self.host_url) @property def structured_client(self): """A client patched with Instructor.""" - raise NotImplementedError(NOT_IMPLEMENTED_REASON) + return instructor.from_openai( + OpenAI( + base_url=f"{self.host_url}/v1", + api_key="ollama", + ), + mode=instructor.Mode.JSON, + ) def send_conversation(self, conversation: "Conversation"): """Send a conversation to the Ollama API.""" from ..models import Message + messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages ] @@ -52,16 +57,21 @@ class Ollama(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, *args, **kwargs): - raise NotImplementedError(NOT_IMPLEMENTED_REASON) + def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs): + messages = [ + {"role": "user", "content": prompt}, + ] + + response = self.structured_client.chat.completions.create( + messages=messages, model=llm_model, response_model=response_model, **kwargs + ) + return response def generate_text(self, prompt, *, llm_model): messages = [ {"role": "user", "content": prompt}, ] - response = self.client.chat( - messages=messages, model=llm_model - ) + response = self.client.chat(messages=messages, model=llm_model) return response.get("message").get("content")