mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
Refactor enhanced_context.py and update requirements.txt
This commit is contained in:
@@ -168,3 +168,4 @@ cython_debug/
|
||||
src/**
|
||||
requirements.txt
|
||||
Pipfile
|
||||
enhanced_context.db
|
||||
|
||||
@@ -1,19 +1,31 @@
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
import spacy
|
||||
import sqlite3
|
||||
|
||||
from _context import simplemind as sm
|
||||
|
||||
DB_PATH = "enhanced_context.db"
|
||||
|
||||
|
||||
class EnhancedContextPlugin(sm.BasePlugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Initialize NLP model and memory database
|
||||
object.__setattr__(self, "nlp", spacy.load("en_core_web_sm"))
|
||||
object.__setattr__(self, "conn", sqlite3.connect(":memory:"))
|
||||
# Add model configuration to allow arbitrary attributes
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__() # Don't forget to call parent's __init__
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize NLP model and memory database
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
self.conn = sqlite3.connect(DB_PATH)
|
||||
self.init_db()
|
||||
self.logger.info(f"EnhancedContextPlugin initialized with database: {DB_PATH}")
|
||||
|
||||
def init_db(self):
|
||||
# Create a table to store entities and their last mention time
|
||||
@@ -29,7 +41,6 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
|
||||
def store_entity(self, entity):
|
||||
# Store or update entity mention time
|
||||
print(f"Storing entity in memory: {entity}")
|
||||
with self.conn:
|
||||
self.conn.execute(
|
||||
"""
|
||||
@@ -38,6 +49,7 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
""",
|
||||
(entity, datetime.now()),
|
||||
)
|
||||
self.logger.info(f"Stored entity: {entity}")
|
||||
|
||||
def retrieve_recent_entities(self):
|
||||
# Retrieve entities mentioned in the last 7 days
|
||||
@@ -48,7 +60,9 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
WHERE last_mentioned >= datetime('now', '-7 days')
|
||||
"""
|
||||
)
|
||||
return [row[0] for row in cur.fetchall()]
|
||||
entities = [row[0] for row in cur.fetchall()]
|
||||
self.logger.info(f"Retrieved recent entities: {entities}")
|
||||
return entities
|
||||
|
||||
def extract_entities(self, text):
|
||||
# Extract entities (people, places, organizations) from text
|
||||
@@ -63,40 +77,53 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
# Process the latest user message
|
||||
last_message = conversation.get_last_message(role="user")
|
||||
if last_message:
|
||||
self.logger.info(f"Processing message: {last_message.text}")
|
||||
|
||||
# Extract entities and store in memory
|
||||
entities = self.extract_entities(last_message.text)
|
||||
|
||||
print(f"Extracted entities: {entities}")
|
||||
for entity in entities:
|
||||
self.store_entity(entity)
|
||||
if entities:
|
||||
self.logger.info(f"Extracted entities: {entities}")
|
||||
for entity in entities:
|
||||
self.store_entity(entity)
|
||||
else:
|
||||
self.logger.info("No entities found in message")
|
||||
|
||||
# Retrieve recent entities for context
|
||||
recent_entities = self.retrieve_recent_entities()
|
||||
|
||||
if recent_entities:
|
||||
print(f"Recent entities found: {recent_entities}")
|
||||
|
||||
context_message = f"Here are some topics recently discussed: {', '.join(recent_entities)}. Feel free to bring them up if relevant."
|
||||
conversation.add_message(role="system", text=context_message)
|
||||
self.logger.info(
|
||||
f"Added context message with entities: {recent_entities}"
|
||||
)
|
||||
|
||||
|
||||
# Create a conversation and add the plugin
|
||||
conversation = sm.create_conversation(llm_model="gpt-4o-mini", llm_provider="openai")
|
||||
conversation.add_plugin(EnhancedContextPlugin())
|
||||
# Replace the example usage code at the bottom with this chat interface:
|
||||
def main():
|
||||
# Create a conversation and add the plugin
|
||||
conversation = sm.create_conversation(llm_model="gpt-4", llm_provider="openai")
|
||||
conversation.add_plugin(EnhancedContextPlugin())
|
||||
|
||||
print("Chat interface ready! Type 'quit' to exit.")
|
||||
print("-" * 50)
|
||||
|
||||
# Replace the single message test with an interactive chat loop
|
||||
def chat():
|
||||
print("Welcome to the enhanced context chat! Type 'quit' to exit.")
|
||||
while True:
|
||||
# Get user input
|
||||
user_input = input("\nYou: ").strip()
|
||||
if user_input.lower() in ["quit", "exit"]:
|
||||
|
||||
# Check for quit command
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
|
||||
# Add user message and get response
|
||||
conversation.add_message(role="user", text=user_input)
|
||||
response = conversation.send()
|
||||
print(f"\nAssistant: {response.text!r}")
|
||||
|
||||
# Print assistant's response
|
||||
print("\nAssistant:", response)
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
chat()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user