mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 06:46:18 +00:00
Refactor conversation plugin hooks and add plugin interface
This commit is contained in:
@@ -9,7 +9,12 @@ import pickle
|
||||
|
||||
|
||||
class ContextualMemoryPlugin:
|
||||
def __init__(self, api_key: str, memory_file: str = "memories.pkl", embedding_model: str = "text-embedding-ada-002"):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
memory_file: str = "memories.pkl",
|
||||
embedding_model: str = "text-embedding-ada-002",
|
||||
):
|
||||
openai.api_key = api_key
|
||||
self.memory_file = memory_file
|
||||
self.embedding_model = embedding_model
|
||||
@@ -35,29 +40,29 @@ class ContextualMemoryPlugin:
|
||||
def build_faiss_index(self):
|
||||
if self.embeddings:
|
||||
self.index = faiss.IndexFlatL2(len(self.embeddings[0]))
|
||||
self.index.add(np.array(self.embeddings).astype('float32'))
|
||||
self.index.add(np.array(self.embeddings).astype("float32"))
|
||||
else:
|
||||
self.index = faiss.IndexFlatL2(1536)
|
||||
|
||||
def get_embedding(self, text: str) -> list:
|
||||
response = openai.Embedding.create(input=text, model=self.embedding_model)
|
||||
return response['data'][0]['embedding']
|
||||
return response["data"][0]["embedding"]
|
||||
|
||||
def add_memory(self, memory: str):
|
||||
embedding = self.get_embedding(memory)
|
||||
self.memories.append(memory)
|
||||
self.embeddings.append(embedding)
|
||||
self.index.add(np.array([embedding]).astype('float32'))
|
||||
self.index.add(np.array([embedding]).astype("float32"))
|
||||
self.save_memories()
|
||||
|
||||
def retrieve_memories(self, query: str, top_k: int = 3) -> list:
|
||||
if not self.index or len(self.embeddings) == 0:
|
||||
return []
|
||||
query_embedding = self.get_embedding(query)
|
||||
D, I = self.index.search(np.array([query_embedding]).astype('float32'), top_k)
|
||||
D, I = self.index.search(np.array([query_embedding]).astype("float32"), top_k)
|
||||
return [self.memories[i] for i in I[0] if i < len(self.memories)]
|
||||
|
||||
def send_hook(self, conversation: sm.Conversation):
|
||||
def pre_send_hook(self, conversation: sm.Conversation):
|
||||
# Retrieve relevant memories based on the latest user message
|
||||
if conversation.messages:
|
||||
last_user_message = conversation.messages[-1].text
|
||||
@@ -69,13 +74,16 @@ class ContextualMemoryPlugin:
|
||||
# Optionally, add the AI's response to memories
|
||||
self.add_memory(response)
|
||||
|
||||
|
||||
# Example Usage
|
||||
|
||||
|
||||
# Define a Pydantic model if needed
|
||||
class Story(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
|
||||
|
||||
# Initialize the conversation with the ContextualMemoryPlugin
|
||||
memory_plugin = ContextualMemoryPlugin(api_key=sm.settings.OPENAI_API_KEY)
|
||||
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import simplemind as sm
|
||||
|
||||
|
||||
class LoggingPlugin(sm.BasePlugin):
|
||||
def pre_send_hook(self, conversation):
|
||||
print(f"Sending conversation with {len(conversation.messages)} messages")
|
||||
|
||||
def add_message_hook(self, conversation, message):
|
||||
print(f"Adding message to conversation: {message.text}")
|
||||
|
||||
def cleanup_hook(self, conversation):
|
||||
print(f"Cleaning up conversation with {len(conversation.messages)} messages")
|
||||
|
||||
def initialize_hook(self, conversation):
|
||||
print("Initializing conversation")
|
||||
|
||||
def post_send_hook(self, conversation, response):
|
||||
print(f"Received response: {response.text}")
|
||||
|
||||
|
||||
with sm.create_conversation() as conversation:
|
||||
# Add the logging plugin.
|
||||
conversation.add_plugin(LoggingPlugin())
|
||||
|
||||
# Add a message to the conversation.
|
||||
conversation.add_message("user", "Hello!", meta={})
|
||||
|
||||
# Send the conversation.
|
||||
response = conversation.send()
|
||||
|
||||
print(f"Response: {response.text}")
|
||||
@@ -1,7 +1,7 @@
|
||||
from _context import sm
|
||||
|
||||
|
||||
class SimpleMemoryPlugin:
|
||||
class SimpleMemoryPlugin(sm.BasePlugin):
|
||||
def __init__(self):
|
||||
self.memories = [
|
||||
"the earth has fictionally beeen destroyed.",
|
||||
@@ -11,9 +11,9 @@ class SimpleMemoryPlugin:
|
||||
def yield_memories(self):
|
||||
return (m for m in self.memories)
|
||||
|
||||
def send_hook(self, conversation: sm.Conversation):
|
||||
def initialize_hook(self, conversation: sm.Conversation):
|
||||
for m in self.yield_memories():
|
||||
conversation.add_message(role="system", text=m)
|
||||
conversation.prepend_system_message(role="system", text=m)
|
||||
|
||||
|
||||
conversation = sm.create_conversation(llm_model="grok-beta", llm_provider="xai")
|
||||
|
||||
@@ -52,4 +52,5 @@ __all__ = [
|
||||
"generate_data",
|
||||
"generate_text",
|
||||
"settings",
|
||||
"BasePlugin",
|
||||
]
|
||||
|
||||
@@ -25,6 +25,9 @@ class SMBaseModel(BaseModel):
|
||||
class BasePlugin(ABC):
|
||||
"""The base conversation plugin class."""
|
||||
|
||||
# Plugin metadata.
|
||||
meta: Dict[str, Any] = {}
|
||||
|
||||
# @abstractmethod
|
||||
def initialize_hook(self, conversation: "Conversation"):
|
||||
"""Initialize a hook for the plugin."""
|
||||
|
||||
Reference in New Issue
Block a user