From 30d8412bbfb4a9ffd2ca14cf9a5c249ddf2276a9 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Wed, 6 Nov 2024 09:57:27 -0500 Subject: [PATCH] Refactor LLM provider and model in enhanced_context.py --- examples/enhanced_context.py | 224 ++++++++++++++++++++++++----------- 1 file changed, 157 insertions(+), 67 deletions(-) diff --git a/examples/enhanced_context.py b/examples/enhanced_context.py index 6b21717..9446935 100644 --- a/examples/enhanced_context.py +++ b/examples/enhanced_context.py @@ -19,10 +19,19 @@ from rich import print from rich.console import Console from rich.panel import Panel from rich.text import Text +from rich.markdown import Markdown +from rich.status import Status + +from concurrent.futures import ThreadPoolExecutor +import random DB_PATH = "enhanced_context.db" -LLM_PROVIDER = "xai" -LLM_MODEL = "grok-beta" + +LLM_MODEL = "gpt-4o-mini" +LLM_PROVIDER = "openai" + +# LLM_PROVIDER = "xai" +# LLM_MODEL = "grok-beta" class EnhancedContextPlugin(sm.BasePlugin): @@ -69,6 +78,15 @@ class EnhancedContextPlugin(sm.BasePlugin): except LookupError as e: self.logger.error(f"Error downloading NLTK data: {e}") + # Add LLM personality traits for easter egg + self.llm_personalities = [ + "You are a wise philosopher who speaks in riddles", + "You are an excited scientist who loves discovering patterns", + "You are a detective who analyzes every detail", + "You are a poet who sees beauty in connections", + "You are a historian who relates everything to the past", + ] + @contextmanager def get_connection(self): """Context manager for database connections""" @@ -81,13 +99,15 @@ class EnhancedContextPlugin(sm.BasePlugin): def init_db(self): """Initialize the database with proper schema""" with self.get_connection() as conn: - # Create memory table for entities + # Modify memory table to include source conn.execute( """ CREATE TABLE IF NOT EXISTS memory ( - entity TEXT PRIMARY KEY, + entity TEXT, + source TEXT, -- 'user' or 'llm' last_mentioned TIMESTAMP, - mention_count INTEGER DEFAULT 1 + mention_count INTEGER DEFAULT 1, + PRIMARY KEY (entity, source) ) """ ) @@ -115,45 +135,59 @@ class EnhancedContextPlugin(sm.BasePlugin): """ ) - def store_entity(self, entity: str) -> None: - """Store or update entity mention with error handling""" + def store_entity(self, entity: str, source: str = "user") -> None: + """Store or update entity mention with source tracking""" try: with self.get_connection() as conn: - # Modified to store datetime in SQLite format now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") conn.execute( """ - INSERT INTO memory (entity, last_mentioned, mention_count) - VALUES (?, ?, 1) - ON CONFLICT(entity) DO UPDATE SET + INSERT INTO memory (entity, source, last_mentioned, mention_count) + VALUES (?, ?, ?, 1) + ON CONFLICT(entity, source) DO UPDATE SET last_mentioned = ?, mention_count = mention_count + 1 """, - (entity, now, now), + (entity, source, now, now), ) - conn.commit() # Added explicit commit - self.logger.info(f"Stored entity: {entity}") + conn.commit() + self.logger.info(f"Stored {source} entity: {entity}") except sqlite3.Error as e: self.logger.error(f"Database error while storing entity {entity}: {e}") - def retrieve_recent_entities(self, days: int = 7) -> List[str]: - """Retrieve recently mentioned entities with frequency""" + def retrieve_recent_entities(self, days: int = 7) -> List[tuple]: + """Retrieve recently mentioned entities with frequency and source""" try: with self.get_connection() as conn: cur = conn.cursor() - # Modified query to handle datetime strings properly cur.execute( """ - SELECT entity, mention_count + SELECT + entity, + SUM(mention_count) as total_mentions, + GROUP_CONCAT(source || ':' || mention_count) as source_counts FROM memory WHERE last_mentioned >= datetime('now', ?, 'localtime') - ORDER BY mention_count DESC, last_mentioned DESC + GROUP BY entity + ORDER BY total_mentions DESC, MAX(last_mentioned) DESC LIMIT 5 """, (f"-{days} days",), ) - entities = [(row[0], row[1]) for row in cur.fetchall()] + entities = [] + for row in cur.fetchall(): + entity, total_count, source_counts = row + source_dict = dict(sc.split(":") for sc in source_counts.split(",")) + entities.append( + ( + entity, + total_count, + int(source_dict.get("user", 0)), + int(source_dict.get("llm", 0)), + ) + ) + self.logger.info(f"Retrieved recent entities: {entities}") return entities except sqlite3.Error as e: @@ -189,25 +223,26 @@ class EnhancedContextPlugin(sm.BasePlugin): def format_context_message( self, entities: List[tuple], include_identity: bool = True ) -> str: - """Format context message more naturally""" + """Format context message with source awareness""" context_parts = [] - # Add identity context if available and requested + # Add identity context if include_identity and self.personal_identity: - context_parts.append( - f"The user's name is {self.personal_identity}. Remember to use their name naturally in conversation when appropriate." - ) + context_parts.append(f"The user's name is {self.personal_identity}.") - # Add entity context if available + # Add entity context with source awareness if entities: - # Format entities with their mention counts - entity_strings = [ - f"{entity} (mentioned {count} {'times' if count > 1 else 'time'})" - for entity, count in entities - ] + entity_strings = [] + for entity, total, user_count, llm_count in entities: + source_info = [] + if user_count > 0: + source_info.append(f"{user_count} by user") + if llm_count > 0: + source_info.append(f"{llm_count} by assistant") + entity_strings.append(f"{entity} (mentioned {', '.join(source_info)})") context_parts.append( - "Recent conversation history includes: " + "Recent conversation topics: " + ( ", ".join(entity_strings[:-1]) + f" and {entity_strings[-1]}" if len(entity_strings) > 1 @@ -215,13 +250,16 @@ class EnhancedContextPlugin(sm.BasePlugin): ) ) - # Add instructions for memory queries - context_parts.append( - "If the user asks about memories or what has been discussed, " - "naturally incorporate the above context into your response." - ) + # Add guidance for heavily mentioned entities + heavy_mentions = [(e, t) for e, t, _, _ in entities if t > 3] + if heavy_mentions: + context_parts.append( + "Note: Be mindful not to overuse " + + ", ".join(f"{e}" for e, _ in heavy_mentions) + + " as they have been frequently discussed." + ) - return " ".join(context_parts) + return "\n".join(context_parts) def extract_identity(self, text: str) -> str | None: """Extract identity statements like 'I am X'""" @@ -448,11 +486,11 @@ class EnhancedContextPlugin(sm.BasePlugin): for marker_type, markers in markers_by_type.items(): context_parts.append(f"- {marker_type.title()}: {', '.join(markers)}") - # Add entity context + # Add entity context - Fixed tuple unpacking if entities: entity_strings = [ - f"{entity} (mentioned {count} {'times' if count > 1 else 'time'})" - for entity, count in entities + f"{entity} (mentioned {total} {'times' if total > 1 else 'time'})" + for entity, total, user_count, llm_count in entities ] context_parts.append( "Recent conversation topics: " @@ -472,10 +510,10 @@ class EnhancedContextPlugin(sm.BasePlugin): self.logger.info(f"Processing user message: {last_message.text}") - # Extract and store entities + # Extract and store entities with user source entities = self.extract_entities(last_message.text) for entity in entities: - self.store_entity(entity) + self.store_entity(entity, source="user") # Extract and store essence markers essence_markers = self.extract_essence_markers(last_message.text) @@ -487,7 +525,7 @@ class EnhancedContextPlugin(sm.BasePlugin): recent_entities = self.retrieve_recent_entities(days=30) context_message = self.format_context_message(recent_entities) if context_message: - conversation.add_message(role="system", text=context_message) + conversation.add_message(role="user", text=context_message) self.logger.info(f"Added context message: {context_message}") return True @@ -500,13 +538,46 @@ class EnhancedContextPlugin(sm.BasePlugin): self.logger.info(f"Processing assistant response: {last_response.text}") - # Extract and store entities from the LLM's response + # Extract and store entities from the LLM's response with llm source entities = self.extract_entities(last_response.text) for entity in entities: - self.store_entity(entity) + self.store_entity(entity, source="llm") return True + def simulate_llm_conversation(self, context: str, num_turns: int = 3) -> str: + """Simulate a conversation between multiple LLM personalities about the context""" + conversation_log = [] + + def get_response(personality: str, previous_messages: str) -> str: + prompt = ( + f"{personality}. You are participating in a brief group discussion " + f"about the following context:\n{context}\n\n" + f"Previous messages:\n{previous_messages}\n\n" + "Provide a short, focused response (1-2 sentences) that builds on " + "the discussion. Be creative but stay on topic." + ) + + temp_conv = sm.create_conversation( + llm_model=LLM_MODEL, llm_provider=LLM_PROVIDER + ) + temp_conv.add_message(role="user", text=prompt) + response = temp_conv.send() + return response.text.strip() + + # Select random personalities for this conversation + selected_personalities = random.sample( + self.llm_personalities, min(num_turns, len(self.llm_personalities)) + ) + + with ThreadPoolExecutor() as executor: + for i, personality in enumerate(selected_personalities, 1): + previous = "\n".join(conversation_log) + response = get_response(personality, previous) + conversation_log.append(f"Speaker {i}: {response}") + + return "\n\n".join(conversation_log) + # Replace the example usage code at the bottom with this chat interface: def main(): @@ -521,51 +592,70 @@ def main(): recent_entities = plugin.retrieve_recent_entities() context_message = plugin.format_context_message(recent_entities) if context_message: - conversation.add_message(role="system", text=context_message) + conversation.add_message(role="user", text=context_message) plugin.logger.info(f"Added initial context message: {context_message}") console = Console() - console.print( - Panel("[bold green]Chat interface ready![/bold green] Type 'quit' to exit.") - ) - print("-" * 50) + md = """# Enhanced Context Chat Interface +Type 'quit' to exit. Type 'go go go' for a special surprise! + +---""" + console.print(Markdown(md)) try: while True: - # Get user input with colored prompt - console.print("\n[bold blue]You:[/bold blue] ", end="") + console.print(Markdown("**You:**"), end=" ") user_input = input().strip() - # Check for quit command if user_input.lower() in ["quit", "exit", "q"]: - console.print("\n[bold green]Goodbye![/bold green]") + console.print(Markdown("**Goodbye!**")) break - # Add user message and get response + # Easter egg handling + if user_input.lower() == "go go go": + console.print(Markdown("## 🎉 Multi-LLM Discussion Initiated!")) + recent_entities = plugin.retrieve_recent_entities() + context = plugin.format_context_message(recent_entities) + conversation_result = plugin.simulate_llm_conversation(context) + + console.print( + Markdown( + f"""### Discussion Results +*{conversation_result}* + +---""" + ) + ) + continue + + # Regular conversation handling conversation.add_message(role="user", text=user_input) should_continue = plugin.pre_send_hook(conversation) - # Only send to LLM if pre_send_hook returns True or None if should_continue is not False: - response = conversation.send() - # Add post-response processing - plugin.post_response_hook(conversation) + with Status("[bold blue]Thinking...", spinner="dots") as status: + response = conversation.send() + plugin.post_response_hook(conversation) console.print( - "\n[bold purple]Assistant:[/bold purple]", - Text(str(response.text), style="italic"), + Markdown( + f"""**Assistant:** *{response.text}* + +---""" + ) ) else: - # Get the last assistant message that was added by the plugin response = conversation.get_last_message(role="assistant") if response: console.print( - "\n[bold purple]Assistant:[/bold purple]", - Text(response.text, style="bold green"), + Markdown( + f"""**Assistant:** *{response.text}* + +---""" + ) ) - console.print(Text("-" * 50, style="dim")) except KeyboardInterrupt: - console.print("\n\n[bold green]Goodbye![/bold green]") + console.print(Markdown("\n**Goodbye!**")) return