mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
325 lines
11 KiB
Python
325 lines
11 KiB
Python
from datetime import datetime, timedelta
|
|
import logging
|
|
import sqlite3
|
|
from typing import List
|
|
|
|
import spacy
|
|
from contextlib import contextmanager
|
|
|
|
from _context import simplemind as sm
|
|
|
|
# Path of the SQLite database file that holds the plugin's entity memory and
# stored identity; relative, so it resolves against the current working directory.
DB_PATH = "enhanced_context.db"
|
|
|
|
|
|
class EnhancedContextPlugin(sm.BasePlugin):
    """Conversation plugin that remembers the user's name and recent entities.

    State lives in the SQLite database at ``DB_PATH`` in two tables:

    * ``memory``   - one row per named entity, with the time it was last
      mentioned and how many times it has been mentioned.
    * ``identity`` - a single row (``id = 1``) holding the user's name.

    Timestamps are written as naive local-time strings in the form
    ``YYYY-MM-DD HH:MM:SS.ffffff`` so that they sort chronologically as text.
    """

    # NOTE(review): presumably lets the base model accept the undeclared
    # attributes set in __init__ (logger, nlp, personal_identity) - confirm
    # against sm.BasePlugin.
    model_config = {"extra": "allow"}

    def __init__(self):
        """Configure logging, load the spaCy model and prepare the database.

        Raises:
            OSError: if the ``en_core_web_sm`` spaCy model is not installed.
        """
        super().__init__()

        # Set up logging
        logging.basicConfig(
            level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
        )
        self.logger = logging.getLogger(__name__)

        # Initialize NLP model
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            self.logger.error(
                "Failed to load spaCy model. Please install it using: python -m spacy download en_core_web_sm"
            )
            raise

        # Initialize database
        self.init_db()
        self.logger.info(f"EnhancedContextPlugin initialized with database: {DB_PATH}")

        # Load identity from database
        self.personal_identity = None
        self.load_identity()

    @staticmethod
    def _format_timestamp(moment: datetime) -> str:
        """Render *moment* as the sortable text form stored in the database."""
        # Fixed-width output (always with microseconds) keeps plain string
        # comparison in SQL equivalent to chronological comparison.
        return moment.isoformat(sep=" ", timespec="microseconds")

    @contextmanager
    def get_connection(self):
        """Yield a SQLite connection to ``DB_PATH``; commit on success, always close.

        The commit matters: Python's sqlite3 opens an implicit transaction
        for INSERT/UPDATE statements, and closing a connection without
        committing discards those changes. If the ``with`` body raises, the
        commit is skipped and closing the connection drops the partial work.
        """
        conn = sqlite3.connect(DB_PATH)
        try:
            yield conn
            conn.commit()
        finally:
            conn.close()

    def init_db(self):
        """Create the ``memory`` and ``identity`` tables if they do not exist."""
        with self.get_connection() as conn:
            # Create memory table for entities
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS memory (
                    entity TEXT PRIMARY KEY,
                    last_mentioned TIMESTAMP,
                    mention_count INTEGER DEFAULT 1
                )
                """
            )

            # Create identity table
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS identity (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL,
                    last_updated TIMESTAMP
                )
                """
            )

    def store_entity(self, entity: str) -> None:
        """Insert *entity*, or bump its mention count and timestamp if known.

        Database errors are logged and swallowed so that a storage failure
        never interrupts the conversation.
        """
        # One timestamp for both the insert and the update branch.
        now = self._format_timestamp(datetime.now())
        try:
            with self.get_connection() as conn:
                conn.execute(
                    """
                    INSERT INTO memory (entity, last_mentioned, mention_count)
                    VALUES (?, ?, 1)
                    ON CONFLICT(entity) DO UPDATE SET
                        last_mentioned = ?,
                        mention_count = mention_count + 1
                    """,
                    (entity, now, now),
                )
        except sqlite3.Error as e:
            self.logger.error(f"Database error while storing entity {entity}: {e}")
        else:
            # Only reached after the transaction has been committed.
            self.logger.info(f"Stored entity: {entity}")

    def retrieve_recent_entities(
        self, days: int = 7, limit: int = 5
    ) -> List[tuple[str, int]]:
        """Return the most-mentioned entities seen within the last *days* days.

        Args:
            days: Size of the look-back window, in days.
            limit: Maximum number of entities to return.

        Returns:
            ``(entity, mention_count)`` pairs ordered by mention count and
            then recency, both descending. An empty list on database error.
        """
        # The cutoff is computed in Python, in the same local-time text form
        # used when writing, so both sides of the comparison share a clock
        # (SQLite's datetime('now') would be UTC).
        cutoff = self._format_timestamp(datetime.now() - timedelta(days=days))
        try:
            with self.get_connection() as conn:
                cur = conn.cursor()
                cur.execute(
                    """
                    SELECT entity, mention_count
                    FROM memory
                    WHERE last_mentioned >= ?
                    ORDER BY mention_count DESC, last_mentioned DESC
                    LIMIT ?
                    """,
                    (cutoff, limit),
                )
                entities = [(row[0], row[1]) for row in cur.fetchall()]
        except sqlite3.Error as e:
            self.logger.error(f"Database error while retrieving entities: {e}")
            return []

        self.logger.info(f"Retrieved recent entities: {entities}")
        return entities

    def extract_entities(self, text: str) -> List[str]:
        """Return the distinct named entities of interest found in *text*.

        Only entities whose spaCy label is in the whitelist below are kept;
        single-character and purely numeric matches are dropped. Order of
        first appearance is preserved.
        """
        doc = self.nlp(text)

        # Define important entity types
        important_types = {
            "PERSON",
            "ORG",
            "GPE",
            "NORP",
            "PRODUCT",
            "EVENT",
            "WORK_OF_ART",
        }

        names = []
        for ent in doc.ents:
            name = ent.text.strip()
            if (
                ent.label_ in important_types
                and len(name) > 1  # Avoid single characters
                and not name.isnumeric()  # Avoid pure numbers
            ):
                names.append(name)

        # dict.fromkeys removes duplicates while keeping a stable order.
        return list(dict.fromkeys(names))

    def format_context_message(
        self, entities: List[tuple], include_identity: bool = True
    ) -> str:
        """Build the context sentence(s) describing the user and past topics.

        Args:
            entities: ``(entity, mention_count)`` pairs to mention.
            include_identity: Whether to include the known user name.

        Returns:
            The context text, or an empty string when there is nothing to say.
        """
        context_parts = []

        # Add identity context if available and requested
        if include_identity and self.personal_identity:
            context_parts.append(f"You are speaking with {self.personal_identity}")

        # Add entity context if available
        if entities:
            # Format entities with their mention counts
            entity_strings = [
                f"{entity} ({'mentioned multiple times' if count > 1 else 'mentioned recently'})"
                for entity, count in entities
            ]

            if len(entity_strings) > 1:
                # "a, b and c" style listing.
                listing = ", ".join(entity_strings[:-1]) + f" and {entity_strings[-1]}"
                context_parts.append(f"Previously discussed topics: {listing}")
            else:
                context_parts.append(f"Previously discussed topic: {entity_strings[0]}")

        return ". ".join(context_parts) + ("." if context_parts else "")

    def extract_identity(self, text: str) -> str | None:
        """Extract the name from statements like 'I am X' or 'My name is X'.

        The prefix is matched case-insensitively, but the name keeps the
        casing the user typed. Trailing '.' and '!' are dropped. This is a
        plain prefix heuristic: "I am tired" is also treated as a name.

        Returns:
            The extracted name, or ``None`` if *text* is not such a statement.
        """
        stripped = text.strip()
        for prefix in ("i am ", "my name is "):
            # Compare only the leading slice so text later in the message is
            # never touched.
            if stripped[: len(prefix)].lower() == prefix:
                identity = stripped[len(prefix):].strip().rstrip(".!").strip()
                return identity or None
        return None

    def is_identity_question(self, text: str) -> bool:
        """Return True if *text* is one of the known 'who am I' questions.

        Matching ignores case, surrounding whitespace, trailing '?', '!' and
        '.', and treats a typographic apostrophe like a plain one.
        """
        normalized = text.lower().strip().rstrip("?!. ").replace("\u2019", "'")
        identity_questions = {
            "who am i",
            "what's my name",
            "what is my name",
            "do you know who i am",
            "do you know my name",
        }
        return normalized in identity_questions

    def store_identity(self, identity: str) -> None:
        """Persist *identity* as the single identity row (``id = 1``).

        Database errors are logged and swallowed.
        """
        try:
            with self.get_connection() as conn:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO identity (id, name, last_updated)
                    VALUES (1, ?, ?)
                    """,
                    (identity, self._format_timestamp(datetime.now())),
                )
        except sqlite3.Error as e:
            self.logger.error(f"Database error while storing identity: {e}")
        else:
            self.logger.info(f"Stored identity in database: {identity}")

    def load_identity(self) -> str | None:
        """Refresh ``self.personal_identity`` from the database and return it.

        When no row is stored the current in-memory value is left untouched
        and returned. Returns ``None`` on database error.
        """
        try:
            with self.get_connection() as conn:
                cur = conn.cursor()
                cur.execute("SELECT name FROM identity WHERE id = 1")
                result = cur.fetchone()
        except sqlite3.Error as e:
            self.logger.error(f"Database error while loading identity: {e}")
            return None

        if result:
            self.personal_identity = result[0]
            self.logger.info(
                f"Loaded identity from database: {self.personal_identity}"
            )
        else:
            self.logger.info("No identity found in database")
        return self.personal_identity

    def pre_send_hook(self, conversation: sm.Conversation):
        """Process the latest user message before it is sent to the LLM.

        * An identity statement is stored and acknowledged directly.
        * An identity question is answered directly from stored state.
        * Anything else has its entities recorded, and a system message
          summarising recent context is appended to the conversation.

        Returns:
            ``False`` when the plugin has already added the assistant reply
            itself (the two identity cases); ``None`` otherwise.
        """
        last_message = conversation.get_last_message(role="user")
        if not last_message:
            return

        self.logger.info(f"Processing message: {last_message.text}")

        # Check for identity statements FIRST
        identity = self.extract_identity(last_message.text)
        if identity:
            self.logger.info(f"Extracted identity: {identity}")
            self.personal_identity = identity
            self.store_identity(identity)
            conversation.add_message(
                role="assistant", text=f"I'll remember that your name is {identity}."
            )
            return False

        # Handle identity questions
        if self.is_identity_question(last_message.text):
            self.load_identity()  # Reload identity from database
            conversation.add_message(
                role="assistant",
                text=(
                    f"You are {self.personal_identity}."
                    if self.personal_identity
                    else "I don't know your name yet. You can tell me by saying 'I am [your name]' or 'My name is [your name]'."
                ),
            )
            return False

        # Extract and store entities (store_entity logs its own outcome)
        entities = self.extract_entities(last_message.text)
        for entity in entities:
            self.store_entity(entity)

        if not entities:
            self.logger.info("No entities found in message")

        # Add context message
        recent_entities = self.retrieve_recent_entities()
        context_message = self.format_context_message(recent_entities)
        if context_message:  # Only add if there's actual context to share
            conversation.add_message(role="system", text=context_message)
            self.logger.info(f"Added context message: {context_message}")
|
|
|
|
|
|
def main():
    """Run an interactive terminal chat backed by ``EnhancedContextPlugin``.

    Reads user input in a loop until the user types ``quit``, ``exit`` or
    ``q``, presses Ctrl-C, or stdin is closed. Blank input is ignored.
    """
    # Create a conversation and add the plugin
    conversation = sm.create_conversation(llm_model="gpt-4", llm_provider="openai")
    plugin = EnhancedContextPlugin()
    conversation.add_plugin(plugin)

    # Add initial context if available
    recent_entities = plugin.retrieve_recent_entities()
    context_message = plugin.format_context_message(recent_entities)
    if context_message:
        conversation.add_message(role="system", text=context_message)
        plugin.logger.info(f"Added initial context message: {context_message}")

    print("Chat interface ready! Type 'quit' to exit.")
    print("-" * 50)

    try:
        while True:
            # Get user input
            user_input = input("\nYou: ").strip()

            # Nothing typed: prompt again instead of sending an empty message.
            if not user_input:
                continue

            # Check for quit command
            if user_input.lower() in {"quit", "exit", "q"}:
                print("\nGoodbye!")
                break

            # Add user message and run the plugin hook explicitly so its
            # return value can decide whether the LLM is called at all.
            # NOTE(review): the plugin is also registered via add_plugin();
            # if conversation.send() invokes pre_send_hook itself, the hook
            # runs twice per message - confirm against simplemind.
            conversation.add_message(role="user", text=user_input)
            should_continue = plugin.pre_send_hook(conversation)

            # Only send to LLM if pre_send_hook returns True or None
            if should_continue is not False:
                response = conversation.send()
                print("\nAssistant:", response)
            else:
                # Get the last assistant message that was added by the plugin
                response = conversation.get_last_message(role="assistant")
                if response:
                    print("\nAssistant:", response.text)

            print("-" * 50)
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C or end of input (Ctrl-D / closed pipe): leave quietly.
        print("\n\nGoodbye!")
|
|
|
|
|
|
# Start the chat loop only when this file is executed as a script.
if __name__ == "__main__":
    main()
|