diff --git a/.envrc.template b/.envrc.template index cb0f439..c365f90 100644 --- a/.envrc.template +++ b/.envrc.template @@ -1,4 +1,5 @@ export OPENAI_API_KEY="" export ANTHROPIC_API_KEY="" export XAI_API_KEY="" -export GROQ_API_KEY="" \ No newline at end of file +export OLLAMA_HOST_URL="" +export GROQ_API_KEY="" diff --git a/.gitignore b/.gitignore index b66e3ef..e43c919 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,4 @@ cython_debug/ .env src/** +requirements.txt diff --git a/pyproject.toml b/pyproject.toml index 906e880..03250bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ version = "0.1.1" description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases." readme = "README.md" requires-python = ">=3.11" -dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "groq"] +dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "ollama", "groq"] [build-system] requires = ["hatchling"] diff --git a/simplemind/models.py b/simplemind/models.py index b7448c1..c6325ed 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -60,6 +60,9 @@ class Conversation(SMBaseModel): def __str__(self): return f"" + def prepend_system_message(self, role: str, text: str, meta: Optional[Dict[str, Any]] = None): + self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages + def add_message( self, role: MESSAGE_ROLE, text: str, meta: Optional[Dict[str, Any]] = None ): diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index 9e38db1..403d1a7 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -4,6 +4,7 @@ from simplemind.providers._base import BaseProvider from simplemind.providers.anthropic import Anthropic from simplemind.providers.groq import Groq from simplemind.providers.openai import OpenAI +from simplemind.providers.ollama import Ollama from simplemind.providers.xai import XAI -providers: List[Type[BaseProvider]] = [Anthropic, Groq, OpenAI, XAI] +providers: List[Type[BaseProvider]] = [Anthropic, Groq, OpenAI, Ollama, XAI] diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py new file mode 100644 index 0000000..ce96503 --- /dev/null +++ b/simplemind/providers/ollama.py @@ -0,0 +1,77 @@ +import ollama as ol +import instructor +from openai import OpenAI + +from ._base import BaseProvider +from ..settings import settings + +PROVIDER_NAME = "ollama" +DEFAULT_MODEL = "llama3.2" +DEFAULT_TIMEOUT = 60 + + +class Ollama(BaseProvider): + NAME = PROVIDER_NAME + DEFAULT_MODEL = DEFAULT_MODEL + TIMEOUT = DEFAULT_TIMEOUT + + def __init__(self, host_url: str = None): + self.host_url = host_url or settings.OLLAMA_HOST_URL + + @property + def client(self): + """The raw Ollama client.""" + if not self.host_url: + raise ValueError("No ollama host url provided") + return ol.Client(timeout=self.TIMEOUT, host=self.host_url) + + @property + def structured_client(self): + """A client patched with Instructor.""" + return instructor.from_openai( + OpenAI( + base_url=f"{self.host_url}/v1", + api_key="ollama", + ), + mode=instructor.Mode.JSON, + ) + + def send_conversation(self, conversation: "Conversation"): + """Send a conversation to the Ollama API.""" + from ..models import Message + + messages = [ + {"role": msg.role, "content": msg.text} for msg in conversation.messages + ] + response = self.client.chat( + model=conversation.llm_model or DEFAULT_MODEL, messages=messages + ) + assistant_message = response.get("message") + + # Create and return a properly formatted Message instance + return Message( + role="assistant", + text=assistant_message.get("content"), + raw=response, + llm_model=conversation.llm_model or DEFAULT_MODEL, + llm_provider=PROVIDER_NAME, + ) + + def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs): + messages = [ + {"role": "user", "content": prompt}, + ] + + response = self.structured_client.chat.completions.create( + messages=messages, model=llm_model, response_model=response_model, **kwargs + ) + return response + + def generate_text(self, prompt, *, llm_model): + messages = [ + {"role": "user", "content": prompt}, + ] + + response = self.client.chat(messages=messages, model=llm_model) + + return response.get("message").get("content") diff --git a/simplemind/settings.py b/simplemind/settings.py index eb420aa..2545f93 100644 --- a/simplemind/settings.py +++ b/simplemind/settings.py @@ -12,6 +12,7 @@ class Settings(BaseSettings): ) GROQ_API_KEY: Optional[SecretStr] = Field(None, description="API key for Groq") OPENAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for OpenAI") + OLLAMA_HOST_URL: Optional[str] = Field(None, description="Fully qualified host URL for Ollama") XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI") DEFAULT_LLM_PROVIDER: str = Field("openai", description="The default LLM provider") diff --git a/test_ollama.py b/test_ollama.py new file mode 100644 index 0000000..0a5b560 --- /dev/null +++ b/test_ollama.py @@ -0,0 +1,66 @@ +import os +import unittest +from unittest import mock +import simplemind as sm +from pydantic import BaseModel + +class TestOllama(unittest.TestCase): + + def test_generate_text(self): + result = sm.generate_text(prompt="What is the meaning of life?", llm_provider="ollama", llm_model="llama3.2") + self.assertGreater(len(result), 0) + self.assertIsNotNone(result) + + def test_create_conversation(self): + conversation = sm.create_conversation(llm_provider="ollama", llm_model="llama3.2") + conversation.add_message("user", "Remember the number 42.") + result = conversation.send() + self.assertIsNotNone(result) + self.assertGreaterEqual(len(result.text), 0) + self.assertIsInstance(result, sm.models.Message) + + def test_memory(self): + class SimpleMemoryPlugin: + def __init__(self): + self.memories = [ + "the earth has fictionally been destroyed.", + "the moon is made of cheese.", + ] + + def yield_memories(self): + return (m for m in self.memories) + + def send_hook(self, conversation: sm.Conversation): + for m in self.yield_memories(): + conversation.prepend_system_message(role="system", text=m) + + conversation = sm.create_conversation(llm_provider="ollama", llm_model="llama3.2") + + conversation.add_message( + role="user", + text="Write a poem about the moon", + ) + self.assertGreater(len(conversation.messages), 0) + conversation.add_plugin(SimpleMemoryPlugin()) + result = conversation.send() + self.assertGreater(len(conversation.messages), 2) + self.assertIsNotNone(result) + self.assertIsNotNone(result.text) + self.assertGreater(len(result.text), 0) + self.assertIsInstance(result, sm.models.Message) + + def test_structure_response(self): + class Poem(BaseModel): + title: str + content: str + # Test for NotImplementedError + with self.assertRaises(NotImplementedError): + sm.generate_data( + prompt="Write a poem about love", + llm_provider="ollama", + llm_model="llama3.2", + response_model=Poem) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file