Refactor conversation plugin hooks and add plugin interface

This commit is contained in:
2024-10-30 09:07:18 -04:00
parent 88e82d1ad1
commit 4d38ac02cc
5 changed files with 52 additions and 9 deletions
+14 -6
View File
@@ -9,7 +9,12 @@ import pickle
class ContextualMemoryPlugin:
def __init__(self, api_key: str, memory_file: str = "memories.pkl", embedding_model: str = "text-embedding-ada-002"):
def __init__(
self,
api_key: str,
memory_file: str = "memories.pkl",
embedding_model: str = "text-embedding-ada-002",
):
openai.api_key = api_key
self.memory_file = memory_file
self.embedding_model = embedding_model
@@ -35,29 +40,29 @@ class ContextualMemoryPlugin:
def build_faiss_index(self):
if self.embeddings:
self.index = faiss.IndexFlatL2(len(self.embeddings[0]))
self.index.add(np.array(self.embeddings).astype('float32'))
self.index.add(np.array(self.embeddings).astype("float32"))
else:
self.index = faiss.IndexFlatL2(1536)
def get_embedding(self, text: str) -> list:
response = openai.Embedding.create(input=text, model=self.embedding_model)
return response['data'][0]['embedding']
return response["data"][0]["embedding"]
def add_memory(self, memory: str):
embedding = self.get_embedding(memory)
self.memories.append(memory)
self.embeddings.append(embedding)
self.index.add(np.array([embedding]).astype('float32'))
self.index.add(np.array([embedding]).astype("float32"))
self.save_memories()
def retrieve_memories(self, query: str, top_k: int = 3) -> list:
if not self.index or len(self.embeddings) == 0:
return []
query_embedding = self.get_embedding(query)
D, I = self.index.search(np.array([query_embedding]).astype('float32'), top_k)
D, I = self.index.search(np.array([query_embedding]).astype("float32"), top_k)
return [self.memories[i] for i in I[0] if i < len(self.memories)]
def send_hook(self, conversation: sm.Conversation):
def pre_send_hook(self, conversation: sm.Conversation):
# Retrieve relevant memories based on the latest user message
if conversation.messages:
last_user_message = conversation.messages[-1].text
@@ -69,13 +74,16 @@ class ContextualMemoryPlugin:
# Optionally, add the AI's response to memories
self.add_memory(response)
# Example Usage
# Define a Pydantic model if needed
class Story(BaseModel):
title: str
content: str
# Initialize the conversation with the ContextualMemoryPlugin
memory_plugin = ContextualMemoryPlugin(api_key=sm.settings.OPENAI_API_KEY)
+31
View File
@@ -0,0 +1,31 @@
import simplemind as sm
class LoggingPlugin(sm.BasePlugin):
def pre_send_hook(self, conversation):
print(f"Sending conversation with {len(conversation.messages)} messages")
def add_message_hook(self, conversation, message):
print(f"Adding message to conversation: {message.text}")
def cleanup_hook(self, conversation):
print(f"Cleaning up conversation with {len(conversation.messages)} messages")
def initialize_hook(self, conversation):
print("Initializing conversation")
def post_send_hook(self, conversation, response):
print(f"Received response: {response.text}")
with sm.create_conversation() as conversation:
# Add the logging plugin.
conversation.add_plugin(LoggingPlugin())
# Add a message to the conversation.
conversation.add_message("user", "Hello!", meta={})
# Send the conversation.
response = conversation.send()
print(f"Response: {response.text}")
+3 -3
View File
@@ -1,7 +1,7 @@
from _context import sm
class SimpleMemoryPlugin:
class SimpleMemoryPlugin(sm.BasePlugin):
def __init__(self):
self.memories = [
"the earth has fictionally beeen destroyed.",
@@ -11,9 +11,9 @@ class SimpleMemoryPlugin:
def yield_memories(self):
return (m for m in self.memories)
def send_hook(self, conversation: sm.Conversation):
def initialize_hook(self, conversation: sm.Conversation):
for m in self.yield_memories():
conversation.add_message(role="system", text=m)
conversation.prepend_system_message(role="system", text=m)
conversation = sm.create_conversation(llm_model="grok-beta", llm_provider="xai")
+1
View File
@@ -52,4 +52,5 @@ __all__ = [
"generate_data",
"generate_text",
"settings",
"BasePlugin",
]
+3
View File
@@ -25,6 +25,9 @@ class SMBaseModel(BaseModel):
class BasePlugin(ABC):
"""The base conversation plugin class."""
# Plugin metadata.
meta: Dict[str, Any] = {}
# @abstractmethod
def initialize_hook(self, conversation: "Conversation"):
"""Initialize a hook for the plugin."""