From 300d5a1d8179492bee4bfcfa307bbc4d3a9d5679 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Mon, 28 Oct 2024 17:29:38 -0400 Subject: [PATCH] finally working --- simplemind/concepts/conversation.py | 29 ------------------------ simplemind/core.py | 0 simplemind/core/__init__.py | 14 +++++++----- simplemind/providers/ollama.py | 35 ++++++++++++++++------------- simplemind/providers/openai.py | 31 ++++++++++++++----------- simplemind/tests/test_openai.py | 24 ++++++++++++++++++++ t.py | 2 +- 7 files changed, 70 insertions(+), 65 deletions(-) delete mode 100644 simplemind/concepts/conversation.py delete mode 100644 simplemind/core.py create mode 100644 simplemind/tests/test_openai.py diff --git a/simplemind/concepts/conversation.py b/simplemind/concepts/conversation.py deleted file mode 100644 index c5d976c..0000000 --- a/simplemind/concepts/conversation.py +++ /dev/null @@ -1,29 +0,0 @@ -class Conversation: - """A class to manage conversation state with an AI model.""" - - def __init__(self, client): - self.client = client - self.messages = [] - - def add_message(self, message, role="user"): - """Add a message to the conversation history.""" - self.messages.append({"role": role, "content": message}) - return self - - def send(self, message=None, **kwargs): - """Send the conversation history (and optionally a new message) to the AI.""" - if message: - self.add_message(message) - - response = self.client.message(message_history=self.messages, **kwargs) - - # Add the AI's response to the conversation history - if isinstance(response.text, str): - self.add_message(response.text, role="assistant") - - return response - - def set_model(self, model: str): - """Set the model for the conversation.""" - self.client.set_model(model) - return True diff --git a/simplemind/core.py b/simplemind/core.py deleted file mode 100644 index e69de29..0000000 diff --git a/simplemind/core/__init__.py b/simplemind/core/__init__.py index fc88883..da9fda1 100644 --- a/simplemind/core/__init__.py +++ b/simplemind/core/__init__.py @@ -32,20 +32,22 @@ class SimpleMind: f"Provider '{self.provider}' not supported. Available providers: {list(providers.keys())}" ) - return providers[self.provider](api_key=self.api_key) + return providers[self.provider]() def generate(self, prompt: str, **kwargs) -> AIResponse: """Generate a response using the configured provider.""" return self._client.message(prompt, **kwargs) - def create_conversation(self, initial_message: str) -> str: + def create_conversation(self) -> str: """Create a new conversation and return its ID.""" - conversation = self._client.create_conversation(initial_message) - return conversation.id + initial_message = "You are a helpful assistant." - def send_message(self, conversation_id: str, message: str) -> AIResponse: + conversation = self._client.create_conversation(initial_message) + return conversation + + def add_message(self, conversation_id: str, message: str) -> AIResponse: """Send a message in an existing conversation.""" - return self._client.send_message(conversation_id, message) + return self._client.add_message(conversation_id, message) diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index 4bfb794..78ee08f 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -1,13 +1,14 @@ import os - -from ollama import Client as BaseOllama - +import uuid +from typing import List, Optional, Dict, Any +from ..core.errors import ProviderError +from ..core.logger import logger from .base import BaseClientProvider -from ..core.models import AIResponse -from ..concepts.conversation import Conversation +from some_module import BaseOllama # Replace 'some_module' with the actual module name TIMEOUT = 60 + class Ollama(BaseClientProvider): def __init__(self, *args, **kwargs): @@ -17,17 +18,17 @@ class Ollama(BaseClientProvider): def login(self): """Initialize Ollama client, with Instructor enabled.""" - if not os.environ.get('OLLAMA_HOST_URL'): - raise ValueError("Please set the OLLAMA_HOST_URL environment variable") + if not os.environ.get("OLLAMA_HOST_URL"): + raise ValueError("Please set the OLLAMA_HOST_URL environment variable") - if not os.environ.get('OLLAMA_MODEL'): + if not os.environ.get("OLLAMA_MODEL"): raise ValueError("Please set the OLLAMA_MODEL environment variable") else: - self.model = os.environ.get('OLLAMA_MODEL') + self.model = os.environ.get("OLLAMA_MODEL") self.client = BaseOllama( - timeout=TIMEOUT, - host=os.environ.get('OLLAMA_HOST_URL')) + timeout=TIMEOUT, host=os.environ.get("OLLAMA_HOST_URL") + ) assert self.test_connection() @property @@ -35,7 +36,7 @@ class Ollama(BaseClientProvider): """Returns the available models from the Ollama client.""" def gen(): - for model in self.client.list().get('models'): + for model in self.client.list().get("models"): yield model return [g for g in gen()] @@ -68,10 +69,12 @@ class Ollama(BaseClientProvider): else: return AIResponse( response=completion, - text=completion.get('response'), + text=completion.get("response"), ) - def message(self, message=None, message_history=None, response_model=False, **kwargs): + def message( + self, message=None, message_history=None, response_model=False, **kwargs + ): """Generates a response from the Ollama client.""" use_instructor = bool(response_model) @@ -82,7 +85,7 @@ class Ollama(BaseClientProvider): if message_history: all_messages.extend(message_history) if message: - all_messages.append({'role': 'user', 'content': message}) + all_messages.append({"role": "user", "content": message}) params = { "messages": all_messages, "model": self.model, @@ -100,7 +103,7 @@ class Ollama(BaseClientProvider): else: return AIResponse( response=completion, - text=completion.get('message').get('content'), + text=completion.get("message").get("content"), ) def start_conversation(self): diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index c227222..1b83b41 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,13 +1,15 @@ from typing import List, Optional -from openai import OpenAI as BaseOpenAI +import openai from ..core.errors import AuthenticationError, ProviderError from ..core.config import settings from ..core.logger import logger from .base import BaseClientProvider +DEFAULT_MODEL = "gpt-4o" + class OpenAI(BaseClientProvider): - def __init__(self, model: str = "gpt-4", api_key: Optional[str] = None): + def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None): super().__init__(model=model, api_key=api_key) self.login() @@ -18,7 +20,7 @@ class OpenAI(BaseClientProvider): raise AuthenticationError("OpenAI API key not provided") try: - self.client = BaseOpenAI(api_key=self._api_key) + openai.api_key = self._api_key if not self.test_connection(): raise ProviderError("Failed to connect to OpenAI API") logger.info("Successfully connected to OpenAI") @@ -28,7 +30,8 @@ class OpenAI(BaseClientProvider): @property def available_models(self) -> List[str]: try: - return [model.id for model in self.client.models.list()] + models = openai.models.list() + return [model.id for model in models["data"]] except Exception as e: logger.error(f"Error fetching models: {e}") return [] @@ -36,10 +39,10 @@ class OpenAI(BaseClientProvider): def test_connection(self): """Test the connection to OpenAI API.""" try: - # A simple test call to verify API key works - self.client.models.list() + openai.models.list() return True except Exception as e: + logger.error(f"Connection test failed: {e}") return False def _handle_api_error(self, e: Exception) -> None: @@ -47,16 +50,18 @@ class OpenAI(BaseClientProvider): logger.error(f"OpenAI API error: {e}") raise ProviderError(f"OpenAI API error: {e}") - def generate_response(self, conversation) -> str: + def add_message(self, conversation_id, message, *args, **kwargs) -> str: """Generate a response using the OpenAI API.""" try: - messages = [ - {"role": msg.role, "content": msg.content} - for msg in conversation.messages - ] + # Create a client instance using the API key + client = openai.OpenAI(api_key=self._api_key) - response = self.client.chat.completions.create( - model=self.model, messages=messages + # Create the message for the conversation + messages = [{"role": "user", "content": message}] + + # Use the new API syntax + response = client.chat.completions.create( + model=self.model, messages=messages, *args, **kwargs ) return response.choices[0].message.content diff --git a/simplemind/tests/test_openai.py b/simplemind/tests/test_openai.py new file mode 100644 index 0000000..7721493 --- /dev/null +++ b/simplemind/tests/test_openai.py @@ -0,0 +1,24 @@ +import unittest +from unittest.mock import patch, MagicMock +from simplemind.providers.openai import OpenAI +from simplemind.core.errors import AuthenticationError, ProviderError + + +class TestOpenAIProvider(unittest.TestCase): + def setUp(self): + self.api_key = "test_key" + + @patch("simplemind.providers.openai.openai") + def test_initialization(self, mock_openai): + mock_openai.Model.list.return_value = {"data": ["gpt-4"]} + provider = OpenAI(api_key=self.api_key) + self.assertIsNotNone(provider.client) + mock_openai.Model.list.assert_called_once() + + def test_missing_api_key(self): + with self.assertRaises(AuthenticationError): + OpenAI(api_key=None) + + +if __name__ == "__main__": + unittest.main() diff --git a/t.py b/t.py index ad7cc26..2863b50 100644 --- a/t.py +++ b/t.py @@ -4,6 +4,6 @@ sm = SimpleMind() # The provider will automatically use OPENAI_API_KEY from environment conversation = sm.create_conversation() -r = conversation.send_message("Who is Kenneth Reitz?", model="gpt-4o") +r = sm.add_message(conversation.id, "Who is Kenneth Reitz?") print(r) #