mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
Refactor enhanced_context.py and update requirements.txt
This commit is contained in:
@@ -2,12 +2,22 @@ from datetime import datetime, timedelta
|
||||
import logging
|
||||
import sqlite3
|
||||
from typing import List
|
||||
import re
|
||||
|
||||
import spacy
|
||||
from contextlib import contextmanager
|
||||
|
||||
from _context import simplemind as sm
|
||||
|
||||
import nltk
|
||||
from nltk.tokenize import word_tokenize
|
||||
from nltk.tag import pos_tag
|
||||
|
||||
from rich import print
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
DB_PATH = "enhanced_context.db"
|
||||
|
||||
|
||||
@@ -39,6 +49,14 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
self.personal_identity = None
|
||||
self.load_identity()
|
||||
|
||||
# Download required NLTK data
|
||||
try:
|
||||
nltk.data.find("tokenizers/punkt")
|
||||
nltk.data.find("averaged_perceptron_tagger")
|
||||
except LookupError:
|
||||
nltk.download("punkt")
|
||||
nltk.download("averaged_perceptron_tagger")
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
"""Context manager for database connections"""
|
||||
@@ -178,30 +196,71 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
return None
|
||||
|
||||
def is_identity_question(self, text: str) -> bool:
|
||||
"""Check if the message is asking about identity"""
|
||||
text = text.lower().strip()
|
||||
identity_questions = {
|
||||
"who am i",
|
||||
"what's my name",
|
||||
"what is my name",
|
||||
"do you know who i am",
|
||||
"do you know my name",
|
||||
}
|
||||
return text in identity_questions
|
||||
"""Use NLTK to detect identity questions"""
|
||||
# Tokenize and tag parts of speech
|
||||
tokens = word_tokenize(text.lower())
|
||||
tagged = pos_tag(tokens)
|
||||
|
||||
# Extract key words and patterns
|
||||
words = set(tokens)
|
||||
has_question_word = any(word in ["who", "what"] for word in words)
|
||||
has_identity_term = any(word in ["i", "me", "my", "name"] for word in words)
|
||||
has_conversation_term = any(
|
||||
word in ["talking", "speaking", "chatting"] for word in words
|
||||
)
|
||||
|
||||
# Check for question structure
|
||||
is_question = (
|
||||
text.endswith("?")
|
||||
or has_question_word
|
||||
or any(
|
||||
tag in ["WP", "WRB"] for word, tag in tagged
|
||||
) # WP = wh-pronoun, WRB = wh-adverb
|
||||
)
|
||||
|
||||
# Combine conditions for identity questions
|
||||
is_identity_question = is_question and (
|
||||
(has_identity_term) or (has_question_word and has_conversation_term)
|
||||
)
|
||||
|
||||
if is_identity_question:
|
||||
self.logger.info(f"Detected identity question: {text}")
|
||||
|
||||
return is_identity_question
|
||||
|
||||
def store_identity(self, identity: str) -> None:
|
||||
"""Store personal identity in database"""
|
||||
"""Store personal identity in database and add to recent entities"""
|
||||
if not identity:
|
||||
return
|
||||
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
now = datetime.now()
|
||||
|
||||
# Store in identity table
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO identity (id, name, last_updated)
|
||||
VALUES (1, ?, ?)
|
||||
""",
|
||||
(identity, datetime.now()),
|
||||
(identity, now),
|
||||
)
|
||||
conn.commit() # Add explicit commit
|
||||
|
||||
# Store in entities table with explicit timestamp
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO entities (name, type, timestamp)
|
||||
VALUES (?, 'identity', ?)
|
||||
""",
|
||||
(identity, now),
|
||||
)
|
||||
conn.commit()
|
||||
self.logger.info(f"Stored identity in database: {identity}")
|
||||
|
||||
# Verify storage
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT name FROM identity WHERE id = 1")
|
||||
self.logger.info(f"Verified identity storage: {cur.fetchone()}")
|
||||
except sqlite3.Error as e:
|
||||
self.logger.error(f"Database error while storing identity: {e}")
|
||||
|
||||
@@ -268,7 +327,7 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
recent_entities = self.retrieve_recent_entities()
|
||||
context_message = self.format_context_message(recent_entities)
|
||||
if context_message: # Only add if there's actual context to share
|
||||
conversation.add_message(role="system", text=context_message)
|
||||
conversation.add_message(role="user", text=context_message)
|
||||
self.logger.info(f"Added context message: {context_message}")
|
||||
|
||||
|
||||
@@ -286,37 +345,46 @@ def main():
|
||||
conversation.add_message(role="system", text=context_message)
|
||||
plugin.logger.info(f"Added initial context message: {context_message}")
|
||||
|
||||
print("Chat interface ready! Type 'quit' to exit.")
|
||||
console = Console()
|
||||
console.print(
|
||||
Panel("[bold green]Chat interface ready![/bold green] Type 'quit' to exit.")
|
||||
)
|
||||
print("-" * 50)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get user input
|
||||
user_input = input("\nYou: ").strip()
|
||||
# Get user input with colored prompt
|
||||
console.print("\n[bold blue]You:[/bold blue] ", end="")
|
||||
user_input = input().strip()
|
||||
|
||||
# Check for quit command
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
print("\nGoodbye!")
|
||||
console.print("\n[bold green]Goodbye![/bold green]")
|
||||
break
|
||||
|
||||
# Add user message and get response
|
||||
conversation.add_message(role="user", text=user_input)
|
||||
# Store the result of pre_send_hook
|
||||
should_continue = plugin.pre_send_hook(conversation)
|
||||
|
||||
# Only send to LLM if pre_send_hook returns True or None
|
||||
if should_continue is not False:
|
||||
response = conversation.send()
|
||||
print("\nAssistant:", response)
|
||||
console.print(
|
||||
"\n[bold purple]Assistant:[/bold purple]",
|
||||
Text(str(response.text), style="italic"),
|
||||
)
|
||||
else:
|
||||
# Get the last assistant message that was added by the plugin
|
||||
response = conversation.get_last_message(role="assistant")
|
||||
if response:
|
||||
print("\nAssistant:", response.text)
|
||||
console.print(
|
||||
"\n[bold purple]Assistant:[/bold purple]",
|
||||
Text(response.text, style="bold green"),
|
||||
)
|
||||
|
||||
print("-" * 50)
|
||||
console.print(Text("-" * 50, style="dim"))
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nGoodbye!")
|
||||
console.print("\n\n[bold green]Goodbye![/bold green]")
|
||||
return
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user