chore: Add base classes for chains and agents

This commit is contained in:
2024-10-28 08:41:10 -04:00
parent 146467e117
commit 45e0eb5175
24 changed files with 454 additions and 179 deletions
+13
View File
@@ -104,3 +104,16 @@ SimpleMind is inspired by the philosophy of "code for humans" and aims to make w
---
SimpleMind: Keep it simple, keep it human.
------------------------
## Plugins
SimpleMind supports a plugin system to extend its functionality. Currently available plugins:
- **KVPlugin**: Key-Value storage for context management.
- **BasicMemoryPlugin**: Simple memory storage for conversations.
**Adding a Plugin:**
View File
+7
View File
@@ -0,0 +1,7 @@
from abc import ABC, abstractmethod
class BaseAgent(ABC):
@abstractmethod
def decide(self, context, *args, **kwargs):
pass
View File
+7
View File
@@ -0,0 +1,7 @@
from abc import ABC, abstractmethod
class BaseChain(ABC):
@abstractmethod
def run(self, input_data):
pass
+6
View File
@@ -0,0 +1,6 @@
from .base import BaseChain
class ReverseTextChain(BaseChain):
def run(self, input_data):
return input_data[::-1]
+35
View File
@@ -0,0 +1,35 @@
from typing import Optional
from simplemind.models import Conversation, AIResponse
from simplemind.concepts import Context
from simplemind.integrations.openai import OpenAI
from simplemind.integrations.anthropic import Anthropic
class Client:
def __init__(self, api_key: str, context: Optional[Context] = None):
self.api_key = api_key
self.context = context or Context()
self.providers = self._initialize_providers()
def _initialize_providers(self):
return {
"openai": OpenAI(api_key=self.api_key),
"anthropic": Anthropic(api_key=self.api_key),
}
def create_conversation(self, provider: str = "openai") -> Conversation:
if provider not in self.providers:
raise ValueError(f"Provider '{provider}' not supported.")
return self.providers[provider].create_conversation(initial_message="Hello!", context=self.context.dict())
def send_message(self, conversation: Conversation, message: str, provider: str = "openai") -> AIResponse:
if provider not in self.providers:
raise ValueError(f"Provider '{provider}' not supported.")
return self.providers[provider].send_message(conversation.id, message)
@property
def available_models(self):
available = {}
for name, provider in self.providers.items():
available[name] = provider.available_models
return available
+14 -4
View File
@@ -1,6 +1,16 @@
class Context:
def __init__(self):
self.plugins = [kv, basic_memory]
from pydantic import BaseModel
from typing import Dict, Any
from simplemind.plugins.base import BasePlugin
# TODO: explore pluggy for this.
class Context(BaseModel):
plugins: Dict[str, BasePlugin] = {}
def add_plugin(self, name: str, plugin: BasePlugin):
self.plugins[name] = plugin
def execute_plugin(self, name: str, *args, **kwargs):
if name in self.plugins:
return self.plugins[name].execute(self, *args, **kwargs)
else:
raise ValueError(f"Plugin '{name}' not found in context.")
+13
View File
@@ -0,0 +1,13 @@
from pydantic import BaseSettings
class Settings(BaseSettings):
openai_api_key: str = ""
anthropic_api_key: str = ""
default_model: str = "gpt-4"
class Config:
env_file = ".env"
settings = Settings()
-21
View File
@@ -1,21 +0,0 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Any
app = FastAPI(title="SimpleMind AI API", description="AI for humans, replacing LangGraph and LangChain for Python users.")
@app.post("/generate", response_model=AIResponse)
def generate_response(request: AIRequest):
try:
# Placeholder for AI generation logic
response = {"message": "This would be the AI response."}
metadata = {"tokens_used": 50}
return AIResponse(response=response, metadata=metadata)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
def health_check():
return {"status": "healthy"}
+49 -21
View File
@@ -1,41 +1,69 @@
import os
from typing import List, Optional
import instructor
from anthropic import Anthropic as BaseAnthropic
from .base import BaseClientProvider
from ..models import AIResponse, Conversation
from ..logger import logger
class Anthropic(BaseClientProvider):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
super().__init__(model=model, api_key=api_key)
self.login()
def login(self):
"""Initialize Anthropic client, with Instructor enabled."""
# Default to environment variable if not provided.
if self._api_key is None:
if not self._api_key:
self._api_key = os.getenv("ANTHROPIC_API_KEY")
if not self._api_key:
raise ValueError("Anthropic API key not provided.")
base_client = BaseAnthropic(api_key=self._api_key)
self.client = instructor.from_anthropic(base_client)
# assert self.test_connection()
if not self.test_connection():
raise ConnectionError("Failed to connect to Anthropic API.")
logger.info("Logged in to Anthropic successfully.")
@property
def available_models(self):
"""Returns the available models from the Anthropic client."""
def available_models(self) -> List[str]:
try:
return [
"claude-3-opus-20240229",
"claude-3-5-sonnet-20240620",
"claude-3-haiku-20240307",
]
except Exception as e:
logger.error(f"Error fetching models: {e}")
return []
# TODO: scrape from website or embed
return [
"claude-3-opus-20240229",
"claude-3-5-sonnet-20240620",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20240620",
def test_connection(self) -> bool:
models = self.available_models
if models:
logger.info(f"Available models: {models}")
return True
logger.warning("No available models found.")
return False
def generate_response(self, conversation: Conversation) -> AIResponse:
messages = [
{"role": msg.role, "content": msg.content} for msg in conversation.messages
]
params = {
"messages": messages,
"model": self.model,
}
if conversation.context:
params["context"] = conversation.context
# def test_connection(self):
# """Test the connection to Anthropic. Returns True if successful."""
# raise NotImplementedError("Anthropic test_connection not implemented.")
try:
completion = self.client.completions.create(**params)
response_text = completion.completion
metadata = {"model": completion.model, "usage": completion.usage}
logger.info("Generated response from Anthropic.")
return AIResponse(
text=response_text, response=completion, metadata=metadata
)
except Exception as e:
logger.error(f"Error generating response: {e}")
raise e
+63 -20
View File
@@ -1,6 +1,10 @@
# import logging
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from ..models import AIResponse, Conversation, Message
import uuid
from abc import ABC, abstractmethod
DEFAULT_MODEL = "gpt-4o"
@@ -8,55 +12,46 @@ DEFAULT_MODEL = "gpt-4o"
class BaseClientProvider:
def __init__(self, *, model=DEFAULT_MODEL, api_key=None):
def __init__(self, *, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
# self.logger = logging.getLogger(self.__class__.__name__)
self.client = None
self.model = model
# Load API key from environment if not provided
self._api_key = api_key
self.conversations: Dict[str, Conversation] = {}
@abstractmethod
def login(self):
"""Initializes the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def test_connection(self):
@abstractmethod
def test_connection(self) -> bool:
"""Tests the connection to the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
# def generate_response(self, request):
# """Generates a response from the AI provider client."""
# msg = "This method must be implemented by the AI provider client."
# raise NotImplementedError(msg)
@abstractmethod
def health_check(self):
"""Checks the health of the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
@property
def available_models(self):
@abstractmethod
def available_models(self) -> List[str]:
"""Returns the available models from the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def message(self, message, **kwargs):
@abstractmethod
def message(self, message: str, **kwargs) -> AIResponse:
"""Generates a response from the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
# Uncomment and implement additional methods as needed
# def features(self):
# """Returns the features of the AI provider client."""
# msg = "This method must be implemented by the AI provider client."
# raise NotImplementedError(msg)
@@ -71,3 +66,51 @@ class BaseClientProvider:
# def start_conversation(self, model, message, **kwargs):
# pass
def create_conversation(
self, initial_message: str, context: Optional[Dict[str, Any]] = None
) -> Conversation:
conv_id = str(uuid.uuid4())
conversation = Conversation(
id=conv_id,
messages=[Message(role="user", content=initial_message)],
context=context or {},
)
self.conversations[conv_id] = conversation
return conversation
def send_message(
self,
conversation_id: str,
message: str,
context_update: Optional[Dict[str, Any]] = None,
) -> AIResponse:
if conversation_id not in self.conversations:
raise ValueError("Conversation ID does not exist.")
conversation = self.conversations[conversation_id]
conversation.messages.append(Message(role="user", content=message))
if context_update:
conversation.context.update(context_update)
response = self.generate_response(conversation)
conversation.messages.append(Message(role="assistant", content=response.text))
return response
def generate_response(self, conversation: Conversation) -> AIResponse:
"""Generates a response based on the conversation."""
raise NotImplementedError(
"This method must be implemented by the AI provider client."
)
def get_conversation(self, conversation_id: str) -> Conversation:
if conversation_id not in self.conversations:
raise ValueError("Conversation ID does not exist.")
return self.conversations[conversation_id]
class BasePlugin(ABC):
@abstractmethod
def execute(self, context, *args, **kwargs):
pass
+47 -48
View File
@@ -1,68 +1,67 @@
import os
from typing import Optional, List
import instructor
from openai import OpenAI as BaseOpenAI
from .base import BaseClientProvider
from ..models import AIResponse
from ..models import AIResponse, Conversation
from ..logger import logger
from simplemind.config import settings
class OpenAI(BaseClientProvider):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
super().__init__(model=model, api_key=api_key)
self.login()
def login(self):
"""Initialize OpenAI client, with Instructor enabled."""
# Default to environment variable if not provided.
if self._api_key is None:
self._api_key = os.getenv("OPENAI_API_KEY")
if not self._api_key:
self._api_key = settings.openai_api_key
if not self._api_key:
raise ValueError("OpenAI API key not provided.")
self.client = BaseOpenAI(api_key=self._api_key)
self.instructor_client = instructor.from_openai(self.client)
assert self.test_connection()
if not self.test_connection():
raise ConnectionError("Failed to connect to OpenAI API.")
logger.info("Logged in to OpenAI successfully.")
@property
def available_models(self):
"""Returns the available models from the OpenAI client."""
def available_models(self) -> List[str]:
try:
return [model.id for model in self.client.models.list()]
except Exception as e:
logger.error(f"Error fetching models: {e}")
return []
def gen():
for model in self.client.models.list():
yield model.id
def test_connection(self) -> bool:
try:
models = self.available_models
if models:
logger.info(f"Available models: {models}")
return True
else:
logger.warning("No available models found.")
return False
except Exception as e:
logger.error(f"Error testing connection: {e}")
return False
return [g for g in gen()]
def test_connection(self):
"""Test the connection to OpenAI. Returns True if successful."""
return bool(len(self.available_models))
def message(self, message, *, response_model=False, **kwargs):
"""Generates a response from the OpenAI client."""
use_instructor = bool(response_model)
client = self.instructor_client if use_instructor else self.client
# Parameters for the OpenAI client.
def generate_response(self, conversation: Conversation) -> AIResponse:
messages = [
{"role": msg.role, "content": msg.content} for msg in conversation.messages
]
params = {
"messages": [{"role": "user", "content": message}],
"messages": messages,
"model": self.model,
}
params.update(kwargs)
if conversation.context:
params["context"] = conversation.context
if use_instructor:
params["response_model"] = response_model
# Make the request to OpenAI.
completion = client.chat.completions.create(**params)
if use_instructor:
return completion.model_dump()
else:
return AIResponse(
response=completion,
text=completion.choices[0].message.content,
)
try:
completion = self.client.chat.completions.create(**params)
response_text = completion.choices[0].message.content
metadata = {"model": completion.model, "usage": completion.usage}
logger.info("Generated response from OpenAI.")
return AIResponse(text=response_text, response=completion, metadata=metadata)
except Exception as e:
logger.error(f"Error generating response: {e}")
raise e
+15
View File
@@ -0,0 +1,15 @@
import logging
def setup_logger(name: str) -> logging.Logger:
logger = logging.getLogger(name)
if not logger.hasHandlers():
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('[%(asctime)s] %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
# Initialize a global logger
logger = setup_logger("simplemind")
+27 -3
View File
@@ -1,10 +1,11 @@
from pydantic import BaseModel
from typing import Any, ClassVar
from typing import Any, Dict, List, Optional
import uuid
class AIRequest(BaseModel):
text: str
parameters: dict = {}
parameters: Dict[str, Any] = {}
def __str__(self):
return self.text
@@ -13,7 +14,30 @@ class AIRequest(BaseModel):
class AIResponse(BaseModel):
text: str
response: Any
metadata: dict = {}
metadata: Dict[str, Any] = {}
def __str__(self):
return self.text
class Message(BaseModel):
role: str # "user" or "assistant"
content: str
class Conversation(BaseModel):
id: str
messages: List[Message] = []
context: Optional[Dict[str, Any]] = {}
class ConversationRequest(BaseModel):
conversation_id: Optional[str] = None
message: str
context_update: Optional[Dict[str, Any]] = None
class ConversationResponse(BaseModel):
conversation_id: str
messages: List[Message]
metadata: Dict[str, Any] = {}
View File
+10
View File
@@ -0,0 +1,10 @@
from .base import BasePlugin
class BasicMemoryPlugin(BasePlugin):
def __init__(self):
self.memory = []
def execute(self, context, message):
self.memory.append(message)
return self.memory
+10
View File
@@ -0,0 +1,10 @@
from .base import BasePlugin
class KVPlugin(BasePlugin):
def __init__(self):
self.store = {}
def execute(self, context, key, value):
self.store[key] = value
return self.store
View File
+19
View File
@@ -0,0 +1,19 @@
import faiss
import numpy as np
from typing import List
class FAISSStore:
def __init__(self, dimension: int):
self.dimension = dimension
self.index = faiss.IndexFlatL2(dimension)
self.ids = []
def add_embeddings(self, embeddings: np.ndarray, ids: List[str]):
self.index.add(embeddings)
self.ids.extend(ids)
def search(self, query_embedding: np.ndarray, top_k: int = 5):
distances, indices = self.index.search(query_embedding, top_k)
results = [(self.ids[idx], distances[i]) for i, idx in enumerate(indices[0])]
return results
+24 -44
View File
@@ -1,54 +1,34 @@
from pprint import pprint
from pydantic import BaseModel
import simplemind
context = None
openai = simplemind.integrations.OpenAI()
from simplemind.concepts import Context
from simplemind.plugins.kv import KVPlugin
from simplemind.plugins.basic_memory import BasicMemoryPlugin
from simplemind.chains.reverse_text import ReverseTextChain
from simplemind.client import Client
class YearlyData(BaseModel):
year: int
events: list[str]
class MyContext(Context):
def __init__(self):
super().__init__()
self.add_plugin("kv", KVPlugin())
self.add_plugin("basic_memory", BasicMemoryPlugin())
class ProjectData(BaseModel):
name: str
description: str
url: str
github_url: str
# Initialize context and client
context = MyContext()
aiclient = Client(api_key="YOUR_API_KEY", context=context)
# Test connection and available models
print(aiclient.available_models)
class BioData(BaseModel):
bio: str
spouse_name: str
history: list[YearlyData]
fun_facts: list[str]
# age: int
# occupation: str
# bio: str
# affiliations: list[str]
# Example usage
conversation = aiclient.create_conversation(provider="openai")
response = aiclient.send_message(
conversation, "Who is Kenneth Reitz?", provider="openai"
)
print(response)
class PersonData(BaseModel):
bio: BioData
projects: list[ProjectData]
yearly_breakdown: list[YearlyData]
print(openai.test_connection())
print(openai.available_models)
print()
print()
message = "who is kenneth reitz?"
print(f"> {message}")
pprint(openai.message(message, response_model=BioData))
# claude = simplemind.integrations.Anthropic()
# # print(claude.test_connection())
# # print(claude.available_models)
# claude.login()
reverse_chain = ReverseTextChain()
result = reverse_chain.run("Hello, World!")
print(result) # Output: !dlroW ,olleH
+54 -18
View File
@@ -1,32 +1,68 @@
import instructor
from pprint import pprint
from pydantic import BaseModel
from openai import OpenAI
import simplemind
from simplemind.vector_store.faiss_store import FAISSStore
import numpy as np
context = None
openai = simplemind.integrations.OpenAI()
class ProjectInfo(BaseModel):
class YearlyData(BaseModel):
year: int
events: list[str]
class ProjectData(BaseModel):
name: str
description: str
url: str
github_url: str
# Define your desired output structure
class UserInfo(BaseModel):
name: str
age: int
class BioData(BaseModel):
bio: str
projects: list[ProjectInfo]
spouse_name: str
history: list[YearlyData]
fun_facts: list[str]
# age: int
# occupation: str
# bio: str
# affiliations: list[str]
# Patch the OpenAI client
client = instructor.from_openai(OpenAI())
class PersonData(BaseModel):
bio: BioData
projects: list[ProjectData]
yearly_breakdown: list[YearlyData]
# Extract structured data from natural language
user_info = client.chat.completions.create(
model="gpt-4o",
response_model=UserInfo,
messages=[{"role": "user", "content": "who is kennethreitz?"}],
)
print(user_info.model_dump())
# > 30
print(openai.test_connection())
print(openai.available_models)
print()
print()
message = "who is kenneth reitz?"
print(f"> {message}")
pprint(openai.message(message, response_model=BioData))
# claude = simplemind.integrations.Anthropic()
# # print(claude.test_connection())
# # print(claude.available_models)
# claude.login()
vector_store = FAISSStore(dimension=768) # Example dimension for embeddings
# Add embeddings
embeddings = np.random.random((10, 768)).astype('float32')
ids = [f"doc_{i}" for i in range(10)]
vector_store.add_embeddings(embeddings, ids)
# Search
query_embedding = np.random.random((1, 768)).astype('float32')
results = vector_store.search(query_embedding, top_k=3)
print(results)
+15
View File
@@ -0,0 +1,15 @@
from simplemind.concepts import Context
# from simplemind.plugins.default_plugin import DefaultPlugin
# Initialize the context
ctx = Context()
# Add and initialize the DefaultPlugin
# ctx.add_plugin(DefaultPlugin, "DefaultPlugin")
# Execute the DefaultPlugin with some data
# ctx.execute_plugin("DefaultPlugin", {"key": "value"})
# Shutdown all plugins
# ctx.shutdown_plugins()
+26
View File
@@ -0,0 +1,26 @@
import unittest
from unittest.mock import patch, MagicMock
from simplemind.integrations.openai import OpenAI
class TestOpenAIProvider(unittest.TestCase):
@patch("simplemind.integrations.openai.BaseOpenAI")
def setUp(self, mock_openai):
self.mock_openai = mock_openai.return_value
self.mock_openai.models.list.return_value = [MagicMock(id="gpt-4")]
self.provider = OpenAI(api_key="test_api_key", model="gpt-4")
def test_available_models(self):
models = self.provider.available_models
self.assertIn("gpt-4", models)
def test_test_connection_success(self):
self.assertTrue(self.provider.test_connection())
def test_generate_response_not_implemented(self):
with self.assertRaises(NotImplementedError):
self.provider.generate_response(None)
if __name__ == "__main__":
unittest.main()