mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
Refactor code structure and imports in simplemind package
This commit is contained in:
@@ -1 +0,0 @@
|
||||
from .integrations import *
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
openai_api_key: str = ""
|
||||
anthropic_api_key: str = ""
|
||||
default_model: str = "gpt-4"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
+22
-10
@@ -1,12 +1,24 @@
|
||||
class Conversation:
|
||||
def __init__(self, ai_client):
|
||||
self.messages = []
|
||||
self.ai_client = ai_client
|
||||
|
||||
def say(self, message):
|
||||
self.messages.append({'role': 'user', 'content': message})
|
||||
"""A class to manage conversation state with an AI model."""
|
||||
|
||||
def get_reply(self):
|
||||
reply = self.ai_client.message(messages=self.messages)
|
||||
self.messages.append({'role': 'system', 'content': reply.text})
|
||||
return reply
|
||||
def __init__(self, client):
|
||||
self.client = client
|
||||
self.messages = []
|
||||
|
||||
def add_message(self, message, role="user"):
|
||||
"""Add a message to the conversation history."""
|
||||
self.messages.append({"role": role, "content": message})
|
||||
return self
|
||||
|
||||
def send(self, message=None, **kwargs):
|
||||
"""Send the conversation history (and optionally a new message) to the AI."""
|
||||
if message:
|
||||
self.add_message(message)
|
||||
|
||||
response = self.client.message(message_history=self.messages, **kwargs)
|
||||
|
||||
# Add the AI's response to the conversation history
|
||||
if isinstance(response.text, str):
|
||||
self.add_message(response.text, role="assistant")
|
||||
|
||||
return response
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from .models import AIResponse
|
||||
from .concepts import Context
|
||||
from .integrations.base import BaseClientProvider
|
||||
|
||||
|
||||
class SimpleMind:
|
||||
"""Main class for SimpleMind functionality."""
|
||||
|
||||
def __init__(
|
||||
self, api_key: str, provider: str = "openai", context: Optional[Context] = None
|
||||
):
|
||||
"""Initialize SimpleMind with the specified provider."""
|
||||
self.api_key = api_key
|
||||
self.provider = provider
|
||||
self.context = context or Context()
|
||||
self._client = self._get_provider()
|
||||
|
||||
def _get_provider(self) -> BaseClientProvider:
|
||||
"""Get the appropriate provider client."""
|
||||
from .integrations.openai import OpenAI
|
||||
from .integrations.anthropic import Anthropic
|
||||
|
||||
providers = {"openai": OpenAI, "anthropic": Anthropic}
|
||||
|
||||
if self.provider not in providers:
|
||||
raise ValueError(
|
||||
f"Provider '{self.provider}' not supported. Available providers: {list(providers.keys())}"
|
||||
)
|
||||
|
||||
return providers[self.provider](api_key=self.api_key)
|
||||
|
||||
def generate(self, prompt: str, **kwargs) -> AIResponse:
|
||||
"""Generate a response using the configured provider."""
|
||||
return self._client.message(prompt, **kwargs)
|
||||
|
||||
def create_conversation(self, initial_message: str) -> str:
|
||||
"""Create a new conversation and return its ID."""
|
||||
conversation = self._client.create_conversation(initial_message)
|
||||
return conversation.id
|
||||
|
||||
def send_message(self, conversation_id: str, message: str) -> AIResponse:
|
||||
"""Send a message in an existing conversation."""
|
||||
return self._client.send_message(conversation_id, message)
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from .models import AIResponse
|
||||
from ..concepts.context import Context
|
||||
from ..integrations.base import BaseClientProvider
|
||||
|
||||
|
||||
class SimpleMind:
|
||||
"""Main class for SimpleMind functionality."""
|
||||
|
||||
def __init__(
|
||||
self, api_key: str, provider: str = "openai", context: Optional[Context] = None
|
||||
):
|
||||
"""Initialize SimpleMind with the specified provider."""
|
||||
self.api_key = api_key
|
||||
self.provider = provider
|
||||
self.context = context or Context()
|
||||
self._client = self._get_provider()
|
||||
|
||||
def _get_provider(self) -> BaseClientProvider:
|
||||
"""Get the appropriate provider client."""
|
||||
from .integrations.openai import OpenAI
|
||||
from .integrations.anthropic import Anthropic
|
||||
|
||||
providers = {"openai": OpenAI, "anthropic": Anthropic}
|
||||
|
||||
if self.provider not in providers:
|
||||
raise ValueError(
|
||||
f"Provider '{self.provider}' not supported. Available providers: {list(providers.keys())}"
|
||||
)
|
||||
|
||||
return providers[self.provider](api_key=self.api_key)
|
||||
|
||||
def generate(self, prompt: str, **kwargs) -> AIResponse:
|
||||
"""Generate a response using the configured provider."""
|
||||
return self._client.message(prompt, **kwargs)
|
||||
|
||||
def create_conversation(self, initial_message: str) -> str:
|
||||
"""Create a new conversation and return its ID."""
|
||||
conversation = self._client.create_conversation(initial_message)
|
||||
return conversation.id
|
||||
|
||||
def send_message(self, conversation_id: str, message: str) -> AIResponse:
|
||||
"""Send a message in an existing conversation."""
|
||||
return self._client.send_message(conversation_id, message)
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
from simplemind.models import Conversation, AIResponse
|
||||
from simplemind.concepts import Context
|
||||
from simplemind.core.models import Conversation, AIResponse
|
||||
from simplemind.concepts.context import Context
|
||||
from simplemind.integrations.openai import OpenAI
|
||||
from simplemind.integrations.anthropic import Anthropic
|
||||
import logging
|
||||
@@ -24,7 +24,7 @@ class Client:
|
||||
if provider not in self.providers:
|
||||
raise ValueError(f"Provider '{provider}' not supported.")
|
||||
return self.providers[provider].create_conversation(
|
||||
initial_message="Hello!", context=self.context.dict()
|
||||
initial_message="Hello!", context=self.context.model_dump()
|
||||
)
|
||||
|
||||
def _handle_api_error(self, error: Exception, operation: str):
|
||||
@@ -0,0 +1,17 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
openai_api_key: Optional[str] = None
|
||||
anthropic_api_key: Optional[str] = None
|
||||
ollama_host_url: Optional[str] = None
|
||||
default_model: str = "gpt-4"
|
||||
log_level: str = "INFO"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = False
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,22 @@
|
||||
class SimpleMindError(Exception):
|
||||
"""Base exception for SimpleMind errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ProviderError(SimpleMindError):
|
||||
"""Raised when there's an error with the AI provider."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ConfigurationError(SimpleMindError):
|
||||
"""Raised when there's a configuration error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationError(SimpleMindError):
|
||||
"""Raised when authentication fails."""
|
||||
|
||||
pass
|
||||
@@ -5,8 +5,8 @@ import instructor
|
||||
from anthropic import Anthropic as BaseAnthropic
|
||||
|
||||
from .base import BaseClientProvider
|
||||
from ..models import AIResponse, Conversation
|
||||
from ..logger import logger
|
||||
from ..core.models import AIResponse, Conversation
|
||||
from ..core.logger import logger
|
||||
|
||||
|
||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional
|
||||
from ..models import AIResponse, Conversation, Message
|
||||
from ..core.models import AIResponse, Conversation, Message
|
||||
import uuid
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
from ollama import Client as BaseOllama
|
||||
|
||||
from .base import BaseClientProvider
|
||||
from ..models import AIResponse
|
||||
from ..core.models import AIResponse
|
||||
from ..conversation import Conversation
|
||||
|
||||
TIMEOUT = 60
|
||||
@@ -19,7 +19,7 @@ class Ollama(BaseClientProvider):
|
||||
"""Initialize Ollama client, with Instructor enabled."""
|
||||
if not os.environ.get('OLLAMA_HOST_URL'):
|
||||
raise ValueError("Please set the OLLAMA_HOST_URL environment variable")
|
||||
|
||||
|
||||
if not os.environ.get('OLLAMA_MODEL'):
|
||||
raise ValueError("Please set the OLLAMA_MODEL environment variable")
|
||||
else:
|
||||
|
||||
@@ -1,32 +1,29 @@
|
||||
import os
|
||||
from typing import Optional, List
|
||||
import instructor
|
||||
from typing import List, Optional
|
||||
from openai import OpenAI as BaseOpenAI
|
||||
|
||||
from ..core.errors import AuthenticationError, ProviderError
|
||||
from simplemind.core.config import settings
|
||||
from ..core.logger import logger
|
||||
from .base import BaseClientProvider
|
||||
from ..models import AIResponse, Conversation
|
||||
from ..logger import logger
|
||||
from simplemind.config import settings
|
||||
|
||||
|
||||
DEFAULT_MODEL = "gpt-4o"
|
||||
|
||||
|
||||
class OpenAI(BaseClientProvider):
|
||||
def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
|
||||
def __init__(self, model: str = "gpt-4", api_key: Optional[str] = None):
|
||||
super().__init__(model=model, api_key=api_key)
|
||||
self.login()
|
||||
|
||||
def login(self):
|
||||
def login(self) -> None:
|
||||
if not self._api_key:
|
||||
self._api_key = settings.openai_api_key
|
||||
if not self._api_key:
|
||||
raise ValueError("OpenAI API key not provided.")
|
||||
self.client = BaseOpenAI(api_key=self._api_key)
|
||||
self.instructor_client = instructor.from_openai(self.client)
|
||||
if not self.test_connection():
|
||||
raise ConnectionError("Failed to connect to OpenAI API.")
|
||||
logger.info("Logged in to OpenAI successfully.")
|
||||
raise AuthenticationError("OpenAI API key not provided")
|
||||
|
||||
try:
|
||||
self.client = BaseOpenAI(api_key=self._api_key)
|
||||
if not self.test_connection():
|
||||
raise ProviderError("Failed to connect to OpenAI API")
|
||||
logger.info("Successfully connected to OpenAI")
|
||||
except Exception as e:
|
||||
raise ProviderError(f"OpenAI initialization failed: {str(e)}")
|
||||
|
||||
@property
|
||||
def available_models(self) -> List[str]:
|
||||
@@ -36,39 +33,32 @@ class OpenAI(BaseClientProvider):
|
||||
logger.error(f"Error fetching models: {e}")
|
||||
return []
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
def test_connection(self):
|
||||
"""Test the connection to OpenAI API."""
|
||||
try:
|
||||
models = self.available_models
|
||||
if models:
|
||||
logger.info(f"Available models: {models}")
|
||||
return True
|
||||
else:
|
||||
logger.warning("No available models found.")
|
||||
return False
|
||||
# A simple test call to verify API key works
|
||||
self.client.models.list()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing connection: {e}")
|
||||
return False
|
||||
|
||||
def generate_response(self, conversation: Conversation) -> AIResponse:
|
||||
messages = conversation.get_messages()
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": [{"type": "text", "text": msg.content}], # New format
|
||||
}
|
||||
for msg in messages
|
||||
],
|
||||
"temperature": getattr(
|
||||
self, "temperature", 0.7
|
||||
), # Use 0.7 as default if not set
|
||||
}
|
||||
def _handle_api_error(self, e: Exception) -> None:
|
||||
"""Handle API errors."""
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise ProviderError(f"OpenAI API error: {e}")
|
||||
|
||||
def generate_response(self, conversation) -> str:
|
||||
"""Generate a response using the OpenAI API."""
|
||||
try:
|
||||
completion = self.client.chat.completions.create(**params)
|
||||
return completion
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in conversation.messages
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model, messages=messages
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
# Enhanced error handling (optional)
|
||||
logger.error(f"OpenAI API Error: {e}")
|
||||
raise RuntimeError(f"Failed to generate response: {e}")
|
||||
self._handle_api_error(e)
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
created_at: datetime = datetime.now()
|
||||
|
||||
class Conversation(BaseModel):
|
||||
id: str
|
||||
messages: List[Message] = []
|
||||
context: Dict[str, Any] = {}
|
||||
created_at: datetime = datetime.now()
|
||||
updated_at: datetime = datetime.now()
|
||||
|
||||
def add_message(self, role: str, content: str) -> Message:
|
||||
message = Message(role=role, content=content)
|
||||
self.messages.append(message)
|
||||
self.updated_at = datetime.now()
|
||||
return message
|
||||
|
||||
class AIResponse(BaseModel):
|
||||
text: str
|
||||
response: Any
|
||||
metadata: Dict[str, Any] = {}
|
||||
@@ -1,12 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class BasePlugin(ABC):
|
||||
"""Base class for all SimpleMind plugins."""
|
||||
|
||||
def __init__(self):
|
||||
self.is_enabled = True
|
||||
self.enabled: bool = True
|
||||
self.name: str = self.__class__.__name__
|
||||
|
||||
@abstractmethod
|
||||
def process(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@@ -20,10 +21,14 @@ class BasePlugin(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def enable(self):
|
||||
def enable(self) -> None:
|
||||
"""Enable the plugin."""
|
||||
self.is_enabled = True
|
||||
self.enabled = True
|
||||
|
||||
def disable(self):
|
||||
def disable(self) -> None:
|
||||
"""Disable the plugin."""
|
||||
self.is_enabled = False
|
||||
self.enabled = False
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
return self.enabled
|
||||
|
||||
@@ -2,11 +2,11 @@ import os
|
||||
from pprint import pprint
|
||||
from pydantic import BaseModel
|
||||
import simplemind
|
||||
from simplemind.concepts import Context
|
||||
from simplemind.concepts.context import Context
|
||||
from simplemind.plugins.kv import KVPlugin
|
||||
from simplemind.plugins.basic_memory import BasicMemoryPlugin
|
||||
from simplemind.chains.reverse_text import ReverseTextChain
|
||||
from simplemind.client import Client
|
||||
from simplemind.core.client import Client
|
||||
|
||||
|
||||
class CustomContext(Context):
|
||||
@@ -27,12 +27,15 @@ aiclient = Client(
|
||||
print(aiclient.available_models)
|
||||
|
||||
# Example usage
|
||||
conversation = aiclient.create_conversation(provider="openai")
|
||||
conversation = aiclient.create_conversation(provider="anthropic", context=ctx)
|
||||
response = aiclient.send_message(
|
||||
conversation, "Who is Kenneth Reitz?", provider="openai"
|
||||
conversation, "Who is Kenneth Reitz?", provider="anthropic"
|
||||
)
|
||||
print(response)
|
||||
# response = aiclient.send_message(
|
||||
# conversation, "Who is Kenneth Reitz?", provider="openai"
|
||||
# )
|
||||
# print(response)
|
||||
|
||||
reverse_chain = ReverseTextChain()
|
||||
result = reverse_chain.run("Hello, World!")
|
||||
print(result) # Output: !dlroW ,olleH
|
||||
# reverse_chain = ReverseTextChain()
|
||||
# result = reverse_chain.run("Hello, World!")
|
||||
# print(result) # Output: !dlroW ,olleH
|
||||
|
||||
@@ -47,22 +47,3 @@ message = "who is kenneth reitz?"
|
||||
|
||||
print(f"> {message}")
|
||||
pprint(openai.message(message, response_model=BioData))
|
||||
|
||||
# claude = simplemind.integrations.Anthropic()
|
||||
|
||||
# # print(claude.test_connection())
|
||||
# # print(claude.available_models)
|
||||
|
||||
# claude.login()
|
||||
|
||||
vector_store = FAISSStore(dimension=768) # Example dimension for embeddings
|
||||
|
||||
# Add embeddings
|
||||
embeddings = np.random.random((10, 768)).astype('float32')
|
||||
ids = [f"doc_{i}" for i in range(10)]
|
||||
vector_store.add_embeddings(embeddings, ids)
|
||||
|
||||
# Search
|
||||
query_embedding = np.random.random((1, 768)).astype('float32')
|
||||
results = vector_store.search(query_embedding, top_k=3)
|
||||
print(results)
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
|
||||
import simplemind
|
||||
|
||||
aiclient = simplemind.Ollama()
|
||||
|
||||
print('Messaging client')
|
||||
print("Messaging client")
|
||||
message_response = aiclient.message(message="Once upon a time in a land far away...")
|
||||
print(message_response)
|
||||
|
||||
print('Generating Text')
|
||||
print("Generating Text")
|
||||
generated_text = aiclient.generate_text(prompt="Once upon a time in a land far away...")
|
||||
print(generated_text)
|
||||
|
||||
print('Initiating Conversation')
|
||||
print("Initiating Conversation")
|
||||
conversation = aiclient.start_conversation()
|
||||
|
||||
# Add a message to the conversation
|
||||
@@ -27,4 +26,3 @@ conversation.say("What number did I ask you to remember?")
|
||||
# Get the AI's response
|
||||
reply = conversation.get_reply()
|
||||
print(reply)
|
||||
|
||||
|
||||
+10
-14
@@ -1,25 +1,21 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from simplemind.integrations.openai import OpenAI
|
||||
from simplemind.core.errors import AuthenticationError, ProviderError
|
||||
|
||||
|
||||
class TestOpenAIProvider(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.api_key = "test_key"
|
||||
|
||||
@patch("simplemind.integrations.openai.BaseOpenAI")
|
||||
def setUp(self, mock_openai):
|
||||
self.mock_openai = mock_openai.return_value
|
||||
self.mock_openai.models.list.return_value = [MagicMock(id="gpt-4")]
|
||||
self.provider = OpenAI(api_key="test_api_key", model="gpt-4")
|
||||
def test_initialization(self, mock_openai):
|
||||
provider = OpenAI(api_key=self.api_key)
|
||||
self.assertIsNotNone(provider.client)
|
||||
|
||||
def test_available_models(self):
|
||||
models = self.provider.available_models
|
||||
self.assertIn("gpt-4", models)
|
||||
|
||||
def test_test_connection_success(self):
|
||||
self.assertTrue(self.provider.test_connection())
|
||||
|
||||
def test_generate_response_not_implemented(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.provider.generate_response(None)
|
||||
def test_missing_api_key(self):
|
||||
with self.assertRaises(AuthenticationError):
|
||||
OpenAI(api_key=None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user