From 0087a7e8f2504c69a22b56f91de69bdc0f5802e2 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Wed, 6 Nov 2024 08:42:43 -0500 Subject: [PATCH] Refactor enhanced_context.py and update requirements.txt --- examples/enhanced_context.py | 114 ++++++++++++++++++++++++++++------- 1 file changed, 91 insertions(+), 23 deletions(-) diff --git a/examples/enhanced_context.py b/examples/enhanced_context.py index d893a9f..9f3d44f 100644 --- a/examples/enhanced_context.py +++ b/examples/enhanced_context.py @@ -2,12 +2,22 @@ from datetime import datetime, timedelta import logging import sqlite3 from typing import List +import re import spacy from contextlib import contextmanager from _context import simplemind as sm +import nltk +from nltk.tokenize import word_tokenize +from nltk.tag import pos_tag + +from rich import print +from rich.console import Console +from rich.panel import Panel +from rich.text import Text + DB_PATH = "enhanced_context.db" @@ -39,6 +49,14 @@ class EnhancedContextPlugin(sm.BasePlugin): self.personal_identity = None self.load_identity() + # Download required NLTK data + try: + nltk.data.find("tokenizers/punkt") + nltk.data.find("averaged_perceptron_tagger") + except LookupError: + nltk.download("punkt") + nltk.download("averaged_perceptron_tagger") + @contextmanager def get_connection(self): """Context manager for database connections""" @@ -178,30 +196,71 @@ class EnhancedContextPlugin(sm.BasePlugin): return None def is_identity_question(self, text: str) -> bool: - """Check if the message is asking about identity""" - text = text.lower().strip() - identity_questions = { - "who am i", - "what's my name", - "what is my name", - "do you know who i am", - "do you know my name", - } - return text in identity_questions + """Use NLTK to detect identity questions""" + # Tokenize and tag parts of speech + tokens = word_tokenize(text.lower()) + tagged = pos_tag(tokens) + + # Extract key words and patterns + words = set(tokens) + has_question_word = any(word in ["who", "what"] for word in words) + has_identity_term = any(word in ["i", "me", "my", "name"] for word in words) + has_conversation_term = any( + word in ["talking", "speaking", "chatting"] for word in words + ) + + # Check for question structure + is_question = ( + text.endswith("?") + or has_question_word + or any( + tag in ["WP", "WRB"] for word, tag in tagged + ) # WP = wh-pronoun, WRB = wh-adverb + ) + + # Combine conditions for identity questions + is_identity_question = is_question and ( + (has_identity_term) or (has_question_word and has_conversation_term) + ) + + if is_identity_question: + self.logger.info(f"Detected identity question: {text}") + + return is_identity_question def store_identity(self, identity: str) -> None: - """Store personal identity in database""" + """Store personal identity in database and add to recent entities""" + if not identity: + return + try: with self.get_connection() as conn: + now = datetime.now() + + # Store in identity table conn.execute( """ INSERT OR REPLACE INTO identity (id, name, last_updated) VALUES (1, ?, ?) """, - (identity, datetime.now()), + (identity, now), ) - conn.commit() # Add explicit commit + + # Store in entities table with explicit timestamp + conn.execute( + """ + INSERT INTO entities (name, type, timestamp) + VALUES (?, 'identity', ?) + """, + (identity, now), + ) + conn.commit() self.logger.info(f"Stored identity in database: {identity}") + + # Verify storage + cur = conn.cursor() + cur.execute("SELECT name FROM identity WHERE id = 1") + self.logger.info(f"Verified identity storage: {cur.fetchone()}") except sqlite3.Error as e: self.logger.error(f"Database error while storing identity: {e}") @@ -268,7 +327,7 @@ class EnhancedContextPlugin(sm.BasePlugin): recent_entities = self.retrieve_recent_entities() context_message = self.format_context_message(recent_entities) if context_message: # Only add if there's actual context to share - conversation.add_message(role="system", text=context_message) + conversation.add_message(role="user", text=context_message) self.logger.info(f"Added context message: {context_message}") @@ -286,37 +345,46 @@ def main(): conversation.add_message(role="system", text=context_message) plugin.logger.info(f"Added initial context message: {context_message}") - print("Chat interface ready! Type 'quit' to exit.") + console = Console() + console.print( + Panel("[bold green]Chat interface ready![/bold green] Type 'quit' to exit.") + ) print("-" * 50) try: while True: - # Get user input - user_input = input("\nYou: ").strip() + # Get user input with colored prompt + console.print("\n[bold blue]You:[/bold blue] ", end="") + user_input = input().strip() # Check for quit command if user_input.lower() in ["quit", "exit", "q"]: - print("\nGoodbye!") + console.print("\n[bold green]Goodbye![/bold green]") break # Add user message and get response conversation.add_message(role="user", text=user_input) - # Store the result of pre_send_hook should_continue = plugin.pre_send_hook(conversation) # Only send to LLM if pre_send_hook returns True or None if should_continue is not False: response = conversation.send() - print("\nAssistant:", response) + console.print( + "\n[bold purple]Assistant:[/bold purple]", + Text(str(response.text), style="italic"), + ) else: # Get the last assistant message that was added by the plugin response = conversation.get_last_message(role="assistant") if response: - print("\nAssistant:", response.text) + console.print( + "\n[bold purple]Assistant:[/bold purple]", + Text(response.text, style="bold green"), + ) - print("-" * 50) + console.print(Text("-" * 50, style="dim")) except KeyboardInterrupt: - print("\n\nGoodbye!") + console.print("\n\n[bold green]Goodbye![/bold green]") return