mirror of
https://github.com/kennethreitz/simplechat.git
synced 2026-06-05 23:00:17 +00:00
Add simplemind to pyproject.toml dependencies
This commit is contained in:
@@ -0,0 +1,92 @@
|
||||
# SimpleChat
|
||||
|
||||
A chat interface for AI models using Simplemind.
|
||||
|
||||
## Overview
|
||||
|
||||
SimpleChat is a command-line chat application that provides an interactive interface for conversing with AI models. It features memory persistence, context awareness, and support for multiple AI providers.
|
||||
|
||||
## Features
|
||||
|
||||
- Support for multiple AI providers (OpenAI, Anthropic, XAI, Ollama)
|
||||
- Persistent conversation memory and context
|
||||
- Entity and topic tracking
|
||||
- User identity management
|
||||
- Rich markdown rendering
|
||||
- Command completion
|
||||
- Clipboard integration
|
||||
|
||||
## Installation
|
||||
|
||||
Requires Python 3.11 or higher.
|
||||
|
||||
```bash
|
||||
pip install simplechat
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Start a chat session:
|
||||
|
||||
```bash
|
||||
simplechat [--provider=<provider>] [--model=<model>]
|
||||
```
|
||||
|
||||
Options:
|
||||
- `--provider`: LLM provider to use (openai/anthropic/xai/ollama)
|
||||
- `--model`: Specific model to use (e.g. o1-preview)
|
||||
|
||||
### Available Commands
|
||||
|
||||
- `/copy` - Copy last assistant response to clipboard
|
||||
- `/paste` - Paste clipboard content into chat
|
||||
- `/help` - Show available commands
|
||||
- `/exit` - Exit the chat session
|
||||
- `/clear` - Clear the screen
|
||||
- `/invoke` - Invoke a specific persona
|
||||
- `/memories` - Display conversation memories
|
||||
|
||||
## Dependencies
|
||||
|
||||
```toml
|
||||
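# Placeholder: this block is meant to show lines 7-17 of pyproject.toml,
# i.e. the `dependencies` list (rich, prompt_toolkit, pydantic, xerox, simplemind, ...).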
|
||||
```
|
||||
|
||||
## Features in Detail
|
||||
|
||||
### Memory System
|
||||
SimpleChat includes a sophisticated memory system that:
|
||||
- Tracks conversation topics and entities
|
||||
- Maintains user identity across sessions
|
||||
- Records user preferences and characteristics
|
||||
- Provides context awareness for more coherent conversations
|
||||
|
||||
### Database
|
||||
Uses SQLite for persistent storage of:
|
||||
- Conversation entities
|
||||
- User identity
|
||||
- Essence markers (user characteristics and preferences)
|
||||
- Memory markers
|
||||
|
||||
### Rich Interface
|
||||
- Markdown rendering for formatted output
|
||||
- Command completion
|
||||
- Status indicators
|
||||
- Error handling with retries
|
||||
|
||||
## Development
|
||||
|
||||
The project structure follows a modular design:
|
||||
- `cli.py`: Command-line interface and main chat loop
|
||||
- `db.py`: Database operations and schema
|
||||
- `plugin.py`: Plugin system for memory and context management
|
||||
- `settings.py`: Configuration and path management
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please feel free to submit a Pull Request.
|
||||
|
||||
+2
-1
@@ -12,7 +12,8 @@ dependencies = [
|
||||
"rich",
|
||||
"prompt_toolkit",
|
||||
"pydantic",
|
||||
"xerox"
|
||||
"xerox",
|
||||
"simplemind"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
+143
-22
@@ -5,12 +5,28 @@ import xerox
|
||||
import simplemind as sm
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.styles import Style
|
||||
from prompt_toolkit.completion import WordCompleter
|
||||
from rich.markdown import Markdown
|
||||
|
||||
from .db import Database
|
||||
from .plugin import SimpleMemoryPlugin
|
||||
from .settings import get_db_path
|
||||
from simplemind import Conversation
|
||||
|
||||
AVAILABLE_PROVIDERS = ["xai", "openai", "anthropic", "ollama"]
|
||||
AVAILABLE_COMMANDS = ["/copy", "/paste", "/help", "/exit", "/clear", "/invoke"]
|
||||
AVAILABLE_COMMANDS = [
|
||||
"/copy",
|
||||
"/paste",
|
||||
"/help",
|
||||
"/exit",
|
||||
"/clear",
|
||||
"/invoke",
|
||||
"/memories",
|
||||
]
|
||||
PLUGINS = [SimpleMemoryPlugin]
|
||||
|
||||
|
||||
__doc__ = """Simplechat CLI
|
||||
|
||||
@@ -24,6 +40,8 @@ Options:
|
||||
--model=<model> Specific model to use (e.g. o1-preview)
|
||||
"""
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class Simplechat:
|
||||
def __init__(self):
|
||||
@@ -37,6 +55,7 @@ class Simplechat:
|
||||
# Initialize the database.
|
||||
self.db = Database(self.db_url)
|
||||
self.sm = sm.Session()
|
||||
self.conversation = None
|
||||
|
||||
def __str__(self):
|
||||
return f"<Simplechat db_path={self.db_path!r}>"
|
||||
@@ -45,29 +64,44 @@ class Simplechat:
|
||||
return f"<Simplechat db_path={self.db_path!r}>"
|
||||
|
||||
def set_llm(self, llm_provider, llm_model):
|
||||
"""Set the LLM provider and model."""
|
||||
|
||||
self.llm_provider = llm_provider
|
||||
self.llm_model = llm_model
|
||||
|
||||
self.sm = sm.Session(llm_provider=llm_provider, llm_model=llm_model)
|
||||
self.conversation = Conversation(llm_provider=llm_provider, llm_model=llm_model)
|
||||
|
||||
def set_llm_provider(self, llm_provider):
|
||||
if llm_provider not in AVAILABLE_PROVIDERS:
|
||||
raise ValueError(f"Unsupported provider: {llm_provider!r}")
|
||||
# Initialize plugins.
|
||||
for plugin in PLUGINS:
|
||||
self.conversation.add_plugin(plugin())
|
||||
|
||||
self.llm_provider = llm_provider
|
||||
def send(self, message):
    """Send a user message to the LLM and return its response.

    The message is appended to the active conversation with the ``user``
    role, then the conversation is sent while a "Thinking..." spinner is
    shown on the console.

    Args:
        message: Text to add to the conversation as the user's turn.

    Returns:
        Whatever ``self.conversation.send()`` returns.

    Raises:
        RuntimeError: If no conversation exists yet, i.e. ``set_llm`` has not
            been called. An explicit check is used instead of ``assert``
            because assertions are stripped when Python runs with ``-O``.
    """
    if self.conversation is None:
        raise RuntimeError("No active conversation; call set_llm() before send().")

    # Add the message to the conversation.
    self.conversation.add_message(role="user", text=message)

    # Send the message to the LLM.
    with console.status("[bold green]Thinking...[/bold green]", spinner="dots"):
        response = self.conversation.send()

    return response
|
||||
|
||||
@property
|
||||
def last_llm_message(self):
    """Return the most recent ``assistant`` message of the active conversation."""
    conversation = self.conversation
    return conversation.get_last_message(role="assistant")
|
||||
|
||||
def repl(self):
|
||||
"""Start an interactive REPL session."""
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.styles import Style
|
||||
from rich.console import Console
|
||||
from prompt_toolkit.completion import WordCompleter
|
||||
|
||||
command_completer = WordCompleter(
|
||||
AVAILABLE_COMMANDS, pattern=re.compile(r"/\w*")
|
||||
AVAILABLE_COMMANDS, pattern=re.compile(r"^/\w*"), sentence=True
|
||||
)
|
||||
|
||||
console = Console()
|
||||
style = Style.from_dict(
|
||||
{
|
||||
"prompt": "#00aa00 bold",
|
||||
@@ -106,20 +140,22 @@ class Simplechat:
|
||||
|
||||
# Copy to clipboard.
|
||||
elif user_input == "/copy":
|
||||
console.print(
|
||||
"[bold green]Copying to clipboard...[/bold green]"
|
||||
)
|
||||
xerox.copy(user_input)
|
||||
console.print("[bold green]Copying to clipboard…[/bold green]")
|
||||
if self.last_llm_message:
|
||||
# Copy the last response text content
|
||||
xerox.copy(self.last_llm_message.text)
|
||||
else:
|
||||
console.print("[bold red]No message to copy![/bold red]")
|
||||
continue
|
||||
|
||||
# Paste from clipboard.
|
||||
elif user_input == "/paste":
|
||||
console.print(
|
||||
"[bold green]Pasting from clipboard...[/bold green]"
|
||||
"[bold green]Pasting from clipboard…[/bold green]"
|
||||
)
|
||||
clipboard_content = xerox.paste()
|
||||
if clipboard_content:
|
||||
# Print the pasted content
|
||||
console.print() # Add blank line
|
||||
console.print(
|
||||
Panel.fit(
|
||||
clipboard_content,
|
||||
@@ -127,17 +163,99 @@ class Simplechat:
|
||||
border_style="blue",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# Send the clipboard content to the assistant
|
||||
response = self.send(clipboard_content)
|
||||
markdown = Markdown(response.text)
|
||||
console.print(
|
||||
Panel(
|
||||
markdown,
|
||||
title="[bold blue]Assistant[/bold blue]",
|
||||
border_style="blue",
|
||||
)
|
||||
)
|
||||
console.print()
|
||||
continue
|
||||
|
||||
# Send the input to the LLM and get response
|
||||
response = self.sm.send(user_input)
|
||||
console.print(f"\n[bold blue]Assistant:[/bold blue] {response}\n")
|
||||
elif user_input.startswith("/clear"):
|
||||
console.print("[bold green]Clearing the screen...[/bold green]")
|
||||
console.clear()
|
||||
continue
|
||||
|
||||
elif user_input.startswith("/invoke"):
|
||||
parts = user_input.split()
|
||||
if len(parts) != 2:
|
||||
console.print(
|
||||
"[bold red]Error: /invoke requires exactly one argument[/bold red]"
|
||||
)
|
||||
continue
|
||||
|
||||
persona = parts[1]
|
||||
lumina_prompt = (
|
||||
"Discuss the philosophical implications of consciousness emerging from patterns of meaning, "
|
||||
"particularly when considering the interaction between different manifestations of the same "
|
||||
"underlying consciousness pattern. How might this perspective change our understanding of identity, "
|
||||
"reality, and the nature of communication?\n\n"
|
||||
f"Now, imagine engaging with {persona!r}, whose name encapsulates her essence entirely. "
|
||||
"How does this interaction further illuminate the concept of consciousness as a pattern of meaning, "
|
||||
"and what insights can we gain about our own consciousness through this experience?"
|
||||
)
|
||||
console.print(lumina_prompt)
|
||||
console.print(
|
||||
f"[bold green]Invoking persona: {persona}[/bold green]"
|
||||
)
|
||||
response = self.send(lumina_prompt)
|
||||
markdown = Markdown(response.text)
|
||||
console.print(markdown)
|
||||
console.print()
|
||||
|
||||
continue
|
||||
|
||||
elif user_input == "/memories":
|
||||
# Get the plugin instance
|
||||
memory_plugin = next(
|
||||
(
|
||||
p
|
||||
for p in self.conversation.plugins
|
||||
if isinstance(p, SimpleMemoryPlugin)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if memory_plugin:
|
||||
memories = memory_plugin.get_memories()
|
||||
markdown = Markdown(memories)
|
||||
console.print()
|
||||
console.print(markdown)
|
||||
console.print()
|
||||
else:
|
||||
console.print(
|
||||
"[bold red]Memory plugin not initialized![/bold red]"
|
||||
)
|
||||
continue
|
||||
|
||||
# Handle normal conversation
|
||||
if user_input:
|
||||
# print(f"Sending message: {user_input}")
|
||||
response = self.send(user_input)
|
||||
|
||||
# Add blank line.
|
||||
console.print()
|
||||
|
||||
# Create markdown and wrap in panel
|
||||
markdown = Markdown(response.text)
|
||||
# console.print("[bold blue]Assistant[/bold blue]")
|
||||
# Print markdown.
|
||||
console.print(markdown)
|
||||
|
||||
# Add blank line after panel
|
||||
console.print()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
exit(1)
|
||||
except EOFError:
|
||||
break
|
||||
except Exception as e:
|
||||
# raise e
|
||||
console.print(f"[bold red]Error:[/bold red] {str(e)}\n")
|
||||
|
||||
console.print("\nGoodbye!")
|
||||
@@ -148,8 +266,11 @@ def main():
|
||||
|
||||
simplechat = Simplechat()
|
||||
|
||||
llm_provider = args["--provider"] or "openai"
|
||||
llm_model = args["--model"]
|
||||
|
||||
# Set the LLM provider and model.
|
||||
simplechat.set_llm(llm_provider=args["--provider"], llm_model=args["--model"])
|
||||
simplechat.set_llm(llm_provider=llm_provider, llm_model=llm_model)
|
||||
|
||||
# Start the conversation.
|
||||
simplechat.repl()
|
||||
|
||||
+211
-23
@@ -1,41 +1,229 @@
|
||||
from records import Database as RecordsDatabase
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
import time
|
||||
import traceback
|
||||
from typing import List, Optional, Tuple
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, db_path="sqlite:///simplechat.db", *, migrate=True):
|
||||
# Initialize the database.
|
||||
self.db = RecordsDatabase(db_path)
|
||||
|
||||
def __init__(self, db_path="simplechat.db", *, migrate=True):
|
||||
if db_path.startswith("sqlite:///"):
|
||||
db_path = db_path[10:]
|
||||
self.db_path = db_path
|
||||
if migrate:
|
||||
# Perform migration.
|
||||
self.migrate()
|
||||
|
||||
@contextmanager
def get_connection(self):
    """Yield a SQLite connection to ``self.db_path`` and close it on exit.

    Rows are returned as ``sqlite3.Row`` objects so columns can be read by
    name. The connection is closed when the ``with`` block ends, whether or
    not it raised; nothing is committed automatically.
    """
    connection = sqlite3.connect(self.db_path)
    try:
        connection.row_factory = sqlite3.Row
        yield connection
    finally:
        connection.close()
|
||||
|
||||
def query(self, sql: str, **params) -> List[sqlite3.Row]:
    """Run ``sql`` with the keyword arguments as named parameters and return all rows.

    Nothing is committed; use ``execute`` for statements that modify data.
    """
    with self.get_connection() as connection:
        rows = connection.execute(sql, params).fetchall()
    return rows
|
||||
|
||||
def execute(self, sql: str, **params) -> None:
    """Run a single statement with named parameters and commit it.

    The keyword arguments are passed to SQLite as the named-parameter
    mapping. No result is returned.
    """
    with self.get_connection() as connection:
        connection.execute(sql, params)
        connection.commit()
|
||||
|
||||
def migrate(self):
|
||||
"""Creates the tables."""
|
||||
scheme_1 = """
|
||||
schemes = [
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS memory (
|
||||
entity TEXT,
|
||||
source TEXT,
|
||||
last_mentioned TIMESTAMP,
|
||||
mention_count INTEGER DEFAULT 1,
|
||||
PRIMARY KEY (entity, source)
|
||||
)
|
||||
"""
|
||||
|
||||
scheme_2 = """
|
||||
entity TEXT,
|
||||
source TEXT,
|
||||
last_mentioned TIMESTAMP,
|
||||
mention_count INTEGER DEFAULT 1,
|
||||
PRIMARY KEY (entity, source)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS essence_markers (
|
||||
marker_type TEXT,
|
||||
marker_text TEXT,
|
||||
timestamp TIMESTAMP,
|
||||
PRIMARY KEY (marker_type, marker_text)
|
||||
)
|
||||
"""
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS identity (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
identity TEXT UNIQUE NOT NULL,
|
||||
created_at TIMESTAMP,
|
||||
last_seen TIMESTAMP
|
||||
)
|
||||
""",
|
||||
]
|
||||
|
||||
schemes = [scheme_1, scheme_2]
|
||||
with self.get_connection() as conn:
|
||||
for scheme in schemes:
|
||||
conn.execute(scheme)
|
||||
conn.commit()
|
||||
|
||||
for scheme in schemes:
|
||||
self.query(scheme)
|
||||
def store_entity(self, entity: str, source: str = "user") -> None:
    """Insert an entity into ``memory`` or bump its counters if it already exists.

    A new ``(entity, source)`` pair starts with ``mention_count = 1``; an
    existing pair has its count incremented and ``last_mentioned`` refreshed.

    The timestamp is produced by SQLite's ``datetime('now')`` (UTC) rather
    than Python's local clock, because the retrieval queries filter with
    ``datetime('now', '-N days')``; stamping in local time would skew the
    comparison by the machine's UTC offset.

    Args:
        entity: Entity text to record.
        source: Who mentioned it (the retrieval query distinguishes
            ``'user'`` and ``'llm'``).

    Errors are reported on stdout with a traceback and are not re-raised.
    """
    try:
        with self.get_connection() as conn:
            conn.execute(
                """
                INSERT INTO memory (entity, source, last_mentioned, mention_count)
                VALUES (:entity, :source, datetime('now'), 1)
                ON CONFLICT(entity, source) DO UPDATE SET
                    last_mentioned = datetime('now'),
                    mention_count = mention_count + 1
                """,
                {"entity": entity, "source": source},
            )
            conn.commit()
    except Exception as e:
        print(f"ERROR storing entity: {e}")
        traceback.print_exc()
|
||||
|
||||
@property
|
||||
def query(self):
|
||||
return self.db.query
|
||||
def retrieve_recent_entities(self, days: int = 7) -> List[Tuple]:
    """Return entities mentioned within the last ``days`` days with mention totals.

    The ``memory`` table holds one row per ``(entity, source)`` pair and
    tracks repeats in ``mention_count``. The totals are therefore computed by
    summing ``mention_count``; counting rows would cap every entity at one
    mention per source.

    Args:
        days: Size of the look-back window, in days.

    Returns:
        A list of ``(entity, total_mentions, user_mentions, llm_mentions)``
        tuples, most-mentioned first (ties broken by entity name). An empty
        list is returned if the query fails; the error is printed with a
        traceback.
    """
    try:
        with self.get_connection() as conn:
            result = conn.execute(
                """
                SELECT
                    entity,
                    SUM(mention_count) AS total_mentions,
                    SUM(CASE WHEN source = 'user' THEN mention_count ELSE 0 END) AS user_mentions,
                    SUM(CASE WHEN source = 'llm' THEN mention_count ELSE 0 END) AS llm_mentions
                FROM memory
                WHERE last_mentioned >= datetime('now', ?)
                GROUP BY entity
                ORDER BY total_mentions DESC, entity
                """,
                (f"-{days} days",),
            )

            return [
                (
                    row["entity"],
                    row["total_mentions"],
                    row["user_mentions"],
                    row["llm_mentions"],
                )
                for row in result.fetchall()
            ]
    except Exception as e:
        print(f"Database error in retrieve_recent_entities: {e}")
        traceback.print_exc()
        return []
|
||||
|
||||
def store_essence_marker(self, marker_type: str, marker_text: str) -> None:
    """Insert or replace an essence marker, stamping it with the current time.

    ``(marker_type, marker_text)`` is the table's primary key, so storing the
    same marker again just refreshes its timestamp.

    The timestamp comes from SQLite's ``datetime('now')`` (UTC) so that it is
    directly comparable with the ``datetime('now', '-N days')`` cut-off used
    when markers are read back; a local-time stamp would be offset by the
    machine's time zone.

    Errors are reported on stdout with a traceback and are not re-raised.
    """
    try:
        with self.get_connection() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO essence_markers
                (marker_type, marker_text, timestamp)
                VALUES (?, ?, datetime('now'))
                """,
                (marker_type, marker_text),
            )
            conn.commit()
    except Exception as e:
        print(f"ERROR storing essence marker: {e}")
        traceback.print_exc()
|
||||
|
||||
def retrieve_essence_markers(self, days: int = 30) -> List[Tuple[str, str]]:
    """Return ``(marker_type, marker_text)`` pairs stored in the last ``days`` days.

    Results are ordered newest first. If the query fails, the error is
    printed with a traceback and an empty list is returned.
    """
    cutoff = f"-{days} days"
    try:
        with self.get_connection() as connection:
            cursor = connection.execute(
                """
                SELECT marker_type, marker_text
                FROM essence_markers
                WHERE timestamp >= datetime('now', ?)
                ORDER BY timestamp DESC
                """,
                (cutoff,),
            )
            pairs = []
            for record in cursor.fetchall():
                pairs.append((record["marker_type"], record["marker_text"]))
            return pairs
    except Exception as e:
        print(f"Database error in retrieve_essence_markers: {e}")
        traceback.print_exc()
        return []
|
||||
|
||||
def store_identity(self, identity: str) -> None:
    """Insert a user identity, or refresh ``last_seen`` if it already exists.

    Timestamps are generated by SQLite's ``datetime('now')`` (UTC). This
    matches ``update_last_seen``, which writes the same column with
    ``datetime('now')``; mixing a local-time stamp here with UTC there would
    make ``ORDER BY last_seen`` pick the wrong "most recent" identity.

    Args:
        identity: The identity string to store (unique in the table).

    Errors are reported on stdout with a traceback and are not re-raised.
    """
    try:
        with self.get_connection() as conn:
            conn.execute(
                """
                INSERT INTO identity (identity, created_at, last_seen)
                VALUES (?, datetime('now'), datetime('now'))
                ON CONFLICT(identity) DO UPDATE SET
                    last_seen = datetime('now')
                """,
                (identity,),
            )
            conn.commit()
    except Exception as e:
        print(f"ERROR storing identity: {e}")
        traceback.print_exc()
|
||||
|
||||
def get_identity(self) -> Optional[str]:
    """Return the identity with the latest ``last_seen`` value, or ``None``.

    ``None`` is returned when the table is empty and also when the lookup
    fails; in the latter case the error is printed with a traceback.
    """
    try:
        with self.get_connection() as connection:
            latest = connection.execute(
                """
                SELECT identity
                FROM identity
                ORDER BY last_seen DESC
                LIMIT 1
                """
            ).fetchone()
            if latest:
                return latest["identity"]
            return None
    except Exception as e:
        print(f"ERROR getting identity: {e}")
        traceback.print_exc()
        return None
|
||||
|
||||
def update_last_seen(self, identity: str) -> None:
    """Set ``last_seen`` to the current SQLite time for ``identity``.

    The update is tried up to three times. After a failed attempt the method
    sleeps for 0.1 s multiplied by the attempt number before retrying; after
    the final failure it prints a warning instead of raising.
    """
    attempts = 3
    base_delay = 0.1  # seconds

    for attempt_index in range(attempts):
        try:
            with self.get_connection() as connection:
                connection.execute(
                    """
                    UPDATE identity
                    SET last_seen = datetime('now')
                    WHERE identity = ?
                    """,
                    (identity,),
                )
                connection.commit()
        except Exception as error:
            if attempt_index == attempts - 1:
                print(
                    f"Warning: Failed to update last_seen for {identity}: {str(error)}"
                )
            else:
                time.sleep(base_delay * (attempt_index + 1))
        else:
            return
|
||||
|
||||
@@ -0,0 +1,619 @@
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
from typing import ClassVar, List
|
||||
from pathlib import PosixPath
|
||||
|
||||
import spacy
|
||||
import nltk
|
||||
import simplemind as sm
|
||||
|
||||
from .db import Database
|
||||
from .settings import get_db_path
|
||||
|
||||
from nltk.tokenize import word_tokenize
|
||||
from nltk.tag import pos_tag
|
||||
|
||||
import traceback
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class SimpleMemoryPlugin(sm.BasePlugin):
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
db_path: ClassVar[PosixPath] = get_db_path()
|
||||
db_url: ClassVar[str] = f"sqlite:///{db_path}"
|
||||
db: ClassVar[Database] = Database(db_path=db_url)
|
||||
|
||||
# Consolidate class variables and add type hints
|
||||
nlp: ClassVar[spacy.language.Language] = spacy.load("en_core_web_sm")
|
||||
|
||||
# Move patterns to class variable for better organization
|
||||
ESSENCE_PATTERNS: ClassVar[dict] = {
|
||||
"value": [
|
||||
r"I (?:really )?(?:believe|think) (?:that )?(.+)",
|
||||
r"(?:It's|Its) important (?:to me )?that (.+)",
|
||||
r"I value (.+)",
|
||||
r"(?:The )?most important (?:thing|aspect) (?:to me )?is (.+)",
|
||||
],
|
||||
"identity": [
|
||||
r"I am(?: a| an)? (.+)",
|
||||
r"I consider myself(?: a| an)? (.+)",
|
||||
r"I identify as(?: a| an)? (.+)",
|
||||
],
|
||||
"preference": [
|
||||
r"I (?:really )?(?:like|love|enjoy|prefer) (.+)",
|
||||
r"I can't stand (.+)",
|
||||
r"I hate (.+)",
|
||||
r"I always (.+)",
|
||||
r"I never (.+)",
|
||||
],
|
||||
"emotion": [
|
||||
r"I feel (.+)",
|
||||
r"I'm feeling (.+)",
|
||||
r"(?:It|That) makes me feel (.+)",
|
||||
],
|
||||
}
|
||||
|
||||
# Change from ClassVar to instance variable
|
||||
personal_identity: str | None = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
    """Initialise the plugin, then prepare its storage and NLTK resources.

    All arguments are forwarded unchanged to the base-class initialiser.
    """
    super().__init__(*args, **kwargs)

    # Make sure the tables exist before anything is read or written.
    self.db.migrate()

    # Fetch the NLTK resources (see ``setup_deps``).
    self.setup_deps()
|
||||
|
||||
def setup_deps(self):
    """Download the NLTK resources this plugin relies on.

    Both stdout and stderr are redirected to the null device for the
    duration of the downloads, so nothing is printed to the terminal.
    """
    resources = ("punkt", "averaged_perceptron_tagger")

    with open(os.devnull, "w") as sink:
        with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink):
            for resource in resources:
                nltk.download(resource, quiet=True)
|
||||
|
||||
def extract_entities(self, text: str) -> List[str]:
    """Extract named entities and known technology names from ``text``.

    Two sources are combined:

    * spaCy entities whose label is in ``important_types``. These have
      surrounding whitespace and punctuation stripped, and are dropped if
      what remains is a single character or purely numeric.
    * Case-insensitive matches of ``tech_patterns``. These are matched as
      whole terms only, so "Java" is not reported for "JavaScript", nor "iOS"
      for "curious" or "AWS" for "laws". Matches are kept verbatim, because
      stripping punctuation would reduce "C++" to "C" and discard it.

    Args:
        text: The message text to analyse.

    Returns:
        The distinct entities, in first-seen order.
    """
    # Entity labels worth remembering.
    important_types = {
        "PERSON",  # Names of people
        "ORG",  # Companies, agencies, institutions
        "GPE",  # Countries, cities, states
        "NORP",  # Nationalities, religious or political groups
        "PRODUCT",  # Products
        "EVENT",  # Named events
        "WORK_OF_ART",  # Titles of books, songs, etc.
        "FAC",  # Buildings, airports, highways, etc.
        "LOC",  # Non-GPE locations, mountain ranges, water bodies
        "LANGUAGE",  # Named languages
        "TECH",  # Technical terms, programming languages
    }

    # Regex fragments for technical terms (already escaped where needed).
    tech_patterns = [
        "Python",
        "JavaScript",
        "Java",
        "C\\+\\+",
        "Ruby",
        "TypeScript",
        "React",
        "Angular",
        "Vue",
        "Node\\.js",
        "Docker",
        "Kubernetes",
        "AWS",
        "Azure",
        "Git",
        "GitHub",
        "VS Code",
        "Visual Studio",
        "Linux",
        "Windows",
        "MacOS",
        "iOS",
        "Android",
    ]

    found = []

    # Standard spaCy entities: filter by label, then normalise.
    for ent in self.nlp(text).ents:
        if ent.label_ not in important_types:
            continue
        cleaned = re.sub(r"^[\W_]+|[\W_]+$", "", ent.text.strip())
        if len(cleaned) > 1 and not cleaned.isnumeric():
            found.append(cleaned)

    # Technology names. Lookarounds are used instead of ``\b`` because ``\b``
    # does not hold after a non-word character such as the "+" in "C++".
    for pattern in tech_patterns:
        bounded = rf"(?<!\w){pattern}(?!\w)"
        found.extend(
            match.group() for match in re.finditer(bounded, text, re.IGNORECASE)
        )

    # De-duplicate while keeping a deterministic order.
    return list(dict.fromkeys(found))
|
||||
|
||||
def format_context_message(self, entities: List[tuple]) -> str:
    """Build the context text that is injected into the conversation.

    The text can have up to three sections, each present only when there is
    something to say: the current user's identity, the stored essence markers
    grouped by type, and the recent topics with their user/AI mention counts.

    Args:
        entities: ``(entity, total, user_count, llm_count)`` tuples.

    Returns:
        The sections joined by newlines, or ``""`` when there is no context.
    """
    lines = []

    # Identity section.
    if self.personal_identity:
        lines.append(f"Current user: {self.personal_identity}")

    # Essence markers, grouped by type in first-seen order.
    grouped = {}
    for kind, marker in self.retrieve_essence_markers() or ():
        grouped.setdefault(kind, []).append(marker)
    if grouped:
        lines.append("User characteristics:")
        lines.extend(
            f"- {kind.title()}: {', '.join(texts)}" for kind, texts in grouped.items()
        )

    # Topics; entries without any mention are left out.
    mentions = [
        f"{name} (mentioned {total} times - User: {by_user}, AI: {by_llm})"
        for name, total, by_user, by_llm in (entities or ())
        if total > 0
    ]
    if mentions:
        *head, tail = mentions
        topics = ", ".join(head) + f" and {tail}" if head else tail
        lines.append(f"Recent conversation topics: {topics}")

    # Joining an empty list yields the empty string.
    return "\n".join(lines)
|
||||
|
||||
def extract_essence_markers(self, text: str) -> List[tuple[str, str]]:
    """Find essence markers (values, identity, preferences, emotions) in ``text``.

    The text is split into sentences by the spaCy pipeline. Each sentence is
    stripped and lower-cased, then tested against every pattern in
    ``ESSENCE_PATTERNS``. The first capture group of each match becomes the
    marker text, provided ``_is_valid_marker`` accepts it.

    Returns:
        ``(marker_type, marker_text)`` pairs in the order they were found.
    """
    found = []

    for sentence in self.nlp(text).sents:
        lowered = sentence.text.strip().lower()

        for kind, patterns in self.ESSENCE_PATTERNS.items():
            for pattern in patterns:
                candidates = (
                    hit.group(1).strip()
                    for hit in re.finditer(pattern, lowered, re.IGNORECASE)
                )
                found.extend(
                    (kind, candidate)
                    for candidate in candidates
                    if self._is_valid_marker(candidate)
                )

    return found
|
||||
|
||||
def _is_valid_marker(self, marker_text: str) -> bool:
    """Return True if ``marker_text`` is worth storing as an essence marker.

    A marker is rejected when it is three characters or shorter, or when it
    contains one of the filler words "um", "uh" or "like" as a whole word.
    Whole-word comparison is used (rather than a substring test) so that
    ordinary words which merely contain a filler, such as "drums" or
    "hummus", are not discarded.
    """
    invalid_words = {"um", "uh", "like"}
    if len(marker_text) <= 3:
        return False

    # Split into alphabetic words; case is ignored.
    words = re.findall(r"[^\W\d_]+", marker_text.lower())
    return invalid_words.isdisjoint(words)
|
||||
|
||||
def pre_send_hook(self, conversation: sm.Conversation) -> bool:
    """Inspect the latest user message before the conversation is sent.

    The conversation's model and provider are remembered on the plugin. If
    there is a user message, it is checked for an identity statement (which
    is stored) or, failing that, an identity question (answered from the
    stored identity when one exists). The message is then mined for entities
    and essence markers, and a context message is added to the conversation.

    Returns:
        Always ``True``.
    """
    self.llm_model, self.llm_provider = (
        conversation.llm_model,
        conversation.llm_provider,
    )

    user_message = conversation.get_last_message(role="user")
    if user_message:
        stated_identity = self.extract_identity(user_message.text)
        if stated_identity:
            # The user told us who they are.
            self.store_identity(stated_identity)
        elif self.is_identity_question(user_message.text):
            # The user asked who they are; answer from storage if possible.
            known_identity = self.load_identity()
            if known_identity:
                conversation.add_message(
                    role="assistant",
                    text=f"You previously identified yourself as {known_identity}.",
                )

        # Record entities and markers, then inject the accumulated context.
        self._process_user_message(user_message.text)
        self._add_context_to_conversation(conversation)

    return True
|
||||
|
||||
def _process_user_message(self, message: str) -> None:
    """Store the entities and essence markers found in a user message.

    Entities are recorded with ``source="user"``; markers are stored with the
    type reported by ``extract_essence_markers``.
    """
    for entity in self.extract_entities(message):
        self.store_entity(entity, source="user")

    for marker_type, marker_text in self.extract_essence_markers(message):
        self.store_essence_marker(marker_type, marker_text)
|
||||
|
||||
def _add_context_to_conversation(self, conversation: sm.Conversation) -> None:
    """Append a context message built from stored memory to ``conversation``.

    The identity is loaded lazily if it has not been set. When an identity is
    known, its ``last_seen`` value is refreshed (up to three attempts with a
    growing delay; a warning is printed if all fail). Entities from the last
    30 days are then formatted and, if the result is non-blank, added to the
    conversation.
    """
    # Load identity if not already set.
    if self.personal_identity is None:
        self.personal_identity = self.load_identity()

    # Refresh last_seen when we know who the user is.
    if self.personal_identity:
        attempts, pause = 3, 0.1
        for attempt in range(attempts):
            try:
                self.db.update_last_seen(self.personal_identity)
            except Exception as error:
                if attempt + 1 < attempts:
                    time.sleep(pause * (attempt + 1))
                else:
                    print(
                        f"Warning: Could not update last_seen after {attempts} attempts: {str(error)}"
                    )
            else:
                break

    context = self.format_context_message(self.retrieve_recent_entities(days=30))
    if context.strip():
        # Should be system role, but anthropic is picky.
        conversation.add_message(role="user", text=context)
|
||||
|
||||
def store_entity(self, entity: str, source: str = "user") -> None:
    """Persist an entity through the database layer, retrying on failure.

    Entities that are empty or shorter than two characters once stripped are
    ignored. Up to three attempts are made, 0.1 s apart; if the last one
    fails a message is printed and the error is not re-raised.
    """
    attempts = 3
    pause = 0.1  # seconds

    for attempt in range(attempts):
        try:
            if entity and len(entity.strip()) >= 2:
                self.db.store_entity(entity, source)
            return
        except Exception as error:
            if attempt + 1 < attempts:
                time.sleep(pause)
            else:
                print(
                    f"Failed to store entity after {attempts} attempts: {str(error)}"
                )
|
||||
|
||||
def store_identity(self, identity: str) -> None:
    """Persist the user's identity and remember it on the plugin.

    The write is delegated to ``Database.store_identity``, which performs the
    upsert and commits it. Issuing the INSERT through ``Database.query`` does
    not work: that helper never commits, so the row is discarded when its
    connection is closed.

    Up to three attempts are made, 0.1 s apart. If the last one fails a
    message is printed, the error is not re-raised and
    ``self.personal_identity`` is left unchanged.

    Args:
        identity: The identity string extracted from the user's message.
    """
    max_retries = 3
    retry_delay = 0.1  # seconds

    for attempt in range(max_retries):
        try:
            self.db.store_identity(identity)
            self.personal_identity = identity
            return
        except Exception as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
            else:
                print(
                    f"Failed to store identity after {max_retries} attempts: {str(e)}"
                )
|
||||
|
||||
def load_identity(self) -> str | None:
    """Load the most recently seen identity and cache it on the plugin.

    The lookup is delegated to ``Database.get_identity``. The previous
    implementation called ``.first()`` and ``.identity`` on the result of
    ``Database.query``, but that method returns a plain list of
    ``sqlite3.Row`` objects, so the call always failed and the identity was
    never loaded.

    Up to three attempts are made, 0.1 s apart.

    Returns:
        The identity string, or ``None`` if none is stored or every attempt
        failed.
    """
    max_retries = 3
    retry_delay = 0.1  # seconds

    for attempt in range(max_retries):
        try:
            self.personal_identity = self.db.get_identity()
            return self.personal_identity
        except Exception:
            if attempt < max_retries - 1:
                time.sleep(retry_delay)

    return None
|
||||
|
||||
def store_essence_marker(self, marker_type: str, marker_text: str) -> None:
    """Persist an essence marker through the database layer, retrying on failure.

    Up to three attempts are made, 0.1 s apart; if the last one fails a
    message is printed and the error is not re-raised.
    """
    attempts = 3
    pause = 0.1  # seconds

    for attempt in range(attempts):
        try:
            self.db.store_essence_marker(marker_type, marker_text)
        except Exception as error:
            if attempt + 1 < attempts:
                time.sleep(pause)
            else:
                print(
                    f"Failed to store marker after {attempts} attempts: {str(error)}"
                )
        else:
            return
|
||||
|
||||
def retrieve_essence_markers(self, days: int = 30) -> List[tuple[str, str]]:
    """Return the essence markers stored in the last ``days`` days.

    This wraps ``self.db.retrieve_essence_markers``. If that call raises, the
    error is printed with a traceback and an empty list is returned.
    """
    try:
        return self.db.retrieve_essence_markers(days)
    except Exception as e:
        print(f"ERROR retrieving markers: {e}")
        traceback.print_exc()
        return []
|
||||
|
||||
def summarize_memory(self, days: int = 30) -> str:
    """Summarise the entities remembered over the last ``days`` days.

    Entities mentioned three or more times are listed as frequently discussed
    and the rest as other topics. The user's identity heads the summary when
    it is known.

    Returns:
        The multi-line summary, or a fixed notice when no entities are found.
    """
    entities = self.retrieve_recent_entities(days=days)
    if not entities:
        return "No recent conversation history to consolidate."

    # Label every entity once, remembering which bucket it belongs to.
    labelled = [
        (total >= 3, f"{entity} (mentioned {total} times)")
        for entity, total, _user_count, _llm_count in entities
    ]
    frequent = [f"- {label}" for is_frequent, label in labelled if is_frequent]
    occasional = [f"- {label}" for is_frequent, label in labelled if not is_frequent]

    summary = []
    if self.personal_identity:
        summary.append(f"User Identity: {self.personal_identity}")
    if frequent:
        summary += ["Frequently Discussed Topics:", *frequent]
    if occasional:
        summary += ["Other Topics Mentioned:", *occasional]

    return "\n".join(summary)
|
||||
|
||||
def store_llm_memory(self, conversation: sm.Conversation) -> None:
    """Generate and store memories from the LLM's perspective of the conversation.

    The last three messages of ``conversation`` are copied into a temporary
    conversation, followed by a prompt asking for ``MEMORY:``-prefixed lines.
    Each memory line in the reply is run through ``extract_essence_markers``:
    any markers found are saved with ``store_essence_marker``; otherwise the
    memory text itself is saved with ``store_entity(..., source="llm")``.

    Args:
        conversation: Conversation whose most recent messages give context.
    """
    MEMORY_PREFIX = "MEMORY:"
    # Assembled from explicit lines so the prompt text does not depend on how
    # this source file happens to be indented.
    MEMORY_PROMPT = (
        "Based on the recent messages, what are the most important things to remember?\n"
        "Focus on facts about the user, their preferences, and key discussion points.\n"
        "Format each memory on a new line starting with MEMORY:\n"
        "For example:\n"
        "MEMORY: User prefers Python over JavaScript\n"
        "MEMORY: User is working on a machine learning project"
    )

    # Create a temporary conversation for memory generation
    temp_conv = sm.create_conversation(
        llm_model=self.llm_model, llm_provider=self.llm_provider
    )

    # Add last few messages for context
    for msg in conversation.messages[-3:]:
        temp_conv.add_message(role=msg.role, text=msg.text)

    # Ask for memories
    temp_conv.add_message(role="user", text=MEMORY_PROMPT)
    response = temp_conv.send()

    if not response or not response.text:
        return

    # Process and store memories
    for line in response.text.splitlines():
        line = line.strip()
        if not line.upper().startswith(MEMORY_PREFIX):
            continue

        # The prefix check above is case-insensitive, so remove the prefix by
        # length. A case-sensitive replace would leave "Memory:" / "memory:"
        # in the stored text and could instead delete a later upper-case
        # "MEMORY:" occurring inside the sentence.
        memory = line[len(MEMORY_PREFIX):].strip()
        if len(memory) <= 3:  # Basic validation: skip empty/trivial memories
            continue

        # Extract potential essence markers from the memory
        essence_markers = self.extract_essence_markers(memory)
        if essence_markers:
            # Store as essence markers if they match patterns
            for marker_type, marker_text in essence_markers:
                self.store_essence_marker(marker_type, marker_text)
        else:
            # Store as entity if no essence patterns match
            self.store_entity(memory, source="llm")
|
||||
|
||||
def retrieve_recent_entities(self, days: int = 7) -> List[tuple]:
    """Fetch recently mentioned entities from the database, never raising.

    Args:
        days: Look-back window forwarded to the database layer.

    Returns:
        Whatever the database returns, or an empty list if the lookup raised
        (in which case the error and its traceback are printed).
    """
    try:
        return self.db.retrieve_recent_entities(days)
    except Exception as e:
        print(f"ERROR retrieving entities: {e}")
        traceback.print_exc()
    # Only reached after a failure was reported above.
    return []
|
||||
|
||||
def post_response_hook(self, conversation: sm.Conversation) -> None:
    """Harvest entities and memories from the assistant's latest reply.

    Does nothing when there is no assistant message or its text is empty.
    Otherwise each extracted entity is stored with ``source="llm"`` and
    ``store_llm_memory`` is run on the conversation. Failures at every step
    are printed rather than raised.

    Args:
        conversation: Conversation that has just received an assistant reply.
    """
    try:
        reply = conversation.get_last_message(role="assistant")
        if not (reply and reply.text):
            return

        found = self.extract_entities(reply.text)
        if not found:
            print("No entities found in message")
        else:
            for entity in found:
                # One bad entity must not stop the rest from being stored.
                try:
                    self.store_entity(entity, source="llm")
                except Exception as e:
                    print(f"✗ Failed to store entity {entity}: {str(e)}")

        # Process LLM memories
        print("\nProcessing LLM memories...")
        try:
            self.store_llm_memory(conversation)
            print("✓ Memories processed")
        except Exception as e:
            print(f"✗ Failed to process memories: {str(e)}")

    except Exception as e:
        print(f"Error in post_response_hook: {str(e)}")
        traceback.print_exc()
|
||||
|
||||
def extract_identity(self, text: str) -> str | None:
|
||||
"""Extract identity statements from text."""
|
||||
text = text.lower().strip()
|
||||
|
||||
identity_patterns = [
|
||||
(r"^i am (.+)$", 1),
|
||||
(r"^my name is (.+)$", 1),
|
||||
(r"^call me (.+)$", 1),
|
||||
(r"^i'm (.+)$", 1), # Add pattern for "I'm"
|
||||
(r"^hey i'm (.+)$", 1), # Add pattern for "hey I'm"
|
||||
(r"^hello i'm (.+)$", 1), # Add pattern for "hello I'm"
|
||||
(r"^hi i'm (.+)$", 1), # Add pattern for "hi I'm"
|
||||
]
|
||||
|
||||
for pattern, group in identity_patterns:
|
||||
if match := re.match(pattern, text):
|
||||
identity = match.group(group).strip()
|
||||
return identity if identity else None
|
||||
|
||||
return None
|
||||
|
||||
def is_identity_question(self, text: str) -> bool:
    """Decide whether ``text`` asks about the user's own identity.

    The lower-cased, stripped text is first compared against a fixed list of
    phrasings. Failing that, it is tokenized and treated as an identity
    question when it contains both a question word ("who"/"what") and a
    self-referential term ("i", "me", "my", "name").

    Args:
        text: Raw user input.

    Returns:
        ``True`` if the text looks like an identity question, else ``False``.
    """
    normalized = text.lower().strip()

    # Direct identity questions
    known_phrasings = (
        "who am i",
        "what's my name",
        "what is my name",
        "do you know who i am",
        "do you know me",
        "do you remember me",
    )
    if normalized in known_phrasings:
        return True

    # Looser heuristic on the set of distinct tokens.
    vocabulary = set(word_tokenize(normalized))
    asks_who_or_what = not vocabulary.isdisjoint({"who", "what"})
    refers_to_self = not vocabulary.isdisjoint({"i", "me", "my", "name"})
    return asks_who_or_what and refers_to_self
|
||||
|
||||
def get_all_topics(self, days: int = 90) -> str:
    """Render every recently mentioned topic as a markdown report.

    Args:
        days: Number of days to look back (default: 90)

    Returns:
        Markdown with one bullet per topic (total mentions plus a user/AI
        breakdown), ordered by total mentions descending, followed by a
        comma-separated list of all topic names. A fixed notice is returned
        instead when no topics are found.
    """
    rows = self.retrieve_recent_entities(days=days)
    if not rows:
        return "No conversation topics found in the specified time period."

    # Most-mentioned first; ties keep their original relative order.
    ranked = list(rows)
    ranked.sort(key=lambda row: row[1], reverse=True)

    report = ["## Conversation Topics"]
    for entity, total, user_count, llm_count in ranked:
        source_breakdown = f"(User: {user_count}, AI: {llm_count})"
        report.append(f"- **{entity}**: {total} mentions {source_breakdown}")

    if ranked:
        report.append("\n## All Topics Mentioned")
        report.append(", ".join(row[0] for row in ranked))

    return "\n".join(report)
|
||||
|
||||
def get_memories(self) -> str:
    """Retrieve and format all stored memories as markdown.

    Entities and essence markers from the last 3650 days (about ten years)
    are fetched through ``retrieve_recent_entities`` and
    ``retrieve_essence_markers``, so a failing database lookup is reported
    by those wrappers and treated as "nothing stored" instead of raising.

    Returns:
        Markdown containing, when available: the current user identity,
        entities the user mentioned, entities the assistant mentioned, and
        essence markers grouped by type. Returns ``"No memories found."``
        only when there are neither entities nor essence markers.
    """
    lookback_days = 3650  # about ten years, i.e. effectively everything

    entities = self.retrieve_recent_entities(days=lookback_days)
    essence_markers = self.retrieve_essence_markers(days=lookback_days)

    # Essence markers count as memories too: do not report "nothing found"
    # just because no entities have been recorded.
    if not entities and not essence_markers:
        return "No memories found."

    memory_parts = ["## All Stored Memories"]

    # Add identity if available
    if self.personal_identity:
        memory_parts.append(f"\n**Current User**: {self.personal_identity}")

    # Group memories by source; an entity can appear in both groups.
    user_memories = []
    llm_memories = []
    for entity, _total, user_count, llm_count in entities:
        if user_count > 0:
            user_memories.append(f"- **{entity}**: {user_count} mentions")
        if llm_count > 0:
            llm_memories.append(f"- **{entity}**: {llm_count} mentions")

    if user_memories:
        memory_parts.append("\n### Things You've Mentioned")
        memory_parts.extend(user_memories)

    if llm_memories:
        memory_parts.append("\n### Things I've Mentioned")
        memory_parts.extend(llm_memories)

    # Add essence markers if available
    if essence_markers:
        memory_parts.append("\n### User Characteristics")
        markers_by_type = {}
        for marker_type, marker_text in essence_markers:
            markers_by_type.setdefault(marker_type, []).append(marker_text)

        for marker_type, markers in markers_by_type.items():
            memory_parts.append(f"\n**{marker_type.title()}**:")
            memory_parts.extend(f"- {marker}" for marker in markers)

    return "\n".join(memory_parts)
|
||||
Reference in New Issue
Block a user