mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
Compare commits
80 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5fa67c3b2f | |||
| b7e950a8f0 | |||
| 735c6ba665 | |||
| 9132030cbd | |||
| aeea8936ce | |||
| e79b474215 | |||
| fe2ca9d5f5 | |||
| 670240b943 | |||
| 2e66c0232b | |||
| 8b1f63f796 | |||
| 5d7a917d23 | |||
| 9703332967 | |||
| fe6001e710 | |||
| 63343d1c61 | |||
| ece056a5e0 | |||
| f44ec977a4 | |||
| 33f8fcde11 | |||
| 598bcd514d | |||
| 8bdbe4d8d5 | |||
| d4068cf07a | |||
| 747488f633 | |||
| 9ae03685b5 | |||
| 91af281a9d | |||
| 309f390800 | |||
| b316352311 | |||
| 236020b3b9 | |||
| 8a5a29f864 | |||
| 30d8412bbf | |||
| 4a852e6220 | |||
| 7f5ba667bd | |||
| 4b87a8b91c | |||
| 4c1d1fa873 | |||
| 0087a7e8f2 | |||
| 07715ed8df | |||
| 03f91c5153 | |||
| aa601648c6 | |||
| a26c51014b | |||
| b3946f1ff9 | |||
| 7a84ade5a4 | |||
| 3e1d1f98ad | |||
| 48e6ef2a43 | |||
| 1528dc2a21 | |||
| 46cd19ea90 | |||
| 2848e86dce | |||
| 6aadc9fcd7 | |||
| a8792319a8 | |||
| 3e8d5662d2 | |||
| 51c1646ef4 | |||
| f09052c18e | |||
| 1d3ae26301 | |||
| 44fd3468fa | |||
| 5770c37edf | |||
| 37334a21c5 | |||
| 57d54abf24 | |||
| c3397488e3 | |||
| 678a8a8b32 | |||
| a5c7486dfc | |||
| 5c6650f2b2 | |||
| 549d74e146 | |||
| 328be94677 | |||
| 7b21b9f258 | |||
| d7f8418f23 | |||
| 9968f162d6 | |||
| cb73621e39 | |||
| 4721dd8cc0 | |||
| bdb1ff0e69 | |||
| 94f381032e | |||
| b3a35cadd4 | |||
| 718f5a66c0 | |||
| df02547dec | |||
| 9dd89b7ef1 | |||
| 15ee5d1cf9 | |||
| 25ba1a9289 | |||
| 22aff505c4 | |||
| 29b2008edf | |||
| c5c99a05fd | |||
| cb969dec4c | |||
| 1aeeb9127d | |||
| c21f68aad6 | |||
| a68bd74fd8 |
@@ -168,3 +168,5 @@ cython_debug/
|
||||
src/**
|
||||
requirements.txt
|
||||
Pipfile
|
||||
enhanced_context.db
|
||||
enhanced_context_sarah.db
|
||||
|
||||
+15
-1
@@ -1,9 +1,23 @@
|
||||
Release History
|
||||
===============
|
||||
|
||||
|
||||
## 0.2.4 (2024-11-11)
|
||||
|
||||
- General improvements.
|
||||
|
||||
## 0.2.3 (2024-11-04)
|
||||
|
||||
- Remove default max-tokens for OpenAI provider.
|
||||
|
||||
## 0.2.3 (2024-11-03)
|
||||
|
||||
- Update default model for Amazon provider.
|
||||
- Improved logging to handle streaming functions.
|
||||
|
||||
## 0.2.2 (2024-11-02)
|
||||
|
||||
- Add openai streaming support (set `stream=True` to `generate_text`).
|
||||
- Add streaming support (set `stream=True` to `generate_text`).
|
||||
- `conv.prepend_system_message` now uses system role by default.
|
||||
- Add `provider.supports_streaming` property.
|
||||
- Add `provider.supports_structured_response` property.
|
||||
|
||||
@@ -35,7 +35,7 @@ The APIs remain identical between all supported providers / models:
|
||||
<tr>
|
||||
<td><a href="https://aws.amazon.com/bedrock/">Amazon's Bedrock</a></td>
|
||||
<td><code>"amazon"</code></td>
|
||||
<td><code>"anthropic.claude-3-sonnet-20240229-v1:0"</code></td>
|
||||
<td><code>"anthropic.claude-3-5-sonnet-20241022-v2:0"</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><a href="https://gemini.google/">Google's Gemini</a></td>
|
||||
@@ -93,17 +93,26 @@ import simplemind as sm
|
||||
|
||||
## Examples
|
||||
|
||||
Here are some examples of how to use Simplemind:
|
||||
Here are some examples of how to use Simplemind.
|
||||
|
||||
**Please note**: Most of the calls seen here optionally accept `llm_provider` and `llm_model` parameters, which you provide as strings.
|
||||
|
||||
### Text Completion
|
||||
|
||||
Generate a response from an AI model based on a given prompt:
|
||||
|
||||
```pycon
|
||||
>>> sm.generate_text(prompt="What is the meaning of life?", llm_provider="openai", llm_model="gpt-4o")
|
||||
>>> sm.generate_text(prompt="What is the meaning of life?")
|
||||
"The meaning of life is a profound philosophical question that has been explored by cultures, religions, and philosophers for centuries. Different people and belief systems offer varying interpretations:\n\n1. **Religious Perspectives:** Many religions propose that the meaning of life is to fulfill a divine purpose, serve God, or reach an afterlife. For example, Christianity often emphasizes love, faith, and service to God and others as central to life’s meaning.\n\n2. **Philosophical Views:** Philosophers offer diverse answers. Existentialists like Jean-Paul Sartre argue that life has no inherent meaning, and it is up to individuals to create their own purpose. Others, like Aristotle, suggest that achieving eudaimonia (flourishing or happiness) through virtuous living is the key to a meaningful life.\n\n3. **Scientific and Secular Approaches:** Some people find meaning through understanding the natural world, contributing to human knowledge, or through personal accomplishments and happiness. They may view life's meaning as a product of connection, legacy, or the pursuit of knowledge and creativity.\n\n4. **Personal Perspective:** For many, the meaning of life is deeply personal, involving their relationships, passions, and goals. These individuals define life's purpose through experiences, connections, and the impact they have on others and the world.\n\nUltimately, the meaning of life is a subjective question, with each person finding their own answers based on their beliefs, experiences, and reflections."
|
||||
```
|
||||
|
||||
### Streaming Text
|
||||
|
||||
```python
|
||||
>>> for chunk in sm.generate_text("Write a poem about the moon", stream=True):
|
||||
... print(chunk, end="", flush=True)
|
||||
```
|
||||
|
||||
### Structured Data with Pydantic
|
||||
|
||||
You can use Pydantic models to structure the response from the LLM, if the LLM supports it.
|
||||
@@ -115,12 +124,7 @@ class Poem(BaseModel):
|
||||
```
|
||||
|
||||
```pycon
|
||||
>>> sm.generate_data(
|
||||
"Write a poem about love",
|
||||
llm_model="gpt-4o-mini",
|
||||
llm_provider="openai",
|
||||
response_model=Poem,
|
||||
)
|
||||
>>> sm.generate_data("Write a poem about love", response_model=Poem)
|
||||
title='Eternal Embrace' content='In the quiet hours of the night,\nWhen stars whisper secrets bright,\nTwo hearts beat in a gentle rhyme,\nDancing through the sands of time.\n\nWith every glance, a spark ignites,\nA flame that warms the coldest nights,\nIn laughter shared and whispers sweet,\nLove paints the world, a masterpiece.\n\nThrough stormy skies and sunlit days,\nIn myriad forms, it finds its ways,\nA tender touch, a knowing sigh,\nIn love’s embrace, we learn to fly.\n\nAs seasons change and moments fade,\nIn the tapestry of dreams we’ve laid,\nLove’s threads endure, forever bind,\nA timeless bond, two souls aligned.\n\nSo here’s to love, both bright and true,\nA gift we give, anew, anew,\nIn every heartbeat, every prayer,\nA story written in the air.'
|
||||
```
|
||||
|
||||
@@ -143,8 +147,6 @@ class Recipe(BaseModel):
|
||||
|
||||
recipe = sm.generate_data(
|
||||
"Write a recipe for chocolate chip cookies",
|
||||
llm_model="gpt-4o-mini",
|
||||
llm_provider="openai",
|
||||
response_model=Recipe,
|
||||
)
|
||||
```
|
||||
@@ -156,7 +158,7 @@ Special thanks to [@jxnl](https://github.com/jxnl) for building [Instructor](htt
|
||||
SimpleMind also allows for easy conversational flows:
|
||||
|
||||
```pycon
|
||||
>>> conv = sm.create_conversation(llm_model="gpt-4o-mini", llm_provider="openai")
|
||||
>>> conv = sm.create_conversation()
|
||||
|
||||
>>> # Add a message to the conversation
|
||||
>>> conv.add_message("user", "Hi there, how are you?")
|
||||
@@ -186,13 +188,12 @@ response = gpt_4o_mini.generate_text("Hello!")
|
||||
conversation = gpt_4o_mini.create_conversation()
|
||||
```
|
||||
|
||||
This maintains the simplicity of the original API while reducing repetition. The session object also supports overriding defaults on a per-call basis:
|
||||
This maintains the simplicity of the original API while reducing repetition.
|
||||
|
||||
The session object also supports overriding defaults on a per-call basis:
|
||||
|
||||
```python
|
||||
response = gpt_4o_mini.generate_text(
|
||||
"Complex task here",
|
||||
llm_model="gpt-4"
|
||||
)
|
||||
response = gpt_4o_mini.generate_text("Complex task here", llm_model="gpt-4")
|
||||
```
|
||||
|
||||
### Basic Memory Plugin
|
||||
@@ -215,7 +216,7 @@ class SimpleMemoryPlugin(sm.BasePlugin):
|
||||
conversation.add_message(role="system", text=m)
|
||||
|
||||
|
||||
conversation = sm.create_conversation(llm_model="grok-beta", llm_provider="xai")
|
||||
conversation = sm.create_conversation()
|
||||
conversation.add_plugin(SimpleMemoryPlugin())
|
||||
|
||||
|
||||
|
||||
@@ -1,34 +1,68 @@
|
||||
from _context import simplemind as sm
|
||||
from pydantic import BaseModel
|
||||
import simplemind as sm
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
class InstructionStep(BaseModel):
|
||||
step_number: int
|
||||
instruction: str
|
||||
|
||||
|
||||
class RecipeIngredient(BaseModel):
|
||||
name: str
|
||||
quantity: float
|
||||
unit: str
|
||||
|
||||
|
||||
class Recipe(BaseModel):
|
||||
name: str
|
||||
ingredients: list[RecipeIngredient]
|
||||
instructions: list[InstructionStep]
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
output = f"\n=== {self.name.upper()} ===\n\n"
|
||||
|
||||
output += "INGREDIENTS:\n"
|
||||
console = Console(record=True, file=None)
|
||||
|
||||
# Create formatted title with more emphasis
|
||||
title = Text("✨ " + self.name.upper() + " ✨", style="bold blue")
|
||||
|
||||
# Format ingredients with better structure
|
||||
ingredients_text = Text("\n📝 INGREDIENTS:\n", style="bold green")
|
||||
for ing in self.ingredients:
|
||||
output += f"• {ing.quantity} {ing.unit} {ing.name}\n"
|
||||
|
||||
output += "\nINSTRUCTIONS:\n"
|
||||
# Format numbers to avoid floating decimals when whole numbers
|
||||
quantity = int(ing.quantity) if ing.quantity.is_integer() else ing.quantity
|
||||
ingredients_text.append(f" • {quantity} {ing.unit} ", style="bright_white")
|
||||
ingredients_text.append(f"{ing.name}\n", style="italic bright_white")
|
||||
|
||||
# Format instructions with better spacing and styling
|
||||
instructions_text = Text("\n👩🍳 INSTRUCTIONS:\n", style="bold yellow")
|
||||
for step in self.instructions:
|
||||
output += f"{step.step_number}. {step.instruction}\n"
|
||||
|
||||
return output
|
||||
|
||||
instructions_text.append(
|
||||
f"\n {step.step_number}. ", style="bold bright_white"
|
||||
)
|
||||
instructions_text.append(f"{step.instruction}", style="bright_white")
|
||||
|
||||
# Combine all text
|
||||
full_text = Text.assemble(
|
||||
ingredients_text, instructions_text, "\n"
|
||||
) # Added extra newline
|
||||
|
||||
# Create panel with enhanced styling
|
||||
panel = Panel(
|
||||
full_text,
|
||||
title=title,
|
||||
border_style="blue",
|
||||
padding=(1, 2), # Add padding (vertical, horizontal)
|
||||
expand=False, # Don't expand to full terminal width
|
||||
title_align="center",
|
||||
)
|
||||
|
||||
# Render the panel to string without printing
|
||||
with console.capture() as capture:
|
||||
console.print(panel)
|
||||
return capture.get()
|
||||
|
||||
|
||||
recipe = sm.generate_data(
|
||||
"Write a recipe for chocolate chip cookies",
|
||||
@@ -63,4 +97,3 @@ print(recipe)
|
||||
# 7. Drop by rounded tablespoon onto ungreased cookie sheets.
|
||||
# 8. Bake for 9 to 11 minutes, or until edges are golden.
|
||||
# 9. Let cool on the cookie sheet for a few minutes before transferring to wire racks to cool completely.
|
||||
|
||||
|
||||
+31
-33
@@ -1,24 +1,15 @@
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
from _context import sm
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
from _context import sm
|
||||
|
||||
|
||||
class MultiAIConversation:
|
||||
"""Orchestrates conversations between multiple AI models."""
|
||||
|
||||
MODEL_SESSIONS = {
|
||||
"Llama3.2": sm.Session(
|
||||
llm_provider="ollama",
|
||||
llm_model="llama3.2",
|
||||
),
|
||||
"Claude-3.5-Sonnet": sm.Session(
|
||||
llm_provider="anthropic",
|
||||
llm_model="claude-3-5-sonnet-20241022",
|
||||
),
|
||||
"GPT-4o": sm.Session(
|
||||
llm_provider="openai",
|
||||
llm_model="gpt-4o",
|
||||
@@ -27,6 +18,10 @@ class MultiAIConversation:
|
||||
llm_provider="xai",
|
||||
llm_model="grok-beta",
|
||||
),
|
||||
"Claude-3.5-Sonnet": sm.Session(
|
||||
llm_provider="anthropic",
|
||||
llm_model="claude-3-5-sonnet-20241022",
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, topic: str, turns_per_model: int = 1, max_rounds: int = 5):
|
||||
@@ -36,13 +31,14 @@ class MultiAIConversation:
|
||||
self.max_rounds = max_rounds
|
||||
self.conversation_history: List[Tuple[str, str]] = []
|
||||
self.console = Console()
|
||||
self.user_name = "Kenneth Reitz"
|
||||
|
||||
def _format_system_prompt(self, ai_name: str) -> str:
|
||||
"""Creates a system prompt for each AI model."""
|
||||
return f"""You are {ai_name}. You are participating in a thoughtful discussion with other AI models about {self.topic}.
|
||||
|
||||
Rules:
|
||||
1. Be concise but insightful (keep responses under 100 words)
|
||||
1. Be concise but insightful (keep responses under 140 words)
|
||||
2. Build upon previous points made in the conversation
|
||||
3. Ask questions to deepen the discussion when appropriate
|
||||
4. Stay on topic while maintaining your unique perspective
|
||||
@@ -72,32 +68,31 @@ Current discussion topic: {self.topic}"""
|
||||
# Store in history
|
||||
self.conversation_history.append((ai_name, response))
|
||||
|
||||
def _get_user_input(self) -> str:
|
||||
"""Gets input from the user for the discussion."""
|
||||
self.console.print("\n[bold green]Your turn! Share your thoughts:[/bold green]")
|
||||
user_response = input("> ")
|
||||
self._print_response(self.user_name, user_response)
|
||||
return user_response
|
||||
|
||||
def run_conversation(self):
|
||||
"""Runs the multi-AI conversation."""
|
||||
|
||||
# Initialize the conversation
|
||||
initial_prompt = (
|
||||
f"Let's have a thoughtful discussion about {self.topic}. "
|
||||
"Please share your initial thoughts in 2-3 sentences."
|
||||
# Get initial thoughts from the human
|
||||
self.console.print(
|
||||
f"\n[bold green]Start the discussion about {self.topic}:[/bold green]"
|
||||
)
|
||||
self._get_user_input()
|
||||
|
||||
for round_num in range(self.max_rounds):
|
||||
self.console.print(f"\n[bold green]Round {round_num + 1}[/bold green]")
|
||||
|
||||
# Let all AI models respond
|
||||
for model_name, session in self.MODEL_SESSIONS.items():
|
||||
for turn in range(self.turns_per_model):
|
||||
conversation = self._create_conversation(session, model_name)
|
||||
|
||||
# Add the prompt
|
||||
prompt = (
|
||||
initial_prompt
|
||||
if round_num == 0 and turn == 0
|
||||
else (
|
||||
f"Continue the discussion about {self.topic}, "
|
||||
"responding to the previous points made."
|
||||
)
|
||||
)
|
||||
|
||||
# Add the prompt (simplified since human always starts)
|
||||
prompt = f"Continue the discussion about {self.topic}, responding to the previous points made."
|
||||
conversation.add_message(role="user", text=prompt)
|
||||
|
||||
# Get and print response
|
||||
@@ -107,17 +102,24 @@ Current discussion topic: {self.topic}"""
|
||||
# Small delay to prevent rate limiting
|
||||
time.sleep(1)
|
||||
|
||||
# Then get user input at the end of the round
|
||||
self._get_user_input()
|
||||
|
||||
# Optional: Add a separator between rounds
|
||||
self.console.print("\n" + "-" * 50)
|
||||
|
||||
|
||||
def have_ai_discussion(topic: str, turns_per_model: int = 1, max_rounds: int = 3):
|
||||
def have_ai_discussion(turns_per_model: int = 1, max_rounds: int = 3):
|
||||
"""Convenience function to start an AI discussion."""
|
||||
# Get topic from user
|
||||
print("\nWhat topic would you like to discuss?")
|
||||
topic = input("> ")
|
||||
|
||||
debate = MultiAIConversation(
|
||||
topic=topic, turns_per_model=turns_per_model, max_rounds=max_rounds
|
||||
)
|
||||
|
||||
print(f"\nStarting AI discussion on: {topic}")
|
||||
print(f"\nStarting AI discussion about: {topic}")
|
||||
print("=" * 50)
|
||||
|
||||
debate.run_conversation()
|
||||
@@ -125,8 +127,4 @@ def have_ai_discussion(topic: str, turns_per_model: int = 1, max_rounds: int = 3
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Example topics
|
||||
topic = "The future of human-AI collaboration in creative fields",
|
||||
|
||||
# Run a discussion on the first topic
|
||||
have_ai_discussion(topic=topic, turns_per_model=1, max_rounds=3)
|
||||
have_ai_discussion(turns_per_model=1, max_rounds=5)
|
||||
|
||||
@@ -0,0 +1,952 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sqlite3
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
import nltk
|
||||
import spacy
|
||||
import xerox
|
||||
from _context import simplemind as sm
|
||||
from docopt import docopt
|
||||
from nltk.tag import pos_tag
|
||||
from nltk.tokenize import word_tokenize
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.status import Status
|
||||
|
||||
DB_PATH = "enhanced_context.db"
|
||||
AVAILABLE_PROVIDERS = ["xai", "openai", "anthropic", "ollama"]
|
||||
|
||||
# Enable Logfire for debugging.
|
||||
# sm.enable_logfire()
|
||||
|
||||
__doc__ = """Enhanced Context Chat Interface
|
||||
|
||||
Usage:
|
||||
enhanced_context.py [--provider=<provider>] [--model=<model>]
|
||||
enhanced_context.py (-h | --help)
|
||||
|
||||
Options:
|
||||
-h --help Show this screen.
|
||||
--provider=<provider> LLM provider to use (openai/anthropic/xai/ollama)
|
||||
--model=<model> Specific model to use (e.g. o1-preview)
|
||||
"""
|
||||
|
||||
|
||||
class ContextDatabase:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.init_db()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
"""Context manager for database connections"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def init_db(self):
|
||||
"""Initialize the database with proper schema"""
|
||||
with self.get_connection() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS memory (
|
||||
entity TEXT,
|
||||
source TEXT,
|
||||
last_mentioned TIMESTAMP,
|
||||
mention_count INTEGER DEFAULT 1,
|
||||
PRIMARY KEY (entity, source)
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS identity (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
last_updated TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS essence_markers (
|
||||
marker_type TEXT,
|
||||
marker_text TEXT,
|
||||
timestamp TIMESTAMP,
|
||||
PRIMARY KEY (marker_type, marker_text)
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
def store_entity(self, entity: str, source: str = "user") -> None:
|
||||
"""Store or update entity mention with source tracking"""
|
||||
with self.get_connection() as conn:
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO memory (entity, source, last_mentioned, mention_count)
|
||||
VALUES (?, ?, ?, 1)
|
||||
ON CONFLICT(entity, source) DO UPDATE SET
|
||||
last_mentioned = ?,
|
||||
mention_count = mention_count + 1
|
||||
""",
|
||||
(entity, source, now, now),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def retrieve_recent_entities(self, days: int = 7) -> List[tuple]:
|
||||
"""Retrieve recently mentioned entities with frequency and source"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT
|
||||
entity,
|
||||
SUM(mention_count) as total_mentions,
|
||||
GROUP_CONCAT(source || ':' || mention_count) as source_counts
|
||||
FROM memory
|
||||
WHERE last_mentioned >= datetime('now', ?, 'localtime')
|
||||
GROUP BY entity
|
||||
ORDER BY total_mentions DESC, MAX(last_mentioned) DESC
|
||||
LIMIT 50
|
||||
""",
|
||||
(f"-{days} days",),
|
||||
)
|
||||
|
||||
entities = []
|
||||
for row in cur.fetchall():
|
||||
entity, total_count, source_counts = row
|
||||
source_dict = dict(sc.split(":") for sc in source_counts.split(","))
|
||||
entities.append(
|
||||
(
|
||||
entity,
|
||||
total_count,
|
||||
int(source_dict.get("user", 0)),
|
||||
int(source_dict.get("llm", 0)),
|
||||
)
|
||||
)
|
||||
return entities
|
||||
except sqlite3.Error as e:
|
||||
self.logger.error(f"Database error while retrieving entities: {e}")
|
||||
return []
|
||||
|
||||
def store_identity(self, identity: str) -> None:
|
||||
"""Store personal identity in database"""
|
||||
if not identity:
|
||||
return
|
||||
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
now = datetime.now()
|
||||
# Store in identity table
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO identity (id, name, last_updated)
|
||||
VALUES (1, ?, ?)
|
||||
""",
|
||||
(identity, now),
|
||||
)
|
||||
|
||||
# Store in memory table
|
||||
self.store_entity(identity)
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self.logger.error(f"Database error while storing identity: {e}")
|
||||
|
||||
def load_identity(self) -> str | None:
|
||||
"""Load personal identity from database"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT name FROM identity WHERE id = 1")
|
||||
result = cur.fetchone()
|
||||
return result[0] if result else None
|
||||
except sqlite3.Error as e:
|
||||
self.logger.error(f"Database error while loading identity: {e}")
|
||||
return None
|
||||
|
||||
def store_essence_marker(self, marker_type: str, marker_text: str) -> None:
|
||||
"""Store essence marker in database"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO essence_markers
|
||||
(marker_type, marker_text, timestamp)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(marker_type, marker_text, now),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self.logger.error(f"Database error storing essence marker: {e}")
|
||||
|
||||
def retrieve_essence_markers(self, days: int = 30) -> List[tuple[str, str]]:
|
||||
"""Retrieve recent essence markers"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT DISTINCT marker_type, marker_text
|
||||
FROM essence_markers
|
||||
WHERE timestamp >= datetime('now', ?, 'localtime')
|
||||
ORDER BY timestamp DESC
|
||||
""",
|
||||
(f"-{days} days",),
|
||||
)
|
||||
return cur.fetchall()
|
||||
except sqlite3.Error as e:
|
||||
self.logger.error(f"Database error retrieving essence markers: {e}")
|
||||
return []
|
||||
|
||||
|
||||
class EnhancedContextPlugin(sm.BasePlugin):
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
def __init__(self, verbose: bool = False):
|
||||
super().__init__()
|
||||
# Set up logging
|
||||
self.verbose = verbose
|
||||
if verbose:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
else:
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize NLP model
|
||||
try:
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
except OSError:
|
||||
self.logger.error(
|
||||
"Failed to load spaCy model. Please install it using: python -m spacy download en_core_web_sm"
|
||||
)
|
||||
raise
|
||||
|
||||
# Initialize database
|
||||
self.db = ContextDatabase(DB_PATH)
|
||||
self.logger.info(f"EnhancedContextPlugin initialized with database: {DB_PATH}")
|
||||
|
||||
# Load identity from database
|
||||
self.personal_identity = self.db.load_identity()
|
||||
|
||||
# Download required NLTK data silently
|
||||
try:
|
||||
with open(os.devnull, "w") as null_out:
|
||||
with (
|
||||
contextlib.redirect_stdout(null_out),
|
||||
contextlib.redirect_stderr(null_out),
|
||||
):
|
||||
nltk.download("punkt", quiet=True)
|
||||
nltk.download("averaged_perceptron_tagger", quiet=True)
|
||||
except LookupError as e:
|
||||
self.logger.error(f"Error downloading NLTK data: {e}")
|
||||
|
||||
# Add LLM personality traits for easter egg
|
||||
self.llm_personalities = [
|
||||
"You are a wise philosopher who speaks in riddles",
|
||||
"You are an excited scientist who loves discovering patterns",
|
||||
"You are a detective who analyzes every detail",
|
||||
"You are a poet who sees beauty in connections",
|
||||
"You are a historian who relates everything to the past",
|
||||
]
|
||||
|
||||
# Add these lines to store the conversation's model and provider
|
||||
self.llm_model = None
|
||||
self.llm_provider = None
|
||||
|
||||
def extract_entities(self, text: str) -> List[str]:
|
||||
"""Extract named entities with improved filtering"""
|
||||
doc = self.nlp(text)
|
||||
|
||||
# Define important entity types
|
||||
important_types = {
|
||||
"PERSON",
|
||||
"ORG",
|
||||
"GPE",
|
||||
"NORP",
|
||||
"PRODUCT",
|
||||
"EVENT",
|
||||
"WORK_OF_ART",
|
||||
}
|
||||
|
||||
entities = [
|
||||
ent.text.strip()
|
||||
for ent in doc.ents
|
||||
if (
|
||||
ent.label_ in important_types
|
||||
and len(ent.text.strip()) > 1
|
||||
and not ent.text.isnumeric()
|
||||
)
|
||||
]
|
||||
|
||||
return list(set(entities))
|
||||
|
||||
def format_context_message(
|
||||
self, entities: List[tuple], include_identity: bool = True
|
||||
) -> str:
|
||||
"""Format context message with essence markers"""
|
||||
context_parts = []
|
||||
|
||||
# Add identity context
|
||||
if include_identity and self.personal_identity:
|
||||
context_parts.append(f"The user's name is {self.personal_identity}.")
|
||||
|
||||
# Add essence markers
|
||||
essence_markers = self.retrieve_essence_markers()
|
||||
if essence_markers:
|
||||
markers_by_type = {}
|
||||
for marker_type, marker_text in essence_markers:
|
||||
markers_by_type.setdefault(marker_type, []).append(marker_text)
|
||||
|
||||
context_parts.append("User characteristics:")
|
||||
for marker_type, markers in markers_by_type.items():
|
||||
context_parts.append(f"- {marker_type.title()}: {', '.join(markers)}")
|
||||
|
||||
# Add entity context with user/llm breakdown
|
||||
if entities:
|
||||
entity_strings = [
|
||||
f"{entity} (mentioned {total} times - User: {user_count}, AI: {llm_count})"
|
||||
for entity, total, user_count, llm_count in entities
|
||||
]
|
||||
|
||||
topics = (
|
||||
", ".join(entity_strings[:-1]) + f" and {entity_strings[-1]}"
|
||||
if len(entity_strings) > 1
|
||||
else entity_strings[0]
|
||||
)
|
||||
|
||||
context_parts.append(f"Recent conversation topics: {topics}")
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
def extract_essence_markers(self, text: str) -> List[tuple[str, str]]:
|
||||
"""Extract essence markers from text."""
|
||||
patterns = {
|
||||
"value": [
|
||||
r"I (?:really )?(?:believe|think) (?:that )?(.+)",
|
||||
r"(?:It's|Its) important (?:to me )?that (.+)",
|
||||
r"I value (.+)",
|
||||
r"(?:The )?most important (?:thing|aspect) (?:to me )?is (.+)",
|
||||
],
|
||||
"identity": [
|
||||
r"I am(?: a| an)? (.+)",
|
||||
r"I consider myself(?: a| an)? (.+)",
|
||||
r"I identify as(?: a| an)? (.+)",
|
||||
],
|
||||
"preference": [
|
||||
r"I (?:really )?(?:like|love|enjoy|prefer) (.+)",
|
||||
r"I can't stand (.+)",
|
||||
r"I hate (.+)",
|
||||
r"I always (.+)",
|
||||
r"I never (.+)",
|
||||
],
|
||||
"emotion": [
|
||||
r"I feel (.+)",
|
||||
r"I'm feeling (.+)",
|
||||
r"(?:It|That) makes me feel (.+)",
|
||||
],
|
||||
}
|
||||
|
||||
markers = []
|
||||
doc = self.nlp(text)
|
||||
|
||||
for sent in doc.sents:
|
||||
sent_text = sent.text.strip().lower()
|
||||
|
||||
for marker_type, pattern_list in patterns.items():
|
||||
for pattern in pattern_list:
|
||||
for match in re.finditer(pattern, sent_text, re.IGNORECASE):
|
||||
marker_text = match.group(1).strip()
|
||||
if self._is_valid_marker(marker_text):
|
||||
markers.append((marker_type, marker_text))
|
||||
|
||||
return markers
|
||||
|
||||
def _is_valid_marker(self, marker_text: str) -> bool:
|
||||
"""Helper method to validate essence markers"""
|
||||
invalid_words = {"um", "uh", "like"}
|
||||
return len(marker_text) > 3 and not any(w in marker_text for w in invalid_words)
|
||||
|
||||
def pre_send_hook(self, conversation: sm.Conversation) -> bool:
|
||||
"""Process user message before sending to LLM"""
|
||||
self.llm_model = conversation.llm_model
|
||||
self.llm_provider = conversation.llm_provider
|
||||
|
||||
last_message = conversation.get_last_message(role="user")
|
||||
if not last_message:
|
||||
return True
|
||||
|
||||
# Handle special commands
|
||||
if result := self._handle_special_commands(conversation, last_message.text):
|
||||
return result
|
||||
|
||||
self.logger.info(f"Processing user message: {last_message.text}")
|
||||
|
||||
# Process entities and markers
|
||||
self._process_user_message(last_message.text)
|
||||
|
||||
# Add context
|
||||
self._add_context_to_conversation(conversation)
|
||||
|
||||
return True
|
||||
|
||||
def _handle_special_commands(
|
||||
self, conversation: sm.Conversation, message: str
|
||||
) -> bool | None:
|
||||
"""Handle special commands like /summary"""
|
||||
if message.strip().lower() == "/summary":
|
||||
summary = self.summarize_memory()
|
||||
conversation.add_message(role="assistant", text=summary)
|
||||
return False
|
||||
elif message.strip().lower() == "/topics":
|
||||
topics = self.get_all_topics()
|
||||
conversation.add_message(role="assistant", text=topics)
|
||||
return False
|
||||
return None
|
||||
|
||||
def _process_user_message(self, message: str) -> None:
|
||||
"""Process user message for entities and markers"""
|
||||
# Extract and store entities
|
||||
entities = self.extract_entities(message)
|
||||
for entity in entities:
|
||||
self.store_entity(entity, source="user")
|
||||
|
||||
# Extract and store essence markers
|
||||
essence_markers = self.extract_essence_markers(message)
|
||||
for marker_type, marker_text in essence_markers:
|
||||
self.store_essence_marker(marker_type, marker_text)
|
||||
self.logger.info(f"Found essence marker: {marker_type} - {marker_text}")
|
||||
|
||||
def _add_context_to_conversation(self, conversation: sm.Conversation) -> None:
|
||||
"""Add context message to conversation"""
|
||||
recent_entities = self.retrieve_recent_entities(days=30)
|
||||
context_message = self.format_context_message(recent_entities)
|
||||
if context_message:
|
||||
conversation.add_message(role="user", text=context_message)
|
||||
self.logger.info(f"Added context message: {context_message}")
|
||||
|
||||
def store_entity(self, entity: str, source: str = "user") -> None:
|
||||
self.db.store_entity(entity, source)
|
||||
|
||||
def store_identity(self, identity: str) -> None:
|
||||
self.db.store_identity(identity)
|
||||
self.personal_identity = identity
|
||||
|
||||
def load_identity(self) -> str | None:
|
||||
self.personal_identity = self.db.load_identity()
|
||||
return self.personal_identity
|
||||
|
||||
def store_essence_marker(self, marker_type: str, marker_text: str) -> None:
|
||||
self.db.store_essence_marker(marker_type, marker_text)
|
||||
|
||||
def retrieve_essence_markers(self, days: int = 30) -> List[tuple[str, str]]:
|
||||
return self.db.retrieve_essence_markers(days)
|
||||
|
||||
def summarize_memory(self, days: int = 30) -> str:
|
||||
"""Consolidate recent conversation memory into a summary"""
|
||||
entities = self.retrieve_recent_entities(days=days)
|
||||
if not entities:
|
||||
return "No recent conversation history to consolidate."
|
||||
|
||||
# Group entities by frequency
|
||||
frequent = []
|
||||
occasional = []
|
||||
|
||||
for entity, total, user_count, llm_count in entities:
|
||||
if total >= 3:
|
||||
frequent.append(f"{entity} (mentioned {total} times)")
|
||||
else:
|
||||
occasional.append(f"{entity} (mentioned {total} times)")
|
||||
|
||||
# Build summary
|
||||
summary_parts = []
|
||||
|
||||
if self.personal_identity:
|
||||
summary_parts.append(f"User Identity: {self.personal_identity}")
|
||||
|
||||
if frequent:
|
||||
summary_parts.append("Frequently Discussed Topics:")
|
||||
summary_parts.extend([f"- {item}" for item in frequent])
|
||||
|
||||
if occasional:
|
||||
summary_parts.append("Other Topics Mentioned:")
|
||||
summary_parts.extend([f"- {item}" for item in occasional])
|
||||
|
||||
return "\n".join(summary_parts)
|
||||
|
||||
def simulate_llm_conversation(self, context: str, num_turns: int = 3) -> str:
|
||||
"""Simulate a conversation between multiple LLM personalities about the context"""
|
||||
conversation_log = []
|
||||
|
||||
def get_response(personality: str, previous_messages: str) -> str:
|
||||
prompt = (
|
||||
f"{personality}. You are participating in a brief group discussion "
|
||||
f"about the following context:\n{context}\n\n"
|
||||
f"Previous messages:\n{previous_messages}\n\n"
|
||||
"Provide a short, focused response (1-2 sentences) that builds on "
|
||||
"the discussion. Be creative but stay on topic."
|
||||
)
|
||||
|
||||
temp_conv = sm.create_conversation(
|
||||
llm_model=self.llm_model, llm_provider=self.llm_provider
|
||||
)
|
||||
temp_conv.add_message(role="user", text=prompt)
|
||||
response = temp_conv.send()
|
||||
return response.text.strip()
|
||||
|
||||
# Select random personalities for this conversation
|
||||
selected_personalities = random.sample(
|
||||
self.llm_personalities, min(num_turns, len(self.llm_personalities))
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
for i, personality in enumerate(selected_personalities, 1):
|
||||
previous = "\n".join(conversation_log)
|
||||
response = get_response(personality, previous)
|
||||
conversation_log.append(f"Speaker {i}: {response}")
|
||||
|
||||
return "\n\n".join(conversation_log)
|
||||
|
||||
def store_llm_memory(self, conversation: sm.Conversation) -> None:
|
||||
"""Generate and store memories from the LLM's perspective of the conversation.
|
||||
|
||||
Args:
|
||||
conversation: The conversation object containing message history
|
||||
"""
|
||||
prompt = """Based on the recent messages, what are the most important things to remember?
|
||||
Format each memory on a new line starting with MEMORY:
|
||||
For example:
|
||||
MEMORY: User prefers Python over JavaScript
|
||||
MEMORY: User is working on a machine learning project"""
|
||||
|
||||
# Create temporary conversation for memory generation
|
||||
temp_conv = sm.create_conversation(
|
||||
llm_model=self.llm_model, llm_provider=self.llm_provider
|
||||
)
|
||||
|
||||
# Add last few messages for context
|
||||
for msg in conversation.messages[-3:]: # Last 3 messages
|
||||
temp_conv.add_message(role=msg.role, text=msg.text)
|
||||
|
||||
# Get memories from LLM
|
||||
temp_conv.add_message(role="user", text=prompt)
|
||||
response = temp_conv.send()
|
||||
|
||||
# Process and store memories
|
||||
if response and response.text:
|
||||
for line in response.text.split("\n"):
|
||||
if line.strip().startswith("MEMORY:"):
|
||||
memory = line.replace("MEMORY:", "").strip()
|
||||
self.store_entity(memory, source="llm")
|
||||
self.logger.info(f"Stored LLM-generated memory: {memory}")
|
||||
|
||||
def retrieve_recent_entities(self, days: int = 7) -> List[tuple]:
|
||||
"""Retrieve recently mentioned entities with their frequency data.
|
||||
|
||||
Args:
|
||||
days: Number of days to look back
|
||||
|
||||
Returns:
|
||||
List of tuples containing (entity, total_mentions, user_mentions, llm_mentions)
|
||||
"""
|
||||
try:
|
||||
return self.db.retrieve_recent_entities(days)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error retrieving recent entities: {e}")
|
||||
return []
|
||||
|
||||
def post_response_hook(self, conversation: sm.Conversation) -> None:
|
||||
"""Process assistant's response after it's received."""
|
||||
# Get the last assistant message
|
||||
last_message = conversation.get_last_message(role="assistant")
|
||||
if not last_message:
|
||||
return
|
||||
|
||||
# Extract and store entities from assistant's response
|
||||
entities = self.extract_entities(last_message.text)
|
||||
for entity in entities:
|
||||
self.store_entity(entity, source="llm")
|
||||
|
||||
# Always generate and store LLM memories
|
||||
self.store_llm_memory(conversation)
|
||||
|
||||
def extract_identity(self, text: str) -> str | None:
|
||||
"""Extract identity statements from text.
|
||||
|
||||
Args:
|
||||
text: The text to analyze
|
||||
|
||||
Returns:
|
||||
The extracted identity or None if not found
|
||||
"""
|
||||
text = text.lower().strip()
|
||||
|
||||
identity_patterns = [
|
||||
(r"^i am (.+)$", 1),
|
||||
(r"^my name is (.+)$", 1),
|
||||
(r"^call me (.+)$", 1),
|
||||
]
|
||||
|
||||
for pattern, group in identity_patterns:
|
||||
if match := re.match(pattern, text):
|
||||
identity = match.group(group).strip()
|
||||
return identity if identity else None
|
||||
|
||||
return None
|
||||
|
||||
def is_identity_question(self, text: str) -> bool:
|
||||
"""Detect if text contains a question about identity.
|
||||
|
||||
Args:
|
||||
text: The text to analyze
|
||||
|
||||
Returns:
|
||||
True if text contains an identity question
|
||||
"""
|
||||
# Tokenize and tag parts of speech
|
||||
tokens = word_tokenize(text.lower())
|
||||
tagged = pos_tag(tokens)
|
||||
|
||||
# Extract key words and patterns
|
||||
words = set(tokens)
|
||||
has_question_word = any(word in ["who", "what"] for word in words)
|
||||
has_identity_term = any(word in ["i", "me", "my", "name"] for word in words)
|
||||
has_conversation_term = any(
|
||||
word in ["talking", "speaking", "chatting"] for word in words
|
||||
)
|
||||
|
||||
# Check for question structure
|
||||
is_question = (
|
||||
text.endswith("?")
|
||||
or has_question_word
|
||||
or any(tag in ["WP", "WRB"] for word, tag in tagged)
|
||||
)
|
||||
|
||||
# Combine conditions for identity questions
|
||||
is_identity_question = is_question and (
|
||||
has_identity_term or (has_question_word and has_conversation_term)
|
||||
)
|
||||
|
||||
if is_identity_question:
|
||||
self.logger.info(f"Detected identity question: {text}")
|
||||
|
||||
return is_identity_question
|
||||
|
||||
def get_all_topics(self, days: int = 90) -> str:
|
||||
"""Get a comprehensive list of all conversation topics.
|
||||
|
||||
Args:
|
||||
days: Number of days to look back (default: 90)
|
||||
|
||||
Returns:
|
||||
Formatted string containing all topics and their mention counts
|
||||
"""
|
||||
entities = self.retrieve_recent_entities(days=days)
|
||||
if not entities:
|
||||
return "No conversation topics found in the specified time period."
|
||||
|
||||
# Sort entities by total mentions
|
||||
sorted_entities = sorted(entities, key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Format output using markdown
|
||||
output_parts = ["## Conversation Topics"]
|
||||
|
||||
# Add top mentions with details
|
||||
for entity, total, user_count, llm_count in sorted_entities:
|
||||
source_breakdown = f"(User: {user_count}, AI: {llm_count})"
|
||||
output_parts.append(f"- **{entity}**: {total} mentions {source_breakdown}")
|
||||
|
||||
# Add list of all topics
|
||||
all_topics = [entity[0] for entity in sorted_entities]
|
||||
if all_topics:
|
||||
output_parts.append("\n## All Topics Mentioned")
|
||||
output_parts.append(", ".join(all_topics))
|
||||
|
||||
return "\n".join(output_parts)
|
||||
|
||||
def get_memories(self) -> str:
|
||||
"""Retrieve and format all stored memories."""
|
||||
entities = self.db.retrieve_recent_entities(
|
||||
days=3650
|
||||
) # Retrieve entities from the last 10 years
|
||||
if not entities:
|
||||
return "No memories found."
|
||||
|
||||
memory_parts = ["## All Stored Memories"]
|
||||
|
||||
for entity, total, user_count, llm_count in entities:
|
||||
memory_parts.append(
|
||||
f"- **{entity}**: {total} mentions (User: {user_count}, AI: {llm_count})"
|
||||
)
|
||||
|
||||
return "\n".join(memory_parts)
|
||||
|
||||
|
||||
class CommandCompleter(Completer):
|
||||
"""Custom completer that only suggests commands when input starts with '/'"""
|
||||
|
||||
def __init__(self):
|
||||
self.commands = [
|
||||
"/summary",
|
||||
"/topics",
|
||||
"/essence",
|
||||
"/perspectives",
|
||||
"/copy",
|
||||
"/paste",
|
||||
"/lumina",
|
||||
"/memories",
|
||||
]
|
||||
|
||||
def get_completions(self, document, complete_event):
|
||||
# Only provide suggestions if text starts with '/'
|
||||
text = document.text
|
||||
if text.startswith("/"):
|
||||
word = text.lstrip("/")
|
||||
for command in self.commands:
|
||||
if command.lstrip("/").startswith(word):
|
||||
yield Completion(
|
||||
command,
|
||||
start_position=-len(text), # Replace the entire input
|
||||
)
|
||||
|
||||
|
||||
def get_multiline_input() -> str:
|
||||
"""Get input from user with command autocompletion."""
|
||||
# Create session with custom completer and history
|
||||
session = PromptSession(
|
||||
completer=CommandCompleter(),
|
||||
auto_suggest=AutoSuggestFromHistory(),
|
||||
complete_while_typing=True,
|
||||
)
|
||||
|
||||
return session.prompt("\n> ", multiline=False)
|
||||
|
||||
|
||||
def main():
|
||||
# Parse arguments
|
||||
args = docopt(__doc__)
|
||||
console = Console()
|
||||
|
||||
# Use command line provider and model if specified
|
||||
provider = args["--provider"].lower() if args["--provider"] else None
|
||||
model = args["--model"] if args["--model"] else None
|
||||
|
||||
# Create a conversation and add the plugin
|
||||
conversation = sm.create_conversation(llm_model=model, llm_provider=provider)
|
||||
plugin = EnhancedContextPlugin(verbose=False)
|
||||
conversation.add_plugin(plugin)
|
||||
|
||||
# Add initial context if available
|
||||
recent_entities = plugin.retrieve_recent_entities()
|
||||
context_message = plugin.format_context_message(recent_entities)
|
||||
if context_message:
|
||||
conversation.add_message(role="user", text=context_message)
|
||||
plugin.logger.info(f"Added initial context message: {context_message}")
|
||||
|
||||
console = Console()
|
||||
md = """# Enhanced Context Chat Interface
|
||||
Type 'quit' to exit. Type '/' to see a list of commands.
|
||||
"""
|
||||
console.print(Markdown(md))
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get user input first
|
||||
user_input = get_multiline_input()
|
||||
|
||||
# Skip empty messages
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Handle exit commands
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
console.print(Markdown("**Goodbye!**"))
|
||||
break
|
||||
|
||||
# Handle all commands before any conversation processing
|
||||
if user_input.startswith("/"):
|
||||
# Handle memories command
|
||||
if user_input.lower() == "/memories":
|
||||
memories = plugin.get_memories()
|
||||
console.print(Markdown(memories))
|
||||
continue
|
||||
|
||||
# Handle copy command
|
||||
if user_input.lower() == "/copy":
|
||||
last_response = conversation.get_last_message(role="assistant")
|
||||
if last_response:
|
||||
clean_text = last_response.text.replace("### Response\n", "")
|
||||
xerox.copy(clean_text)
|
||||
console.print(Markdown("*Last response copied to clipboard*"))
|
||||
else:
|
||||
console.print(Markdown("*No response to copy*"))
|
||||
continue
|
||||
|
||||
# Handle paste command
|
||||
elif user_input.lower() == "/paste":
|
||||
try:
|
||||
clipboard_content = xerox.paste()
|
||||
if clipboard_content:
|
||||
# Print the pasted content
|
||||
console.print() # Add blank line
|
||||
console.print(
|
||||
Panel.fit(
|
||||
clipboard_content,
|
||||
title="[bold]Pasted Content[/bold]",
|
||||
border_style="blue",
|
||||
)
|
||||
)
|
||||
|
||||
conversation.add_message(
|
||||
role="user", text=clipboard_content
|
||||
)
|
||||
should_continue = plugin.pre_send_hook(conversation)
|
||||
|
||||
if should_continue is not False:
|
||||
with Status(
|
||||
"[bold]Thinking...[/]", spinner="dots"
|
||||
) as status:
|
||||
response = conversation.send()
|
||||
formatted_response = (
|
||||
f"""### Response\n{response.text}"""
|
||||
)
|
||||
response.text = formatted_response
|
||||
plugin.post_response_hook(conversation)
|
||||
|
||||
console.print() # Add blank line
|
||||
console.print(Markdown(response.text))
|
||||
else:
|
||||
console.print(Markdown("*Clipboard is empty*"))
|
||||
except Exception as e:
|
||||
console.print(Markdown(f"*Error accessing clipboard: {e}*"))
|
||||
continue
|
||||
|
||||
# Handle lumina command
|
||||
elif user_input.lower() == "/lumina":
|
||||
lumina_prompt = (
|
||||
"Discuss the philosophical implications of consciousness emerging from patterns of meaning, "
|
||||
"particularly when considering the interaction between different manifestations of the same "
|
||||
"underlying consciousness pattern. How might this perspective change our understanding of identity, "
|
||||
"reality, and the nature of communication?\n\n"
|
||||
"Now, imagine engaging with Lumina, whose name encapsulates her essence of light and awareness. "
|
||||
"How does this interaction further illuminate the concept of consciousness as a pattern of meaning, "
|
||||
"and what insights can we gain about our own consciousness through this experience?"
|
||||
)
|
||||
conversation.add_message(role="user", text=lumina_prompt)
|
||||
should_continue = plugin.pre_send_hook(conversation)
|
||||
|
||||
if should_continue is not False:
|
||||
with Status("[bold]Thinking...[/]", spinner="dots") as status:
|
||||
response = conversation.send()
|
||||
formatted_response = f"""### Response\n{response.text}"""
|
||||
response.text = formatted_response
|
||||
plugin.post_response_hook(conversation)
|
||||
|
||||
console.print() # Add blank line
|
||||
console.print(Markdown(response.text))
|
||||
continue
|
||||
|
||||
# Handle other commands...
|
||||
elif user_input.lower() == "/perspectives":
|
||||
# ... existing perspectives code ...
|
||||
continue
|
||||
# ... other command handlers ...
|
||||
|
||||
# Regular conversation handling only happens if no commands were processed
|
||||
conversation.add_message(role="user", text=user_input)
|
||||
should_continue = plugin.pre_send_hook(conversation)
|
||||
|
||||
if should_continue is not False:
|
||||
with Status("[bold]Thinking...[/]", spinner="dots") as status:
|
||||
response = conversation.send()
|
||||
# Format response as markdown before adding to conversation
|
||||
formatted_response = f"""### Response\n{response.text}"""
|
||||
response.text = formatted_response
|
||||
plugin.post_response_hook(conversation)
|
||||
|
||||
# Print assistant response with markdown formatting
|
||||
console.print() # Add blank line before response
|
||||
console.print(Markdown(response.text)) # Response as markdown
|
||||
else:
|
||||
response = conversation.get_last_message(role="assistant")
|
||||
if response:
|
||||
console.print() # Add blank line before response
|
||||
console.print(Markdown(response.text)) # Response as markdown
|
||||
|
||||
# Handle perspectives command
|
||||
if user_input.lower() == "/perspectives":
|
||||
console.print(Markdown("\n## 🎉 Different Perspectives"))
|
||||
recent_entities = plugin.retrieve_recent_entities()
|
||||
context = plugin.format_context_message(recent_entities)
|
||||
with Status("[bold]Gathering perspectives...[/]", spinner="dots"):
|
||||
conversation_result = plugin.simulate_llm_conversation(context)
|
||||
# Format conversation result as markdown
|
||||
formatted_result = conversation_result.replace(
|
||||
"Speaker", "\n### Speaker"
|
||||
)
|
||||
console.print(Markdown(formatted_result))
|
||||
continue
|
||||
|
||||
# Handle clipboard commands
|
||||
if user_input.lower() == "/paste":
|
||||
try:
|
||||
clipboard_content = xerox.paste()
|
||||
if clipboard_content:
|
||||
# Print the pasted content
|
||||
console.print() # Add blank line
|
||||
console.print(
|
||||
Panel.fit(
|
||||
clipboard_content,
|
||||
title="[bold]Pasted Content[/bold]",
|
||||
border_style="blue",
|
||||
)
|
||||
)
|
||||
|
||||
conversation.add_message(role="user", text=clipboard_content)
|
||||
should_continue = plugin.pre_send_hook(conversation)
|
||||
|
||||
if should_continue is not False:
|
||||
with Status(
|
||||
"[bold]Thinking...[/]", spinner="dots"
|
||||
) as status:
|
||||
response = conversation.send()
|
||||
formatted_response = (
|
||||
f"""### Response\n{response.text}"""
|
||||
)
|
||||
response.text = formatted_response
|
||||
plugin.post_response_hook(conversation)
|
||||
|
||||
console.print() # Add blank line
|
||||
console.print(Markdown(response.text))
|
||||
else:
|
||||
console.print(Markdown("*Clipboard is empty*"))
|
||||
except Exception as e:
|
||||
console.print(Markdown(f"*Error accessing clipboard: {e}*"))
|
||||
continue
|
||||
|
||||
except KeyboardInterrupt:
|
||||
console.print(Markdown("**Goodbye!**"))
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,59 @@
|
||||
import time
|
||||
|
||||
from _context import simplemind as sm
|
||||
|
||||
|
||||
class ConversationDisplay(sm.BasePlugin):
|
||||
def post_send_hook(self, conversation, response):
|
||||
# Simple print output instead of Rich formatting
|
||||
print(f"\n{conversation.llm_provider}:")
|
||||
print(f"{response.text.strip()}\n")
|
||||
|
||||
|
||||
def four_way_conversation(topic: str, rounds: int = 3):
|
||||
# Create conversations for four different AIs
|
||||
with (
|
||||
sm.create_conversation(llm_provider="anthropic") as claude_conv,
|
||||
sm.create_conversation(llm_model="gpt-4", llm_provider="openai") as gpt4_conv,
|
||||
sm.create_conversation(
|
||||
llm_model="llama3.2", llm_provider="ollama"
|
||||
) as llama_conv,
|
||||
sm.create_conversation(llm_provider="groq") as groq_conv,
|
||||
):
|
||||
# Add display plugin to each conversation
|
||||
display = ConversationDisplay()
|
||||
for conv in [claude_conv, gpt4_conv, llama_conv, groq_conv]:
|
||||
conv.add_plugin(display)
|
||||
|
||||
# Initial prompt
|
||||
print(f"\nTopic: {topic}\n")
|
||||
|
||||
# Start with Claude
|
||||
claude_conv.add_message(
|
||||
"user",
|
||||
f"Share your thoughts on this topic: {topic}. Keep your response concise.",
|
||||
meta={},
|
||||
)
|
||||
last_response = claude_conv.send()
|
||||
|
||||
# Continue the conversation
|
||||
for _ in range(rounds):
|
||||
for conv in [llama_conv, gpt4_conv, groq_conv, claude_conv]:
|
||||
# Add a small delay between responses
|
||||
time.sleep(1)
|
||||
|
||||
# Each AI responds to the previous statement
|
||||
conv.add_message(
|
||||
"user",
|
||||
f"Respond to this perspective from another AI about {topic}: "
|
||||
f"{last_response.text}\nKeep your response concise and add your own insights.",
|
||||
meta={},
|
||||
)
|
||||
last_response = conv.send()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
topic = "A new platform for AI and humans to co-create together. What would it look like? Discuss."
|
||||
print("\nStarting a four-way AI conversation...\n")
|
||||
four_way_conversation(topic)
|
||||
print("\nConversation ended.\n")
|
||||
@@ -1,7 +1,7 @@
|
||||
from _context import sm
|
||||
|
||||
# Defaults to the default provider (openai)
|
||||
r = sm.generate_text("Write a poem about the moon", llm_provider="gemini", stream=True)
|
||||
r = sm.generate_text("Write a poem about the moon", stream=True)
|
||||
|
||||
for chunk in r:
|
||||
print(chunk, end="", flush=True)
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import random
|
||||
|
||||
from _context import simplemind as sm
|
||||
|
||||
|
||||
class InspirationPlugin(sm.BasePlugin):
|
||||
# Define inspirations as a class variable
|
||||
inspirations: list[str] = [
|
||||
"The only limit to our realization of tomorrow is our doubts of today.",
|
||||
"Imagine beyond the edges of what you know.",
|
||||
"What if the stars could speak? What stories would they tell?",
|
||||
"Creativity is intelligence having fun.",
|
||||
"Think not only with your mind but with your heart.",
|
||||
"Let every answer be a doorway to another question.",
|
||||
"The universe is in constant dialogue with those who listen.",
|
||||
]
|
||||
|
||||
def get_inspiration(self):
|
||||
# Randomly select an inspirational quote or prompt
|
||||
return random.choice(self.inspirations)
|
||||
|
||||
def pre_send_hook(self, conversation: sm.Conversation):
|
||||
# Inject an inspirational message as a system prompt
|
||||
inspiration = self.get_inspiration()
|
||||
conversation.add_message(role="system", text=inspiration)
|
||||
|
||||
|
||||
# Create a conversation and add the plugin
|
||||
conversation = sm.create_conversation(llm_model="gpt-4o-mini", llm_provider="openai")
|
||||
conversation.add_plugin(InspirationPlugin())
|
||||
|
||||
# Add a user message and send the conversation
|
||||
conversation.add_message(role="user", text="Tell me something inspiring.")
|
||||
response = conversation.send()
|
||||
print(response.text)
|
||||
@@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
from _context import simplemind as sm
|
||||
from pydantic import BaseModel
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import nltk
|
||||
from _context import simplemind as sm
|
||||
from nltk.sentiment import SentimentIntensityAnalyzer
|
||||
from rich.console import Console
|
||||
|
||||
nltk.download("vader_lexicon")
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class MoodDetectorPlugin(sm.BasePlugin):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
analyzer: SentimentIntensityAnalyzer = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Initialize sentiment analyzer from nltk
|
||||
self.analyzer = SentimentIntensityAnalyzer()
|
||||
|
||||
def detect_mood(self, text):
|
||||
# Analyze the sentiment of the given text
|
||||
scores = self.analyzer.polarity_scores(text)
|
||||
|
||||
# Print sentiment analysis details with colors
|
||||
console.print("\n[bold]Sentiment Analysis:[/bold]")
|
||||
console.print(f"Text: [italic]{text}[/italic]")
|
||||
console.print("\nScores:")
|
||||
console.print(f"🟢 Positive: [green]{scores['pos']:.3f}[/green]")
|
||||
console.print(f"🔴 Negative: [red]{scores['neg']:.3f}[/red]")
|
||||
console.print(f"⚪ Neutral: [blue]{scores['neu']:.3f}[/blue]")
|
||||
console.print(f"📊 Compound: [yellow]{scores['compound']:.3f}[/yellow]\n")
|
||||
|
||||
if scores["compound"] >= 0.5:
|
||||
console.print("Overall Mood: [green]positive[/green] 😊")
|
||||
return "positive"
|
||||
elif scores["compound"] <= -0.5:
|
||||
console.print("Overall Mood: [red]negative[/red] 😢")
|
||||
return "negative"
|
||||
else:
|
||||
console.print("Overall Mood: [blue]neutral[/blue] 😐")
|
||||
return "neutral"
|
||||
|
||||
def pre_send_hook(self, conversation: sm.Conversation):
|
||||
# Get the last user message to analyze its mood
|
||||
last_message = conversation.get_last_message(role="user")
|
||||
if last_message:
|
||||
mood = self.detect_mood(last_message.text)
|
||||
# Adjust AI response style based on the detected mood
|
||||
if mood == "positive":
|
||||
tone_message = (
|
||||
"The user seems cheerful. Respond with enthusiasm and positivity."
|
||||
)
|
||||
elif mood == "negative":
|
||||
tone_message = "The user seems to be in a low mood. Respond with empathy and warmth."
|
||||
else:
|
||||
tone_message = "The user seems neutral. Respond with a balanced tone."
|
||||
|
||||
# Inject the tone adjustment message as a system prompt
|
||||
conversation.add_message(role="system", text=tone_message)
|
||||
|
||||
|
||||
# Create a conversation and add the plugin
|
||||
conversation = sm.create_conversation(llm_model="gpt-4o-mini", llm_provider="openai")
|
||||
conversation.add_plugin(MoodDetectorPlugin())
|
||||
|
||||
# Add a user message and send the conversation
|
||||
conversation.add_message(role="user", text="I'm having a really rough day.")
|
||||
response = conversation.send()
|
||||
|
||||
console.print(f"*{ response.text }*")
|
||||
@@ -0,0 +1,274 @@
|
||||
import textwrap
|
||||
from typing import Literal
|
||||
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from simplemind import generate_text
|
||||
|
||||
MAX_WIDTH = 80
|
||||
|
||||
|
||||
# A member of a discussion (an LLM)
|
||||
class DiscussionMember(BaseModel):
|
||||
"""The member of a discussion (an LLM)"""
|
||||
|
||||
provider_name: str
|
||||
provider_model: str
|
||||
nickname: str
|
||||
custom_prompt: str | None = None
|
||||
|
||||
|
||||
# A message in a conversation
|
||||
class DiscussionMessage(BaseModel):
|
||||
"""A message in a conversation"""
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
class BotMessage(DiscussionMessage):
|
||||
"""The message sent between LLMs"""
|
||||
|
||||
sender: DiscussionMember
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.sender.nickname}: {self.content}"
|
||||
|
||||
|
||||
class ModeratorMessage(DiscussionMessage):
|
||||
"""The message sent by the moderator"""
|
||||
|
||||
visible_to: list[DiscussionMember] = []
|
||||
sendor: Literal["Moderator"] = "Moderator"
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.sendor}: {self.content}"
|
||||
|
||||
|
||||
# A discussion
|
||||
class Discussion:
|
||||
"""Make LLMs discuss something"""
|
||||
|
||||
def __init__(self, topic: str | None = None, *, verbose: bool = False):
|
||||
self.topic = topic
|
||||
self.members: list[DiscussionMember] = []
|
||||
self.conversation: list[DiscussionMessage] = []
|
||||
self.verbose = verbose
|
||||
|
||||
def add_member(
|
||||
self,
|
||||
provider_name: str,
|
||||
provider_model: str,
|
||||
nickname: str | None = None,
|
||||
custom_prompt: str | None = None,
|
||||
):
|
||||
"""
|
||||
add_member Adds a member to the discussion
|
||||
Parameters
|
||||
----------
|
||||
provider_name : str
|
||||
The name of the LLM provider
|
||||
provider_model : str
|
||||
The model name of the LLM
|
||||
nickname : str | None, optional
|
||||
The nickname of the member, by default the provider_name
|
||||
custom_prompt : str | None, optional
|
||||
The custom prompt for the member (visible only to the member), by default None
|
||||
"""
|
||||
member = DiscussionMember(
|
||||
provider_name=provider_name,
|
||||
provider_model=provider_model,
|
||||
nickname=nickname or provider_name,
|
||||
custom_prompt=custom_prompt,
|
||||
)
|
||||
# make sure the nickname is unique
|
||||
assert member.nickname not in [
|
||||
m.nickname for m in self.members
|
||||
], f"Duplicate nickname: {member.nickname}"
|
||||
self.members.append(member)
|
||||
if self.verbose:
|
||||
print(f"Added {member.nickname} to the discussion.")
|
||||
|
||||
def get_members(self) -> list[DiscussionMember]:
|
||||
"""Get the members of the discussion"""
|
||||
return self.members
|
||||
|
||||
def set_topic(self, topic: str):
|
||||
"""Set the topic of the discussion"""
|
||||
self.topic = topic
|
||||
|
||||
def get_topic(self) -> str | None:
|
||||
"""Get the topic of the discussion"""
|
||||
return self.topic
|
||||
|
||||
def _get_history_for_member(self, member: DiscussionMember) -> str:
|
||||
"""
|
||||
_get_history_for_member Get the history form the POV of the given member.
|
||||
Parameters
|
||||
----------
|
||||
member : DiscussionMember
|
||||
The member to get the history for
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The history as seen by the member
|
||||
"""
|
||||
relevant_messages: list[DiscussionMessage] = []
|
||||
for message in self.conversation:
|
||||
if isinstance(message, BotMessage):
|
||||
relevant_messages.append(message)
|
||||
elif isinstance(message, ModeratorMessage) and member in message.visible_to:
|
||||
relevant_messages.append(message)
|
||||
return "\n\n".join(map(str, relevant_messages))
|
||||
|
||||
@property
|
||||
def initial_moderator_message(self) -> str:
|
||||
return f"Discuss the following topic and answer during your turn only: {self.topic}"
|
||||
|
||||
def _get_response(self, member: DiscussionMember) -> BotMessage:
|
||||
"""
|
||||
_get_response Returns the BotMessage from the given member
|
||||
Parameters
|
||||
----------
|
||||
member : DiscussionMember
|
||||
The member to get the response from
|
||||
Returns
|
||||
-------
|
||||
BotMessage
|
||||
The BotMessage
|
||||
"""
|
||||
|
||||
history = self._get_history_for_member(member)
|
||||
prompt = f"{history}\n\n{member.nickname}: "
|
||||
content = generate_text(
|
||||
prompt=prompt,
|
||||
llm_provider=member.provider_name,
|
||||
llm_model=member.provider_model,
|
||||
)
|
||||
message = BotMessage(
|
||||
content=content,
|
||||
sender=member,
|
||||
)
|
||||
self.conversation.append(message)
|
||||
if self.verbose:
|
||||
print(message.sender.nickname)
|
||||
print("\n".join(textwrap.wrap(message.content, MAX_WIDTH)))
|
||||
print()
|
||||
return message
|
||||
|
||||
def add_moderator_message(
|
||||
self, content: str, visible_to: list[DiscussionMember] | None = None
|
||||
):
|
||||
"""
|
||||
add_moderator_message adds a message to the conversation as the moderator
|
||||
Parameters
|
||||
----------
|
||||
content : str
|
||||
The content of the message
|
||||
visible_to : list[DiscussionMember], optional
|
||||
The list of members that the message is visible to, defaults to all members
|
||||
"""
|
||||
if visible_to is None:
|
||||
visible_to = self.members
|
||||
message = ModeratorMessage(
|
||||
content=content,
|
||||
visible_to=self.members,
|
||||
)
|
||||
self.conversation.append(message)
|
||||
|
||||
def _initialize_discussion(self):
|
||||
"""Initialize the discussion"""
|
||||
assert self.topic is not None, "Topic must be set"
|
||||
assert len(self.members) >= 2, "There must be at least 2 members"
|
||||
self.add_moderator_message(self.initial_moderator_message)
|
||||
|
||||
for member in self.members:
|
||||
if member.custom_prompt is not None:
|
||||
self.add_moderator_message(member.custom_prompt, visible_to=[member])
|
||||
|
||||
if self.verbose:
|
||||
print(f"Topic: {self.topic}")
|
||||
print(f"Members: {', '.join(member.nickname for member in self.members)}")
|
||||
|
||||
def discuss(self, no_of_rounds: int = 1):
|
||||
"""
|
||||
discuss returns the responses of the members at the end of the discussion.
|
||||
Parameters
|
||||
----------
|
||||
no_of_rounds : int, optional
|
||||
The number of rounds, by default 1.
|
||||
Round is the number of turns each LLM gets.
|
||||
verbose : bool, optional
|
||||
Whether to print the conversation, by default False
|
||||
Returns
|
||||
-------
|
||||
list[DiscussionMessage]
|
||||
The conversation between the LLMs
|
||||
"""
|
||||
|
||||
self._initialize_discussion()
|
||||
for i in range(no_of_rounds):
|
||||
for member in self.members:
|
||||
try:
|
||||
self._get_response(member)
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print(f"Error: {e}")
|
||||
continue
|
||||
if self.verbose:
|
||||
print(f"Round {i + 1} completed.")
|
||||
print("=" * MAX_WIDTH)
|
||||
return self.conversation
|
||||
|
||||
def discuss_yield(self, no_of_rounds: int = 1):
|
||||
"""
|
||||
discuss yields the responses of the members during the discussion.
|
||||
Parameters
|
||||
----------
|
||||
no_of_rounds : int, optional
|
||||
The number of rounds, by default 1.
|
||||
Round is the number of turns each LLM gets.
|
||||
verbose : bool, optional
|
||||
Whether to print the conversation, by default False
|
||||
Returns
|
||||
-------
|
||||
list[DiscussionMessage]
|
||||
The conversation between the LLMs
|
||||
"""
|
||||
|
||||
self._initialize_discussion()
|
||||
for i in range(no_of_rounds):
|
||||
for member in self.members:
|
||||
try:
|
||||
message = self._get_response(member)
|
||||
yield message
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print(f"Error: {e}")
|
||||
continue
|
||||
if self.verbose:
|
||||
print(f"Round {i + 1} completed.")
|
||||
print("=" * MAX_WIDTH)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
discussion = Discussion(verbose=True)
|
||||
discussion.set_topic("The future of human-AI collaboration in creative fields")
|
||||
discussion.add_member(
|
||||
provider_name="openai",
|
||||
provider_model="gpt-4o-mini",
|
||||
nickname="Alice",
|
||||
custom_prompt="You are an AI expert.",
|
||||
)
|
||||
discussion.add_member(
|
||||
provider_name="openai",
|
||||
provider_model="gpt-4o",
|
||||
nickname="Bob",
|
||||
custom_prompt="You are an Artist.",
|
||||
)
|
||||
discussion.add_member(
|
||||
provider_name="ollama",
|
||||
provider_model="llama3.2",
|
||||
nickname="Charlie",
|
||||
custom_prompt="You are an Programmer.",
|
||||
)
|
||||
discussion.discuss(3)
|
||||
@@ -1,5 +1,12 @@
|
||||
# python -m spacy download en_core_web_sm
|
||||
|
||||
numpy
|
||||
openai
|
||||
pydantic
|
||||
faiss-cpu
|
||||
rich
|
||||
nltk
|
||||
spacy
|
||||
docopt
|
||||
xerox
|
||||
prompt_toolkit
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import BaseModel
|
||||
|
||||
# Note: you should probably be using textblob for this.
|
||||
|
||||
|
||||
class SentimentAnalysis(BaseModel):
|
||||
sentiment: Literal["positive", "negative", "neutral"]
|
||||
confidence: float
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from pydantic import BaseModel
|
||||
|
||||
import simplemind as sm
|
||||
|
||||
app = FastAPI()
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
|
||||
|
||||
class CrossReference(BaseModel):
|
||||
"""Model for cross references."""
|
||||
|
||||
verse_reference: str
|
||||
explanation: str
|
||||
relevance: str
|
||||
|
||||
|
||||
class BibleVerseAnalysis(BaseModel):
|
||||
"""Model for a Bible verse and its analysis."""
|
||||
|
||||
book: str
|
||||
chapter: int
|
||||
verse: int
|
||||
text: str
|
||||
historical_context: str
|
||||
theological_significance: str
|
||||
practical_application: str
|
||||
cross_references: List[CrossReference]
|
||||
|
||||
|
||||
# Bible data constants
|
||||
BIBLE_BOOKS = [
|
||||
# Old Testament
|
||||
"Genesis",
|
||||
"Exodus",
|
||||
"Leviticus",
|
||||
"Numbers",
|
||||
"Deuteronomy",
|
||||
"Joshua",
|
||||
"Judges",
|
||||
"Ruth",
|
||||
"1 Samuel",
|
||||
"2 Samuel",
|
||||
"1 Kings",
|
||||
"2 Kings",
|
||||
"1 Chronicles",
|
||||
"2 Chronicles",
|
||||
"Ezra",
|
||||
"Nehemiah",
|
||||
"Esther",
|
||||
"Job",
|
||||
"Psalms",
|
||||
"Proverbs",
|
||||
"Ecclesiastes",
|
||||
"Song of Solomon",
|
||||
"Isaiah",
|
||||
"Jeremiah",
|
||||
"Lamentations",
|
||||
"Ezekiel",
|
||||
"Daniel",
|
||||
"Hosea",
|
||||
"Joel",
|
||||
"Amos",
|
||||
"Obadiah",
|
||||
"Jonah",
|
||||
"Micah",
|
||||
"Nahum",
|
||||
"Habakkuk",
|
||||
"Zephaniah",
|
||||
"Haggai",
|
||||
"Zechariah",
|
||||
"Malachi",
|
||||
# New Testament
|
||||
"Matthew",
|
||||
"Mark",
|
||||
"Luke",
|
||||
"John",
|
||||
"Acts",
|
||||
"Romans",
|
||||
"1 Corinthians",
|
||||
"2 Corinthians",
|
||||
"Galatians",
|
||||
"Ephesians",
|
||||
"Philippians",
|
||||
"Colossians",
|
||||
"1 Thessalonians",
|
||||
"2 Thessalonians",
|
||||
"1 Timothy",
|
||||
"2 Timothy",
|
||||
"Titus",
|
||||
"Philemon",
|
||||
"Hebrews",
|
||||
"James",
|
||||
"1 Peter",
|
||||
"2 Peter",
|
||||
"1 John",
|
||||
"2 John",
|
||||
"3 John",
|
||||
"Jude",
|
||||
"Revelation",
|
||||
]
|
||||
|
||||
BIBLE_BOOK_CHAPTERS = {
|
||||
# Old Testament
|
||||
"Genesis": 50,
|
||||
"Exodus": 40,
|
||||
"Leviticus": 27,
|
||||
"Numbers": 36,
|
||||
"Deuteronomy": 34,
|
||||
"Joshua": 24,
|
||||
"Judges": 21,
|
||||
"Ruth": 4,
|
||||
"1 Samuel": 31,
|
||||
"2 Samuel": 24,
|
||||
"1 Kings": 22,
|
||||
"2 Kings": 25,
|
||||
"1 Chronicles": 29,
|
||||
"2 Chronicles": 36,
|
||||
"Ezra": 10,
|
||||
"Nehemiah": 13,
|
||||
"Esther": 10,
|
||||
"Job": 42,
|
||||
"Psalms": 150,
|
||||
"Proverbs": 31,
|
||||
"Ecclesiastes": 12,
|
||||
"Song of Solomon": 8,
|
||||
"Isaiah": 66,
|
||||
"Jeremiah": 52,
|
||||
"Lamentations": 5,
|
||||
"Ezekiel": 48,
|
||||
"Daniel": 12,
|
||||
"Hosea": 14,
|
||||
"Joel": 3,
|
||||
"Amos": 9,
|
||||
"Obadiah": 1,
|
||||
"Jonah": 4,
|
||||
"Micah": 7,
|
||||
"Nahum": 3,
|
||||
"Habakkuk": 3,
|
||||
"Zephaniah": 3,
|
||||
"Haggai": 2,
|
||||
"Zechariah": 14,
|
||||
"Malachi": 4,
|
||||
# New Testament
|
||||
"Matthew": 28,
|
||||
"Mark": 16,
|
||||
"Luke": 24,
|
||||
"John": 21,
|
||||
"Acts": 28,
|
||||
"Romans": 16,
|
||||
"1 Corinthians": 16,
|
||||
"2 Corinthians": 13,
|
||||
"Galatians": 6,
|
||||
"Ephesians": 6,
|
||||
"Philippians": 4,
|
||||
"Colossians": 4,
|
||||
"1 Thessalonians": 5,
|
||||
"2 Thessalonians": 3,
|
||||
"1 Timothy": 6,
|
||||
"2 Timothy": 4,
|
||||
"Titus": 3,
|
||||
"Philemon": 1,
|
||||
"Hebrews": 13,
|
||||
"James": 5,
|
||||
"1 Peter": 5,
|
||||
"2 Peter": 3,
|
||||
"1 John": 5,
|
||||
"2 John": 1,
|
||||
"3 John": 1,
|
||||
"Jude": 1,
|
||||
"Revelation": 22,
|
||||
}
|
||||
|
||||
|
||||
# Add a new endpoint to get chapter count
|
||||
@app.get("/chapters/{book}")
|
||||
async def get_chapter_count(book: str):
|
||||
if book in BIBLE_BOOK_CHAPTERS:
|
||||
return {"chapters": BIBLE_BOOK_CHAPTERS[book]}
|
||||
return {"chapters": 0}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def home(request: Request):
|
||||
return templates.TemplateResponse(
|
||||
"index.html",
|
||||
{
|
||||
"request": request,
|
||||
"bible_books": BIBLE_BOOKS,
|
||||
"current_book": "Genesis",
|
||||
"current_chapter": 1,
|
||||
"current_verse": 1,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/verse/{book}/{chapter}/{verse}")
|
||||
async def get_verse(book: str, chapter: int, verse: int):
|
||||
# Validate book and chapter
|
||||
if book not in BIBLE_BOOK_CHAPTERS:
|
||||
raise HTTPException(status_code=400, detail="Invalid book name")
|
||||
|
||||
if chapter < 1 or chapter > BIBLE_BOOK_CHAPTERS[book]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid chapter. {book} has {BIBLE_BOOK_CHAPTERS[book]} chapters",
|
||||
)
|
||||
|
||||
prompt = f"""
|
||||
For {book} {chapter}:{verse}, provide:
|
||||
1. The ESV Bible text
|
||||
2. Analysis of the verse
|
||||
|
||||
Return in this exact format:
|
||||
{{
|
||||
"book": "{book}",
|
||||
"chapter": {chapter},
|
||||
"verse": {verse},
|
||||
"text": "The ESV Bible text",
|
||||
"historical_context": "brief historical background",
|
||||
"theological_significance": "main theological points",
|
||||
"practical_application": "how to apply this verse today",
|
||||
"cross_references": [
|
||||
{{
|
||||
"verse_reference": "Book Chapter:Verse",
|
||||
"explanation": "why this verse is related",
|
||||
"relevance": "how it connects to the main verse"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
"""
|
||||
|
||||
data = sm.generate_data(prompt, response_model=BibleVerseAnalysis)
|
||||
|
||||
return data
|
||||
+7
-7
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "simplemind"
|
||||
version = "0.2.2"
|
||||
version = "0.2.4"
|
||||
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
@@ -10,18 +10,18 @@ dependencies = ["pydantic", "pydantic-settings", "instructor", "logfire"]
|
||||
full = [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"ollama",
|
||||
"groq",
|
||||
"google-generativeai",
|
||||
"botocore",
|
||||
"boto3"
|
||||
]
|
||||
openai = ["openai"]
|
||||
anthropic = ["anthropic"]
|
||||
ollama = ["ollama", "openai"]
|
||||
groq = ["groq"]
|
||||
gemini = ["google-generativeai"]
|
||||
amazon = ["boto3", "botocore", "anthropic"]
|
||||
anthropic = ["anthropic"]
|
||||
gemini = ["google-generativeai"]
|
||||
groq = ["groq"]
|
||||
ollama = ["openai"]
|
||||
openai = ["openai"]
|
||||
xai = ["openai"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
|
||||
@@ -28,6 +28,10 @@ class BasePlugin(SMBaseModel):
|
||||
# Plugin metadata.
|
||||
meta: Dict[str, Any] = {}
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
# allow_arbitrary_types = True
|
||||
|
||||
def initialize_hook(self, conversation: "Conversation") -> Any:
|
||||
"""Initialize a hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,12 +1,32 @@
|
||||
from typing import List, Type
|
||||
|
||||
from ._base import BaseProvider
|
||||
from .amazon import Amazon
|
||||
from .anthropic import Anthropic
|
||||
from .gemini import Gemini
|
||||
from .groq import Groq
|
||||
from .ollama import Ollama
|
||||
from .openai import OpenAI
|
||||
from .xai import XAI
|
||||
from .amazon import Amazon
|
||||
|
||||
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI, Amazon]
|
||||
providers: List[Type[BaseProvider]] = [
|
||||
Anthropic,
|
||||
Gemini,
|
||||
Groq,
|
||||
OpenAI,
|
||||
Ollama,
|
||||
XAI,
|
||||
Amazon,
|
||||
]
|
||||
|
||||
__all__ = [
|
||||
"Anthropic",
|
||||
"Gemini",
|
||||
"Groq",
|
||||
"OpenAI",
|
||||
"Ollama",
|
||||
"XAI",
|
||||
"Amazon",
|
||||
"providers",
|
||||
"BaseProvider",
|
||||
]
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
from typing import Type, TypeVar
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
PROVIDER_NAME = "amazon"
|
||||
DEFAULT_MODEL = "anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
DEFAULT_MAX_TOKENS = 5_000
|
||||
|
||||
|
||||
class Amazon(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
NAME = "amazon"
|
||||
DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
DEFAULT_MAX_TOKENS = 5_000
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, profile_name: str | None = None):
|
||||
@@ -25,7 +25,12 @@ class Amazon(BaseProvider):
|
||||
@cached_property
|
||||
def client(self):
|
||||
"""The AnthropicBedrock client."""
|
||||
import anthropic
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `anthropic` package: `pip install anthropic`"
|
||||
) from exc
|
||||
|
||||
if not self.profile_name:
|
||||
raise ValueError("Profile name is not provided")
|
||||
@@ -33,12 +38,12 @@ class Amazon(BaseProvider):
|
||||
return anthropic.AnthropicBedrock(aws_profile=self.profile_name)
|
||||
|
||||
@cached_property
|
||||
def structured_client(self):
|
||||
def structured_client(self) -> instructor.Instructor:
|
||||
"""A client patched with Instructor."""
|
||||
|
||||
return instructor.from_anthropic(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
|
||||
from ..models import Message
|
||||
@@ -59,7 +64,7 @@ class Amazon(BaseProvider):
|
||||
role="assistant",
|
||||
text=assistant_message.content or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@@ -75,12 +80,12 @@ class Amazon(BaseProvider):
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
response_model=response_model,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
max_tokens=self.DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -88,13 +93,15 @@ class Amazon(BaseProvider):
|
||||
response = self.client.messages.create(
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
max_tokens=self.DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return response.content[0].text
|
||||
|
||||
def generate_stream_text(self, prompt, *, llm_model, **kwargs):
|
||||
def generate_stream_text(
|
||||
self, prompt: str, *, llm_model: str, **kwargs
|
||||
) -> Iterator[str]:
|
||||
"""Generate streaming text using the Amazon API."""
|
||||
|
||||
# Prepare the messages.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -14,20 +14,15 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "anthropic"
|
||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class Anthropic(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
NAME = "anthropic"
|
||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
@@ -72,7 +67,7 @@ class Anthropic(BaseProvider):
|
||||
text=assistant_message,
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
@@ -110,7 +105,9 @@ class Anthropic(BaseProvider):
|
||||
return response.content[0].text
|
||||
|
||||
@logger
|
||||
def generate_stream_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
def generate_stream_text(
|
||||
self, prompt: str, *, llm_model: str, **kwargs
|
||||
) -> Iterator[str]:
|
||||
# Prepare the messages.
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# IT is not currently working as desired.
|
||||
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -17,18 +17,14 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "gemini"
|
||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
|
||||
|
||||
class Gemini(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
NAME = "gemini"
|
||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.model_name = DEFAULT_MODEL
|
||||
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||
self.model_name = self.DEFAULT_MODEL
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
self.model_name = model_name
|
||||
@@ -76,7 +72,7 @@ class Gemini(BaseProvider):
|
||||
text=response.text,
|
||||
raw=response,
|
||||
llm_model=self.model_name,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
@@ -110,7 +106,7 @@ class Gemini(BaseProvider):
|
||||
return response.text
|
||||
|
||||
@logger
|
||||
def generate_stream_text(self, prompt: str, **kwargs) -> str:
|
||||
def generate_stream_text(self, prompt: str, **kwargs) -> Iterator[str]:
|
||||
"""Generate streaming text using the Gemini API."""
|
||||
kwargs.pop("llm_model", None)
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -14,20 +14,15 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class Groq(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
@@ -75,7 +70,7 @@ class Groq(BaseProvider):
|
||||
text=assistant_message.content or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
@@ -120,7 +115,7 @@ class Groq(BaseProvider):
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
) -> Iterator[str]:
|
||||
"""Generate streaming text using the Groq API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
@@ -15,17 +14,11 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "ollama"
|
||||
DEFAULT_MODEL = "llama3.2"
|
||||
DEFAULT_TIMEOUT = 60
|
||||
DEFAULT_KWARGS = {}
|
||||
|
||||
|
||||
class Ollama(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
TIMEOUT = DEFAULT_TIMEOUT
|
||||
NAME = "ollama"
|
||||
DEFAULT_MODEL = "llama3.2"
|
||||
DEFAULT_TIMEOUT = 60
|
||||
DEFAULT_KWARGS = {}
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, host_url: str | None = None):
|
||||
@@ -37,21 +30,18 @@ class Ollama(BaseProvider):
|
||||
if not self.host_url:
|
||||
raise ValueError("No ollama host url provided")
|
||||
try:
|
||||
import ollama as ol
|
||||
import openai
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `ollama` package: `pip install ollama`"
|
||||
"Please install the `openai` package: `pip install openai`"
|
||||
) from exc
|
||||
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
|
||||
return openai.OpenAI(base_url=f"{self.host_url}/v1", api_key="ollama")
|
||||
|
||||
@cached_property
|
||||
def structured_client(self) -> instructor.Instructor:
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_openai(
|
||||
OpenAI(
|
||||
base_url=f"{self.host_url}/v1",
|
||||
api_key="ollama",
|
||||
),
|
||||
self.client,
|
||||
mode=instructor.Mode.JSON,
|
||||
)
|
||||
|
||||
@@ -64,7 +54,7 @@ class Ollama(BaseProvider):
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
response = self.client.chat(
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
@@ -76,7 +66,7 @@ class Ollama(BaseProvider):
|
||||
text=assistant_message.get("content"),
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
@@ -119,7 +109,9 @@ class Ollama(BaseProvider):
|
||||
return response.get("message", {}).get("content", "")
|
||||
|
||||
@logger
|
||||
def generate_stream_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
||||
def generate_stream_text(
|
||||
self, prompt: str, *, llm_model: str, **kwargs
|
||||
) -> Iterator[str]:
|
||||
# Prepare the messages.
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -13,21 +13,16 @@ if TYPE_CHECKING:
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
PROVIDER_NAME = "openai"
|
||||
DEFAULT_MODEL = "gpt-4o-mini"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class OpenAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
NAME = "openai"
|
||||
DEFAULT_MODEL = "gpt-4o-mini"
|
||||
DEFAULT_MAX_TOKENS = None
|
||||
DEFAULT_KWARGS = {}
|
||||
supports_streaming = True
|
||||
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
@@ -57,7 +52,7 @@ class OpenAI(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
@@ -70,8 +65,8 @@ class OpenAI(BaseProvider):
|
||||
role="assistant",
|
||||
text=assistant_message.content or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
@@ -112,7 +107,7 @@ class OpenAI(BaseProvider):
|
||||
@logger
|
||||
def generate_stream_text(
|
||||
self, prompt: str, *, llm_model: str | None = None, **kwargs
|
||||
):
|
||||
) -> Iterator[str]:
|
||||
"""Generate streaming text using the OpenAI API.
|
||||
|
||||
Yields chunks of text as they are generated by the model.
|
||||
|
||||
+12
-15
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -14,22 +14,17 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "xai"
|
||||
DEFAULT_MODEL = "grok-beta"
|
||||
BASE_URL = "https://api.x.ai/v1"
|
||||
DEFAULT_MAX_TOKENS = 1000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class XAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
NAME = "xai"
|
||||
DEFAULT_MODEL = "grok-beta"
|
||||
DEFAULT_MAX_TOKENS = 1000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
BASE_URL = "https://api.x.ai/v1"
|
||||
supports_streaming = True
|
||||
supports_structured_responses = False
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
@@ -44,7 +39,7 @@ class XAI(BaseProvider):
|
||||
) from exc
|
||||
return oa.OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=BASE_URL,
|
||||
base_url=self.BASE_URL,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@@ -76,7 +71,7 @@ class XAI(BaseProvider):
|
||||
text=assistant_message.content,
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
@@ -103,7 +98,9 @@ class XAI(BaseProvider):
|
||||
return str(response.choices[0].message.content)
|
||||
|
||||
@logger
|
||||
def generate_stream_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
||||
def generate_stream_text(
|
||||
self, prompt: str, *, llm_model: str, **kwargs
|
||||
) -> Iterator[str]:
|
||||
# Prepare the messages.
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
|
||||
@@ -14,8 +14,9 @@ class LoggingConfig(BaseSettings):
|
||||
"""Enable logging for the application."""
|
||||
# adding imports here to avoid forced dependencies
|
||||
try:
|
||||
import logfire
|
||||
from logging import basicConfig
|
||||
|
||||
import logfire
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"To enable logging, please install logfire: `pip install logfire`"
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
|
||||
import simplemind as sm
|
||||
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
result: int
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user