Refactor enhanced_context.py and update requirements.txt

This commit is contained in:
2024-11-06 08:42:43 -05:00
parent 07715ed8df
commit 0087a7e8f2
+91 -23
View File
@@ -2,12 +2,22 @@ from datetime import datetime, timedelta
import logging
import sqlite3
from typing import List
import re
import spacy
from contextlib import contextmanager
from _context import simplemind as sm
import nltk
from nltk.tokenize import word_tokenize
from nltk.tag import pos_tag
from rich import print
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
DB_PATH = "enhanced_context.db"
@@ -39,6 +49,14 @@ class EnhancedContextPlugin(sm.BasePlugin):
self.personal_identity = None
self.load_identity()
# Download required NLTK data
try:
nltk.data.find("tokenizers/punkt")
nltk.data.find("averaged_perceptron_tagger")
except LookupError:
nltk.download("punkt")
nltk.download("averaged_perceptron_tagger")
@contextmanager
def get_connection(self):
"""Context manager for database connections"""
@@ -178,30 +196,71 @@ class EnhancedContextPlugin(sm.BasePlugin):
return None
def is_identity_question(self, text: str) -> bool:
"""Check if the message is asking about identity"""
text = text.lower().strip()
identity_questions = {
"who am i",
"what's my name",
"what is my name",
"do you know who i am",
"do you know my name",
}
return text in identity_questions
"""Use NLTK to detect identity questions"""
# Tokenize and tag parts of speech
tokens = word_tokenize(text.lower())
tagged = pos_tag(tokens)
# Extract key words and patterns
words = set(tokens)
has_question_word = any(word in ["who", "what"] for word in words)
has_identity_term = any(word in ["i", "me", "my", "name"] for word in words)
has_conversation_term = any(
word in ["talking", "speaking", "chatting"] for word in words
)
# Check for question structure
is_question = (
text.endswith("?")
or has_question_word
or any(
tag in ["WP", "WRB"] for word, tag in tagged
) # WP = wh-pronoun, WRB = wh-adverb
)
# Combine conditions for identity questions
is_identity_question = is_question and (
(has_identity_term) or (has_question_word and has_conversation_term)
)
if is_identity_question:
self.logger.info(f"Detected identity question: {text}")
return is_identity_question
def store_identity(self, identity: str) -> None:
"""Store personal identity in database"""
"""Store personal identity in database and add to recent entities"""
if not identity:
return
try:
with self.get_connection() as conn:
now = datetime.now()
# Store in identity table
conn.execute(
"""
INSERT OR REPLACE INTO identity (id, name, last_updated)
VALUES (1, ?, ?)
""",
(identity, datetime.now()),
(identity, now),
)
conn.commit() # Add explicit commit
# Store in entities table with explicit timestamp
conn.execute(
"""
INSERT INTO entities (name, type, timestamp)
VALUES (?, 'identity', ?)
""",
(identity, now),
)
conn.commit()
self.logger.info(f"Stored identity in database: {identity}")
# Verify storage
cur = conn.cursor()
cur.execute("SELECT name FROM identity WHERE id = 1")
self.logger.info(f"Verified identity storage: {cur.fetchone()}")
except sqlite3.Error as e:
self.logger.error(f"Database error while storing identity: {e}")
@@ -268,7 +327,7 @@ class EnhancedContextPlugin(sm.BasePlugin):
recent_entities = self.retrieve_recent_entities()
context_message = self.format_context_message(recent_entities)
if context_message: # Only add if there's actual context to share
conversation.add_message(role="system", text=context_message)
conversation.add_message(role="user", text=context_message)
self.logger.info(f"Added context message: {context_message}")
@@ -286,37 +345,46 @@ def main():
conversation.add_message(role="system", text=context_message)
plugin.logger.info(f"Added initial context message: {context_message}")
print("Chat interface ready! Type 'quit' to exit.")
console = Console()
console.print(
Panel("[bold green]Chat interface ready![/bold green] Type 'quit' to exit.")
)
print("-" * 50)
try:
while True:
# Get user input
user_input = input("\nYou: ").strip()
# Get user input with colored prompt
console.print("\n[bold blue]You:[/bold blue] ", end="")
user_input = input().strip()
# Check for quit command
if user_input.lower() in ["quit", "exit", "q"]:
print("\nGoodbye!")
console.print("\n[bold green]Goodbye![/bold green]")
break
# Add user message and get response
conversation.add_message(role="user", text=user_input)
# Store the result of pre_send_hook
should_continue = plugin.pre_send_hook(conversation)
# Only send to LLM if pre_send_hook returns True or None
if should_continue is not False:
response = conversation.send()
print("\nAssistant:", response)
console.print(
"\n[bold purple]Assistant:[/bold purple]",
Text(str(response.text), style="italic"),
)
else:
# Get the last assistant message that was added by the plugin
response = conversation.get_last_message(role="assistant")
if response:
print("\nAssistant:", response.text)
console.print(
"\n[bold purple]Assistant:[/bold purple]",
Text(response.text, style="bold green"),
)
print("-" * 50)
console.print(Text("-" * 50, style="dim"))
except KeyboardInterrupt:
print("\n\nGoodbye!")
console.print("\n\n[bold green]Goodbye![/bold green]")
return