Refactor enhanced_context.py and update requirements.txt

This commit is contained in:
2024-11-06 08:30:29 -05:00
parent 03f91c5153
commit 07715ed8df
+268 -73
View File
@@ -1,8 +1,10 @@
from datetime import datetime
from datetime import datetime, timedelta
import logging
import sqlite3
from typing import List
import spacy
import sqlite3
from contextlib import contextmanager
from _context import simplemind as sm
@@ -10,119 +12,312 @@ DB_PATH = "enhanced_context.db"
class EnhancedContextPlugin(sm.BasePlugin):
    """Conversation plugin that remembers named entities and the user's name.

    Entities found by spaCy in user messages are counted in a SQLite
    ``memory`` table; the user's self-declared name is kept in a single-row
    ``identity`` table. Before a message is sent, the plugin either answers
    identity statements/questions itself or injects a system message that
    summarises recently mentioned entities.
    """

    # Add model configuration to allow arbitrary attributes
    model_config = {"extra": "allow"}

    def __init__(self):
        """Set up logging, load the spaCy model, and prepare the database.

        Raises:
            OSError: if the ``en_core_web_sm`` spaCy model is not installed.
        """
        super().__init__()

        # Set up logging
        logging.basicConfig(
            level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
        )
        self.logger = logging.getLogger(__name__)

        # Initialize NLP model
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            self.logger.error(
                "Failed to load spaCy model. Please install it using: python -m spacy download en_core_web_sm"
            )
            raise

        # Initialize database
        self.init_db()
        self.logger.info(f"EnhancedContextPlugin initialized with database: {DB_PATH}")

        # Load identity from database
        self.personal_identity = None
        self.load_identity()

    @contextmanager
    def get_connection(self):
        """Yield a short-lived SQLite connection to ``DB_PATH``.

        The transaction is committed when the ``with`` body finishes without
        raising. If the body raises, the commit is skipped and closing the
        connection discards the uncommitted changes. The connection is always
        closed.
        """
        conn = sqlite3.connect(DB_PATH)
        try:
            yield conn
            # Without this commit, INSERT/UPDATE statements issued by callers
            # would be rolled back when the connection is closed.
            conn.commit()
        finally:
            conn.close()

    def init_db(self):
        """Create the ``memory`` and ``identity`` tables if they are missing.

        Also adds the ``mention_count`` column to a ``memory`` table created
        by an earlier schema that only had ``entity`` and ``last_mentioned``.
        """
        with self.get_connection() as conn:
            # Create memory table for entities
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS memory (
                    entity TEXT PRIMARY KEY,
                    last_mentioned TIMESTAMP,
                    mention_count INTEGER DEFAULT 1
                )
                """
            )

            # CREATE TABLE IF NOT EXISTS leaves an existing table untouched,
            # so upgrade an old two-column table in place.
            columns = {row[1] for row in conn.execute("PRAGMA table_info(memory)")}
            if "mention_count" not in columns:
                conn.execute(
                    "ALTER TABLE memory ADD COLUMN mention_count INTEGER DEFAULT 1"
                )

            # Create identity table
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS identity (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL,
                    last_updated TIMESTAMP
                )
                """
            )

    def store_entity(self, entity: str) -> None:
        """Insert ``entity`` or, if already known, bump its count and timestamp.

        Database errors are logged and swallowed so a storage failure never
        interrupts the conversation.
        """
        # Explicit text form ("YYYY-MM-DD HH:MM:SS[.ffffff]", local time);
        # avoids relying on sqlite3's implicit datetime adapter.
        now = datetime.now().isoformat(" ")
        try:
            with self.get_connection() as conn:
                conn.execute(
                    """
                    INSERT INTO memory (entity, last_mentioned, mention_count)
                    VALUES (?, ?, 1)
                    ON CONFLICT(entity) DO UPDATE SET
                        last_mentioned = excluded.last_mentioned,
                        mention_count = mention_count + 1
                    """,
                    (entity, now),
                )
            self.logger.info(f"Stored entity: {entity}")
        except sqlite3.Error as e:
            self.logger.error(f"Database error while storing entity {entity}: {e}")

    def retrieve_recent_entities(
        self, days: int = 7, limit: int = 5
    ) -> List[tuple[str, int]]:
        """Return the most-mentioned entities seen within the last ``days`` days.

        Args:
            days: Size of the look-back window, in days.
            limit: Maximum number of entities to return.

        Returns:
            ``(entity, mention_count)`` tuples ordered by mention count, then
            by recency. An empty list is returned on a database error.
        """
        # The cutoff is computed in Python with the same local-time clock and
        # text format used when storing, so the string comparison in SQL is
        # consistent (SQLite's datetime('now') would be UTC).
        cutoff = (datetime.now() - timedelta(days=days)).isoformat(" ")
        try:
            with self.get_connection() as conn:
                rows = conn.execute(
                    """
                    SELECT entity, mention_count
                    FROM memory
                    WHERE last_mentioned >= ?
                    ORDER BY mention_count DESC, last_mentioned DESC
                    LIMIT ?
                    """,
                    (cutoff, limit),
                ).fetchall()
            entities = [(row[0], row[1]) for row in rows]
            self.logger.info(f"Retrieved recent entities: {entities}")
            return entities
        except sqlite3.Error as e:
            self.logger.error(f"Database error while retrieving entities: {e}")
            return []

    def extract_entities(self, text: str) -> List[str]:
        """Return the distinct named entities of interest found in ``text``.

        Single-character and purely numeric entities are skipped. The result
        keeps first-occurrence order.
        """
        # Define important entity types
        important_types = {
            "PERSON",
            "ORG",
            "GPE",
            "NORP",
            "PRODUCT",
            "EVENT",
            "WORK_OF_ART",
        }

        doc = self.nlp(text)
        candidates = (
            ent.text.strip()
            for ent in doc.ents
            if ent.label_ in important_types
            and len(ent.text.strip()) > 1  # Avoid single characters
            and not ent.text.isnumeric()  # Avoid pure numbers
        )
        # dict.fromkeys removes duplicates while keeping a stable order.
        return list(dict.fromkeys(candidates))

    def format_context_message(
        self, entities: List[tuple], include_identity: bool = True
    ) -> str:
        """Build the system-message text describing known context.

        Args:
            entities: ``(entity, mention_count)`` tuples.
            include_identity: Whether to mention the stored user name.

        Returns:
            A sentence-style summary, or ``""`` when there is nothing to say.
        """
        context_parts = []

        # Add identity context if available and requested
        if include_identity and self.personal_identity:
            context_parts.append(f"You are speaking with {self.personal_identity}")

        # Add entity context if available
        if entities:
            entity_strings = [
                f"{entity} ({'mentioned multiple times' if count > 1 else 'mentioned recently'})"
                for entity, count in entities
            ]
            if len(entity_strings) > 1:
                label = "Previously discussed topics"
                listing = (
                    ", ".join(entity_strings[:-1]) + f" and {entity_strings[-1]}"
                )
            else:
                label = "Previously discussed topic"
                listing = entity_strings[0]
            context_parts.append(f"{label}: {listing}")

        return ". ".join(context_parts) + ("." if context_parts else "")

    def extract_identity(self, text: str) -> str | None:
        """Return the name from a leading "I am X" / "My name is X" statement.

        The trigger phrase is matched case-insensitively at the start of the
        message only, and the name keeps the capitalisation the user typed.
        Returns ``None`` when the message is not an identity statement.
        """
        stripped = text.strip()
        for prefix in ("i am ", "my name is "):
            # Compare only the leading slice so the name itself is untouched.
            if stripped[: len(prefix)].lower() == prefix:
                identity = stripped[len(prefix):].strip()
                return identity or None
        return None

    def is_identity_question(self, text: str) -> bool:
        """Return True if ``text`` asks who the user is.

        Matching ignores case, surrounding whitespace, trailing ``?``/``!``/``.``
        and treats a typographic apostrophe like a plain one.
        """
        identity_questions = {
            "who am i",
            "what's my name",
            "what is my name",
            "do you know who i am",
            "do you know my name",
        }
        normalized = text.lower().replace("\u2019", "'").strip()
        normalized = normalized.rstrip("?!.").strip()
        return normalized in identity_questions

    def store_identity(self, identity: str) -> None:
        """Persist ``identity`` as the single stored user name (row id 1).

        Database errors are logged and swallowed.
        """
        try:
            with self.get_connection() as conn:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO identity (id, name, last_updated)
                    VALUES (1, ?, ?)
                    """,
                    (identity, datetime.now().isoformat(" ")),
                )
            self.logger.info(f"Stored identity in database: {identity}")
        except sqlite3.Error as e:
            self.logger.error(f"Database error while storing identity: {e}")

    def load_identity(self) -> str | None:
        """Refresh ``self.personal_identity`` from the database and return it.

        If no row exists the current value is left unchanged. Returns ``None``
        on a database error.
        """
        try:
            with self.get_connection() as conn:
                result = conn.execute(
                    "SELECT name FROM identity WHERE id = 1"
                ).fetchone()
            if result:
                self.personal_identity = result[0]
                self.logger.info(
                    f"Loaded identity from database: {self.personal_identity}"
                )
            else:
                self.logger.info("No identity found in database")
            return self.personal_identity
        except sqlite3.Error as e:
            self.logger.error(f"Database error while loading identity: {e}")
            return None

    def pre_send_hook(self, conversation: sm.Conversation):
        """Process the latest user message before it is sent to the model.

        Returns:
            ``False`` when the plugin has answered the message itself (identity
            statement or identity question) by appending an assistant message;
            ``None`` otherwise, after recording entities and, if there is any
            context, appending a system message.
        """
        last_message = conversation.get_last_message(role="user")
        if not last_message:
            return None

        self.logger.info(f"Processing message: {last_message.text}")

        # Check for identity statements FIRST
        identity = self.extract_identity(last_message.text)
        if identity:
            self.logger.info(f"Extracted identity: {identity}")
            self.personal_identity = identity
            self.store_identity(identity)
            conversation.add_message(
                role="assistant", text=f"I'll remember that your name is {identity}."
            )
            return False

        # Handle identity questions
        if self.is_identity_question(last_message.text):
            self.load_identity()  # Reload identity from database
            conversation.add_message(
                role="assistant",
                text=(
                    f"You are {self.personal_identity}."
                    if self.personal_identity
                    else "I don't know your name yet. You can tell me by saying 'I am [your name]' or 'My name is [your name]'."
                ),
            )
            return False

        # Extract and store entities (store_entity logs success or failure)
        entities = self.extract_entities(last_message.text)
        for entity in entities:
            self.store_entity(entity)
        if not entities:
            self.logger.info("No entities found in message")

        # Add context message
        recent_entities = self.retrieve_recent_entities()
        context_message = self.format_context_message(recent_entities)
        if context_message:  # Only add if there's actual context to share
            conversation.add_message(role="system", text=context_message)
            self.logger.info(f"Added context message: {context_message}")
        return None
# Interactive command-line chat loop that exercises the plugin.
def main():
    """Run an interactive terminal chat that uses EnhancedContextPlugin.

    Type ``quit``, ``exit`` or ``q`` (or press Ctrl-C / send EOF) to leave.
    """
    # Create a conversation and register a single plugin instance
    conversation = sm.create_conversation(llm_model="gpt-4", llm_provider="openai")
    plugin = EnhancedContextPlugin()
    conversation.add_plugin(plugin)

    # Add initial context if available
    recent_entities = plugin.retrieve_recent_entities()
    context_message = plugin.format_context_message(recent_entities)
    if context_message:
        conversation.add_message(role="system", text=context_message)
        plugin.logger.info(f"Added initial context message: {context_message}")

    print("Chat interface ready! Type 'quit' to exit.")
    print("-" * 50)

    try:
        while True:
            # Get user input
            user_input = input("\nYou: ").strip()

            # Nothing to process or send for a blank line
            if not user_input:
                continue

            # Check for quit command
            if user_input.lower() in ["quit", "exit", "q"]:
                print("\nGoodbye!")
                break

            # Add user message and let the plugin inspect it first
            conversation.add_message(role="user", text=user_input)
            # NOTE(review): the plugin is also registered on the conversation;
            # if conversation.send() invokes pre_send_hook itself, the hook
            # runs twice per message — confirm against the simplemind API.
            should_continue = plugin.pre_send_hook(conversation)

            # Only send to LLM if pre_send_hook returns True or None
            if should_continue is not False:
                response = conversation.send()
                print("\nAssistant:", response)
            else:
                # Get the last assistant message that was added by the plugin
                response = conversation.get_last_message(role="assistant")
                if response:
                    print("\nAssistant:", response.text)

            print("-" * 50)
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C or end of input (Ctrl-D / closed stdin): exit quietly
        print("\n\nGoodbye!")
if __name__ == "__main__":