Refactor LLM provider and model in enhanced_context.py

This commit is contained in:
2024-11-06 09:57:27 -05:00
parent 4a852e6220
commit 30d8412bbf
+157 -67
View File
@@ -19,10 +19,19 @@ from rich import print
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from rich.markdown import Markdown
from rich.status import Status
from concurrent.futures import ThreadPoolExecutor
import random
DB_PATH = "enhanced_context.db"
LLM_PROVIDER = "xai"
LLM_MODEL = "grok-beta"
LLM_MODEL = "gpt-4o-mini"
LLM_PROVIDER = "openai"
# LLM_PROVIDER = "xai"
# LLM_MODEL = "grok-beta"
class EnhancedContextPlugin(sm.BasePlugin):
@@ -69,6 +78,15 @@ class EnhancedContextPlugin(sm.BasePlugin):
except LookupError as e:
self.logger.error(f"Error downloading NLTK data: {e}")
# Add LLM personality traits for easter egg
self.llm_personalities = [
"You are a wise philosopher who speaks in riddles",
"You are an excited scientist who loves discovering patterns",
"You are a detective who analyzes every detail",
"You are a poet who sees beauty in connections",
"You are a historian who relates everything to the past",
]
@contextmanager
def get_connection(self):
"""Context manager for database connections"""
@@ -81,13 +99,15 @@ class EnhancedContextPlugin(sm.BasePlugin):
def init_db(self):
"""Initialize the database with proper schema"""
with self.get_connection() as conn:
# Create memory table for entities
# Modify memory table to include source
conn.execute(
"""
CREATE TABLE IF NOT EXISTS memory (
entity TEXT PRIMARY KEY,
entity TEXT,
source TEXT, -- 'user' or 'llm'
last_mentioned TIMESTAMP,
mention_count INTEGER DEFAULT 1
mention_count INTEGER DEFAULT 1,
PRIMARY KEY (entity, source)
)
"""
)
@@ -115,45 +135,59 @@ class EnhancedContextPlugin(sm.BasePlugin):
"""
)
def store_entity(self, entity: str) -> None:
"""Store or update entity mention with error handling"""
def store_entity(self, entity: str, source: str = "user") -> None:
"""Store or update entity mention with source tracking"""
try:
with self.get_connection() as conn:
# Modified to store datetime in SQLite format
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conn.execute(
"""
INSERT INTO memory (entity, last_mentioned, mention_count)
VALUES (?, ?, 1)
ON CONFLICT(entity) DO UPDATE SET
INSERT INTO memory (entity, source, last_mentioned, mention_count)
VALUES (?, ?, ?, 1)
ON CONFLICT(entity, source) DO UPDATE SET
last_mentioned = ?,
mention_count = mention_count + 1
""",
(entity, now, now),
(entity, source, now, now),
)
conn.commit() # Added explicit commit
self.logger.info(f"Stored entity: {entity}")
conn.commit()
self.logger.info(f"Stored {source} entity: {entity}")
except sqlite3.Error as e:
self.logger.error(f"Database error while storing entity {entity}: {e}")
def retrieve_recent_entities(self, days: int = 7) -> List[str]:
"""Retrieve recently mentioned entities with frequency"""
def retrieve_recent_entities(self, days: int = 7) -> List[tuple]:
"""Retrieve recently mentioned entities with frequency and source"""
try:
with self.get_connection() as conn:
cur = conn.cursor()
# Modified query to handle datetime strings properly
cur.execute(
"""
SELECT entity, mention_count
SELECT
entity,
SUM(mention_count) as total_mentions,
GROUP_CONCAT(source || ':' || mention_count) as source_counts
FROM memory
WHERE last_mentioned >= datetime('now', ?, 'localtime')
ORDER BY mention_count DESC, last_mentioned DESC
GROUP BY entity
ORDER BY total_mentions DESC, MAX(last_mentioned) DESC
LIMIT 5
""",
(f"-{days} days",),
)
entities = [(row[0], row[1]) for row in cur.fetchall()]
entities = []
for row in cur.fetchall():
entity, total_count, source_counts = row
source_dict = dict(sc.split(":") for sc in source_counts.split(","))
entities.append(
(
entity,
total_count,
int(source_dict.get("user", 0)),
int(source_dict.get("llm", 0)),
)
)
self.logger.info(f"Retrieved recent entities: {entities}")
return entities
except sqlite3.Error as e:
@@ -189,25 +223,26 @@ class EnhancedContextPlugin(sm.BasePlugin):
def format_context_message(
self, entities: List[tuple], include_identity: bool = True
) -> str:
"""Format context message more naturally"""
"""Format context message with source awareness"""
context_parts = []
# Add identity context if available and requested
# Add identity context
if include_identity and self.personal_identity:
context_parts.append(
f"The user's name is {self.personal_identity}. Remember to use their name naturally in conversation when appropriate."
)
context_parts.append(f"The user's name is {self.personal_identity}.")
# Add entity context if available
# Add entity context with source awareness
if entities:
# Format entities with their mention counts
entity_strings = [
f"{entity} (mentioned {count} {'times' if count > 1 else 'time'})"
for entity, count in entities
]
entity_strings = []
for entity, total, user_count, llm_count in entities:
source_info = []
if user_count > 0:
source_info.append(f"{user_count} by user")
if llm_count > 0:
source_info.append(f"{llm_count} by assistant")
entity_strings.append(f"{entity} (mentioned {', '.join(source_info)})")
context_parts.append(
"Recent conversation history includes: "
"Recent conversation topics: "
+ (
", ".join(entity_strings[:-1]) + f" and {entity_strings[-1]}"
if len(entity_strings) > 1
@@ -215,13 +250,16 @@ class EnhancedContextPlugin(sm.BasePlugin):
)
)
# Add instructions for memory queries
context_parts.append(
"If the user asks about memories or what has been discussed, "
"naturally incorporate the above context into your response."
)
# Add guidance for heavily mentioned entities
heavy_mentions = [(e, t) for e, t, _, _ in entities if t > 3]
if heavy_mentions:
context_parts.append(
"Note: Be mindful not to overuse "
+ ", ".join(f"{e}" for e, _ in heavy_mentions)
+ " as they have been frequently discussed."
)
return " ".join(context_parts)
return "\n".join(context_parts)
def extract_identity(self, text: str) -> str | None:
"""Extract identity statements like 'I am X'"""
@@ -448,11 +486,11 @@ class EnhancedContextPlugin(sm.BasePlugin):
for marker_type, markers in markers_by_type.items():
context_parts.append(f"- {marker_type.title()}: {', '.join(markers)}")
# Add entity context
# Add entity context - Fixed tuple unpacking
if entities:
entity_strings = [
f"{entity} (mentioned {count} {'times' if count > 1 else 'time'})"
for entity, count in entities
f"{entity} (mentioned {total} {'times' if total > 1 else 'time'})"
for entity, total, user_count, llm_count in entities
]
context_parts.append(
"Recent conversation topics: "
@@ -472,10 +510,10 @@ class EnhancedContextPlugin(sm.BasePlugin):
self.logger.info(f"Processing user message: {last_message.text}")
# Extract and store entities
# Extract and store entities with user source
entities = self.extract_entities(last_message.text)
for entity in entities:
self.store_entity(entity)
self.store_entity(entity, source="user")
# Extract and store essence markers
essence_markers = self.extract_essence_markers(last_message.text)
@@ -487,7 +525,7 @@ class EnhancedContextPlugin(sm.BasePlugin):
recent_entities = self.retrieve_recent_entities(days=30)
context_message = self.format_context_message(recent_entities)
if context_message:
conversation.add_message(role="system", text=context_message)
conversation.add_message(role="user", text=context_message)
self.logger.info(f"Added context message: {context_message}")
return True
@@ -500,13 +538,46 @@ class EnhancedContextPlugin(sm.BasePlugin):
self.logger.info(f"Processing assistant response: {last_response.text}")
# Extract and store entities from the LLM's response
# Extract and store entities from the LLM's response with llm source
entities = self.extract_entities(last_response.text)
for entity in entities:
self.store_entity(entity)
self.store_entity(entity, source="llm")
return True
def simulate_llm_conversation(self, context: str, num_turns: int = 3) -> str:
"""Simulate a conversation between multiple LLM personalities about the context"""
conversation_log = []
def get_response(personality: str, previous_messages: str) -> str:
prompt = (
f"{personality}. You are participating in a brief group discussion "
f"about the following context:\n{context}\n\n"
f"Previous messages:\n{previous_messages}\n\n"
"Provide a short, focused response (1-2 sentences) that builds on "
"the discussion. Be creative but stay on topic."
)
temp_conv = sm.create_conversation(
llm_model=LLM_MODEL, llm_provider=LLM_PROVIDER
)
temp_conv.add_message(role="user", text=prompt)
response = temp_conv.send()
return response.text.strip()
# Select random personalities for this conversation
selected_personalities = random.sample(
self.llm_personalities, min(num_turns, len(self.llm_personalities))
)
with ThreadPoolExecutor() as executor:
for i, personality in enumerate(selected_personalities, 1):
previous = "\n".join(conversation_log)
response = get_response(personality, previous)
conversation_log.append(f"Speaker {i}: {response}")
return "\n\n".join(conversation_log)
# Replace the example usage code at the bottom with this chat interface:
def main():
@@ -521,51 +592,70 @@ def main():
recent_entities = plugin.retrieve_recent_entities()
context_message = plugin.format_context_message(recent_entities)
if context_message:
conversation.add_message(role="system", text=context_message)
conversation.add_message(role="user", text=context_message)
plugin.logger.info(f"Added initial context message: {context_message}")
console = Console()
console.print(
Panel("[bold green]Chat interface ready![/bold green] Type 'quit' to exit.")
)
print("-" * 50)
md = """# Enhanced Context Chat Interface
Type 'quit' to exit. Type 'go go go' for a special surprise!
---"""
console.print(Markdown(md))
try:
while True:
# Get user input with colored prompt
console.print("\n[bold blue]You:[/bold blue] ", end="")
console.print(Markdown("**You:**"), end=" ")
user_input = input().strip()
# Check for quit command
if user_input.lower() in ["quit", "exit", "q"]:
console.print("\n[bold green]Goodbye![/bold green]")
console.print(Markdown("**Goodbye!**"))
break
# Add user message and get response
# Easter egg handling
if user_input.lower() == "go go go":
console.print(Markdown("## 🎉 Multi-LLM Discussion Initiated!"))
recent_entities = plugin.retrieve_recent_entities()
context = plugin.format_context_message(recent_entities)
conversation_result = plugin.simulate_llm_conversation(context)
console.print(
Markdown(
f"""### Discussion Results
*{conversation_result}*
---"""
)
)
continue
# Regular conversation handling
conversation.add_message(role="user", text=user_input)
should_continue = plugin.pre_send_hook(conversation)
# Only send to LLM if pre_send_hook returns True or None
if should_continue is not False:
response = conversation.send()
# Add post-response processing
plugin.post_response_hook(conversation)
with Status("[bold blue]Thinking...", spinner="dots") as status:
response = conversation.send()
plugin.post_response_hook(conversation)
console.print(
"\n[bold purple]Assistant:[/bold purple]",
Text(str(response.text), style="italic"),
Markdown(
f"""**Assistant:** *{response.text}*
---"""
)
)
else:
# Get the last assistant message that was added by the plugin
response = conversation.get_last_message(role="assistant")
if response:
console.print(
"\n[bold purple]Assistant:[/bold purple]",
Text(response.text, style="bold green"),
Markdown(
f"""**Assistant:** *{response.text}*
---"""
)
)
console.print(Text("-" * 50, style="dim"))
except KeyboardInterrupt:
console.print("\n\n[bold green]Goodbye![/bold green]")
console.print(Markdown("\n**Goodbye!**"))
return