mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
Refactor LLM provider and model in enhanced_context.py
This commit is contained in:
+157
-67
@@ -19,10 +19,19 @@ from rich import print
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
from rich.markdown import Markdown
|
||||
from rich.status import Status
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import random
|
||||
|
||||
DB_PATH = "enhanced_context.db"
|
||||
LLM_PROVIDER = "xai"
|
||||
LLM_MODEL = "grok-beta"
|
||||
|
||||
LLM_MODEL = "gpt-4o-mini"
|
||||
LLM_PROVIDER = "openai"
|
||||
|
||||
# LLM_PROVIDER = "xai"
|
||||
# LLM_MODEL = "grok-beta"
|
||||
|
||||
|
||||
class EnhancedContextPlugin(sm.BasePlugin):
|
||||
@@ -69,6 +78,15 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
except LookupError as e:
|
||||
self.logger.error(f"Error downloading NLTK data: {e}")
|
||||
|
||||
# Add LLM personality traits for easter egg
|
||||
self.llm_personalities = [
|
||||
"You are a wise philosopher who speaks in riddles",
|
||||
"You are an excited scientist who loves discovering patterns",
|
||||
"You are a detective who analyzes every detail",
|
||||
"You are a poet who sees beauty in connections",
|
||||
"You are a historian who relates everything to the past",
|
||||
]
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
"""Context manager for database connections"""
|
||||
@@ -81,13 +99,15 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
def init_db(self):
|
||||
"""Initialize the database with proper schema"""
|
||||
with self.get_connection() as conn:
|
||||
# Create memory table for entities
|
||||
# Modify memory table to include source
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS memory (
|
||||
entity TEXT PRIMARY KEY,
|
||||
entity TEXT,
|
||||
source TEXT, -- 'user' or 'llm'
|
||||
last_mentioned TIMESTAMP,
|
||||
mention_count INTEGER DEFAULT 1
|
||||
mention_count INTEGER DEFAULT 1,
|
||||
PRIMARY KEY (entity, source)
|
||||
)
|
||||
"""
|
||||
)
|
||||
@@ -115,45 +135,59 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
"""
|
||||
)
|
||||
|
||||
def store_entity(self, entity: str) -> None:
|
||||
"""Store or update entity mention with error handling"""
|
||||
def store_entity(self, entity: str, source: str = "user") -> None:
|
||||
"""Store or update entity mention with source tracking"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
# Modified to store datetime in SQLite format
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO memory (entity, last_mentioned, mention_count)
|
||||
VALUES (?, ?, 1)
|
||||
ON CONFLICT(entity) DO UPDATE SET
|
||||
INSERT INTO memory (entity, source, last_mentioned, mention_count)
|
||||
VALUES (?, ?, ?, 1)
|
||||
ON CONFLICT(entity, source) DO UPDATE SET
|
||||
last_mentioned = ?,
|
||||
mention_count = mention_count + 1
|
||||
""",
|
||||
(entity, now, now),
|
||||
(entity, source, now, now),
|
||||
)
|
||||
conn.commit() # Added explicit commit
|
||||
self.logger.info(f"Stored entity: {entity}")
|
||||
conn.commit()
|
||||
self.logger.info(f"Stored {source} entity: {entity}")
|
||||
except sqlite3.Error as e:
|
||||
self.logger.error(f"Database error while storing entity {entity}: {e}")
|
||||
|
||||
def retrieve_recent_entities(self, days: int = 7) -> List[str]:
|
||||
"""Retrieve recently mentioned entities with frequency"""
|
||||
def retrieve_recent_entities(self, days: int = 7) -> List[tuple]:
|
||||
"""Retrieve recently mentioned entities with frequency and source"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cur = conn.cursor()
|
||||
# Modified query to handle datetime strings properly
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT entity, mention_count
|
||||
SELECT
|
||||
entity,
|
||||
SUM(mention_count) as total_mentions,
|
||||
GROUP_CONCAT(source || ':' || mention_count) as source_counts
|
||||
FROM memory
|
||||
WHERE last_mentioned >= datetime('now', ?, 'localtime')
|
||||
ORDER BY mention_count DESC, last_mentioned DESC
|
||||
GROUP BY entity
|
||||
ORDER BY total_mentions DESC, MAX(last_mentioned) DESC
|
||||
LIMIT 5
|
||||
""",
|
||||
(f"-{days} days",),
|
||||
)
|
||||
|
||||
entities = [(row[0], row[1]) for row in cur.fetchall()]
|
||||
entities = []
|
||||
for row in cur.fetchall():
|
||||
entity, total_count, source_counts = row
|
||||
source_dict = dict(sc.split(":") for sc in source_counts.split(","))
|
||||
entities.append(
|
||||
(
|
||||
entity,
|
||||
total_count,
|
||||
int(source_dict.get("user", 0)),
|
||||
int(source_dict.get("llm", 0)),
|
||||
)
|
||||
)
|
||||
|
||||
self.logger.info(f"Retrieved recent entities: {entities}")
|
||||
return entities
|
||||
except sqlite3.Error as e:
|
||||
@@ -189,25 +223,26 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
def format_context_message(
|
||||
self, entities: List[tuple], include_identity: bool = True
|
||||
) -> str:
|
||||
"""Format context message more naturally"""
|
||||
"""Format context message with source awareness"""
|
||||
context_parts = []
|
||||
|
||||
# Add identity context if available and requested
|
||||
# Add identity context
|
||||
if include_identity and self.personal_identity:
|
||||
context_parts.append(
|
||||
f"The user's name is {self.personal_identity}. Remember to use their name naturally in conversation when appropriate."
|
||||
)
|
||||
context_parts.append(f"The user's name is {self.personal_identity}.")
|
||||
|
||||
# Add entity context if available
|
||||
# Add entity context with source awareness
|
||||
if entities:
|
||||
# Format entities with their mention counts
|
||||
entity_strings = [
|
||||
f"{entity} (mentioned {count} {'times' if count > 1 else 'time'})"
|
||||
for entity, count in entities
|
||||
]
|
||||
entity_strings = []
|
||||
for entity, total, user_count, llm_count in entities:
|
||||
source_info = []
|
||||
if user_count > 0:
|
||||
source_info.append(f"{user_count} by user")
|
||||
if llm_count > 0:
|
||||
source_info.append(f"{llm_count} by assistant")
|
||||
entity_strings.append(f"{entity} (mentioned {', '.join(source_info)})")
|
||||
|
||||
context_parts.append(
|
||||
"Recent conversation history includes: "
|
||||
"Recent conversation topics: "
|
||||
+ (
|
||||
", ".join(entity_strings[:-1]) + f" and {entity_strings[-1]}"
|
||||
if len(entity_strings) > 1
|
||||
@@ -215,13 +250,16 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
)
|
||||
)
|
||||
|
||||
# Add instructions for memory queries
|
||||
context_parts.append(
|
||||
"If the user asks about memories or what has been discussed, "
|
||||
"naturally incorporate the above context into your response."
|
||||
)
|
||||
# Add guidance for heavily mentioned entities
|
||||
heavy_mentions = [(e, t) for e, t, _, _ in entities if t > 3]
|
||||
if heavy_mentions:
|
||||
context_parts.append(
|
||||
"Note: Be mindful not to overuse "
|
||||
+ ", ".join(f"{e}" for e, _ in heavy_mentions)
|
||||
+ " as they have been frequently discussed."
|
||||
)
|
||||
|
||||
return " ".join(context_parts)
|
||||
return "\n".join(context_parts)
|
||||
|
||||
def extract_identity(self, text: str) -> str | None:
|
||||
"""Extract identity statements like 'I am X'"""
|
||||
@@ -448,11 +486,11 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
for marker_type, markers in markers_by_type.items():
|
||||
context_parts.append(f"- {marker_type.title()}: {', '.join(markers)}")
|
||||
|
||||
# Add entity context
|
||||
# Add entity context - Fixed tuple unpacking
|
||||
if entities:
|
||||
entity_strings = [
|
||||
f"{entity} (mentioned {count} {'times' if count > 1 else 'time'})"
|
||||
for entity, count in entities
|
||||
f"{entity} (mentioned {total} {'times' if total > 1 else 'time'})"
|
||||
for entity, total, user_count, llm_count in entities
|
||||
]
|
||||
context_parts.append(
|
||||
"Recent conversation topics: "
|
||||
@@ -472,10 +510,10 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
|
||||
self.logger.info(f"Processing user message: {last_message.text}")
|
||||
|
||||
# Extract and store entities
|
||||
# Extract and store entities with user source
|
||||
entities = self.extract_entities(last_message.text)
|
||||
for entity in entities:
|
||||
self.store_entity(entity)
|
||||
self.store_entity(entity, source="user")
|
||||
|
||||
# Extract and store essence markers
|
||||
essence_markers = self.extract_essence_markers(last_message.text)
|
||||
@@ -487,7 +525,7 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
recent_entities = self.retrieve_recent_entities(days=30)
|
||||
context_message = self.format_context_message(recent_entities)
|
||||
if context_message:
|
||||
conversation.add_message(role="system", text=context_message)
|
||||
conversation.add_message(role="user", text=context_message)
|
||||
self.logger.info(f"Added context message: {context_message}")
|
||||
|
||||
return True
|
||||
@@ -500,13 +538,46 @@ class EnhancedContextPlugin(sm.BasePlugin):
|
||||
|
||||
self.logger.info(f"Processing assistant response: {last_response.text}")
|
||||
|
||||
# Extract and store entities from the LLM's response
|
||||
# Extract and store entities from the LLM's response with llm source
|
||||
entities = self.extract_entities(last_response.text)
|
||||
for entity in entities:
|
||||
self.store_entity(entity)
|
||||
self.store_entity(entity, source="llm")
|
||||
|
||||
return True
|
||||
|
||||
def simulate_llm_conversation(self, context: str, num_turns: int = 3) -> str:
|
||||
"""Simulate a conversation between multiple LLM personalities about the context"""
|
||||
conversation_log = []
|
||||
|
||||
def get_response(personality: str, previous_messages: str) -> str:
|
||||
prompt = (
|
||||
f"{personality}. You are participating in a brief group discussion "
|
||||
f"about the following context:\n{context}\n\n"
|
||||
f"Previous messages:\n{previous_messages}\n\n"
|
||||
"Provide a short, focused response (1-2 sentences) that builds on "
|
||||
"the discussion. Be creative but stay on topic."
|
||||
)
|
||||
|
||||
temp_conv = sm.create_conversation(
|
||||
llm_model=LLM_MODEL, llm_provider=LLM_PROVIDER
|
||||
)
|
||||
temp_conv.add_message(role="user", text=prompt)
|
||||
response = temp_conv.send()
|
||||
return response.text.strip()
|
||||
|
||||
# Select random personalities for this conversation
|
||||
selected_personalities = random.sample(
|
||||
self.llm_personalities, min(num_turns, len(self.llm_personalities))
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
for i, personality in enumerate(selected_personalities, 1):
|
||||
previous = "\n".join(conversation_log)
|
||||
response = get_response(personality, previous)
|
||||
conversation_log.append(f"Speaker {i}: {response}")
|
||||
|
||||
return "\n\n".join(conversation_log)
|
||||
|
||||
|
||||
# Replace the example usage code at the bottom with this chat interface:
|
||||
def main():
|
||||
@@ -521,51 +592,70 @@ def main():
|
||||
recent_entities = plugin.retrieve_recent_entities()
|
||||
context_message = plugin.format_context_message(recent_entities)
|
||||
if context_message:
|
||||
conversation.add_message(role="system", text=context_message)
|
||||
conversation.add_message(role="user", text=context_message)
|
||||
plugin.logger.info(f"Added initial context message: {context_message}")
|
||||
|
||||
console = Console()
|
||||
console.print(
|
||||
Panel("[bold green]Chat interface ready![/bold green] Type 'quit' to exit.")
|
||||
)
|
||||
print("-" * 50)
|
||||
md = """# Enhanced Context Chat Interface
|
||||
Type 'quit' to exit. Type 'go go go' for a special surprise!
|
||||
|
||||
---"""
|
||||
console.print(Markdown(md))
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get user input with colored prompt
|
||||
console.print("\n[bold blue]You:[/bold blue] ", end="")
|
||||
console.print(Markdown("**You:**"), end=" ")
|
||||
user_input = input().strip()
|
||||
|
||||
# Check for quit command
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
console.print("\n[bold green]Goodbye![/bold green]")
|
||||
console.print(Markdown("**Goodbye!**"))
|
||||
break
|
||||
|
||||
# Add user message and get response
|
||||
# Easter egg handling
|
||||
if user_input.lower() == "go go go":
|
||||
console.print(Markdown("## 🎉 Multi-LLM Discussion Initiated!"))
|
||||
recent_entities = plugin.retrieve_recent_entities()
|
||||
context = plugin.format_context_message(recent_entities)
|
||||
conversation_result = plugin.simulate_llm_conversation(context)
|
||||
|
||||
console.print(
|
||||
Markdown(
|
||||
f"""### Discussion Results
|
||||
*{conversation_result}*
|
||||
|
||||
---"""
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# Regular conversation handling
|
||||
conversation.add_message(role="user", text=user_input)
|
||||
should_continue = plugin.pre_send_hook(conversation)
|
||||
|
||||
# Only send to LLM if pre_send_hook returns True or None
|
||||
if should_continue is not False:
|
||||
response = conversation.send()
|
||||
# Add post-response processing
|
||||
plugin.post_response_hook(conversation)
|
||||
with Status("[bold blue]Thinking...", spinner="dots") as status:
|
||||
response = conversation.send()
|
||||
plugin.post_response_hook(conversation)
|
||||
console.print(
|
||||
"\n[bold purple]Assistant:[/bold purple]",
|
||||
Text(str(response.text), style="italic"),
|
||||
Markdown(
|
||||
f"""**Assistant:** *{response.text}*
|
||||
|
||||
---"""
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Get the last assistant message that was added by the plugin
|
||||
response = conversation.get_last_message(role="assistant")
|
||||
if response:
|
||||
console.print(
|
||||
"\n[bold purple]Assistant:[/bold purple]",
|
||||
Text(response.text, style="bold green"),
|
||||
Markdown(
|
||||
f"""**Assistant:** *{response.text}*
|
||||
|
||||
---"""
|
||||
)
|
||||
)
|
||||
|
||||
console.print(Text("-" * 50, style="dim"))
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n\n[bold green]Goodbye![/bold green]")
|
||||
console.print(Markdown("\n**Goodbye!**"))
|
||||
return
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user