mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 06:46:18 +00:00
Refactor EnhancedContextPlugin to store datetime in SQLite format, handle datetime strings properly, and extract/store entities for context
This commit is contained in:
+199
-40
@@ -24,12 +24,16 @@ DB_PATH = "enhanced_context.db"
|
||||
class EnhancedContextPlugin(sm.BasePlugin):
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, verbose: bool = False):
|
||||
super().__init__()
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
self.verbose = verbose
|
||||
if verbose:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
else:
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize NLP model
|
||||
@@ -91,6 +95,18 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
"""
|
||||
)
|
||||
|
||||
# Create essence markers table
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS essence_markers (
|
||||
marker_type TEXT,
|
||||
marker_text TEXT,
|
||||
timestamp TIMESTAMP,
|
||||
PRIMARY KEY (marker_type, marker_text)
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
def store_entity(self, entity: str) -> None:
|
||||
"""Store or update entity mention with error handling"""
|
||||
try:
|
||||
@@ -170,18 +186,20 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
|
||||
# Add identity context if available and requested
|
||||
if include_identity and self.personal_identity:
|
||||
context_parts.append(f"You are speaking with {self.personal_identity}")
|
||||
context_parts.append(
|
||||
f"The user's name is {self.personal_identity}. Remember to use their name naturally in conversation when appropriate."
|
||||
)
|
||||
|
||||
# Add entity context if available
|
||||
if entities:
|
||||
# Format entities with their mention counts
|
||||
entity_strings = [
|
||||
f"{entity} ({'mentioned multiple times' if count > 1 else 'mentioned recently'})"
|
||||
f"{entity} (mentioned {count} {'times' if count > 1 else 'time'})"
|
||||
for entity, count in entities
|
||||
]
|
||||
|
||||
context_parts.append(
|
||||
f"Previously discussed topic{'s' if len(entity_strings) > 1 else ''}: "
|
||||
"Recent conversation history includes: "
|
||||
+ (
|
||||
", ".join(entity_strings[:-1]) + f" and {entity_strings[-1]}"
|
||||
if len(entity_strings) > 1
|
||||
@@ -189,7 +207,13 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
)
|
||||
)
|
||||
|
||||
return ". ".join(context_parts) + ("." if context_parts else "")
|
||||
# Add instructions for memory queries
|
||||
context_parts.append(
|
||||
"If the user asks about memories or what has been discussed, "
|
||||
"naturally incorporate the above context into your response."
|
||||
)
|
||||
|
||||
return " ".join(context_parts)
|
||||
|
||||
def extract_identity(self, text: str) -> str | None:
|
||||
"""Extract identity statements like 'I am X'"""
|
||||
@@ -280,6 +304,159 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
self.logger.error(f"Database error while loading identity: {e}")
|
||||
return None
|
||||
|
||||
def is_memory_question(self, text: str) -> bool:
|
||||
"""Detect questions about memory and recall"""
|
||||
text = text.lower().strip()
|
||||
|
||||
# Keywords related to memory and recall
|
||||
memory_words = {
|
||||
"remember",
|
||||
"recall",
|
||||
"memory",
|
||||
"memories",
|
||||
"mentioned",
|
||||
"talked about",
|
||||
"discussed",
|
||||
"tell me about",
|
||||
"what do you know",
|
||||
}
|
||||
|
||||
return any(word in text for word in memory_words)
|
||||
|
||||
def extract_essence_markers(self, text: str) -> List[tuple[str, str]]:
|
||||
"""Extract essence markers from text.
|
||||
Returns list of tuples (marker_type, marker_text)"""
|
||||
|
||||
# Common patterns for essence markers
|
||||
patterns = {
|
||||
"value": [
|
||||
r"I (?:really )?(?:believe|think) (?:that )?(.+)",
|
||||
r"(?:It's|Its) important (?:to me )?that (.+)",
|
||||
r"I value (.+)",
|
||||
r"(?:The )?most important (?:thing|aspect) (?:to me )?is (.+)",
|
||||
],
|
||||
"identity": [
|
||||
r"I am(?: a| an)? (.+)",
|
||||
r"I consider myself(?: a| an)? (.+)",
|
||||
r"I identify as(?: a| an)? (.+)",
|
||||
],
|
||||
"preference": [
|
||||
r"I (?:really )?(?:like|love|enjoy|prefer) (.+)",
|
||||
r"I can't stand (.+)",
|
||||
r"I hate (.+)",
|
||||
r"I always (.+)",
|
||||
r"I never (.+)",
|
||||
],
|
||||
"emotion": [
|
||||
r"I feel (.+)",
|
||||
r"I'm feeling (.+)",
|
||||
r"(?:It|That) makes me feel (.+)",
|
||||
],
|
||||
}
|
||||
|
||||
markers = []
|
||||
|
||||
# Process with spaCy for better sentence splitting
|
||||
doc = self.nlp(text)
|
||||
|
||||
for sent in doc.sents:
|
||||
sent_text = sent.text.strip().lower()
|
||||
|
||||
# Check each pattern type
|
||||
for marker_type, pattern_list in patterns.items():
|
||||
for pattern in pattern_list:
|
||||
matches = re.finditer(pattern, sent_text, re.IGNORECASE)
|
||||
for match in matches:
|
||||
marker_text = match.group(1).strip()
|
||||
# Filter out very short or common phrases
|
||||
if len(marker_text) > 3 and not any(
|
||||
w in marker_text for w in ["um", "uh", "like"]
|
||||
):
|
||||
markers.append((marker_type, marker_text))
|
||||
|
||||
return markers
|
||||
|
||||
def store_essence_marker(self, marker_type: str, marker_text: str) -> None:
|
||||
"""Store essence marker in database"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO essence_markers (marker_type, marker_text, timestamp)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(marker_type, marker_text, now),
|
||||
)
|
||||
conn.commit()
|
||||
self.logger.info(
|
||||
f"Stored essence marker: {marker_type} - {marker_text}"
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
self.logger.error(f"Database error storing essence marker: {e}")
|
||||
|
||||
def retrieve_essence_markers(self, days: int = 30) -> List[tuple[str, str]]:
|
||||
"""Retrieve recent essence markers"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT DISTINCT marker_type, marker_text
|
||||
FROM essence_markers
|
||||
WHERE timestamp >= datetime('now', ?, 'localtime')
|
||||
ORDER BY timestamp DESC
|
||||
""",
|
||||
(f"-{days} days",),
|
||||
)
|
||||
|
||||
markers = cur.fetchall()
|
||||
self.logger.info(f"Retrieved essence markers: {markers}")
|
||||
return markers
|
||||
except sqlite3.Error as e:
|
||||
self.logger.error(f"Database error retrieving essence markers: {e}")
|
||||
return []
|
||||
|
||||
def format_context_message(
|
||||
self, entities: List[tuple], include_identity: bool = True
|
||||
) -> str:
|
||||
"""Format context message with essence markers"""
|
||||
context_parts = []
|
||||
|
||||
# Add identity context
|
||||
if include_identity and self.personal_identity:
|
||||
context_parts.append(f"The user's name is {self.personal_identity}.")
|
||||
|
||||
# Add essence markers
|
||||
essence_markers = self.retrieve_essence_markers()
|
||||
if essence_markers:
|
||||
markers_by_type = {}
|
||||
for marker_type, marker_text in essence_markers:
|
||||
if marker_type not in markers_by_type:
|
||||
markers_by_type[marker_type] = []
|
||||
markers_by_type[marker_type].append(marker_text)
|
||||
|
||||
context_parts.append("User characteristics:")
|
||||
for marker_type, markers in markers_by_type.items():
|
||||
context_parts.append(f"- {marker_type.title()}: {', '.join(markers)}")
|
||||
|
||||
# Add entity context
|
||||
if entities:
|
||||
entity_strings = [
|
||||
f"{entity} (mentioned {count} {'times' if count > 1 else 'time'})"
|
||||
for entity, count in entities
|
||||
]
|
||||
context_parts.append(
|
||||
"Recent conversation topics: "
|
||||
+ (
|
||||
", ".join(entity_strings[:-1]) + f" and {entity_strings[-1]}"
|
||||
if len(entity_strings) > 1
|
||||
else entity_strings[0]
|
||||
)
|
||||
)
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
def pre_send_hook(self, conversation: sm.Conversation):
|
||||
last_message = conversation.get_last_message(role="user")
|
||||
if not last_message:
|
||||
@@ -287,52 +464,34 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
|
||||
self.logger.info(f"Processing message: {last_message.text}")
|
||||
|
||||
# Check for identity statements FIRST
|
||||
identity = self.extract_identity(last_message.text)
|
||||
if identity:
|
||||
self.logger.info(f"Extracted identity: {identity}")
|
||||
self.personal_identity = identity
|
||||
self.store_identity(identity)
|
||||
conversation.add_message(
|
||||
role="assistant", text=f"I'll remember that your name is {identity}."
|
||||
)
|
||||
return False
|
||||
|
||||
# Handle identity questions
|
||||
if self.is_identity_question(last_message.text):
|
||||
self.load_identity() # Reload identity from database
|
||||
conversation.add_message(
|
||||
role="assistant",
|
||||
text=(
|
||||
f"You are {self.personal_identity}."
|
||||
if self.personal_identity
|
||||
else "I don't know your name yet. You can tell me by saying 'I am [your name]' or 'My name is [your name]'."
|
||||
),
|
||||
)
|
||||
return False
|
||||
|
||||
# Extract and store entities
|
||||
entities = self.extract_entities(last_message.text)
|
||||
for entity in entities:
|
||||
self.store_entity(entity)
|
||||
self.logger.info(f"Stored entity: {entity}")
|
||||
|
||||
if not entities:
|
||||
self.logger.info("No entities found in message")
|
||||
# Extract and store essence markers
|
||||
essence_markers = self.extract_essence_markers(last_message.text)
|
||||
for marker_type, marker_text in essence_markers:
|
||||
self.store_essence_marker(marker_type, marker_text)
|
||||
self.logger.info(f"Found essence marker: {marker_type} - {marker_text}")
|
||||
|
||||
# Add context message
|
||||
recent_entities = self.retrieve_recent_entities()
|
||||
recent_entities = self.retrieve_recent_entities(days=30)
|
||||
context_message = self.format_context_message(recent_entities)
|
||||
if context_message: # Only add if there's actual context to share
|
||||
conversation.add_message(role="user", text=context_message)
|
||||
if context_message:
|
||||
conversation.add_message(role="system", text=context_message)
|
||||
self.logger.info(f"Added context message: {context_message}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Replace the example usage code at the bottom with this chat interface:
|
||||
def main():
|
||||
# Create a conversation and add the plugin
|
||||
conversation = sm.create_conversation(llm_model="gpt-4", llm_provider="openai")
|
||||
plugin = EnhancedContextPlugin()
|
||||
conversation = sm.create_conversation(
|
||||
llm_model="gpt-4o-mini", llm_provider="openai"
|
||||
)
|
||||
plugin = EnhancedContextPlugin(verbose=False) # Set verbose here
|
||||
conversation.add_plugin(plugin)
|
||||
|
||||
# Add initial context if available
|
||||
|
||||
Reference in New Issue
Block a user