From 4d38ac02cc987a2aaaca9d4c61894a9dd089e0b5 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Wed, 30 Oct 2024 09:07:18 -0400 Subject: [PATCH] Refactor conversation plugin hooks and add plugin interface --- examples/contextual_memory.py | 20 ++++++++++++++------ examples/simple_logging_plugin.py | 31 +++++++++++++++++++++++++++++++ examples/simple_memory.py | 6 +++--- simplemind/__init__.py | 1 + simplemind/models.py | 3 +++ 5 files changed, 52 insertions(+), 9 deletions(-) create mode 100644 examples/simple_logging_plugin.py diff --git a/examples/contextual_memory.py b/examples/contextual_memory.py index 780fb40..5797213 100644 --- a/examples/contextual_memory.py +++ b/examples/contextual_memory.py @@ -9,7 +9,12 @@ import pickle class ContextualMemoryPlugin: - def __init__(self, api_key: str, memory_file: str = "memories.pkl", embedding_model: str = "text-embedding-ada-002"): + def __init__( + self, + api_key: str, + memory_file: str = "memories.pkl", + embedding_model: str = "text-embedding-ada-002", + ): openai.api_key = api_key self.memory_file = memory_file self.embedding_model = embedding_model @@ -35,29 +40,29 @@ class ContextualMemoryPlugin: def build_faiss_index(self): if self.embeddings: self.index = faiss.IndexFlatL2(len(self.embeddings[0])) - self.index.add(np.array(self.embeddings).astype('float32')) + self.index.add(np.array(self.embeddings).astype("float32")) else: self.index = faiss.IndexFlatL2(1536) def get_embedding(self, text: str) -> list: response = openai.Embedding.create(input=text, model=self.embedding_model) - return response['data'][0]['embedding'] + return response["data"][0]["embedding"] def add_memory(self, memory: str): embedding = self.get_embedding(memory) self.memories.append(memory) self.embeddings.append(embedding) - self.index.add(np.array([embedding]).astype('float32')) + self.index.add(np.array([embedding]).astype("float32")) self.save_memories() def retrieve_memories(self, query: str, top_k: int = 3) -> list: if not self.index or len(self.embeddings) == 0: return [] query_embedding = self.get_embedding(query) - D, I = self.index.search(np.array([query_embedding]).astype('float32'), top_k) + D, I = self.index.search(np.array([query_embedding]).astype("float32"), top_k) return [self.memories[i] for i in I[0] if i < len(self.memories)] - def send_hook(self, conversation: sm.Conversation): + def pre_send_hook(self, conversation: sm.Conversation): # Retrieve relevant memories based on the latest user message if conversation.messages: last_user_message = conversation.messages[-1].text @@ -69,13 +74,16 @@ class ContextualMemoryPlugin: # Optionally, add the AI's response to memories self.add_memory(response) + # Example Usage + # Define a Pydantic model if needed class Story(BaseModel): title: str content: str + # Initialize the conversation with the ContextualMemoryPlugin memory_plugin = ContextualMemoryPlugin(api_key=sm.settings.OPENAI_API_KEY) diff --git a/examples/simple_logging_plugin.py b/examples/simple_logging_plugin.py new file mode 100644 index 0000000..c966f61 --- /dev/null +++ b/examples/simple_logging_plugin.py @@ -0,0 +1,31 @@ +import simplemind as sm + + +class LoggingPlugin(sm.BasePlugin): + def pre_send_hook(self, conversation): + print(f"Sending conversation with {len(conversation.messages)} messages") + + def add_message_hook(self, conversation, message): + print(f"Adding message to conversation: {message.text}") + + def cleanup_hook(self, conversation): + print(f"Cleaning up conversation with {len(conversation.messages)} messages") + + def initialize_hook(self, conversation): + print("Initializing conversation") + + def post_send_hook(self, conversation, response): + print(f"Received response: {response.text}") + + +with sm.create_conversation() as conversation: + # Add the logging plugin. + conversation.add_plugin(LoggingPlugin()) + + # Add a message to the conversation. + conversation.add_message("user", "Hello!", meta={}) + + # Send the conversation. + response = conversation.send() + +print(f"Response: {response.text}") diff --git a/examples/simple_memory.py b/examples/simple_memory.py index 12efd10..c2974da 100644 --- a/examples/simple_memory.py +++ b/examples/simple_memory.py @@ -1,7 +1,7 @@ from _context import sm -class SimpleMemoryPlugin: +class SimpleMemoryPlugin(sm.BasePlugin): def __init__(self): self.memories = [ "the earth has fictionally beeen destroyed.", @@ -11,9 +11,9 @@ class SimpleMemoryPlugin: def yield_memories(self): return (m for m in self.memories) - def send_hook(self, conversation: sm.Conversation): + def initialize_hook(self, conversation: sm.Conversation): for m in self.yield_memories(): - conversation.add_message(role="system", text=m) + conversation.prepend_system_message(role="system", text=m) conversation = sm.create_conversation(llm_model="grok-beta", llm_provider="xai") diff --git a/simplemind/__init__.py b/simplemind/__init__.py index bfa7045..29da250 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -52,4 +52,5 @@ __all__ = [ "generate_data", "generate_text", "settings", + "BasePlugin", ] diff --git a/simplemind/models.py b/simplemind/models.py index a031d31..0327d46 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -25,6 +25,9 @@ class SMBaseModel(BaseModel): class BasePlugin(ABC): """The base conversation plugin class.""" + # Plugin metadata. + meta: Dict[str, Any] = {} + # @abstractmethod def initialize_hook(self, conversation: "Conversation"): """Initialize a hook for the plugin."""