Refactor code structure and imports in simplemind package

This commit is contained in:
2024-10-28 16:24:21 -04:00
parent 0e81e5bf17
commit 25b523e372
24 changed files with 211 additions and 175 deletions
-1
View File
@@ -1 +0,0 @@
from .integrations import *
-13
View File
@@ -1,13 +0,0 @@
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
openai_api_key: str = ""
anthropic_api_key: str = ""
default_model: str = "gpt-4"
class Config:
env_file = ".env"
settings = Settings()
+22 -10
View File
@@ -1,12 +1,24 @@
class Conversation:
def __init__(self, ai_client):
self.messages = []
self.ai_client = ai_client
def say(self, message):
self.messages.append({'role': 'user', 'content': message})
"""A class to manage conversation state with an AI model."""
def get_reply(self):
reply = self.ai_client.message(messages=self.messages)
self.messages.append({'role': 'system', 'content': reply.text})
return reply
def __init__(self, client):
self.client = client
self.messages = []
def add_message(self, message, role="user"):
"""Add a message to the conversation history."""
self.messages.append({"role": role, "content": message})
return self
def send(self, message=None, **kwargs):
"""Send the conversation history (and optionally a new message) to the AI."""
if message:
self.add_message(message)
response = self.client.message(message_history=self.messages, **kwargs)
# Add the AI's response to the conversation history
if isinstance(response.text, str):
self.add_message(response.text, role="assistant")
return response
-44
View File
@@ -1,44 +0,0 @@
from typing import Dict, Any, Optional
from .models import AIResponse
from .concepts import Context
from .integrations.base import BaseClientProvider
class SimpleMind:
"""Main class for SimpleMind functionality."""
def __init__(
self, api_key: str, provider: str = "openai", context: Optional[Context] = None
):
"""Initialize SimpleMind with the specified provider."""
self.api_key = api_key
self.provider = provider
self.context = context or Context()
self._client = self._get_provider()
def _get_provider(self) -> BaseClientProvider:
"""Get the appropriate provider client."""
from .integrations.openai import OpenAI
from .integrations.anthropic import Anthropic
providers = {"openai": OpenAI, "anthropic": Anthropic}
if self.provider not in providers:
raise ValueError(
f"Provider '{self.provider}' not supported. Available providers: {list(providers.keys())}"
)
return providers[self.provider](api_key=self.api_key)
def generate(self, prompt: str, **kwargs) -> AIResponse:
"""Generate a response using the configured provider."""
return self._client.message(prompt, **kwargs)
def create_conversation(self, initial_message: str) -> str:
"""Create a new conversation and return its ID."""
conversation = self._client.create_conversation(initial_message)
return conversation.id
def send_message(self, conversation_id: str, message: str) -> AIResponse:
"""Send a message in an existing conversation."""
return self._client.send_message(conversation_id, message)
+44
View File
@@ -0,0 +1,44 @@
from typing import Dict, Any, Optional
from .models import AIResponse
from ..concepts.context import Context
from ..integrations.base import BaseClientProvider
class SimpleMind:
"""Main class for SimpleMind functionality."""
def __init__(
self, api_key: str, provider: str = "openai", context: Optional[Context] = None
):
"""Initialize SimpleMind with the specified provider."""
self.api_key = api_key
self.provider = provider
self.context = context or Context()
self._client = self._get_provider()
def _get_provider(self) -> BaseClientProvider:
"""Get the appropriate provider client."""
from .integrations.openai import OpenAI
from .integrations.anthropic import Anthropic
providers = {"openai": OpenAI, "anthropic": Anthropic}
if self.provider not in providers:
raise ValueError(
f"Provider '{self.provider}' not supported. Available providers: {list(providers.keys())}"
)
return providers[self.provider](api_key=self.api_key)
def generate(self, prompt: str, **kwargs) -> AIResponse:
"""Generate a response using the configured provider."""
return self._client.message(prompt, **kwargs)
def create_conversation(self, initial_message: str) -> str:
"""Create a new conversation and return its ID."""
conversation = self._client.create_conversation(initial_message)
return conversation.id
def send_message(self, conversation_id: str, message: str) -> AIResponse:
"""Send a message in an existing conversation."""
return self._client.send_message(conversation_id, message)
@@ -1,6 +1,6 @@
from typing import Optional
from simplemind.models import Conversation, AIResponse
from simplemind.concepts import Context
from simplemind.core.models import Conversation, AIResponse
from simplemind.concepts.context import Context
from simplemind.integrations.openai import OpenAI
from simplemind.integrations.anthropic import Anthropic
import logging
@@ -24,7 +24,7 @@ class Client:
if provider not in self.providers:
raise ValueError(f"Provider '{provider}' not supported.")
return self.providers[provider].create_conversation(
initial_message="Hello!", context=self.context.dict()
initial_message="Hello!", context=self.context.model_dump()
)
def _handle_api_error(self, error: Exception, operation: str):
+17
View File
@@ -0,0 +1,17 @@
from pydantic_settings import BaseSettings
from typing import Optional
class Settings(BaseSettings):
openai_api_key: Optional[str] = None
anthropic_api_key: Optional[str] = None
ollama_host_url: Optional[str] = None
default_model: str = "gpt-4"
log_level: str = "INFO"
class Config:
env_file = ".env"
case_sensitive = False
settings = Settings()
+22
View File
@@ -0,0 +1,22 @@
class SimpleMindError(Exception):
"""Base exception for SimpleMind errors."""
pass
class ProviderError(SimpleMindError):
"""Raised when there's an error with the AI provider."""
pass
class ConfigurationError(SimpleMindError):
"""Raised when there's a configuration error."""
pass
class AuthenticationError(SimpleMindError):
"""Raised when authentication fails."""
pass
+2 -2
View File
@@ -5,8 +5,8 @@ import instructor
from anthropic import Anthropic as BaseAnthropic
from .base import BaseClientProvider
from ..models import AIResponse, Conversation
from ..logger import logger
from ..core.models import AIResponse, Conversation
from ..core.logger import logger
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
+1 -1
View File
@@ -2,7 +2,7 @@
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from ..models import AIResponse, Conversation, Message
from ..core.models import AIResponse, Conversation, Message
import uuid
+2 -2
View File
@@ -3,7 +3,7 @@ import os
from ollama import Client as BaseOllama
from .base import BaseClientProvider
from ..models import AIResponse
from ..core.models import AIResponse
from ..conversation import Conversation
TIMEOUT = 60
@@ -19,7 +19,7 @@ class Ollama(BaseClientProvider):
"""Initialize Ollama client, with Instructor enabled."""
if not os.environ.get('OLLAMA_HOST_URL'):
raise ValueError("Please set the OLLAMA_HOST_URL environment variable")
if not os.environ.get('OLLAMA_MODEL'):
raise ValueError("Please set the OLLAMA_MODEL environment variable")
else:
+37 -47
View File
@@ -1,32 +1,29 @@
import os
from typing import Optional, List
import instructor
from typing import List, Optional
from openai import OpenAI as BaseOpenAI
from ..core.errors import AuthenticationError, ProviderError
from simplemind.core.config import settings
from ..core.logger import logger
from .base import BaseClientProvider
from ..models import AIResponse, Conversation
from ..logger import logger
from simplemind.config import settings
DEFAULT_MODEL = "gpt-4o"
class OpenAI(BaseClientProvider):
def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
def __init__(self, model: str = "gpt-4", api_key: Optional[str] = None):
super().__init__(model=model, api_key=api_key)
self.login()
def login(self):
def login(self) -> None:
if not self._api_key:
self._api_key = settings.openai_api_key
if not self._api_key:
raise ValueError("OpenAI API key not provided.")
self.client = BaseOpenAI(api_key=self._api_key)
self.instructor_client = instructor.from_openai(self.client)
if not self.test_connection():
raise ConnectionError("Failed to connect to OpenAI API.")
logger.info("Logged in to OpenAI successfully.")
raise AuthenticationError("OpenAI API key not provided")
try:
self.client = BaseOpenAI(api_key=self._api_key)
if not self.test_connection():
raise ProviderError("Failed to connect to OpenAI API")
logger.info("Successfully connected to OpenAI")
except Exception as e:
raise ProviderError(f"OpenAI initialization failed: {str(e)}")
@property
def available_models(self) -> List[str]:
@@ -36,39 +33,32 @@ class OpenAI(BaseClientProvider):
logger.error(f"Error fetching models: {e}")
return []
def test_connection(self) -> bool:
def test_connection(self):
"""Test the connection to OpenAI API."""
try:
models = self.available_models
if models:
logger.info(f"Available models: {models}")
return True
else:
logger.warning("No available models found.")
return False
# A simple test call to verify API key works
self.client.models.list()
return True
except Exception as e:
logger.error(f"Error testing connection: {e}")
return False
def generate_response(self, conversation: Conversation) -> AIResponse:
messages = conversation.get_messages()
params = {
"model": self.model,
"messages": [
{
"role": msg.role,
"content": [{"type": "text", "text": msg.content}], # New format
}
for msg in messages
],
"temperature": getattr(
self, "temperature", 0.7
), # Use 0.7 as default if not set
}
def _handle_api_error(self, e: Exception) -> None:
"""Handle API errors."""
logger.error(f"OpenAI API error: {e}")
raise ProviderError(f"OpenAI API error: {e}")
def generate_response(self, conversation) -> str:
"""Generate a response using the OpenAI API."""
try:
completion = self.client.chat.completions.create(**params)
return completion
messages = [
{"role": msg.role, "content": msg.content}
for msg in conversation.messages
]
response = self.client.chat.completions.create(
model=self.model, messages=messages
)
return response.choices[0].message.content
except Exception as e:
# Enhanced error handling (optional)
logger.error(f"OpenAI API Error: {e}")
raise RuntimeError(f"Failed to generate response: {e}")
self._handle_api_error(e)
+26
View File
@@ -0,0 +1,26 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
class Message(BaseModel):
role: str
content: str
created_at: datetime = datetime.now()
class Conversation(BaseModel):
id: str
messages: List[Message] = []
context: Dict[str, Any] = {}
created_at: datetime = datetime.now()
updated_at: datetime = datetime.now()
def add_message(self, role: str, content: str) -> Message:
message = Message(role=role, content=content)
self.messages.append(message)
self.updated_at = datetime.now()
return message
class AIResponse(BaseModel):
text: str
response: Any
metadata: Dict[str, Any] = {}
+11 -6
View File
@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
from typing import Any, Dict
from typing import Any, Dict, Optional
class BasePlugin(ABC):
"""Base class for all SimpleMind plugins."""
def __init__(self):
self.is_enabled = True
self.enabled: bool = True
self.name: str = self.__class__.__name__
@abstractmethod
def process(self, context: Dict[str, Any]) -> Dict[str, Any]:
@@ -20,10 +21,14 @@ class BasePlugin(ABC):
"""
pass
def enable(self):
def enable(self) -> None:
"""Enable the plugin."""
self.is_enabled = True
self.enabled = True
def disable(self):
def disable(self) -> None:
"""Disable the plugin."""
self.is_enabled = False
self.enabled = False
@property
def is_enabled(self) -> bool:
return self.enabled
View File
+11 -8
View File
@@ -2,11 +2,11 @@ import os
from pprint import pprint
from pydantic import BaseModel
import simplemind
from simplemind.concepts import Context
from simplemind.concepts.context import Context
from simplemind.plugins.kv import KVPlugin
from simplemind.plugins.basic_memory import BasicMemoryPlugin
from simplemind.chains.reverse_text import ReverseTextChain
from simplemind.client import Client
from simplemind.core.client import Client
class CustomContext(Context):
@@ -27,12 +27,15 @@ aiclient = Client(
print(aiclient.available_models)
# Example usage
conversation = aiclient.create_conversation(provider="openai")
conversation = aiclient.create_conversation(provider="anthropic", context=ctx)
response = aiclient.send_message(
conversation, "Who is Kenneth Reitz?", provider="openai"
conversation, "Who is Kenneth Reitz?", provider="anthropic"
)
print(response)
# response = aiclient.send_message(
# conversation, "Who is Kenneth Reitz?", provider="openai"
# )
# print(response)
reverse_chain = ReverseTextChain()
result = reverse_chain.run("Hello, World!")
print(result) # Output: !dlroW ,olleH
# reverse_chain = ReverseTextChain()
# result = reverse_chain.run("Hello, World!")
# print(result) # Output: !dlroW ,olleH
-19
View File
@@ -47,22 +47,3 @@ message = "who is kenneth reitz?"
print(f"> {message}")
pprint(openai.message(message, response_model=BioData))
# claude = simplemind.integrations.Anthropic()
# # print(claude.test_connection())
# # print(claude.available_models)
# claude.login()
vector_store = FAISSStore(dimension=768) # Example dimension for embeddings
# Add embeddings
embeddings = np.random.random((10, 768)).astype('float32')
ids = [f"doc_{i}" for i in range(10)]
vector_store.add_embeddings(embeddings, ids)
# Search
query_embedding = np.random.random((1, 768)).astype('float32')
results = vector_store.search(query_embedding, top_k=3)
print(results)
+3 -5
View File
@@ -1,17 +1,16 @@
import simplemind
aiclient = simplemind.Ollama()
print('Messaging client')
print("Messaging client")
message_response = aiclient.message(message="Once upon a time in a land far away...")
print(message_response)
print('Generating Text')
print("Generating Text")
generated_text = aiclient.generate_text(prompt="Once upon a time in a land far away...")
print(generated_text)
print('Initiating Conversation')
print("Initiating Conversation")
conversation = aiclient.start_conversation()
# Add a message to the conversation
@@ -27,4 +26,3 @@ conversation.say("What number did I ask you to remember?")
# Get the AI's response
reply = conversation.get_reply()
print(reply)
+10 -14
View File
@@ -1,25 +1,21 @@
import unittest
from unittest.mock import patch, MagicMock
from simplemind.integrations.openai import OpenAI
from simplemind.core.errors import AuthenticationError, ProviderError
class TestOpenAIProvider(unittest.TestCase):
def setUp(self):
self.api_key = "test_key"
@patch("simplemind.integrations.openai.BaseOpenAI")
def setUp(self, mock_openai):
self.mock_openai = mock_openai.return_value
self.mock_openai.models.list.return_value = [MagicMock(id="gpt-4")]
self.provider = OpenAI(api_key="test_api_key", model="gpt-4")
def test_initialization(self, mock_openai):
provider = OpenAI(api_key=self.api_key)
self.assertIsNotNone(provider.client)
def test_available_models(self):
models = self.provider.available_models
self.assertIn("gpt-4", models)
def test_test_connection_success(self):
self.assertTrue(self.provider.test_connection())
def test_generate_response_not_implemented(self):
with self.assertRaises(NotImplementedError):
self.provider.generate_response(None)
def test_missing_api_key(self):
with self.assertRaises(AuthenticationError):
OpenAI(api_key=None)
if __name__ == "__main__":