finally working

This commit is contained in:
2024-10-28 17:29:38 -04:00
parent 86b3ce2b81
commit 300d5a1d81
7 changed files with 70 additions and 65 deletions
-29
View File
@@ -1,29 +0,0 @@
class Conversation:
"""A class to manage conversation state with an AI model."""
def __init__(self, client):
self.client = client
self.messages = []
def add_message(self, message, role="user"):
"""Add a message to the conversation history."""
self.messages.append({"role": role, "content": message})
return self
def send(self, message=None, **kwargs):
"""Send the conversation history (and optionally a new message) to the AI."""
if message:
self.add_message(message)
response = self.client.message(message_history=self.messages, **kwargs)
# Add the AI's response to the conversation history
if isinstance(response.text, str):
self.add_message(response.text, role="assistant")
return response
def set_model(self, model: str):
"""Set the model for the conversation."""
self.client.set_model(model)
return True
View File
+8 -6
View File
@@ -32,20 +32,22 @@ class SimpleMind:
f"Provider '{self.provider}' not supported. Available providers: {list(providers.keys())}"
)
return providers[self.provider](api_key=self.api_key)
return providers[self.provider]()
def generate(self, prompt: str, **kwargs) -> AIResponse:
"""Generate a response using the configured provider."""
return self._client.message(prompt, **kwargs)
def create_conversation(self, initial_message: str) -> str:
def create_conversation(self) -> str:
"""Create a new conversation and return its ID."""
conversation = self._client.create_conversation(initial_message)
return conversation.id
initial_message = "You are a helpful assistant."
def send_message(self, conversation_id: str, message: str) -> AIResponse:
conversation = self._client.create_conversation(initial_message)
return conversation
def add_message(self, conversation_id: str, message: str) -> AIResponse:
"""Send a message in an existing conversation."""
return self._client.send_message(conversation_id, message)
return self._client.add_message(conversation_id, message)
+19 -16
View File
@@ -1,13 +1,14 @@
import os
from ollama import Client as BaseOllama
import uuid
from typing import List, Optional, Dict, Any
from ..core.errors import ProviderError
from ..core.logger import logger
from .base import BaseClientProvider
from ..core.models import AIResponse
from ..concepts.conversation import Conversation
from some_module import BaseOllama # Replace 'some_module' with the actual module name
TIMEOUT = 60
class Ollama(BaseClientProvider):
def __init__(self, *args, **kwargs):
@@ -17,17 +18,17 @@ class Ollama(BaseClientProvider):
def login(self):
"""Initialize Ollama client, with Instructor enabled."""
if not os.environ.get('OLLAMA_HOST_URL'):
raise ValueError("Please set the OLLAMA_HOST_URL environment variable")
if not os.environ.get("OLLAMA_HOST_URL"):
raise ValueError("Please set the OLLAMA_HOST_URL environment variable")
if not os.environ.get('OLLAMA_MODEL'):
if not os.environ.get("OLLAMA_MODEL"):
raise ValueError("Please set the OLLAMA_MODEL environment variable")
else:
self.model = os.environ.get('OLLAMA_MODEL')
self.model = os.environ.get("OLLAMA_MODEL")
self.client = BaseOllama(
timeout=TIMEOUT,
host=os.environ.get('OLLAMA_HOST_URL'))
timeout=TIMEOUT, host=os.environ.get("OLLAMA_HOST_URL")
)
assert self.test_connection()
@property
@@ -35,7 +36,7 @@ class Ollama(BaseClientProvider):
"""Returns the available models from the Ollama client."""
def gen():
for model in self.client.list().get('models'):
for model in self.client.list().get("models"):
yield model
return [g for g in gen()]
@@ -68,10 +69,12 @@ class Ollama(BaseClientProvider):
else:
return AIResponse(
response=completion,
text=completion.get('response'),
text=completion.get("response"),
)
def message(self, message=None, message_history=None, response_model=False, **kwargs):
def message(
self, message=None, message_history=None, response_model=False, **kwargs
):
"""Generates a response from the Ollama client."""
use_instructor = bool(response_model)
@@ -82,7 +85,7 @@ class Ollama(BaseClientProvider):
if message_history:
all_messages.extend(message_history)
if message:
all_messages.append({'role': 'user', 'content': message})
all_messages.append({"role": "user", "content": message})
params = {
"messages": all_messages,
"model": self.model,
@@ -100,7 +103,7 @@ class Ollama(BaseClientProvider):
else:
return AIResponse(
response=completion,
text=completion.get('message').get('content'),
text=completion.get("message").get("content"),
)
def start_conversation(self):
+18 -13
View File
@@ -1,13 +1,15 @@
from typing import List, Optional
from openai import OpenAI as BaseOpenAI
import openai
from ..core.errors import AuthenticationError, ProviderError
from ..core.config import settings
from ..core.logger import logger
from .base import BaseClientProvider
DEFAULT_MODEL = "gpt-4o"
class OpenAI(BaseClientProvider):
def __init__(self, model: str = "gpt-4", api_key: Optional[str] = None):
def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
super().__init__(model=model, api_key=api_key)
self.login()
@@ -18,7 +20,7 @@ class OpenAI(BaseClientProvider):
raise AuthenticationError("OpenAI API key not provided")
try:
self.client = BaseOpenAI(api_key=self._api_key)
openai.api_key = self._api_key
if not self.test_connection():
raise ProviderError("Failed to connect to OpenAI API")
logger.info("Successfully connected to OpenAI")
@@ -28,7 +30,8 @@ class OpenAI(BaseClientProvider):
@property
def available_models(self) -> List[str]:
try:
return [model.id for model in self.client.models.list()]
models = openai.models.list()
return [model.id for model in models["data"]]
except Exception as e:
logger.error(f"Error fetching models: {e}")
return []
@@ -36,10 +39,10 @@ class OpenAI(BaseClientProvider):
def test_connection(self):
"""Test the connection to OpenAI API."""
try:
# A simple test call to verify API key works
self.client.models.list()
openai.models.list()
return True
except Exception as e:
logger.error(f"Connection test failed: {e}")
return False
def _handle_api_error(self, e: Exception) -> None:
@@ -47,16 +50,18 @@ class OpenAI(BaseClientProvider):
logger.error(f"OpenAI API error: {e}")
raise ProviderError(f"OpenAI API error: {e}")
def generate_response(self, conversation) -> str:
def add_message(self, conversation_id, message, *args, **kwargs) -> str:
"""Generate a response using the OpenAI API."""
try:
messages = [
{"role": msg.role, "content": msg.content}
for msg in conversation.messages
]
# Create a client instance using the API key
client = openai.OpenAI(api_key=self._api_key)
response = self.client.chat.completions.create(
model=self.model, messages=messages
# Create the message for the conversation
messages = [{"role": "user", "content": message}]
# Use the new API syntax
response = client.chat.completions.create(
model=self.model, messages=messages, *args, **kwargs
)
return response.choices[0].message.content
+24
View File
@@ -0,0 +1,24 @@
import unittest
from unittest.mock import patch, MagicMock
from simplemind.providers.openai import OpenAI
from simplemind.core.errors import AuthenticationError, ProviderError
class TestOpenAIProvider(unittest.TestCase):
def setUp(self):
self.api_key = "test_key"
@patch("simplemind.providers.openai.openai")
def test_initialization(self, mock_openai):
mock_openai.Model.list.return_value = {"data": ["gpt-4"]}
provider = OpenAI(api_key=self.api_key)
self.assertIsNotNone(provider.client)
mock_openai.Model.list.assert_called_once()
def test_missing_api_key(self):
with self.assertRaises(AuthenticationError):
OpenAI(api_key=None)
if __name__ == "__main__":
unittest.main()
+1 -1
View File
@@ -4,6 +4,6 @@ sm = SimpleMind()
# The provider will automatically use OPENAI_API_KEY from environment
conversation = sm.create_conversation()
r = conversation.send_message("Who is Kenneth Reitz?", model="gpt-4o")
r = sm.add_message(conversation.id, "Who is Kenneth Reitz?")
print(r)
#