mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 06:46:18 +00:00
Refactor EnhancedContextPlugin to store datetime in SQLite format, handle datetime strings properly, and extract/store entities for context
This commit is contained in:
@@ -3,6 +3,8 @@ import logging
|
||||
import sqlite3
|
||||
from typing import List
|
||||
import re
|
||||
import os
|
||||
import contextlib
|
||||
|
||||
import spacy
|
||||
from contextlib import contextmanager
|
||||
@@ -19,6 +21,8 @@ from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
DB_PATH = "enhanced_context.db"
|
||||
LLM_PROVIDER = "openai"
|
||||
LLM_MODEL = "gpt-4o-mini"
|
||||
|
||||
|
||||
class EnhancedContextPlugin(sm.BasePlugin):
|
||||
@@ -53,13 +57,17 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
self.personal_identity = None
|
||||
self.load_identity()
|
||||
|
||||
# Download required NLTK data
|
||||
# Download required NLTK data silently
|
||||
try:
|
||||
nltk.data.find("tokenizers/punkt")
|
||||
nltk.data.find("averaged_perceptron_tagger")
|
||||
except LookupError:
|
||||
nltk.download("punkt")
|
||||
nltk.download("averaged_perceptron_tagger")
|
||||
with open(os.devnull, "w") as null_out:
|
||||
with (
|
||||
contextlib.redirect_stdout(null_out),
|
||||
contextlib.redirect_stderr(null_out),
|
||||
):
|
||||
nltk.download("punkt", quiet=True)
|
||||
nltk.download("averaged_perceptron_tagger", quiet=True)
|
||||
except LookupError as e:
|
||||
self.logger.error(f"Error downloading NLTK data: {e}")
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
@@ -462,7 +470,7 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
if not last_message:
|
||||
return
|
||||
|
||||
self.logger.info(f"Processing message: {last_message.text}")
|
||||
self.logger.info(f"Processing user message: {last_message.text}")
|
||||
|
||||
# Extract and store entities
|
||||
entities = self.extract_entities(last_message.text)
|
||||
@@ -484,12 +492,27 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
|
||||
return True
|
||||
|
||||
def post_response_hook(self, conversation: sm.Conversation):
|
||||
"""Process the LLM's response to extract and store entities"""
|
||||
last_response = conversation.get_last_message(role="assistant")
|
||||
if not last_response:
|
||||
return
|
||||
|
||||
self.logger.info(f"Processing assistant response: {last_response.text}")
|
||||
|
||||
# Extract and store entities from the LLM's response
|
||||
entities = self.extract_entities(last_response.text)
|
||||
for entity in entities:
|
||||
self.store_entity(entity)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Replace the example usage code at the bottom with this chat interface:
|
||||
def main():
|
||||
# Create a conversation and add the plugin
|
||||
conversation = sm.create_conversation(
|
||||
llm_model="gpt-4o-mini", llm_provider="openai"
|
||||
llm_model=LLM_MODEL, llm_provider=LLM_PROVIDER
|
||||
)
|
||||
plugin = EnhancedContextPlugin(verbose=False) # Set verbose here
|
||||
conversation.add_plugin(plugin)
|
||||
@@ -525,6 +548,8 @@ def main():
|
||||
# Only send to LLM if pre_send_hook returns True or None
|
||||
if should_continue is not False:
|
||||
response = conversation.send()
|
||||
# Add post-response processing
|
||||
plugin.post_response_hook(conversation)
|
||||
console.print(
|
||||
"\n[bold purple]Assistant:[/bold purple]",
|
||||
Text(str(response.text), style="italic"),
|
||||
|
||||
Reference in New Issue
Block a user