mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
finally working
This commit is contained in:
@@ -1,29 +0,0 @@
|
||||
class Conversation:
|
||||
"""A class to manage conversation state with an AI model."""
|
||||
|
||||
def __init__(self, client):
|
||||
self.client = client
|
||||
self.messages = []
|
||||
|
||||
def add_message(self, message, role="user"):
|
||||
"""Add a message to the conversation history."""
|
||||
self.messages.append({"role": role, "content": message})
|
||||
return self
|
||||
|
||||
def send(self, message=None, **kwargs):
|
||||
"""Send the conversation history (and optionally a new message) to the AI."""
|
||||
if message:
|
||||
self.add_message(message)
|
||||
|
||||
response = self.client.message(message_history=self.messages, **kwargs)
|
||||
|
||||
# Add the AI's response to the conversation history
|
||||
if isinstance(response.text, str):
|
||||
self.add_message(response.text, role="assistant")
|
||||
|
||||
return response
|
||||
|
||||
def set_model(self, model: str):
|
||||
"""Set the model for the conversation."""
|
||||
self.client.set_model(model)
|
||||
return True
|
||||
@@ -32,20 +32,22 @@ class SimpleMind:
|
||||
f"Provider '{self.provider}' not supported. Available providers: {list(providers.keys())}"
|
||||
)
|
||||
|
||||
return providers[self.provider](api_key=self.api_key)
|
||||
return providers[self.provider]()
|
||||
|
||||
def generate(self, prompt: str, **kwargs) -> AIResponse:
|
||||
"""Generate a response using the configured provider."""
|
||||
|
||||
return self._client.message(prompt, **kwargs)
|
||||
|
||||
def create_conversation(self, initial_message: str) -> str:
|
||||
def create_conversation(self) -> str:
|
||||
"""Create a new conversation and return its ID."""
|
||||
|
||||
conversation = self._client.create_conversation(initial_message)
|
||||
return conversation.id
|
||||
initial_message = "You are a helpful assistant."
|
||||
|
||||
def send_message(self, conversation_id: str, message: str) -> AIResponse:
|
||||
conversation = self._client.create_conversation(initial_message)
|
||||
return conversation
|
||||
|
||||
def add_message(self, conversation_id: str, message: str) -> AIResponse:
|
||||
"""Send a message in an existing conversation."""
|
||||
|
||||
return self._client.send_message(conversation_id, message)
|
||||
return self._client.add_message(conversation_id, message)
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import os
|
||||
|
||||
from ollama import Client as BaseOllama
|
||||
|
||||
import uuid
|
||||
from typing import List, Optional, Dict, Any
|
||||
from ..core.errors import ProviderError
|
||||
from ..core.logger import logger
|
||||
from .base import BaseClientProvider
|
||||
from ..core.models import AIResponse
|
||||
from ..concepts.conversation import Conversation
|
||||
from some_module import BaseOllama # Replace 'some_module' with the actual module name
|
||||
|
||||
TIMEOUT = 60
|
||||
|
||||
|
||||
class Ollama(BaseClientProvider):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -17,17 +18,17 @@ class Ollama(BaseClientProvider):
|
||||
|
||||
def login(self):
|
||||
"""Initialize Ollama client, with Instructor enabled."""
|
||||
if not os.environ.get('OLLAMA_HOST_URL'):
|
||||
raise ValueError("Please set the OLLAMA_HOST_URL environment variable")
|
||||
if not os.environ.get("OLLAMA_HOST_URL"):
|
||||
raise ValueError("Please set the OLLAMA_HOST_URL environment variable")
|
||||
|
||||
if not os.environ.get('OLLAMA_MODEL'):
|
||||
if not os.environ.get("OLLAMA_MODEL"):
|
||||
raise ValueError("Please set the OLLAMA_MODEL environment variable")
|
||||
else:
|
||||
self.model = os.environ.get('OLLAMA_MODEL')
|
||||
self.model = os.environ.get("OLLAMA_MODEL")
|
||||
|
||||
self.client = BaseOllama(
|
||||
timeout=TIMEOUT,
|
||||
host=os.environ.get('OLLAMA_HOST_URL'))
|
||||
timeout=TIMEOUT, host=os.environ.get("OLLAMA_HOST_URL")
|
||||
)
|
||||
assert self.test_connection()
|
||||
|
||||
@property
|
||||
@@ -35,7 +36,7 @@ class Ollama(BaseClientProvider):
|
||||
"""Returns the available models from the Ollama client."""
|
||||
|
||||
def gen():
|
||||
for model in self.client.list().get('models'):
|
||||
for model in self.client.list().get("models"):
|
||||
yield model
|
||||
|
||||
return [g for g in gen()]
|
||||
@@ -68,10 +69,12 @@ class Ollama(BaseClientProvider):
|
||||
else:
|
||||
return AIResponse(
|
||||
response=completion,
|
||||
text=completion.get('response'),
|
||||
text=completion.get("response"),
|
||||
)
|
||||
|
||||
def message(self, message=None, message_history=None, response_model=False, **kwargs):
|
||||
def message(
|
||||
self, message=None, message_history=None, response_model=False, **kwargs
|
||||
):
|
||||
"""Generates a response from the Ollama client."""
|
||||
use_instructor = bool(response_model)
|
||||
|
||||
@@ -82,7 +85,7 @@ class Ollama(BaseClientProvider):
|
||||
if message_history:
|
||||
all_messages.extend(message_history)
|
||||
if message:
|
||||
all_messages.append({'role': 'user', 'content': message})
|
||||
all_messages.append({"role": "user", "content": message})
|
||||
params = {
|
||||
"messages": all_messages,
|
||||
"model": self.model,
|
||||
@@ -100,7 +103,7 @@ class Ollama(BaseClientProvider):
|
||||
else:
|
||||
return AIResponse(
|
||||
response=completion,
|
||||
text=completion.get('message').get('content'),
|
||||
text=completion.get("message").get("content"),
|
||||
)
|
||||
|
||||
def start_conversation(self):
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from typing import List, Optional
|
||||
from openai import OpenAI as BaseOpenAI
|
||||
import openai
|
||||
from ..core.errors import AuthenticationError, ProviderError
|
||||
from ..core.config import settings
|
||||
from ..core.logger import logger
|
||||
from .base import BaseClientProvider
|
||||
|
||||
DEFAULT_MODEL = "gpt-4o"
|
||||
|
||||
|
||||
class OpenAI(BaseClientProvider):
|
||||
def __init__(self, model: str = "gpt-4", api_key: Optional[str] = None):
|
||||
def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
|
||||
super().__init__(model=model, api_key=api_key)
|
||||
self.login()
|
||||
|
||||
@@ -18,7 +20,7 @@ class OpenAI(BaseClientProvider):
|
||||
raise AuthenticationError("OpenAI API key not provided")
|
||||
|
||||
try:
|
||||
self.client = BaseOpenAI(api_key=self._api_key)
|
||||
openai.api_key = self._api_key
|
||||
if not self.test_connection():
|
||||
raise ProviderError("Failed to connect to OpenAI API")
|
||||
logger.info("Successfully connected to OpenAI")
|
||||
@@ -28,7 +30,8 @@ class OpenAI(BaseClientProvider):
|
||||
@property
|
||||
def available_models(self) -> List[str]:
|
||||
try:
|
||||
return [model.id for model in self.client.models.list()]
|
||||
models = openai.models.list()
|
||||
return [model.id for model in models["data"]]
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching models: {e}")
|
||||
return []
|
||||
@@ -36,10 +39,10 @@ class OpenAI(BaseClientProvider):
|
||||
def test_connection(self):
|
||||
"""Test the connection to OpenAI API."""
|
||||
try:
|
||||
# A simple test call to verify API key works
|
||||
self.client.models.list()
|
||||
openai.models.list()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Connection test failed: {e}")
|
||||
return False
|
||||
|
||||
def _handle_api_error(self, e: Exception) -> None:
|
||||
@@ -47,16 +50,18 @@ class OpenAI(BaseClientProvider):
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise ProviderError(f"OpenAI API error: {e}")
|
||||
|
||||
def generate_response(self, conversation) -> str:
|
||||
def add_message(self, conversation_id, message, *args, **kwargs) -> str:
|
||||
"""Generate a response using the OpenAI API."""
|
||||
try:
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in conversation.messages
|
||||
]
|
||||
# Create a client instance using the API key
|
||||
client = openai.OpenAI(api_key=self._api_key)
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model, messages=messages
|
||||
# Create the message for the conversation
|
||||
messages = [{"role": "user", "content": message}]
|
||||
|
||||
# Use the new API syntax
|
||||
response = client.chat.completions.create(
|
||||
model=self.model, messages=messages, *args, **kwargs
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from simplemind.providers.openai import OpenAI
|
||||
from simplemind.core.errors import AuthenticationError, ProviderError
|
||||
|
||||
|
||||
class TestOpenAIProvider(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.api_key = "test_key"
|
||||
|
||||
@patch("simplemind.providers.openai.openai")
|
||||
def test_initialization(self, mock_openai):
|
||||
mock_openai.Model.list.return_value = {"data": ["gpt-4"]}
|
||||
provider = OpenAI(api_key=self.api_key)
|
||||
self.assertIsNotNone(provider.client)
|
||||
mock_openai.Model.list.assert_called_once()
|
||||
|
||||
def test_missing_api_key(self):
|
||||
with self.assertRaises(AuthenticationError):
|
||||
OpenAI(api_key=None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -4,6 +4,6 @@ sm = SimpleMind()
|
||||
|
||||
# The provider will automatically use OPENAI_API_KEY from environment
|
||||
conversation = sm.create_conversation()
|
||||
r = conversation.send_message("Who is Kenneth Reitz?", model="gpt-4o")
|
||||
r = sm.add_message(conversation.id, "Who is Kenneth Reitz?")
|
||||
print(r)
|
||||
#
|
||||
|
||||
Reference in New Issue
Block a user