From 07715ed8df3cfa33e7d8e406e3a08cfcae0be0b2 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Wed, 6 Nov 2024 08:30:29 -0500 Subject: [PATCH] Refactor enhanced_context.py and update requirements.txt --- examples/enhanced_context.py | 341 +++++++++++++++++++++++++++-------- 1 file changed, 268 insertions(+), 73 deletions(-) diff --git a/examples/enhanced_context.py b/examples/enhanced_context.py index ef88305..d893a9f 100644 --- a/examples/enhanced_context.py +++ b/examples/enhanced_context.py @@ -1,8 +1,10 @@ -from datetime import datetime +from datetime import datetime, timedelta import logging +import sqlite3 +from typing import List import spacy -import sqlite3 +from contextlib import contextmanager from _context import simplemind as sm @@ -10,119 +12,312 @@ DB_PATH = "enhanced_context.db" class EnhancedContextPlugin(sm.BasePlugin): - # Add model configuration to allow arbitrary attributes model_config = {"extra": "allow"} def __init__(self): - super().__init__() # Don't forget to call parent's __init__ + super().__init__() # Set up logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) self.logger = logging.getLogger(__name__) - # Initialize NLP model and memory database - self.nlp = spacy.load("en_core_web_sm") - self.conn = sqlite3.connect(DB_PATH) + # Initialize NLP model + try: + self.nlp = spacy.load("en_core_web_sm") + except OSError: + self.logger.error( + "Failed to load spaCy model. Please install it using: python -m spacy download en_core_web_sm" + ) + raise + + # Initialize database self.init_db() self.logger.info(f"EnhancedContextPlugin initialized with database: {DB_PATH}") + # Load identity from database + self.personal_identity = None + self.load_identity() + + @contextmanager + def get_connection(self): + """Context manager for database connections""" + conn = sqlite3.connect(DB_PATH) + try: + yield conn + finally: + conn.close() + def init_db(self): - # Create a table to store entities and their last mention time - with self.conn: - self.conn.execute( + """Initialize the database with proper schema""" + with self.get_connection() as conn: + # Create memory table for entities + conn.execute( """ CREATE TABLE IF NOT EXISTS memory ( entity TEXT PRIMARY KEY, - last_mentioned TIMESTAMP + last_mentioned TIMESTAMP, + mention_count INTEGER DEFAULT 1 ) """ ) - def store_entity(self, entity): - # Store or update entity mention time - with self.conn: - self.conn.execute( + # Create identity table + conn.execute( """ - INSERT OR REPLACE INTO memory (entity, last_mentioned) - VALUES (?, ?) - """, - (entity, datetime.now()), - ) - self.logger.info(f"Stored entity: {entity}") - - def retrieve_recent_entities(self): - # Retrieve entities mentioned in the last 7 days - cur = self.conn.cursor() - cur.execute( + CREATE TABLE IF NOT EXISTS identity ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + last_updated TIMESTAMP + ) """ - SELECT entity FROM memory - WHERE last_mentioned >= datetime('now', '-7 days') - """ - ) - entities = [row[0] for row in cur.fetchall()] - self.logger.info(f"Retrieved recent entities: {entities}") - return entities + ) - def extract_entities(self, text): - # Extract entities (people, places, organizations) from text + def store_entity(self, entity: str) -> None: + """Store or update entity mention with error handling""" + try: + with self.get_connection() as conn: + conn.execute( + """ + INSERT INTO memory (entity, last_mentioned, mention_count) + VALUES (?, ?, 1) + ON CONFLICT(entity) DO UPDATE SET + last_mentioned = ?, + mention_count = mention_count + 1 + """, + (entity, datetime.now(), datetime.now()), + ) + self.logger.info(f"Stored entity: {entity}") + except sqlite3.Error as e: + self.logger.error(f"Database error while storing entity {entity}: {e}") + + def retrieve_recent_entities(self, days: int = 7) -> List[str]: + """Retrieve recently mentioned entities with frequency""" + try: + with self.get_connection() as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT entity, mention_count + FROM memory + WHERE last_mentioned >= datetime('now', ?) + ORDER BY mention_count DESC, last_mentioned DESC + LIMIT 5 + """, + (f"-{days} days",), + ) + + entities = [(row[0], row[1]) for row in cur.fetchall()] + self.logger.info(f"Retrieved recent entities: {entities}") + return entities + except sqlite3.Error as e: + self.logger.error(f"Database error while retrieving entities: {e}") + return [] + + def extract_entities(self, text: str) -> List[str]: + """Extract named entities with improved filtering""" doc = self.nlp(text) - return [ - ent.text - for ent in doc.ents - if ent.label_ in {"PERSON", "ORG", "GPE", "NORP"} - ] + entities = [] + + # Define important entity types + important_types = { + "PERSON", + "ORG", + "GPE", + "NORP", + "PRODUCT", + "EVENT", + "WORK_OF_ART", + } + + for ent in doc.ents: + if ( + ent.label_ in important_types + and len(ent.text.strip()) > 1 # Avoid single characters + and not ent.text.isnumeric() + ): # Avoid pure numbers + entities.append(ent.text.strip()) + + return list(set(entities)) # Remove duplicates + + def format_context_message( + self, entities: List[tuple], include_identity: bool = True + ) -> str: + """Format context message more naturally""" + context_parts = [] + + # Add identity context if available and requested + if include_identity and self.personal_identity: + context_parts.append(f"You are speaking with {self.personal_identity}") + + # Add entity context if available + if entities: + # Format entities with their mention counts + entity_strings = [ + f"{entity} ({'mentioned multiple times' if count > 1 else 'mentioned recently'})" + for entity, count in entities + ] + + context_parts.append( + f"Previously discussed topic{'s' if len(entity_strings) > 1 else ''}: " + + ( + ", ".join(entity_strings[:-1]) + f" and {entity_strings[-1]}" + if len(entity_strings) > 1 + else entity_strings[0] + ) + ) + + return ". ".join(context_parts) + ("." if context_parts else "") + + def extract_identity(self, text: str) -> str | None: + """Extract identity statements like 'I am X'""" + text = text.lower().strip() + if text.startswith("i am ") or text.startswith("my name is "): + identity = text.replace("i am ", "").replace("my name is ", "").strip() + return identity if identity else None + return None + + def is_identity_question(self, text: str) -> bool: + """Check if the message is asking about identity""" + text = text.lower().strip() + identity_questions = { + "who am i", + "what's my name", + "what is my name", + "do you know who i am", + "do you know my name", + } + return text in identity_questions + + def store_identity(self, identity: str) -> None: + """Store personal identity in database""" + try: + with self.get_connection() as conn: + conn.execute( + """ + INSERT OR REPLACE INTO identity (id, name, last_updated) + VALUES (1, ?, ?) + """, + (identity, datetime.now()), + ) + conn.commit() # Add explicit commit + self.logger.info(f"Stored identity in database: {identity}") + except sqlite3.Error as e: + self.logger.error(f"Database error while storing identity: {e}") + + def load_identity(self) -> str | None: + """Load personal identity from database""" + try: + with self.get_connection() as conn: + cur = conn.cursor() + cur.execute("SELECT name FROM identity WHERE id = 1") + result = cur.fetchone() + if result: + self.personal_identity = result[0] + self.logger.info( + f"Loaded identity from database: {self.personal_identity}" + ) + else: + self.logger.info("No identity found in database") + return self.personal_identity + except sqlite3.Error as e: + self.logger.error(f"Database error while loading identity: {e}") + return None def pre_send_hook(self, conversation: sm.Conversation): - # Process the latest user message last_message = conversation.get_last_message(role="user") - if last_message: - self.logger.info(f"Processing message: {last_message.text}") + if not last_message: + return - # Extract entities and store in memory - entities = self.extract_entities(last_message.text) - if entities: - self.logger.info(f"Extracted entities: {entities}") - for entity in entities: - self.store_entity(entity) - else: - self.logger.info("No entities found in message") + self.logger.info(f"Processing message: {last_message.text}") - # Retrieve recent entities for context - recent_entities = self.retrieve_recent_entities() - if recent_entities: - context_message = f"Here are some topics recently discussed: {', '.join(recent_entities)}. Feel free to bring them up if relevant." - conversation.add_message(role="system", text=context_message) - self.logger.info( - f"Added context message with entities: {recent_entities}" - ) + # Check for identity statements FIRST + identity = self.extract_identity(last_message.text) + if identity: + self.logger.info(f"Extracted identity: {identity}") + self.personal_identity = identity + self.store_identity(identity) + conversation.add_message( + role="assistant", text=f"I'll remember that your name is {identity}." + ) + return False + + # Handle identity questions + if self.is_identity_question(last_message.text): + self.load_identity() # Reload identity from database + conversation.add_message( + role="assistant", + text=( + f"You are {self.personal_identity}." + if self.personal_identity + else "I don't know your name yet. You can tell me by saying 'I am [your name]' or 'My name is [your name]'." + ), + ) + return False + + # Extract and store entities + entities = self.extract_entities(last_message.text) + for entity in entities: + self.store_entity(entity) + self.logger.info(f"Stored entity: {entity}") + + if not entities: + self.logger.info("No entities found in message") + + # Add context message + recent_entities = self.retrieve_recent_entities() + context_message = self.format_context_message(recent_entities) + if context_message: # Only add if there's actual context to share + conversation.add_message(role="system", text=context_message) + self.logger.info(f"Added context message: {context_message}") # Replace the example usage code at the bottom with this chat interface: def main(): # Create a conversation and add the plugin conversation = sm.create_conversation(llm_model="gpt-4", llm_provider="openai") - conversation.add_plugin(EnhancedContextPlugin()) + plugin = EnhancedContextPlugin() + conversation.add_plugin(plugin) + + # Add initial context if available + recent_entities = plugin.retrieve_recent_entities() + context_message = plugin.format_context_message(recent_entities) + if context_message: + conversation.add_message(role="system", text=context_message) + plugin.logger.info(f"Added initial context message: {context_message}") print("Chat interface ready! Type 'quit' to exit.") print("-" * 50) - while True: - # Get user input - user_input = input("\nYou: ").strip() + try: + while True: + # Get user input + user_input = input("\nYou: ").strip() - # Check for quit command - if user_input.lower() in ["quit", "exit", "q"]: - print("\nGoodbye!") - break + # Check for quit command + if user_input.lower() in ["quit", "exit", "q"]: + print("\nGoodbye!") + break - # Add user message and get response - conversation.add_message(role="user", text=user_input) - response = conversation.send() + # Add user message and get response + conversation.add_message(role="user", text=user_input) + # Store the result of pre_send_hook + should_continue = plugin.pre_send_hook(conversation) - # Print assistant's response - print("\nAssistant:", response) - print("-" * 50) + # Only send to LLM if pre_send_hook returns True or None + if should_continue is not False: + response = conversation.send() + print("\nAssistant:", response) + else: + # Get the last assistant message that was added by the plugin + response = conversation.get_last_message(role="assistant") + if response: + print("\nAssistant:", response.text) + + print("-" * 50) + except KeyboardInterrupt: + print("\n\nGoodbye!") + return if __name__ == "__main__":