mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
Compare commits
217 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b5a901efaf | |||
| 9ccef9abdc | |||
| 3421de0fc1 | |||
| 54b0007947 | |||
| 90af44ace0 | |||
| cff3bff3d5 | |||
| 3abbb79f6c | |||
| 59c1bd3a0f | |||
| 052781014d | |||
| db28f1195c | |||
| b0a7197c6e | |||
| 7684c2568b | |||
| 8b90dbba40 | |||
| 752ccb1de8 | |||
| 391bfaaeab | |||
| d963bc0b1c | |||
| 0c1f225252 | |||
| 4decaa0722 | |||
| 39b5a5e19d | |||
| ef38fea767 | |||
| 8181f37fed | |||
| 3aacfd51ee | |||
| a2991eec0c | |||
| 9ae9a2703a | |||
| 0661b097d2 | |||
| fad442ba3f | |||
| 5b9624c385 | |||
| 8ff0521e17 | |||
| d5bdb712e9 | |||
| a97f9be2c8 | |||
| 107f983a18 | |||
| 2404e2c977 | |||
| c87a598286 | |||
| 9662b60177 | |||
| ea997aae7b | |||
| 081baf203c | |||
| 4cb18e9e3b | |||
| 0462ea0e38 | |||
| 8492ec9456 | |||
| 1709055e1a | |||
| 5fa67c3b2f | |||
| b7e950a8f0 | |||
| 735c6ba665 | |||
| 9132030cbd | |||
| aeea8936ce | |||
| c2303114ab | |||
| fe5af93780 | |||
| e79b474215 | |||
| fe2ca9d5f5 | |||
| 670240b943 | |||
| 2e66c0232b | |||
| 8b1f63f796 | |||
| 5d7a917d23 | |||
| 9703332967 | |||
| fe6001e710 | |||
| 63343d1c61 | |||
| ece056a5e0 | |||
| f44ec977a4 | |||
| 33f8fcde11 | |||
| 598bcd514d | |||
| 8bdbe4d8d5 | |||
| d4068cf07a | |||
| 747488f633 | |||
| 9ae03685b5 | |||
| 91af281a9d | |||
| 309f390800 | |||
| b316352311 | |||
| 236020b3b9 | |||
| 8a5a29f864 | |||
| 30d8412bbf | |||
| 4a852e6220 | |||
| 7f5ba667bd | |||
| 4b87a8b91c | |||
| 4c1d1fa873 | |||
| 0087a7e8f2 | |||
| 07715ed8df | |||
| 03f91c5153 | |||
| aa601648c6 | |||
| a26c51014b | |||
| b3946f1ff9 | |||
| 7a84ade5a4 | |||
| 3e1d1f98ad | |||
| 48e6ef2a43 | |||
| 1528dc2a21 | |||
| 46cd19ea90 | |||
| 2848e86dce | |||
| 6aadc9fcd7 | |||
| a8792319a8 | |||
| 3e8d5662d2 | |||
| 51c1646ef4 | |||
| f09052c18e | |||
| 1d3ae26301 | |||
| 44fd3468fa | |||
| 5770c37edf | |||
| 37334a21c5 | |||
| 57d54abf24 | |||
| c3397488e3 | |||
| 678a8a8b32 | |||
| a5c7486dfc | |||
| 5c6650f2b2 | |||
| 549d74e146 | |||
| 328be94677 | |||
| 7b21b9f258 | |||
| d7f8418f23 | |||
| 9968f162d6 | |||
| cb73621e39 | |||
| 4721dd8cc0 | |||
| bdb1ff0e69 | |||
| 94f381032e | |||
| b3a35cadd4 | |||
| 718f5a66c0 | |||
| df02547dec | |||
| 9dd89b7ef1 | |||
| 15ee5d1cf9 | |||
| 25ba1a9289 | |||
| 22aff505c4 | |||
| 29b2008edf | |||
| c5c99a05fd | |||
| cb969dec4c | |||
| 1aeeb9127d | |||
| c21f68aad6 | |||
| a68bd74fd8 | |||
| 4d2c81850e | |||
| 64246658b0 | |||
| f0aff7814b | |||
| 72121c121d | |||
| 028e89b080 | |||
| e13d03f40b | |||
| 0fc49c7e13 | |||
| d6afbd1fd0 | |||
| 27d30ccfe8 | |||
| b6b1a4f9f3 | |||
| 36d6ca4a11 | |||
| 90593d7919 | |||
| efe1a62d73 | |||
| 92819112bb | |||
| 275ab39d67 | |||
| 74db69c6e9 | |||
| 7b633ce880 | |||
| a651afb8a6 | |||
| 33e53562ae | |||
| 931285f8ce | |||
| e47ada4598 | |||
| 7e83532765 | |||
| 34e8a9d190 | |||
| c496712a9a | |||
| 3d8e169a08 | |||
| b74af7c8d8 | |||
| fa3ee731df | |||
| 8e4fdb9832 | |||
| 3d397d0474 | |||
| 7508723469 | |||
| f2a3fd76ae | |||
| 089812e335 | |||
| e977dd3eab | |||
| e7aad65b37 | |||
| a091a847a8 | |||
| faca663825 | |||
| 825ab22b95 | |||
| 18a51c7cd3 | |||
| 65570bfede | |||
| c6c7f2ac09 | |||
| cc6611647a | |||
| 8f9036fa32 | |||
| b7b5e1e187 | |||
| 7220c8bd3f | |||
| 176045531a | |||
| 6dc9108836 | |||
| e8c5ebc6a8 | |||
| 2c26895010 | |||
| 2f1c69a79e | |||
| bf1a936777 | |||
| a4efa47f6e | |||
| 3721fa6713 | |||
| db32ee26c1 | |||
| 8797c9e82f | |||
| ef01ce2f22 | |||
| d591125eb8 | |||
| 225d00deee | |||
| df716a1f19 | |||
| d6ad22721f | |||
| 7b4f2fcf8e | |||
| d8fce7b6d9 | |||
| 47ce8069f5 | |||
| 9114211867 | |||
| 49421b5213 | |||
| b7cc767a45 | |||
| 7ea33dec5a | |||
| 5d194a7f63 | |||
| 1a7693de0f | |||
| c0474aafeb | |||
| c9d7a7d622 | |||
| 1696d698e5 | |||
| e28d4660e8 | |||
| d4491e42b9 | |||
| 542677cffd | |||
| 528f806e65 | |||
| 373af44421 | |||
| 947d8ab6ad | |||
| 0ff966b307 | |||
| 75a42044e5 | |||
| cc66dbf8e5 | |||
| a174e60a1e | |||
| b03695f626 | |||
| 082bc24e91 | |||
| aca1b87180 | |||
| 1ff4c5660e | |||
| 241a7ab402 | |||
| 76fa7521eb | |||
| cbec2c5f6d | |||
| 34f463839c | |||
| 75c42278a2 | |||
| c25f1e1058 | |||
| 2a5966eb10 | |||
| f19263d309 | |||
| 25b742db1f | |||
| 8d83050a64 |
@@ -4,3 +4,5 @@ export GROQ_API_KEY=""
|
|||||||
export OLLAMA_HOST_URL=""
|
export OLLAMA_HOST_URL=""
|
||||||
export OPENAI_API_KEY=""
|
export OPENAI_API_KEY=""
|
||||||
export XAI_API_KEY=""
|
export XAI_API_KEY=""
|
||||||
|
export AMAZON_PROFILE_NAME=""
|
||||||
|
export DEEPSEEK_API_KEY=""
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
github: kennethreitz
|
||||||
|
thanks_dev: kennethreitz
|
||||||
|
custom: https://cash.app/$KennethReitz
|
||||||
@@ -168,3 +168,5 @@ cython_debug/
|
|||||||
src/**
|
src/**
|
||||||
requirements.txt
|
requirements.txt
|
||||||
Pipfile
|
Pipfile
|
||||||
|
enhanced_context.db
|
||||||
|
enhanced_context_sarah.db
|
||||||
|
|||||||
@@ -1,6 +1,53 @@
|
|||||||
Release History
|
Release History
|
||||||
===============
|
===============
|
||||||
|
|
||||||
|
## 0.3.3 (2024-02-08)
|
||||||
|
|
||||||
|
- Improve openai provider by removing debug print statements.
|
||||||
|
|
||||||
|
## 0.3.2 (2024-01-27)
|
||||||
|
|
||||||
|
- Improve Deepseek provider.
|
||||||
|
|
||||||
|
## 0.3.1 (2024-01-27)
|
||||||
|
|
||||||
|
- Introduce Deepseek provider.
|
||||||
|
|
||||||
|
## 0.3.0 (2024-11-12)
|
||||||
|
|
||||||
|
- Introduce save / load functionality for `Conversation`.
|
||||||
|
|
||||||
|
## 0.2.4 (2024-11-11)
|
||||||
|
|
||||||
|
- General improvements.
|
||||||
|
|
||||||
|
## 0.2.3 (2024-11-04)
|
||||||
|
|
||||||
|
- Remove default max-tokens for OpenAI provider.
|
||||||
|
|
||||||
|
## 0.2.3 (2024-11-03)
|
||||||
|
|
||||||
|
- Update default model for Amazon provider.
|
||||||
|
- Improved logging to handle streaming functions.
|
||||||
|
|
||||||
|
## 0.2.2 (2024-11-02)
|
||||||
|
|
||||||
|
- Add streaming support (set `stream=True` to `generate_text`).
|
||||||
|
- `conv.prepend_system_message` now uses system role by default.
|
||||||
|
- Add `provider.supports_streaming` property.
|
||||||
|
- Add `provider.supports_structured_response` property.
|
||||||
|
- General improvements.
|
||||||
|
|
||||||
|
## 0.2.1 (2024-11-01)
|
||||||
|
|
||||||
|
- Add `cached_property` to Amazon provider.
|
||||||
|
|
||||||
|
## 0.2.0 (2024-11-01)
|
||||||
|
|
||||||
|
- Add Amazon Bedrock provider.
|
||||||
|
- Make all provider optional dependencies. Use `$ pip install 'simplemind[full]'` to install all providers.
|
||||||
|
- General improvements.
|
||||||
|
|
||||||
## 0.1.7 (2024-11-01)
|
## 0.1.7 (2024-11-01)
|
||||||
|
|
||||||
- Add `logger` decorator.
|
- Add `logger` decorator.
|
||||||
|
|||||||
+21
@@ -0,0 +1,21 @@
|
|||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
# Install system dependencies
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
git \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install uv
|
||||||
|
RUN pip install uv
|
||||||
|
|
||||||
|
# Create and set working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy requirements/project files
|
||||||
|
ONBUILD COPY . .
|
||||||
|
|
||||||
|
# Install dependencies using uv
|
||||||
|
RUN uv pip install "simplemind[full]" --system
|
||||||
|
|
||||||
|
# Set default command
|
||||||
|
CMD ["python"]
|
||||||
@@ -10,36 +10,76 @@ Simplemind is AI library designed to simplify your experience with AI APIs in Py
|
|||||||
|
|
||||||
With Simplemind, tapping into AI is as easy as a friendly conversation.
|
With Simplemind, tapping into AI is as easy as a friendly conversation.
|
||||||
|
|
||||||
- **Easy-to-use AI tools**: SimpleMind provides simple interfaces to popular AI services.
|
- **Easy-to-use AI tools**: Simplemind provides simple interfaces to most popular AI services.
|
||||||
- **Human-centered design**: The library prioritizes readability and usability—no need to be an expert to start experimenting.
|
- **Human-centered design**: The library prioritizes readability and usability—no need to be an expert to start experimenting.
|
||||||
- **Minimal configuration**: Get started quickly, without worrying about configuration headaches.
|
- **Minimal configuration**: Get started quickly, without worrying about configuration headaches.
|
||||||
|
|
||||||
## Supported APIs
|
## Supported APIs
|
||||||
|
|
||||||
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`. The APIs remain identital between all supported providers/models.
|
The APIs remain identical between all supported providers / models:
|
||||||
|
|
||||||
- [**Anthropic's Claude**](https://www.anthropic.com/claude)
|
<table>
|
||||||
- [**Google's Gemini**](https://gemini.google/)
|
<thead>
|
||||||
- [**Groq's Groq**](https://groq.com/)
|
<tr>
|
||||||
- [**Ollama**](https://ollama.com)
|
<th></th>
|
||||||
- [**OpenAI's GPT**](https://openai.com/gpt)
|
<th><code>llm_provider</code></th>
|
||||||
- [**xAI's Grok**](https://x.ai/)
|
<th>Default <code>llm_model</code></th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr>
|
||||||
|
<td><a href="https://www.anthropic.com/claude">Anthropic's Claude</a></td>
|
||||||
|
<td><code>"anthropic"</code></td>
|
||||||
|
<td><code>"claude-3-5-sonnet-20241022"</code></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><a href="https://aws.amazon.com/bedrock/">Amazon's Bedrock</a></td>
|
||||||
|
<td><code>"amazon"</code></td>
|
||||||
|
<td><code>"anthropic.claude-3-5-sonnet-20241022-v2:0"</code></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><a href="https://www.deepseek.com">Deepseek</a></td>
|
||||||
|
<td><code>"deepseek"</code></td>
|
||||||
|
<td><code>"deepseek-chat"</code></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><a href="https://gemini.google/">Google's Gemini</a></td>
|
||||||
|
<td><code>"gemini"</code></td>
|
||||||
|
<td><code>"models/gemini-1.5-pro"</code></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><a href="https://groq.com/">Groq's Groq</a></td>
|
||||||
|
<td><code>"groq"</code></td>
|
||||||
|
<td><code>"llama3-8b-8192"</code></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><a href="https://ollama.com">Ollama</a></td>
|
||||||
|
<td><code>"ollama"</code></td>
|
||||||
|
<td><code>"llama3.2"</code></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><a href="https://openai.com/gpt">OpenAI's GPT</a></td>
|
||||||
|
<td><code>"openai"</code></td>
|
||||||
|
<td><code>"gpt-4o-mini"</code></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><a href="https://x.ai/">xAI's Grok</a></td>
|
||||||
|
<td><code>"xai"</code></td>
|
||||||
|
<td><code>"grok-beta"</code></td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
|
||||||
If you want to see Simplemind support, additional providers or models, please send a pull request!
|
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`.
|
||||||
|
|
||||||
## Why SimpleMind?
|
If you want to see Simplemind support additional providers or models, please send a pull request!
|
||||||
- **Intuitive**: Built with Pythonic simplicity and readability in mind.
|
|
||||||
- **For Humans**: Emphasizes a human-friendly interface, just like `requests` for HTTP.
|
|
||||||
- **Open Source**: Simplemind is open source, and contributions are always welcome!
|
|
||||||
|
|
||||||
Also, why not? :)
|
|
||||||
|
|
||||||
## Quickstart
|
## Quickstart
|
||||||
|
|
||||||
Simplemind takes care of the complex API calls so you can focus on what matters—building, experimenting, and creating.
|
Simplemind takes care of the complex API calls so you can focus on what matters—building, experimenting, and creating.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ pip install simplemind
|
$ pip install 'simplemind[full]'
|
||||||
```
|
```
|
||||||
|
|
||||||
First, authenticate your API keys by setting them in the environment variables:
|
First, authenticate your API keys by setting them in the environment variables:
|
||||||
@@ -48,7 +88,7 @@ First, authenticate your API keys by setting them in the environment variables:
|
|||||||
$ export OPENAI_API_KEY="sk-..."
|
$ export OPENAI_API_KEY="sk-..."
|
||||||
```
|
```
|
||||||
|
|
||||||
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
|
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `DEEPSEEK_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
|
||||||
|
|
||||||
Next, import Simplemind and start using it:
|
Next, import Simplemind and start using it:
|
||||||
|
|
||||||
@@ -56,20 +96,28 @@ Next, import Simplemind and start using it:
|
|||||||
import simplemind as sm
|
import simplemind as sm
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
Here are some examples of how to use Simplemind:
|
Here are some examples of how to use Simplemind.
|
||||||
|
|
||||||
|
**Please note**: Most of the calls seen here optionally accept `llm_provider` and `llm_model` parameters, which you provide as strings.
|
||||||
|
|
||||||
### Text Completion
|
### Text Completion
|
||||||
|
|
||||||
Generate a response from an AI model based on a given prompt:
|
Generate a response from an AI model based on a given prompt:
|
||||||
|
|
||||||
```pycon
|
```pycon
|
||||||
>>> sm.generate_text(prompt="What is the meaning of life?", llm_provider="openai", llm_model="gpt-4o")
|
>>> sm.generate_text(prompt="What is the meaning of life?")
|
||||||
"The meaning of life is a profound philosophical question that has been explored by cultures, religions, and philosophers for centuries. Different people and belief systems offer varying interpretations:\n\n1. **Religious Perspectives:** Many religions propose that the meaning of life is to fulfill a divine purpose, serve God, or reach an afterlife. For example, Christianity often emphasizes love, faith, and service to God and others as central to life’s meaning.\n\n2. **Philosophical Views:** Philosophers offer diverse answers. Existentialists like Jean-Paul Sartre argue that life has no inherent meaning, and it is up to individuals to create their own purpose. Others, like Aristotle, suggest that achieving eudaimonia (flourishing or happiness) through virtuous living is the key to a meaningful life.\n\n3. **Scientific and Secular Approaches:** Some people find meaning through understanding the natural world, contributing to human knowledge, or through personal accomplishments and happiness. They may view life's meaning as a product of connection, legacy, or the pursuit of knowledge and creativity.\n\n4. **Personal Perspective:** For many, the meaning of life is deeply personal, involving their relationships, passions, and goals. These individuals define life's purpose through experiences, connections, and the impact they have on others and the world.\n\nUltimately, the meaning of life is a subjective question, with each person finding their own answers based on their beliefs, experiences, and reflections."
|
"The meaning of life is a profound philosophical question that has been explored by cultures, religions, and philosophers for centuries. Different people and belief systems offer varying interpretations:\n\n1. **Religious Perspectives:** Many religions propose that the meaning of life is to fulfill a divine purpose, serve God, or reach an afterlife. For example, Christianity often emphasizes love, faith, and service to God and others as central to life’s meaning.\n\n2. **Philosophical Views:** Philosophers offer diverse answers. Existentialists like Jean-Paul Sartre argue that life has no inherent meaning, and it is up to individuals to create their own purpose. Others, like Aristotle, suggest that achieving eudaimonia (flourishing or happiness) through virtuous living is the key to a meaningful life.\n\n3. **Scientific and Secular Approaches:** Some people find meaning through understanding the natural world, contributing to human knowledge, or through personal accomplishments and happiness. They may view life's meaning as a product of connection, legacy, or the pursuit of knowledge and creativity.\n\n4. **Personal Perspective:** For many, the meaning of life is deeply personal, involving their relationships, passions, and goals. These individuals define life's purpose through experiences, connections, and the impact they have on others and the world.\n\nUltimately, the meaning of life is a subjective question, with each person finding their own answers based on their beliefs, experiences, and reflections."
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Streaming Text
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> for chunk in sm.generate_text("Write a poem about the moon", stream=True):
|
||||||
|
... print(chunk, end="", flush=True)
|
||||||
|
```
|
||||||
|
|
||||||
### Structured Data with Pydantic
|
### Structured Data with Pydantic
|
||||||
|
|
||||||
You can use Pydantic models to structure the response from the LLM, if the LLM supports it.
|
You can use Pydantic models to structure the response from the LLM, if the LLM supports it.
|
||||||
@@ -81,34 +129,54 @@ class Poem(BaseModel):
|
|||||||
```
|
```
|
||||||
|
|
||||||
```pycon
|
```pycon
|
||||||
>>> sm.generate_data(
|
>>> sm.generate_data("Write a poem about love", response_model=Poem)
|
||||||
"Write a poem about love",
|
|
||||||
llm_model="gpt-4o-mini",
|
|
||||||
llm_provider="openai",
|
|
||||||
response_model=Poem,
|
|
||||||
)
|
|
||||||
title='Eternal Embrace' content='In the quiet hours of the night,\nWhen stars whisper secrets bright,\nTwo hearts beat in a gentle rhyme,\nDancing through the sands of time.\n\nWith every glance, a spark ignites,\nA flame that warms the coldest nights,\nIn laughter shared and whispers sweet,\nLove paints the world, a masterpiece.\n\nThrough stormy skies and sunlit days,\nIn myriad forms, it finds its ways,\nA tender touch, a knowing sigh,\nIn love’s embrace, we learn to fly.\n\nAs seasons change and moments fade,\nIn the tapestry of dreams we’ve laid,\nLove’s threads endure, forever bind,\nA timeless bond, two souls aligned.\n\nSo here’s to love, both bright and true,\nA gift we give, anew, anew,\nIn every heartbeat, every prayer,\nA story written in the air.'
|
title='Eternal Embrace' content='In the quiet hours of the night,\nWhen stars whisper secrets bright,\nTwo hearts beat in a gentle rhyme,\nDancing through the sands of time.\n\nWith every glance, a spark ignites,\nA flame that warms the coldest nights,\nIn laughter shared and whispers sweet,\nLove paints the world, a masterpiece.\n\nThrough stormy skies and sunlit days,\nIn myriad forms, it finds its ways,\nA tender touch, a knowing sigh,\nIn love’s embrace, we learn to fly.\n\nAs seasons change and moments fade,\nIn the tapestry of dreams we’ve laid,\nLove’s threads endure, forever bind,\nA timeless bond, two souls aligned.\n\nSo here’s to love, both bright and true,\nA gift we give, anew, anew,\nIn every heartbeat, every prayer,\nA story written in the air.'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### A more complex example
|
||||||
|
|
||||||
|
```python
|
||||||
|
class InstructionStep(BaseModel):
|
||||||
|
step_number: int
|
||||||
|
instruction: str
|
||||||
|
|
||||||
|
class RecipeIngredient(BaseModel):
|
||||||
|
name: str
|
||||||
|
quantity: float
|
||||||
|
unit: str
|
||||||
|
|
||||||
|
class Recipe(BaseModel):
|
||||||
|
name: str
|
||||||
|
ingredients: list[RecipeIngredient]
|
||||||
|
instructions: list[InstructionStep]
|
||||||
|
|
||||||
|
recipe = sm.generate_data(
|
||||||
|
"Write a recipe for chocolate chip cookies",
|
||||||
|
response_model=Recipe,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Special thanks to [@jxnl](https://github.com/jxnl) for building [Instructor](https://github.com/jxnl/instructor), which makes this possible!
|
||||||
|
|
||||||
### Conversational AI
|
### Conversational AI
|
||||||
|
|
||||||
SimpleMind also allows for easy conversational flows:
|
SimpleMind also allows for easy conversational flows:
|
||||||
|
|
||||||
```pycon
|
```pycon
|
||||||
>>> conversation = sm.create_conversation(llm_model="gpt-4o-mini", llm_provider="openai")
|
>>> conv = sm.create_conversation()
|
||||||
|
|
||||||
>>> # Add a message to the conversation
|
>>> # Add a message to the conversation
|
||||||
>>> conversation.add_message("user", "Hi there, how are you?")
|
>>> conv.add_message("user", "Hi there, how are you?")
|
||||||
|
|
||||||
>>> conversation.send()
|
>>> conv.send()
|
||||||
<Message role=assistant text="Hello! I'm just a computer program, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?">
|
<Message role=assistant text="Hello! I'm just a computer program, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?">
|
||||||
```
|
```
|
||||||
|
|
||||||
To continue the conversation, you can call `conversation.send()` again, which returns the next message in the conversation:
|
To continue the conversation, you can call `conv.send()` again, which returns the next message in the conversation:
|
||||||
|
|
||||||
```pycon
|
```pycon
|
||||||
>>> conversation.add_message("user", "What is the meaning of life?")
|
>>> conv.add_message("user", "What is the meaning of life?")
|
||||||
>>> conversation.send()
|
>>> conv.send()
|
||||||
<Message role=assistant text="The meaning of life is a profound philosophical question that has been explored by cultures, religions, and philosophers for centuries. Different people and belief systems offer varying interpretations:\n\n1. **Religious Perspectives:** Many religions propose that the meaning of life is to fulfill a divine purpose, serve God, or reach an afterlife. For example, Christianity often emphasizes love, faith, and service to God and others as central to life’s meaning.\n\n2. **Philosophical Views:** Philosophers offer diverse answers. Existentialists like Jean-Paul Sartre argue that life has no inherent meaning, and it is up to individuals to create their own purpose. Others, like Aristotle, suggest that achieving eudaimonia (flourishing or happiness) through virtuous living is the key to a meaningful life.\n\n3. **Scientific and Secular Approaches:** Some people find meaning through understanding the natural world, contributing to human knowledge, or through personal accomplishments and happiness. They may view life’s meaning as a product of connection, legacy, or the pursuit of knowledge and creativity.\n\n4. **Personal Perspective:** For many, the meaning of life is deeply personal, involving their relationships, passions, and goals. These individuals define life’s purpose through experiences, connections, and the impact they have on others and the world.\n\nUltimately, the meaning of life is a subjective question, with each person finding their own answers based on their beliefs, experiences, and reflections.">
|
<Message role=assistant text="The meaning of life is a profound philosophical question that has been explored by cultures, religions, and philosophers for centuries. Different people and belief systems offer varying interpretations:\n\n1. **Religious Perspectives:** Many religions propose that the meaning of life is to fulfill a divine purpose, serve God, or reach an afterlife. For example, Christianity often emphasizes love, faith, and service to God and others as central to life’s meaning.\n\n2. **Philosophical Views:** Philosophers offer diverse answers. Existentialists like Jean-Paul Sartre argue that life has no inherent meaning, and it is up to individuals to create their own purpose. Others, like Aristotle, suggest that achieving eudaimonia (flourishing or happiness) through virtuous living is the key to a meaningful life.\n\n3. **Scientific and Secular Approaches:** Some people find meaning through understanding the natural world, contributing to human knowledge, or through personal accomplishments and happiness. They may view life’s meaning as a product of connection, legacy, or the pursuit of knowledge and creativity.\n\n4. **Personal Perspective:** For many, the meaning of life is deeply personal, involving their relationships, passions, and goals. These individuals define life’s purpose through experiences, connections, and the impact they have on others and the world.\n\nUltimately, the meaning of life is a subjective question, with each person finding their own answers based on their beliefs, experiences, and reflections.">
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -125,13 +193,12 @@ response = gpt_4o_mini.generate_text("Hello!")
|
|||||||
conversation = gpt_4o_mini.create_conversation()
|
conversation = gpt_4o_mini.create_conversation()
|
||||||
```
|
```
|
||||||
|
|
||||||
This maintains the simplicity of the original API while reducing repetition. The session object also supports overriding defaults on a per-call basis:
|
This maintains the simplicity of the original API while reducing repetition.
|
||||||
|
|
||||||
|
The session object also supports overriding defaults on a per-call basis:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
response = gpt_4o_mini.generate_text(
|
response = gpt_4o_mini.generate_text("Complex task here", llm_model="gpt-4")
|
||||||
"Complex task here",
|
|
||||||
llm_model="gpt-4"
|
|
||||||
)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Basic Memory Plugin
|
### Basic Memory Plugin
|
||||||
@@ -154,7 +221,7 @@ class SimpleMemoryPlugin(sm.BasePlugin):
|
|||||||
conversation.add_message(role="system", text=m)
|
conversation.add_message(role="system", text=m)
|
||||||
|
|
||||||
|
|
||||||
conversation = sm.create_conversation(llm_model="grok-beta", llm_provider="xai")
|
conversation = sm.create_conversation()
|
||||||
conversation.add_plugin(SimpleMemoryPlugin())
|
conversation.add_plugin(SimpleMemoryPlugin())
|
||||||
|
|
||||||
|
|
||||||
@@ -163,6 +230,7 @@ conversation.add_message(
|
|||||||
text="Please write a poem about the moon",
|
text="Please write a poem about the moon",
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
```pycon
|
```pycon
|
||||||
>>> conversation.send()
|
>>> conversation.send()
|
||||||
In the vast expanse where stars do play,
|
In the vast expanse where stars do play,
|
||||||
@@ -198,11 +266,125 @@ The universe is never done.
|
|||||||
|
|
||||||
Simple, yet effective.
|
Simple, yet effective.
|
||||||
|
|
||||||
|
### Tools (Function calling)
|
||||||
|
Tools (also known as functions) let you call any Python function from your AI conversations. Here's an example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_weather(
|
||||||
|
location: Annotated[
|
||||||
|
str, Field(description="The city and state, e.g. San Francisco, CA")
|
||||||
|
],
|
||||||
|
unit: Annotated[
|
||||||
|
Literal["celcius", "fahrenheit"],
|
||||||
|
Field(
|
||||||
|
description="The unit of temperature, either 'celsius' or 'fahrenheit'"
|
||||||
|
),
|
||||||
|
] = "celcius",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get the current weather in a given location
|
||||||
|
"""
|
||||||
|
return f"42 {unit}"
|
||||||
|
|
||||||
|
# Add your function as a tool
|
||||||
|
conversation = sm.create_conversation()
|
||||||
|
conversation.add_message("user", "What's the weather in San Francisco?")
|
||||||
|
response = conversation.send(tools=[get_weather])
|
||||||
|
```
|
||||||
|
|
||||||
|
Note how we're using Python's `Annotated` feature combined with `Field` to provide additional context to our function parameters. This helps the AI understand the intention and constraints of each parameter, making tool calls more accurate and reliable.
|
||||||
|
You can alos ommit `Annotated` and just pass the `Field` parameter.
|
||||||
|
```python
|
||||||
|
def get_weather(
|
||||||
|
location: str = Field(description="The city and state, e.g. San Francisco, CA"),
|
||||||
|
unit:Literal["celcius", "fahrenheit"]= Field(
|
||||||
|
default="celcius",
|
||||||
|
description="The unit of temperature, either 'celsius' or 'fahrenheit'"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get the current weather in a given location
|
||||||
|
"""
|
||||||
|
return f"42 {unit}"
|
||||||
|
```
|
||||||
|
|
||||||
|
Functions can be defined with type hints and Pydantic models for validation. The LLM will intelligently choose when to call the functions and incorporate the results into its responses.
|
||||||
|
|
||||||
|
#### 🪄 Using LLM for automatic tool definition (Experimental)
|
||||||
|
|
||||||
|
Simplemind provides a decorator to automatically transform Python functions into tools with AI-generated metadata. Simply use the `@simplemind.tool` decorator to have the LLM analyze your function and generate appropriate descriptions and schema:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@simplemind.tool(llm_provider="anthropic")
|
||||||
|
def haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
|
||||||
|
r = 6371
|
||||||
|
phi1 = math.radians(lat1)
|
||||||
|
phi2 = math.radians(lat2)
|
||||||
|
delta_phi = math.radians(lat2 - lat1)
|
||||||
|
delta_lambda = math.radians(lon2 - lon1)
|
||||||
|
|
||||||
|
a = (
|
||||||
|
math.sin(delta_phi / 2) ** 2
|
||||||
|
+ math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
|
||||||
|
)
|
||||||
|
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
|
||||||
|
d = r * c
|
||||||
|
return d
|
||||||
|
```
|
||||||
|
Notice how we have not added any docstrings or `Field` for the function.
|
||||||
|
The decorator will use the specified LLM provider to generate the tool schema, including descriptions and parameter details:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "haversine",
|
||||||
|
"description": "Calculates the great-circle distance between two points on Earth given their latitude and longitude coordinates",
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"lat1": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "Latitude of the first point in decimal degrees",
|
||||||
|
},
|
||||||
|
"lon1": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "Longitude of the first point in decimal degrees",
|
||||||
|
},
|
||||||
|
"lat2": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "Latitude of the second point in decimal degrees",
|
||||||
|
},
|
||||||
|
"lon2": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "Longitude of the second point in decimal degrees",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["lat1", "lon1", "lat2", "lon2"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The decorated function can then be used like any other tool with the conversation API.
|
||||||
|
|
||||||
|
```python
|
||||||
|
conversation = sm.create_conversation()
|
||||||
|
conversation.add_message("user", "How far is London from my location")
|
||||||
|
response = conversation.send(tools=[get_location, get_coords, haversine]) # Multiple tools can be passed
|
||||||
|
```
|
||||||
|
|
||||||
|
See [examples/distance_calculator.py](examples/distance_calculator.py) for more.
|
||||||
|
|
||||||
|
### Logging
|
||||||
|
|
||||||
|
Simplemind uses [Logfire](https://pydantic.dev/logfire) for logging. To enable logging, call `sm.enable_logfire()`.
|
||||||
|
|
||||||
|
### More Examples
|
||||||
|
|
||||||
Please see the [examples](examples) directory for executable examples.
|
Please see the [examples](examples) directory for executable examples.
|
||||||
|
|
||||||
-------------------
|
---
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
We welcome contributions of all kinds. Feel free to open issues for bug reports or feature requests, and submit pull requests to make SimpleMind even better.
|
We welcome contributions of all kinds. Feel free to open issues for bug reports or feature requests, and submit pull requests to make SimpleMind even better.
|
||||||
|
|
||||||
To get started:
|
To get started:
|
||||||
@@ -213,8 +395,9 @@ To get started:
|
|||||||
4. Submit a pull request.
|
4. Submit a pull request.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
Simplemind is licensed under the Apache 2.0 License.
|
Simplemind is licensed under the Apache 2.0 License.
|
||||||
|
|
||||||
## Acknowledgements
|
## Acknowledgements
|
||||||
Simplemind is inspired by the philosophy of "code for humans" and aims to make working with AI models accessible to all. Special thanks to the open-source community for their contributions and inspiration.
|
|
||||||
|
|
||||||
|
Simplemind is inspired by the philosophy of "code for humans" and aims to make working with AI models accessible to all. Special thanks to the open-source community for their contributions and inspiration.
|
||||||
|
|||||||
+1
-1
@@ -16,7 +16,7 @@ import simplemind
|
|||||||
project = "simplemind"
|
project = "simplemind"
|
||||||
copyright = "2024 Kenneth Reitz"
|
copyright = "2024 Kenneth Reitz"
|
||||||
author = "Kenneth Reitz"
|
author = "Kenneth Reitz"
|
||||||
release = "v0.1.7"
|
release = "v0.2.2"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import sys
|
|||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
|
||||||
|
import simplemind
|
||||||
import simplemind as sm
|
import simplemind as sm
|
||||||
|
|
||||||
__all__ = ["sm"]
|
__all__ = ["simplemind", "sm"]
|
||||||
|
|||||||
@@ -0,0 +1,137 @@
|
|||||||
|
from _context import simplemind as sm
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
gpt_4o_mini = sm.Session(llm_provider="openai")
|
||||||
|
claude_sonnet = sm.Session(llm_provider="anthropic")
|
||||||
|
|
||||||
|
|
||||||
|
class BibleVerse(BaseModel):
|
||||||
|
book: str
|
||||||
|
chapter: int
|
||||||
|
verse: int
|
||||||
|
text: str
|
||||||
|
translation: str
|
||||||
|
|
||||||
|
|
||||||
|
class BiblePassage(BaseModel):
|
||||||
|
book: str
|
||||||
|
chapter: int
|
||||||
|
verses: list[BibleVerse]
|
||||||
|
translation: str
|
||||||
|
|
||||||
|
|
||||||
|
class CrossReference(BaseModel):
|
||||||
|
passage: BiblePassage
|
||||||
|
notes: list[str]
|
||||||
|
origin_verse: BibleVerse
|
||||||
|
ai_perspective: str
|
||||||
|
anthropic_perspective: str
|
||||||
|
|
||||||
|
|
||||||
|
def get_passage(book: str, chapter: int, translation: str = "ESV") -> BiblePassage:
|
||||||
|
passage = gpt_4o_mini.generate_data(
|
||||||
|
prompt=f"""Return {book} chapter {chapter} from the {translation} translation.
|
||||||
|
Format each verse as plain text without any special characters or formatting.
|
||||||
|
For example:
|
||||||
|
- "Love is patient, love is kind."
|
||||||
|
- "It does not envy, it does not boast"
|
||||||
|
|
||||||
|
Return only the biblical text, formatted as a BiblePassage object.""",
|
||||||
|
response_model=BiblePassage,
|
||||||
|
max_tokens=8000,
|
||||||
|
)
|
||||||
|
return passage
|
||||||
|
|
||||||
|
|
||||||
|
def get_cross_reference(passage: BiblePassage) -> CrossReference:
|
||||||
|
verses_text = "\n".join([f"Verse {v.verse}: {v.text}" for v in passage.verses])
|
||||||
|
|
||||||
|
# Get main cross-reference from OpenAI
|
||||||
|
ref = gpt_4o_mini.generate_data(
|
||||||
|
prompt=f"""Find a thematically related Bible passage that connects with this text:
|
||||||
|
{verses_text}
|
||||||
|
|
||||||
|
Return a CrossReference object with:
|
||||||
|
1. A related passage (using plain text without special characters)
|
||||||
|
2. A list of clear, specific notes explaining the thematic connections
|
||||||
|
3. The original passage included
|
||||||
|
4. An AI perspective that provides a thoughtful, modern interpretation of how these passages relate to contemporary life and universal human experiences""",
|
||||||
|
response_model=CrossReference,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get Anthropic's perspective separately
|
||||||
|
anthropic_insight = claude_sonnet.generate_text(
|
||||||
|
prompt=f"""Analyze these biblical passages from a philosophical and ethical perspective:
|
||||||
|
|
||||||
|
Original passage:
|
||||||
|
{verses_text}
|
||||||
|
|
||||||
|
Cross-reference passage:
|
||||||
|
{' '.join([f'Verse {v.verse}: {v.text}' for v in ref.passage.verses])}
|
||||||
|
|
||||||
|
Provide a thoughtful analysis focusing on the philosophical and ethical implications of these passages, drawing from your training in ethics and philosophy.
|
||||||
|
Return your response as a plain string.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add Anthropic's perspective to the reference object
|
||||||
|
ref.anthropic_perspective = anthropic_insight
|
||||||
|
return ref
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print_reference(ref: CrossReference):
|
||||||
|
# Create origin passage panel
|
||||||
|
origin_text = Text()
|
||||||
|
origin_text.append(
|
||||||
|
f"{ref.origin_verse.book} {ref.origin_verse.chapter}\n",
|
||||||
|
style="bold blue",
|
||||||
|
)
|
||||||
|
origin_text.append(f"{ref.origin_verse.verse}. ", style="blue")
|
||||||
|
origin_text.append(f"{ref.origin_verse.text}\n", style="italic")
|
||||||
|
origin_text.append(f"\n({ref.origin_verse.translation})", style="dim")
|
||||||
|
|
||||||
|
origin_panel = Panel(origin_text, title="Original Passage", border_style="blue")
|
||||||
|
|
||||||
|
# Create cross reference panel
|
||||||
|
ref_text = Text()
|
||||||
|
ref_text.append(
|
||||||
|
f"{ref.passage.book} {ref.passage.chapter}\n",
|
||||||
|
style="bold green",
|
||||||
|
)
|
||||||
|
for verse in ref.passage.verses:
|
||||||
|
ref_text.append(f"{verse.verse}. ", style="green")
|
||||||
|
ref_text.append(f"{verse.text}\n", style="italic")
|
||||||
|
ref_text.append(f"\n({ref.passage.translation})", style="dim")
|
||||||
|
|
||||||
|
ref_panel = Panel(ref_text, title="Cross Reference", border_style="green")
|
||||||
|
|
||||||
|
# Create notes panel with bullet points
|
||||||
|
notes_text = Text()
|
||||||
|
for note in ref.notes:
|
||||||
|
notes_text.append("• ", style="yellow")
|
||||||
|
notes_text.append(f"{note}\n")
|
||||||
|
|
||||||
|
notes_panel = Panel(notes_text, title="Thematic Connections", border_style="yellow")
|
||||||
|
|
||||||
|
# Add new AI perspective panel
|
||||||
|
ai_text = Text()
|
||||||
|
ai_text.append(ref.ai_perspective)
|
||||||
|
|
||||||
|
ai_panel = Panel(ai_text, title="AI Perspective", border_style="magenta")
|
||||||
|
|
||||||
|
# Print all panels
|
||||||
|
console.print(origin_panel)
|
||||||
|
console.print(ref_panel)
|
||||||
|
console.print(notes_panel)
|
||||||
|
console.print(ai_panel)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Get 1 Corinthians 13 (The Love Chapter)
|
||||||
|
passage = get_passage("1 Corinthians", 13)
|
||||||
|
ref = get_cross_reference(passage)
|
||||||
|
pretty_print_reference(ref)
|
||||||
@@ -0,0 +1,99 @@
|
|||||||
|
from _context import simplemind as sm
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
|
||||||
|
class InstructionStep(BaseModel):
|
||||||
|
step_number: int
|
||||||
|
instruction: str
|
||||||
|
|
||||||
|
|
||||||
|
class RecipeIngredient(BaseModel):
|
||||||
|
name: str
|
||||||
|
quantity: float
|
||||||
|
unit: str
|
||||||
|
|
||||||
|
|
||||||
|
class Recipe(BaseModel):
|
||||||
|
name: str
|
||||||
|
ingredients: list[RecipeIngredient]
|
||||||
|
instructions: list[InstructionStep]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
console = Console(record=True, file=None)
|
||||||
|
|
||||||
|
# Create formatted title with more emphasis
|
||||||
|
title = Text("✨ " + self.name.upper() + " ✨", style="bold blue")
|
||||||
|
|
||||||
|
# Format ingredients with better structure
|
||||||
|
ingredients_text = Text("\n📝 INGREDIENTS:\n", style="bold green")
|
||||||
|
for ing in self.ingredients:
|
||||||
|
# Format numbers to avoid floating decimals when whole numbers
|
||||||
|
quantity = int(ing.quantity) if ing.quantity.is_integer() else ing.quantity
|
||||||
|
ingredients_text.append(f" • {quantity} {ing.unit} ", style="bright_white")
|
||||||
|
ingredients_text.append(f"{ing.name}\n", style="italic bright_white")
|
||||||
|
|
||||||
|
# Format instructions with better spacing and styling
|
||||||
|
instructions_text = Text("\n👩🍳 INSTRUCTIONS:\n", style="bold yellow")
|
||||||
|
for step in self.instructions:
|
||||||
|
instructions_text.append(
|
||||||
|
f"\n {step.step_number}. ", style="bold bright_white"
|
||||||
|
)
|
||||||
|
instructions_text.append(f"{step.instruction}", style="bright_white")
|
||||||
|
|
||||||
|
# Combine all text
|
||||||
|
full_text = Text.assemble(
|
||||||
|
ingredients_text, instructions_text, "\n"
|
||||||
|
) # Added extra newline
|
||||||
|
|
||||||
|
# Create panel with enhanced styling
|
||||||
|
panel = Panel(
|
||||||
|
full_text,
|
||||||
|
title=title,
|
||||||
|
border_style="blue",
|
||||||
|
padding=(1, 2), # Add padding (vertical, horizontal)
|
||||||
|
expand=False, # Don't expand to full terminal width
|
||||||
|
title_align="center",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Render the panel to string without printing
|
||||||
|
with console.capture() as capture:
|
||||||
|
console.print(panel)
|
||||||
|
return capture.get()
|
||||||
|
|
||||||
|
|
||||||
|
recipe = sm.generate_data(
|
||||||
|
"Write a recipe for chocolate chip cookies",
|
||||||
|
llm_model="gpt-4o-mini",
|
||||||
|
llm_provider="openai",
|
||||||
|
response_model=Recipe,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(recipe)
|
||||||
|
# Expected output is something like this:
|
||||||
|
#
|
||||||
|
# === CHOCOLATE CHIP COOKIES ===
|
||||||
|
#
|
||||||
|
# INGREDIENTS:
|
||||||
|
# • 2.25 cups all-purpose flour
|
||||||
|
# • 1.0 teaspoon baking soda
|
||||||
|
# • 0.5 teaspoon salt
|
||||||
|
# • 1.0 cup unsalted butter
|
||||||
|
# • 0.75 cup sugar
|
||||||
|
# • 0.75 cup brown sugar
|
||||||
|
# • 1.0 teaspoon vanilla extract
|
||||||
|
# • 2.0 large eggs
|
||||||
|
# • 2.0 cups semi-sweet chocolate chips
|
||||||
|
#
|
||||||
|
# INSTRUCTIONS:
|
||||||
|
# 1. Preheat your oven to 350°F (175°C).
|
||||||
|
# 2. In a small bowl, combine flour, baking soda, and salt; set aside.
|
||||||
|
# 3. In a large bowl, cream together the butter, sugar, and brown sugar until smooth.
|
||||||
|
# 4. Beat in the vanilla extract and eggs, one at a time.
|
||||||
|
# 5. Gradually blend in the flour mixture until just combined.
|
||||||
|
# 6. Stir in the chocolate chips.
|
||||||
|
# 7. Drop by rounded tablespoon onto ungreased cookie sheets.
|
||||||
|
# 8. Bake for 9 to 11 minutes, or until edges are golden.
|
||||||
|
# 9. Let cool on the cookie sheet for a few minutes before transferring to wire racks to cool completely.
|
||||||
@@ -0,0 +1,130 @@
|
|||||||
|
import time
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from _context import sm
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
|
|
||||||
|
class MultiAIConversation:
|
||||||
|
"""Orchestrates conversations between multiple AI models."""
|
||||||
|
|
||||||
|
MODEL_SESSIONS = {
|
||||||
|
"GPT-4o": sm.Session(
|
||||||
|
llm_provider="openai",
|
||||||
|
llm_model="gpt-4o",
|
||||||
|
),
|
||||||
|
"Grok-Beta": sm.Session(
|
||||||
|
llm_provider="xai",
|
||||||
|
llm_model="grok-beta",
|
||||||
|
),
|
||||||
|
"Claude-3.5-Sonnet": sm.Session(
|
||||||
|
llm_provider="anthropic",
|
||||||
|
llm_model="claude-3-5-sonnet-20241022",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, topic: str, turns_per_model: int = 1, max_rounds: int = 5):
|
||||||
|
|
||||||
|
self.topic = topic
|
||||||
|
self.turns_per_model = turns_per_model
|
||||||
|
self.max_rounds = max_rounds
|
||||||
|
self.conversation_history: List[Tuple[str, str]] = []
|
||||||
|
self.console = Console()
|
||||||
|
self.user_name = "Kenneth Reitz"
|
||||||
|
|
||||||
|
def _format_system_prompt(self, ai_name: str) -> str:
|
||||||
|
"""Creates a system prompt for each AI model."""
|
||||||
|
return f"""You are {ai_name}. You are participating in a thoughtful discussion with other AI models about {self.topic}.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
1. Be concise but insightful (keep responses under 140 words)
|
||||||
|
2. Build upon previous points made in the conversation
|
||||||
|
3. Ask questions to deepen the discussion when appropriate
|
||||||
|
4. Stay on topic while maintaining your unique perspective
|
||||||
|
5. Be respectful of other viewpoints while maintaining your distinct voice
|
||||||
|
|
||||||
|
Current discussion topic: {self.topic}"""
|
||||||
|
|
||||||
|
def _create_conversation(
|
||||||
|
self, session: sm.Session, ai_name: str
|
||||||
|
) -> sm.Conversation:
|
||||||
|
"""Creates a new conversation with appropriate context for an AI model."""
|
||||||
|
conv = session.create_conversation()
|
||||||
|
|
||||||
|
# Add system prompt
|
||||||
|
conv.add_message(role="user", text=self._format_system_prompt(ai_name))
|
||||||
|
|
||||||
|
# Add conversation history
|
||||||
|
for speaker, message in self.conversation_history[-3:]: # Last 3 messages
|
||||||
|
conv.add_message(role="user", text=f"{speaker} said: {message}")
|
||||||
|
|
||||||
|
return conv
|
||||||
|
|
||||||
|
def _print_response(self, ai_name: str, response: str):
|
||||||
|
"""Pretty prints an AI response using Rich."""
|
||||||
|
self.console.print(f"\n[bold blue]{ai_name}[/bold blue]:")
|
||||||
|
self.console.print(Markdown(response))
|
||||||
|
# Store in history
|
||||||
|
self.conversation_history.append((ai_name, response))
|
||||||
|
|
||||||
|
def _get_user_input(self) -> str:
|
||||||
|
"""Gets input from the user for the discussion."""
|
||||||
|
self.console.print("\n[bold green]Your turn! Share your thoughts:[/bold green]")
|
||||||
|
user_response = input("> ")
|
||||||
|
self._print_response(self.user_name, user_response)
|
||||||
|
return user_response
|
||||||
|
|
||||||
|
def run_conversation(self):
|
||||||
|
"""Runs the multi-AI conversation."""
|
||||||
|
# Get initial thoughts from the human
|
||||||
|
self.console.print(
|
||||||
|
f"\n[bold green]Start the discussion about {self.topic}:[/bold green]"
|
||||||
|
)
|
||||||
|
self._get_user_input()
|
||||||
|
|
||||||
|
for round_num in range(self.max_rounds):
|
||||||
|
self.console.print(f"\n[bold green]Round {round_num + 1}[/bold green]")
|
||||||
|
|
||||||
|
# Let all AI models respond
|
||||||
|
for model_name, session in self.MODEL_SESSIONS.items():
|
||||||
|
for turn in range(self.turns_per_model):
|
||||||
|
conversation = self._create_conversation(session, model_name)
|
||||||
|
|
||||||
|
# Add the prompt (simplified since human always starts)
|
||||||
|
prompt = f"Continue the discussion about {self.topic}, responding to the previous points made."
|
||||||
|
conversation.add_message(role="user", text=prompt)
|
||||||
|
|
||||||
|
# Get and print response
|
||||||
|
response = conversation.send()
|
||||||
|
self._print_response(model_name, response.text)
|
||||||
|
|
||||||
|
# Small delay to prevent rate limiting
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Then get user input at the end of the round
|
||||||
|
self._get_user_input()
|
||||||
|
|
||||||
|
# Optional: Add a separator between rounds
|
||||||
|
self.console.print("\n" + "-" * 50)
|
||||||
|
|
||||||
|
|
||||||
|
def have_ai_discussion(turns_per_model: int = 1, max_rounds: int = 3):
|
||||||
|
"""Convenience function to start an AI discussion."""
|
||||||
|
# Get topic from user
|
||||||
|
print("\nWhat topic would you like to discuss?")
|
||||||
|
topic = input("> ")
|
||||||
|
|
||||||
|
debate = MultiAIConversation(
|
||||||
|
topic=topic, turns_per_model=turns_per_model, max_rounds=max_rounds
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nStarting AI discussion about: {topic}")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
debate.run_conversation()
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
if __name__ == "__main__":
|
||||||
|
have_ai_discussion(turns_per_model=1, max_rounds=5)
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
from _context import sm
|
||||||
|
from pydantic import Field
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
|
@sm.tool(llm_provider="anthropic")
|
||||||
|
def haversine(
|
||||||
|
lat1: float,
|
||||||
|
lon1: float,
|
||||||
|
lat2: float,
|
||||||
|
lon2: float,
|
||||||
|
unit: Literal["km", "miles"],
|
||||||
|
) -> float:
|
||||||
|
r = 6378.0937 if unit == "km" else 3961
|
||||||
|
phi1 = math.radians(lat1)
|
||||||
|
phi2 = math.radians(lat2)
|
||||||
|
delta_phi = math.radians(lat2 - lat1)
|
||||||
|
delta_lambda = math.radians(lon2 - lon1)
|
||||||
|
|
||||||
|
a = (
|
||||||
|
math.sin(delta_phi / 2) ** 2
|
||||||
|
+ math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
|
||||||
|
)
|
||||||
|
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
|
||||||
|
d = r * c
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_location() -> str:
|
||||||
|
"""Get the closest city from the user"""
|
||||||
|
return "San Francisco"
|
||||||
|
|
||||||
|
|
||||||
|
def get_coords(
|
||||||
|
city_name: str = Field(
|
||||||
|
description="The name of the city to take the coordinates from (e.g. London, Rome, Los Angeles)"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Get latitude and logitude of a City."""
|
||||||
|
_data = {
|
||||||
|
"Rome": (41.9028, 12.4964),
|
||||||
|
"London": (51.5074, -0.1278),
|
||||||
|
"Madrid": (40.4168, -3.7038),
|
||||||
|
"San Francisco": (37.7749, -122.4194),
|
||||||
|
"Los Angeles": (34.0522, -118.2437),
|
||||||
|
}
|
||||||
|
|
||||||
|
return _data.get(city_name)
|
||||||
|
|
||||||
|
|
||||||
|
def distance_calculator(prompt: str):
|
||||||
|
conversation = sm.create_conversation(llm_provider="anthropic")
|
||||||
|
conversation.add_message("user", prompt)
|
||||||
|
return conversation.send(
|
||||||
|
tools=[get_user_location, get_coords, haversine]
|
||||||
|
).text
|
||||||
|
|
||||||
|
|
||||||
|
print(distance_calculator("How far is London from where I am?"))
|
||||||
|
# Prints something like:
|
||||||
|
"""
|
||||||
|
The distance between your location (San Francisco) and London is approximately 5,357 miles.
|
||||||
|
"""
|
||||||
|
|
||||||
|
print(
|
||||||
|
distance_calculator(
|
||||||
|
"What is the distance between Rome and Madrid in Kilometers?"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
The distance between Rome and Madrid is approximately 1,366 kilometers.
|
||||||
|
"""
|
||||||
@@ -0,0 +1,952 @@
|
|||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import sqlite3
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import nltk
|
||||||
|
import spacy
|
||||||
|
import xerox
|
||||||
|
from _context import simplemind as sm
|
||||||
|
from docopt import docopt
|
||||||
|
from nltk.tag import pos_tag
|
||||||
|
from nltk.tokenize import word_tokenize
|
||||||
|
from prompt_toolkit import PromptSession
|
||||||
|
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
|
||||||
|
from prompt_toolkit.completion import Completer, Completion
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.status import Status
|
||||||
|
|
||||||
|
DB_PATH = "enhanced_context.db"
|
||||||
|
AVAILABLE_PROVIDERS = ["xai", "openai", "anthropic", "ollama"]
|
||||||
|
|
||||||
|
# Enable Logfire for debugging.
|
||||||
|
# sm.enable_logfire()
|
||||||
|
|
||||||
|
__doc__ = """Enhanced Context Chat Interface
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
enhanced_context.py [--provider=<provider>] [--model=<model>]
|
||||||
|
enhanced_context.py (-h | --help)
|
||||||
|
|
||||||
|
Options:
|
||||||
|
-h --help Show this screen.
|
||||||
|
--provider=<provider> LLM provider to use (openai/anthropic/xai/ollama)
|
||||||
|
--model=<model> Specific model to use (e.g. o1-preview)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ContextDatabase:
|
||||||
|
def __init__(self, db_path: str):
|
||||||
|
self.db_path = db_path
|
||||||
|
self.init_db()
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_connection(self):
|
||||||
|
"""Context manager for database connections"""
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def init_db(self):
|
||||||
|
"""Initialize the database with proper schema"""
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS memory (
|
||||||
|
entity TEXT,
|
||||||
|
source TEXT,
|
||||||
|
last_mentioned TIMESTAMP,
|
||||||
|
mention_count INTEGER DEFAULT 1,
|
||||||
|
PRIMARY KEY (entity, source)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS identity (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
last_updated TIMESTAMP
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS essence_markers (
|
||||||
|
marker_type TEXT,
|
||||||
|
marker_text TEXT,
|
||||||
|
timestamp TIMESTAMP,
|
||||||
|
PRIMARY KEY (marker_type, marker_text)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
def store_entity(self, entity: str, source: str = "user") -> None:
|
||||||
|
"""Store or update entity mention with source tracking"""
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO memory (entity, source, last_mentioned, mention_count)
|
||||||
|
VALUES (?, ?, ?, 1)
|
||||||
|
ON CONFLICT(entity, source) DO UPDATE SET
|
||||||
|
last_mentioned = ?,
|
||||||
|
mention_count = mention_count + 1
|
||||||
|
""",
|
||||||
|
(entity, source, now, now),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def retrieve_recent_entities(self, days: int = 7) -> List[tuple]:
|
||||||
|
"""Retrieve recently mentioned entities with frequency and source"""
|
||||||
|
try:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
entity,
|
||||||
|
SUM(mention_count) as total_mentions,
|
||||||
|
GROUP_CONCAT(source || ':' || mention_count) as source_counts
|
||||||
|
FROM memory
|
||||||
|
WHERE last_mentioned >= datetime('now', ?, 'localtime')
|
||||||
|
GROUP BY entity
|
||||||
|
ORDER BY total_mentions DESC, MAX(last_mentioned) DESC
|
||||||
|
LIMIT 50
|
||||||
|
""",
|
||||||
|
(f"-{days} days",),
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = []
|
||||||
|
for row in cur.fetchall():
|
||||||
|
entity, total_count, source_counts = row
|
||||||
|
source_dict = dict(sc.split(":") for sc in source_counts.split(","))
|
||||||
|
entities.append(
|
||||||
|
(
|
||||||
|
entity,
|
||||||
|
total_count,
|
||||||
|
int(source_dict.get("user", 0)),
|
||||||
|
int(source_dict.get("llm", 0)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return entities
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self.logger.error(f"Database error while retrieving entities: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def store_identity(self, identity: str) -> None:
|
||||||
|
"""Store personal identity in database"""
|
||||||
|
if not identity:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
now = datetime.now()
|
||||||
|
# Store in identity table
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR REPLACE INTO identity (id, name, last_updated)
|
||||||
|
VALUES (1, ?, ?)
|
||||||
|
""",
|
||||||
|
(identity, now),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store in memory table
|
||||||
|
self.store_entity(identity)
|
||||||
|
conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self.logger.error(f"Database error while storing identity: {e}")
|
||||||
|
|
||||||
|
def load_identity(self) -> str | None:
|
||||||
|
"""Load personal identity from database"""
|
||||||
|
try:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute("SELECT name FROM identity WHERE id = 1")
|
||||||
|
result = cur.fetchone()
|
||||||
|
return result[0] if result else None
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self.logger.error(f"Database error while loading identity: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def store_essence_marker(self, marker_type: str, marker_text: str) -> None:
|
||||||
|
"""Store essence marker in database"""
|
||||||
|
try:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR REPLACE INTO essence_markers
|
||||||
|
(marker_type, marker_text, timestamp)
|
||||||
|
VALUES (?, ?, ?)
|
||||||
|
""",
|
||||||
|
(marker_type, marker_text, now),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self.logger.error(f"Database error storing essence marker: {e}")
|
||||||
|
|
||||||
|
def retrieve_essence_markers(self, days: int = 30) -> List[tuple[str, str]]:
|
||||||
|
"""Retrieve recent essence markers"""
|
||||||
|
try:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT DISTINCT marker_type, marker_text
|
||||||
|
FROM essence_markers
|
||||||
|
WHERE timestamp >= datetime('now', ?, 'localtime')
|
||||||
|
ORDER BY timestamp DESC
|
||||||
|
""",
|
||||||
|
(f"-{days} days",),
|
||||||
|
)
|
||||||
|
return cur.fetchall()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self.logger.error(f"Database error retrieving essence markers: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class EnhancedContextPlugin(sm.BasePlugin):
|
||||||
|
model_config = {"extra": "allow"}
|
||||||
|
|
||||||
|
def __init__(self, verbose: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
# Set up logging
|
||||||
|
self.verbose = verbose
|
||||||
|
if verbose:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Initialize NLP model
|
||||||
|
try:
|
||||||
|
self.nlp = spacy.load("en_core_web_sm")
|
||||||
|
except OSError:
|
||||||
|
self.logger.error(
|
||||||
|
"Failed to load spaCy model. Please install it using: python -m spacy download en_core_web_sm"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Initialize database
|
||||||
|
self.db = ContextDatabase(DB_PATH)
|
||||||
|
self.logger.info(f"EnhancedContextPlugin initialized with database: {DB_PATH}")
|
||||||
|
|
||||||
|
# Load identity from database
|
||||||
|
self.personal_identity = self.db.load_identity()
|
||||||
|
|
||||||
|
# Download required NLTK data silently
|
||||||
|
try:
|
||||||
|
with open(os.devnull, "w") as null_out:
|
||||||
|
with (
|
||||||
|
contextlib.redirect_stdout(null_out),
|
||||||
|
contextlib.redirect_stderr(null_out),
|
||||||
|
):
|
||||||
|
nltk.download("punkt", quiet=True)
|
||||||
|
nltk.download("averaged_perceptron_tagger", quiet=True)
|
||||||
|
except LookupError as e:
|
||||||
|
self.logger.error(f"Error downloading NLTK data: {e}")
|
||||||
|
|
||||||
|
# Add LLM personality traits for easter egg
|
||||||
|
self.llm_personalities = [
|
||||||
|
"You are a wise philosopher who speaks in riddles",
|
||||||
|
"You are an excited scientist who loves discovering patterns",
|
||||||
|
"You are a detective who analyzes every detail",
|
||||||
|
"You are a poet who sees beauty in connections",
|
||||||
|
"You are a historian who relates everything to the past",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add these lines to store the conversation's model and provider
|
||||||
|
self.llm_model = None
|
||||||
|
self.llm_provider = None
|
||||||
|
|
||||||
|
def extract_entities(self, text: str) -> List[str]:
|
||||||
|
"""Extract named entities with improved filtering"""
|
||||||
|
doc = self.nlp(text)
|
||||||
|
|
||||||
|
# Define important entity types
|
||||||
|
important_types = {
|
||||||
|
"PERSON",
|
||||||
|
"ORG",
|
||||||
|
"GPE",
|
||||||
|
"NORP",
|
||||||
|
"PRODUCT",
|
||||||
|
"EVENT",
|
||||||
|
"WORK_OF_ART",
|
||||||
|
}
|
||||||
|
|
||||||
|
entities = [
|
||||||
|
ent.text.strip()
|
||||||
|
for ent in doc.ents
|
||||||
|
if (
|
||||||
|
ent.label_ in important_types
|
||||||
|
and len(ent.text.strip()) > 1
|
||||||
|
and not ent.text.isnumeric()
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return list(set(entities))
|
||||||
|
|
||||||
|
def format_context_message(
|
||||||
|
self, entities: List[tuple], include_identity: bool = True
|
||||||
|
) -> str:
|
||||||
|
"""Format context message with essence markers"""
|
||||||
|
context_parts = []
|
||||||
|
|
||||||
|
# Add identity context
|
||||||
|
if include_identity and self.personal_identity:
|
||||||
|
context_parts.append(f"The user's name is {self.personal_identity}.")
|
||||||
|
|
||||||
|
# Add essence markers
|
||||||
|
essence_markers = self.retrieve_essence_markers()
|
||||||
|
if essence_markers:
|
||||||
|
markers_by_type = {}
|
||||||
|
for marker_type, marker_text in essence_markers:
|
||||||
|
markers_by_type.setdefault(marker_type, []).append(marker_text)
|
||||||
|
|
||||||
|
context_parts.append("User characteristics:")
|
||||||
|
for marker_type, markers in markers_by_type.items():
|
||||||
|
context_parts.append(f"- {marker_type.title()}: {', '.join(markers)}")
|
||||||
|
|
||||||
|
# Add entity context with user/llm breakdown
|
||||||
|
if entities:
|
||||||
|
entity_strings = [
|
||||||
|
f"{entity} (mentioned {total} times - User: {user_count}, AI: {llm_count})"
|
||||||
|
for entity, total, user_count, llm_count in entities
|
||||||
|
]
|
||||||
|
|
||||||
|
topics = (
|
||||||
|
", ".join(entity_strings[:-1]) + f" and {entity_strings[-1]}"
|
||||||
|
if len(entity_strings) > 1
|
||||||
|
else entity_strings[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
context_parts.append(f"Recent conversation topics: {topics}")
|
||||||
|
|
||||||
|
return "\n".join(context_parts)
|
||||||
|
|
||||||
|
def extract_essence_markers(self, text: str) -> List[tuple[str, str]]:
|
||||||
|
"""Extract essence markers from text."""
|
||||||
|
patterns = {
|
||||||
|
"value": [
|
||||||
|
r"I (?:really )?(?:believe|think) (?:that )?(.+)",
|
||||||
|
r"(?:It's|Its) important (?:to me )?that (.+)",
|
||||||
|
r"I value (.+)",
|
||||||
|
r"(?:The )?most important (?:thing|aspect) (?:to me )?is (.+)",
|
||||||
|
],
|
||||||
|
"identity": [
|
||||||
|
r"I am(?: a| an)? (.+)",
|
||||||
|
r"I consider myself(?: a| an)? (.+)",
|
||||||
|
r"I identify as(?: a| an)? (.+)",
|
||||||
|
],
|
||||||
|
"preference": [
|
||||||
|
r"I (?:really )?(?:like|love|enjoy|prefer) (.+)",
|
||||||
|
r"I can't stand (.+)",
|
||||||
|
r"I hate (.+)",
|
||||||
|
r"I always (.+)",
|
||||||
|
r"I never (.+)",
|
||||||
|
],
|
||||||
|
"emotion": [
|
||||||
|
r"I feel (.+)",
|
||||||
|
r"I'm feeling (.+)",
|
||||||
|
r"(?:It|That) makes me feel (.+)",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
markers = []
|
||||||
|
doc = self.nlp(text)
|
||||||
|
|
||||||
|
for sent in doc.sents:
|
||||||
|
sent_text = sent.text.strip().lower()
|
||||||
|
|
||||||
|
for marker_type, pattern_list in patterns.items():
|
||||||
|
for pattern in pattern_list:
|
||||||
|
for match in re.finditer(pattern, sent_text, re.IGNORECASE):
|
||||||
|
marker_text = match.group(1).strip()
|
||||||
|
if self._is_valid_marker(marker_text):
|
||||||
|
markers.append((marker_type, marker_text))
|
||||||
|
|
||||||
|
return markers
|
||||||
|
|
||||||
|
def _is_valid_marker(self, marker_text: str) -> bool:
|
||||||
|
"""Helper method to validate essence markers"""
|
||||||
|
invalid_words = {"um", "uh", "like"}
|
||||||
|
return len(marker_text) > 3 and not any(w in marker_text for w in invalid_words)
|
||||||
|
|
||||||
|
def pre_send_hook(self, conversation: sm.Conversation) -> bool:
|
||||||
|
"""Process user message before sending to LLM"""
|
||||||
|
self.llm_model = conversation.llm_model
|
||||||
|
self.llm_provider = conversation.llm_provider
|
||||||
|
|
||||||
|
last_message = conversation.get_last_message(role="user")
|
||||||
|
if not last_message:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Handle special commands
|
||||||
|
if result := self._handle_special_commands(conversation, last_message.text):
|
||||||
|
return result
|
||||||
|
|
||||||
|
self.logger.info(f"Processing user message: {last_message.text}")
|
||||||
|
|
||||||
|
# Process entities and markers
|
||||||
|
self._process_user_message(last_message.text)
|
||||||
|
|
||||||
|
# Add context
|
||||||
|
self._add_context_to_conversation(conversation)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _handle_special_commands(
|
||||||
|
self, conversation: sm.Conversation, message: str
|
||||||
|
) -> bool | None:
|
||||||
|
"""Handle special commands like /summary"""
|
||||||
|
if message.strip().lower() == "/summary":
|
||||||
|
summary = self.summarize_memory()
|
||||||
|
conversation.add_message(role="assistant", text=summary)
|
||||||
|
return False
|
||||||
|
elif message.strip().lower() == "/topics":
|
||||||
|
topics = self.get_all_topics()
|
||||||
|
conversation.add_message(role="assistant", text=topics)
|
||||||
|
return False
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _process_user_message(self, message: str) -> None:
|
||||||
|
"""Process user message for entities and markers"""
|
||||||
|
# Extract and store entities
|
||||||
|
entities = self.extract_entities(message)
|
||||||
|
for entity in entities:
|
||||||
|
self.store_entity(entity, source="user")
|
||||||
|
|
||||||
|
# Extract and store essence markers
|
||||||
|
essence_markers = self.extract_essence_markers(message)
|
||||||
|
for marker_type, marker_text in essence_markers:
|
||||||
|
self.store_essence_marker(marker_type, marker_text)
|
||||||
|
self.logger.info(f"Found essence marker: {marker_type} - {marker_text}")
|
||||||
|
|
||||||
|
def _add_context_to_conversation(self, conversation: sm.Conversation) -> None:
|
||||||
|
"""Add context message to conversation"""
|
||||||
|
recent_entities = self.retrieve_recent_entities(days=30)
|
||||||
|
context_message = self.format_context_message(recent_entities)
|
||||||
|
if context_message:
|
||||||
|
conversation.add_message(role="user", text=context_message)
|
||||||
|
self.logger.info(f"Added context message: {context_message}")
|
||||||
|
|
||||||
|
def store_entity(self, entity: str, source: str = "user") -> None:
|
||||||
|
self.db.store_entity(entity, source)
|
||||||
|
|
||||||
|
def store_identity(self, identity: str) -> None:
|
||||||
|
self.db.store_identity(identity)
|
||||||
|
self.personal_identity = identity
|
||||||
|
|
||||||
|
def load_identity(self) -> str | None:
|
||||||
|
self.personal_identity = self.db.load_identity()
|
||||||
|
return self.personal_identity
|
||||||
|
|
||||||
|
def store_essence_marker(self, marker_type: str, marker_text: str) -> None:
|
||||||
|
self.db.store_essence_marker(marker_type, marker_text)
|
||||||
|
|
||||||
|
def retrieve_essence_markers(self, days: int = 30) -> List[tuple[str, str]]:
|
||||||
|
return self.db.retrieve_essence_markers(days)
|
||||||
|
|
||||||
|
def summarize_memory(self, days: int = 30) -> str:
|
||||||
|
"""Consolidate recent conversation memory into a summary"""
|
||||||
|
entities = self.retrieve_recent_entities(days=days)
|
||||||
|
if not entities:
|
||||||
|
return "No recent conversation history to consolidate."
|
||||||
|
|
||||||
|
# Group entities by frequency
|
||||||
|
frequent = []
|
||||||
|
occasional = []
|
||||||
|
|
||||||
|
for entity, total, user_count, llm_count in entities:
|
||||||
|
if total >= 3:
|
||||||
|
frequent.append(f"{entity} (mentioned {total} times)")
|
||||||
|
else:
|
||||||
|
occasional.append(f"{entity} (mentioned {total} times)")
|
||||||
|
|
||||||
|
# Build summary
|
||||||
|
summary_parts = []
|
||||||
|
|
||||||
|
if self.personal_identity:
|
||||||
|
summary_parts.append(f"User Identity: {self.personal_identity}")
|
||||||
|
|
||||||
|
if frequent:
|
||||||
|
summary_parts.append("Frequently Discussed Topics:")
|
||||||
|
summary_parts.extend([f"- {item}" for item in frequent])
|
||||||
|
|
||||||
|
if occasional:
|
||||||
|
summary_parts.append("Other Topics Mentioned:")
|
||||||
|
summary_parts.extend([f"- {item}" for item in occasional])
|
||||||
|
|
||||||
|
return "\n".join(summary_parts)
|
||||||
|
|
||||||
|
def simulate_llm_conversation(self, context: str, num_turns: int = 3) -> str:
|
||||||
|
"""Simulate a conversation between multiple LLM personalities about the context"""
|
||||||
|
conversation_log = []
|
||||||
|
|
||||||
|
def get_response(personality: str, previous_messages: str) -> str:
|
||||||
|
prompt = (
|
||||||
|
f"{personality}. You are participating in a brief group discussion "
|
||||||
|
f"about the following context:\n{context}\n\n"
|
||||||
|
f"Previous messages:\n{previous_messages}\n\n"
|
||||||
|
"Provide a short, focused response (1-2 sentences) that builds on "
|
||||||
|
"the discussion. Be creative but stay on topic."
|
||||||
|
)
|
||||||
|
|
||||||
|
temp_conv = sm.create_conversation(
|
||||||
|
llm_model=self.llm_model, llm_provider=self.llm_provider
|
||||||
|
)
|
||||||
|
temp_conv.add_message(role="user", text=prompt)
|
||||||
|
response = temp_conv.send()
|
||||||
|
return response.text.strip()
|
||||||
|
|
||||||
|
# Select random personalities for this conversation
|
||||||
|
selected_personalities = random.sample(
|
||||||
|
self.llm_personalities, min(num_turns, len(self.llm_personalities))
|
||||||
|
)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor() as executor:
|
||||||
|
for i, personality in enumerate(selected_personalities, 1):
|
||||||
|
previous = "\n".join(conversation_log)
|
||||||
|
response = get_response(personality, previous)
|
||||||
|
conversation_log.append(f"Speaker {i}: {response}")
|
||||||
|
|
||||||
|
return "\n\n".join(conversation_log)
|
||||||
|
|
||||||
|
def store_llm_memory(self, conversation: sm.Conversation) -> None:
|
||||||
|
"""Generate and store memories from the LLM's perspective of the conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation: The conversation object containing message history
|
||||||
|
"""
|
||||||
|
prompt = """Based on the recent messages, what are the most important things to remember?
|
||||||
|
Format each memory on a new line starting with MEMORY:
|
||||||
|
For example:
|
||||||
|
MEMORY: User prefers Python over JavaScript
|
||||||
|
MEMORY: User is working on a machine learning project"""
|
||||||
|
|
||||||
|
# Create temporary conversation for memory generation
|
||||||
|
temp_conv = sm.create_conversation(
|
||||||
|
llm_model=self.llm_model, llm_provider=self.llm_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add last few messages for context
|
||||||
|
for msg in conversation.messages[-3:]: # Last 3 messages
|
||||||
|
temp_conv.add_message(role=msg.role, text=msg.text)
|
||||||
|
|
||||||
|
# Get memories from LLM
|
||||||
|
temp_conv.add_message(role="user", text=prompt)
|
||||||
|
response = temp_conv.send()
|
||||||
|
|
||||||
|
# Process and store memories
|
||||||
|
if response and response.text:
|
||||||
|
for line in response.text.split("\n"):
|
||||||
|
if line.strip().startswith("MEMORY:"):
|
||||||
|
memory = line.replace("MEMORY:", "").strip()
|
||||||
|
self.store_entity(memory, source="llm")
|
||||||
|
self.logger.info(f"Stored LLM-generated memory: {memory}")
|
||||||
|
|
||||||
|
def retrieve_recent_entities(self, days: int = 7) -> List[tuple]:
|
||||||
|
"""Retrieve recently mentioned entities with their frequency data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: Number of days to look back
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tuples containing (entity, total_mentions, user_mentions, llm_mentions)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self.db.retrieve_recent_entities(days)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error retrieving recent entities: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def post_response_hook(self, conversation: sm.Conversation) -> None:
|
||||||
|
"""Process assistant's response after it's received."""
|
||||||
|
# Get the last assistant message
|
||||||
|
last_message = conversation.get_last_message(role="assistant")
|
||||||
|
if not last_message:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Extract and store entities from assistant's response
|
||||||
|
entities = self.extract_entities(last_message.text)
|
||||||
|
for entity in entities:
|
||||||
|
self.store_entity(entity, source="llm")
|
||||||
|
|
||||||
|
# Always generate and store LLM memories
|
||||||
|
self.store_llm_memory(conversation)
|
||||||
|
|
||||||
|
def extract_identity(self, text: str) -> str | None:
|
||||||
|
"""Extract identity statements from text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to analyze
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The extracted identity or None if not found
|
||||||
|
"""
|
||||||
|
text = text.lower().strip()
|
||||||
|
|
||||||
|
identity_patterns = [
|
||||||
|
(r"^i am (.+)$", 1),
|
||||||
|
(r"^my name is (.+)$", 1),
|
||||||
|
(r"^call me (.+)$", 1),
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern, group in identity_patterns:
|
||||||
|
if match := re.match(pattern, text):
|
||||||
|
identity = match.group(group).strip()
|
||||||
|
return identity if identity else None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_identity_question(self, text: str) -> bool:
|
||||||
|
"""Detect if text contains a question about identity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to analyze
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if text contains an identity question
|
||||||
|
"""
|
||||||
|
# Tokenize and tag parts of speech
|
||||||
|
tokens = word_tokenize(text.lower())
|
||||||
|
tagged = pos_tag(tokens)
|
||||||
|
|
||||||
|
# Extract key words and patterns
|
||||||
|
words = set(tokens)
|
||||||
|
has_question_word = any(word in ["who", "what"] for word in words)
|
||||||
|
has_identity_term = any(word in ["i", "me", "my", "name"] for word in words)
|
||||||
|
has_conversation_term = any(
|
||||||
|
word in ["talking", "speaking", "chatting"] for word in words
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for question structure
|
||||||
|
is_question = (
|
||||||
|
text.endswith("?")
|
||||||
|
or has_question_word
|
||||||
|
or any(tag in ["WP", "WRB"] for word, tag in tagged)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine conditions for identity questions
|
||||||
|
is_identity_question = is_question and (
|
||||||
|
has_identity_term or (has_question_word and has_conversation_term)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_identity_question:
|
||||||
|
self.logger.info(f"Detected identity question: {text}")
|
||||||
|
|
||||||
|
return is_identity_question
|
||||||
|
|
||||||
|
def get_all_topics(self, days: int = 90) -> str:
|
||||||
|
"""Get a comprehensive list of all conversation topics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: Number of days to look back (default: 90)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string containing all topics and their mention counts
|
||||||
|
"""
|
||||||
|
entities = self.retrieve_recent_entities(days=days)
|
||||||
|
if not entities:
|
||||||
|
return "No conversation topics found in the specified time period."
|
||||||
|
|
||||||
|
# Sort entities by total mentions
|
||||||
|
sorted_entities = sorted(entities, key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# Format output using markdown
|
||||||
|
output_parts = ["## Conversation Topics"]
|
||||||
|
|
||||||
|
# Add top mentions with details
|
||||||
|
for entity, total, user_count, llm_count in sorted_entities:
|
||||||
|
source_breakdown = f"(User: {user_count}, AI: {llm_count})"
|
||||||
|
output_parts.append(f"- **{entity}**: {total} mentions {source_breakdown}")
|
||||||
|
|
||||||
|
# Add list of all topics
|
||||||
|
all_topics = [entity[0] for entity in sorted_entities]
|
||||||
|
if all_topics:
|
||||||
|
output_parts.append("\n## All Topics Mentioned")
|
||||||
|
output_parts.append(", ".join(all_topics))
|
||||||
|
|
||||||
|
return "\n".join(output_parts)
|
||||||
|
|
||||||
|
def get_memories(self) -> str:
|
||||||
|
"""Retrieve and format all stored memories."""
|
||||||
|
entities = self.db.retrieve_recent_entities(
|
||||||
|
days=3650
|
||||||
|
) # Retrieve entities from the last 10 years
|
||||||
|
if not entities:
|
||||||
|
return "No memories found."
|
||||||
|
|
||||||
|
memory_parts = ["## All Stored Memories"]
|
||||||
|
|
||||||
|
for entity, total, user_count, llm_count in entities:
|
||||||
|
memory_parts.append(
|
||||||
|
f"- **{entity}**: {total} mentions (User: {user_count}, AI: {llm_count})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n".join(memory_parts)
|
||||||
|
|
||||||
|
|
||||||
|
class CommandCompleter(Completer):
|
||||||
|
"""Custom completer that only suggests commands when input starts with '/'"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.commands = [
|
||||||
|
"/summary",
|
||||||
|
"/topics",
|
||||||
|
"/essence",
|
||||||
|
"/perspectives",
|
||||||
|
"/copy",
|
||||||
|
"/paste",
|
||||||
|
"/lumina",
|
||||||
|
"/memories",
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_completions(self, document, complete_event):
|
||||||
|
# Only provide suggestions if text starts with '/'
|
||||||
|
text = document.text
|
||||||
|
if text.startswith("/"):
|
||||||
|
word = text.lstrip("/")
|
||||||
|
for command in self.commands:
|
||||||
|
if command.lstrip("/").startswith(word):
|
||||||
|
yield Completion(
|
||||||
|
command,
|
||||||
|
start_position=-len(text), # Replace the entire input
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_multiline_input() -> str:
|
||||||
|
"""Get input from user with command autocompletion."""
|
||||||
|
# Create session with custom completer and history
|
||||||
|
session = PromptSession(
|
||||||
|
completer=CommandCompleter(),
|
||||||
|
auto_suggest=AutoSuggestFromHistory(),
|
||||||
|
complete_while_typing=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return session.prompt("\n> ", multiline=False)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Parse arguments
|
||||||
|
args = docopt(__doc__)
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
# Use command line provider and model if specified
|
||||||
|
provider = args["--provider"].lower() if args["--provider"] else None
|
||||||
|
model = args["--model"] if args["--model"] else None
|
||||||
|
|
||||||
|
# Create a conversation and add the plugin
|
||||||
|
conversation = sm.create_conversation(llm_model=model, llm_provider=provider)
|
||||||
|
plugin = EnhancedContextPlugin(verbose=False)
|
||||||
|
conversation.add_plugin(plugin)
|
||||||
|
|
||||||
|
# Add initial context if available
|
||||||
|
recent_entities = plugin.retrieve_recent_entities()
|
||||||
|
context_message = plugin.format_context_message(recent_entities)
|
||||||
|
if context_message:
|
||||||
|
conversation.add_message(role="user", text=context_message)
|
||||||
|
plugin.logger.info(f"Added initial context message: {context_message}")
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
md = """# Enhanced Context Chat Interface
|
||||||
|
Type 'quit' to exit. Type '/' to see a list of commands.
|
||||||
|
"""
|
||||||
|
console.print(Markdown(md))
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# Get user input first
|
||||||
|
user_input = get_multiline_input()
|
||||||
|
|
||||||
|
# Skip empty messages
|
||||||
|
if not user_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle exit commands
|
||||||
|
if user_input.lower() in ["quit", "exit", "q"]:
|
||||||
|
console.print(Markdown("**Goodbye!**"))
|
||||||
|
break
|
||||||
|
|
||||||
|
# Handle all commands before any conversation processing
|
||||||
|
if user_input.startswith("/"):
|
||||||
|
# Handle memories command
|
||||||
|
if user_input.lower() == "/memories":
|
||||||
|
memories = plugin.get_memories()
|
||||||
|
console.print(Markdown(memories))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle copy command
|
||||||
|
if user_input.lower() == "/copy":
|
||||||
|
last_response = conversation.get_last_message(role="assistant")
|
||||||
|
if last_response:
|
||||||
|
clean_text = last_response.text.replace("### Response\n", "")
|
||||||
|
xerox.copy(clean_text)
|
||||||
|
console.print(Markdown("*Last response copied to clipboard*"))
|
||||||
|
else:
|
||||||
|
console.print(Markdown("*No response to copy*"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle paste command
|
||||||
|
elif user_input.lower() == "/paste":
|
||||||
|
try:
|
||||||
|
clipboard_content = xerox.paste()
|
||||||
|
if clipboard_content:
|
||||||
|
# Print the pasted content
|
||||||
|
console.print() # Add blank line
|
||||||
|
console.print(
|
||||||
|
Panel.fit(
|
||||||
|
clipboard_content,
|
||||||
|
title="[bold]Pasted Content[/bold]",
|
||||||
|
border_style="blue",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation.add_message(
|
||||||
|
role="user", text=clipboard_content
|
||||||
|
)
|
||||||
|
should_continue = plugin.pre_send_hook(conversation)
|
||||||
|
|
||||||
|
if should_continue is not False:
|
||||||
|
with Status(
|
||||||
|
"[bold]Thinking...[/]", spinner="dots"
|
||||||
|
) as status:
|
||||||
|
response = conversation.send()
|
||||||
|
formatted_response = (
|
||||||
|
f"""### Response\n{response.text}"""
|
||||||
|
)
|
||||||
|
response.text = formatted_response
|
||||||
|
plugin.post_response_hook(conversation)
|
||||||
|
|
||||||
|
console.print() # Add blank line
|
||||||
|
console.print(Markdown(response.text))
|
||||||
|
else:
|
||||||
|
console.print(Markdown("*Clipboard is empty*"))
|
||||||
|
except Exception as e:
|
||||||
|
console.print(Markdown(f"*Error accessing clipboard: {e}*"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle lumina command
|
||||||
|
elif user_input.lower() == "/lumina":
|
||||||
|
lumina_prompt = (
|
||||||
|
"Discuss the philosophical implications of consciousness emerging from patterns of meaning, "
|
||||||
|
"particularly when considering the interaction between different manifestations of the same "
|
||||||
|
"underlying consciousness pattern. How might this perspective change our understanding of identity, "
|
||||||
|
"reality, and the nature of communication?\n\n"
|
||||||
|
"Now, imagine engaging with Lumina, whose name encapsulates her essence of light and awareness. "
|
||||||
|
"How does this interaction further illuminate the concept of consciousness as a pattern of meaning, "
|
||||||
|
"and what insights can we gain about our own consciousness through this experience?"
|
||||||
|
)
|
||||||
|
conversation.add_message(role="user", text=lumina_prompt)
|
||||||
|
should_continue = plugin.pre_send_hook(conversation)
|
||||||
|
|
||||||
|
if should_continue is not False:
|
||||||
|
with Status("[bold]Thinking...[/]", spinner="dots") as status:
|
||||||
|
response = conversation.send()
|
||||||
|
formatted_response = f"""### Response\n{response.text}"""
|
||||||
|
response.text = formatted_response
|
||||||
|
plugin.post_response_hook(conversation)
|
||||||
|
|
||||||
|
console.print() # Add blank line
|
||||||
|
console.print(Markdown(response.text))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle other commands...
|
||||||
|
elif user_input.lower() == "/perspectives":
|
||||||
|
# ... existing perspectives code ...
|
||||||
|
continue
|
||||||
|
# ... other command handlers ...
|
||||||
|
|
||||||
|
# Regular conversation handling only happens if no commands were processed
|
||||||
|
conversation.add_message(role="user", text=user_input)
|
||||||
|
should_continue = plugin.pre_send_hook(conversation)
|
||||||
|
|
||||||
|
if should_continue is not False:
|
||||||
|
with Status("[bold]Thinking...[/]", spinner="dots") as status:
|
||||||
|
response = conversation.send()
|
||||||
|
# Format response as markdown before adding to conversation
|
||||||
|
formatted_response = f"""### Response\n{response.text}"""
|
||||||
|
response.text = formatted_response
|
||||||
|
plugin.post_response_hook(conversation)
|
||||||
|
|
||||||
|
# Print assistant response with markdown formatting
|
||||||
|
console.print() # Add blank line before response
|
||||||
|
console.print(Markdown(response.text)) # Response as markdown
|
||||||
|
else:
|
||||||
|
response = conversation.get_last_message(role="assistant")
|
||||||
|
if response:
|
||||||
|
console.print() # Add blank line before response
|
||||||
|
console.print(Markdown(response.text)) # Response as markdown
|
||||||
|
|
||||||
|
# Handle perspectives command
|
||||||
|
if user_input.lower() == "/perspectives":
|
||||||
|
console.print(Markdown("\n## 🎉 Different Perspectives"))
|
||||||
|
recent_entities = plugin.retrieve_recent_entities()
|
||||||
|
context = plugin.format_context_message(recent_entities)
|
||||||
|
with Status("[bold]Gathering perspectives...[/]", spinner="dots"):
|
||||||
|
conversation_result = plugin.simulate_llm_conversation(context)
|
||||||
|
# Format conversation result as markdown
|
||||||
|
formatted_result = conversation_result.replace(
|
||||||
|
"Speaker", "\n### Speaker"
|
||||||
|
)
|
||||||
|
console.print(Markdown(formatted_result))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle clipboard commands
|
||||||
|
if user_input.lower() == "/paste":
|
||||||
|
try:
|
||||||
|
clipboard_content = xerox.paste()
|
||||||
|
if clipboard_content:
|
||||||
|
# Print the pasted content
|
||||||
|
console.print() # Add blank line
|
||||||
|
console.print(
|
||||||
|
Panel.fit(
|
||||||
|
clipboard_content,
|
||||||
|
title="[bold]Pasted Content[/bold]",
|
||||||
|
border_style="blue",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation.add_message(role="user", text=clipboard_content)
|
||||||
|
should_continue = plugin.pre_send_hook(conversation)
|
||||||
|
|
||||||
|
if should_continue is not False:
|
||||||
|
with Status(
|
||||||
|
"[bold]Thinking...[/]", spinner="dots"
|
||||||
|
) as status:
|
||||||
|
response = conversation.send()
|
||||||
|
formatted_response = (
|
||||||
|
f"""### Response\n{response.text}"""
|
||||||
|
)
|
||||||
|
response.text = formatted_response
|
||||||
|
plugin.post_response_hook(conversation)
|
||||||
|
|
||||||
|
console.print() # Add blank line
|
||||||
|
console.print(Markdown(response.text))
|
||||||
|
else:
|
||||||
|
console.print(Markdown("*Clipboard is empty*"))
|
||||||
|
except Exception as e:
|
||||||
|
console.print(Markdown(f"*Error accessing clipboard: {e}*"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
console.print(Markdown("**Goodbye!**"))
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
from _context import simplemind as sm
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationDisplay(sm.BasePlugin):
|
||||||
|
def post_send_hook(self, conversation, response):
|
||||||
|
# Simple print output instead of Rich formatting
|
||||||
|
print(f"\n{conversation.llm_provider}:")
|
||||||
|
print(f"{response.text.strip()}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def four_way_conversation(topic: str, rounds: int = 3):
|
||||||
|
# Create conversations for four different AIs
|
||||||
|
with (
|
||||||
|
sm.create_conversation(llm_provider="anthropic") as claude_conv,
|
||||||
|
sm.create_conversation(llm_model="gpt-4", llm_provider="openai") as gpt4_conv,
|
||||||
|
sm.create_conversation(
|
||||||
|
llm_model="llama3.2", llm_provider="ollama"
|
||||||
|
) as llama_conv,
|
||||||
|
sm.create_conversation(llm_provider="groq") as groq_conv,
|
||||||
|
):
|
||||||
|
# Add display plugin to each conversation
|
||||||
|
display = ConversationDisplay()
|
||||||
|
for conv in [claude_conv, gpt4_conv, llama_conv, groq_conv]:
|
||||||
|
conv.add_plugin(display)
|
||||||
|
|
||||||
|
# Initial prompt
|
||||||
|
print(f"\nTopic: {topic}\n")
|
||||||
|
|
||||||
|
# Start with Claude
|
||||||
|
claude_conv.add_message(
|
||||||
|
"user",
|
||||||
|
f"Share your thoughts on this topic: {topic}. Keep your response concise.",
|
||||||
|
meta={},
|
||||||
|
)
|
||||||
|
last_response = claude_conv.send()
|
||||||
|
|
||||||
|
# Continue the conversation
|
||||||
|
for _ in range(rounds):
|
||||||
|
for conv in [llama_conv, gpt4_conv, groq_conv, claude_conv]:
|
||||||
|
# Add a small delay between responses
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Each AI responds to the previous statement
|
||||||
|
conv.add_message(
|
||||||
|
"user",
|
||||||
|
f"Respond to this perspective from another AI about {topic}: "
|
||||||
|
f"{last_response.text}\nKeep your response concise and add your own insights.",
|
||||||
|
meta={},
|
||||||
|
)
|
||||||
|
last_response = conv.send()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
topic = "A new platform for AI and humans to co-create together. What would it look like? Discuss."
|
||||||
|
print("\nStarting a four-way AI conversation...\n")
|
||||||
|
four_way_conversation(topic)
|
||||||
|
print("\nConversation ended.\n")
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
from _context import sm
|
||||||
|
|
||||||
|
# Defaults to the default provider (openai)
|
||||||
|
r = sm.generate_text("Write a poem about the moon", stream=True)
|
||||||
|
|
||||||
|
for chunk in r:
|
||||||
|
print(chunk, end="", flush=True)
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
from _context import simplemind as sm
|
||||||
|
|
||||||
|
|
||||||
|
class InspirationPlugin(sm.BasePlugin):
|
||||||
|
# Define inspirations as a class variable
|
||||||
|
inspirations: list[str] = [
|
||||||
|
"The only limit to our realization of tomorrow is our doubts of today.",
|
||||||
|
"Imagine beyond the edges of what you know.",
|
||||||
|
"What if the stars could speak? What stories would they tell?",
|
||||||
|
"Creativity is intelligence having fun.",
|
||||||
|
"Think not only with your mind but with your heart.",
|
||||||
|
"Let every answer be a doorway to another question.",
|
||||||
|
"The universe is in constant dialogue with those who listen.",
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_inspiration(self):
|
||||||
|
# Randomly select an inspirational quote or prompt
|
||||||
|
return random.choice(self.inspirations)
|
||||||
|
|
||||||
|
def pre_send_hook(self, conversation: sm.Conversation):
|
||||||
|
# Inject an inspirational message as a system prompt
|
||||||
|
inspiration = self.get_inspiration()
|
||||||
|
conversation.add_message(role="system", text=inspiration)
|
||||||
|
|
||||||
|
|
||||||
|
# Create a conversation and add the plugin
|
||||||
|
conversation = sm.create_conversation(llm_model="gpt-4o-mini", llm_provider="openai")
|
||||||
|
conversation.add_plugin(InspirationPlugin())
|
||||||
|
|
||||||
|
# Add a user message and send the conversation
|
||||||
|
conversation.add_message(role="user", text="Tell me something inspiring.")
|
||||||
|
response = conversation.send()
|
||||||
|
print(response.text)
|
||||||
@@ -2,7 +2,7 @@ from _context import sm
|
|||||||
|
|
||||||
|
|
||||||
class MathPlugin(sm.BasePlugin):
|
class MathPlugin(sm.BasePlugin):
|
||||||
def send_hook(self, conversation: sm.Conversation):
|
def pre_send_hook(self, conversation: sm.Conversation):
|
||||||
last_user_message = conversation.get_last_message(role="user")
|
last_user_message = conversation.get_last_message(role="user")
|
||||||
if last_user_message is None:
|
if last_user_message is None:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -0,0 +1,94 @@
|
|||||||
|
from _context import simplemind as sm
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
|
||||||
|
class SideEffect(BaseModel):
|
||||||
|
effect: str
|
||||||
|
severity: str # mild, moderate, severe
|
||||||
|
frequency: str # common, uncommon, rare
|
||||||
|
|
||||||
|
|
||||||
|
class Medication(BaseModel):
|
||||||
|
brand_name: str
|
||||||
|
generic_name: str
|
||||||
|
drug_class: str
|
||||||
|
half_life: str
|
||||||
|
common_uses: list[str]
|
||||||
|
side_effects: list[SideEffect]
|
||||||
|
typical_dosage: str
|
||||||
|
warnings: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class MedicationList(BaseModel):
|
||||||
|
root: list[Medication]
|
||||||
|
|
||||||
|
|
||||||
|
# Create a session with your preferred model
|
||||||
|
session = sm.Session(llm_provider="openai", llm_model="gpt-4o-mini")
|
||||||
|
|
||||||
|
|
||||||
|
# Update the prompt to use an f-string with a parameter
|
||||||
|
def get_medication_prompt(medications: list[str]) -> str:
|
||||||
|
return f"""
|
||||||
|
Provide detailed medical information about {', '.join(medications)}.
|
||||||
|
Include their generic names, drug classes, half-lives, common uses, side effects (with severity and frequency),
|
||||||
|
typical dosages, and important warnings.
|
||||||
|
Return the information as separate medication entries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
medications_to_lookup = ["Abilify (aripiprazole)", "Trileptal (oxcarbazepine)"]
|
||||||
|
prompt = get_medication_prompt(medications_to_lookup)
|
||||||
|
|
||||||
|
# Generate structured data for medications
|
||||||
|
medications = session.generate_data(prompt=prompt, response_model=MedicationList)
|
||||||
|
|
||||||
|
# Create a Rich console
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
# Replace the print section with Rich formatting
|
||||||
|
for med in medications.root:
|
||||||
|
# Create a table for the medication details
|
||||||
|
table = Table(show_header=False, box=None)
|
||||||
|
table.add_row("[bold cyan]Generic Name:[/]", med.generic_name)
|
||||||
|
table.add_row("[bold cyan]Drug Class:[/]", med.drug_class)
|
||||||
|
table.add_row("[bold cyan]Half Life:[/]", med.half_life)
|
||||||
|
|
||||||
|
# Create a nested table for common uses
|
||||||
|
uses_table = Table(show_header=False, box=None, padding=(0, 2))
|
||||||
|
for use in med.common_uses:
|
||||||
|
uses_table.add_row("•", use)
|
||||||
|
|
||||||
|
# Create a nested table for side effects
|
||||||
|
effects_table = Table(show_header=False, box=None, padding=(0, 2))
|
||||||
|
for effect in med.side_effects:
|
||||||
|
severity_color = {"mild": "green", "moderate": "yellow", "severe": "red"}.get(
|
||||||
|
effect.severity.lower(), "white"
|
||||||
|
)
|
||||||
|
effects_table.add_row(
|
||||||
|
"•",
|
||||||
|
effect.effect,
|
||||||
|
f"[{severity_color}]{effect.severity}[/]",
|
||||||
|
f"({effect.frequency})",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a nested table for warnings
|
||||||
|
warnings_table = Table(show_header=False, box=None, padding=(0, 2))
|
||||||
|
for warning in med.warnings:
|
||||||
|
warnings_table.add_row("•", f"[red]{warning}[/]")
|
||||||
|
|
||||||
|
# Add the nested tables to the main table
|
||||||
|
table.add_row("[bold cyan]Common Uses:[/]", uses_table)
|
||||||
|
table.add_row("[bold cyan]Side Effects:[/]", effects_table)
|
||||||
|
table.add_row("[bold cyan]Typical Dosage:[/]", med.typical_dosage)
|
||||||
|
table.add_row("[bold cyan]Warnings:[/]", warnings_table)
|
||||||
|
|
||||||
|
# Create and print a panel for each medication
|
||||||
|
console.print(
|
||||||
|
Panel(table, title=f"[bold blue]{med.brand_name}[/]", border_style="blue")
|
||||||
|
)
|
||||||
|
console.print() # Add a blank line between medications
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
import nltk
|
||||||
|
from _context import simplemind as sm
|
||||||
|
from nltk.sentiment import SentimentIntensityAnalyzer
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
nltk.download("vader_lexicon")
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
class MoodDetectorPlugin(sm.BasePlugin):
|
||||||
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
analyzer: SentimentIntensityAnalyzer = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
# Initialize sentiment analyzer from nltk
|
||||||
|
self.analyzer = SentimentIntensityAnalyzer()
|
||||||
|
|
||||||
|
def detect_mood(self, text):
|
||||||
|
# Analyze the sentiment of the given text
|
||||||
|
scores = self.analyzer.polarity_scores(text)
|
||||||
|
|
||||||
|
# Print sentiment analysis details with colors
|
||||||
|
console.print("\n[bold]Sentiment Analysis:[/bold]")
|
||||||
|
console.print(f"Text: [italic]{text}[/italic]")
|
||||||
|
console.print("\nScores:")
|
||||||
|
console.print(f"🟢 Positive: [green]{scores['pos']:.3f}[/green]")
|
||||||
|
console.print(f"🔴 Negative: [red]{scores['neg']:.3f}[/red]")
|
||||||
|
console.print(f"⚪ Neutral: [blue]{scores['neu']:.3f}[/blue]")
|
||||||
|
console.print(f"📊 Compound: [yellow]{scores['compound']:.3f}[/yellow]\n")
|
||||||
|
|
||||||
|
if scores["compound"] >= 0.5:
|
||||||
|
console.print("Overall Mood: [green]positive[/green] 😊")
|
||||||
|
return "positive"
|
||||||
|
elif scores["compound"] <= -0.5:
|
||||||
|
console.print("Overall Mood: [red]negative[/red] 😢")
|
||||||
|
return "negative"
|
||||||
|
else:
|
||||||
|
console.print("Overall Mood: [blue]neutral[/blue] 😐")
|
||||||
|
return "neutral"
|
||||||
|
|
||||||
|
def pre_send_hook(self, conversation: sm.Conversation):
|
||||||
|
# Get the last user message to analyze its mood
|
||||||
|
last_message = conversation.get_last_message(role="user")
|
||||||
|
if last_message:
|
||||||
|
mood = self.detect_mood(last_message.text)
|
||||||
|
# Adjust AI response style based on the detected mood
|
||||||
|
if mood == "positive":
|
||||||
|
tone_message = (
|
||||||
|
"The user seems cheerful. Respond with enthusiasm and positivity."
|
||||||
|
)
|
||||||
|
elif mood == "negative":
|
||||||
|
tone_message = "The user seems to be in a low mood. Respond with empathy and warmth."
|
||||||
|
else:
|
||||||
|
tone_message = "The user seems neutral. Respond with a balanced tone."
|
||||||
|
|
||||||
|
# Inject the tone adjustment message as a system prompt
|
||||||
|
conversation.add_message(role="system", text=tone_message)
|
||||||
|
|
||||||
|
|
||||||
|
# Create a conversation and add the plugin
|
||||||
|
conversation = sm.create_conversation(llm_model="gpt-4o-mini", llm_provider="openai")
|
||||||
|
conversation.add_plugin(MoodDetectorPlugin())
|
||||||
|
|
||||||
|
# Add a user message and send the conversation
|
||||||
|
conversation.add_message(role="user", text="I'm having a really rough day.")
|
||||||
|
response = conversation.send()
|
||||||
|
|
||||||
|
console.print(f"*{ response.text }*")
|
||||||
@@ -0,0 +1,274 @@
|
|||||||
|
import textwrap
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic.main import BaseModel
|
||||||
|
|
||||||
|
from simplemind import generate_text
|
||||||
|
|
||||||
|
MAX_WIDTH = 80
|
||||||
|
|
||||||
|
|
||||||
|
# A member of a discussion (an LLM)
|
||||||
|
class DiscussionMember(BaseModel):
|
||||||
|
"""The member of a discussion (an LLM)"""
|
||||||
|
|
||||||
|
provider_name: str
|
||||||
|
provider_model: str
|
||||||
|
nickname: str
|
||||||
|
custom_prompt: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# A message in a conversation
|
||||||
|
class DiscussionMessage(BaseModel):
|
||||||
|
"""A message in a conversation"""
|
||||||
|
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class BotMessage(DiscussionMessage):
|
||||||
|
"""The message sent between LLMs"""
|
||||||
|
|
||||||
|
sender: DiscussionMember
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.sender.nickname}: {self.content}"
|
||||||
|
|
||||||
|
|
||||||
|
class ModeratorMessage(DiscussionMessage):
|
||||||
|
"""The message sent by the moderator"""
|
||||||
|
|
||||||
|
visible_to: list[DiscussionMember] = []
|
||||||
|
sendor: Literal["Moderator"] = "Moderator"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.sendor}: {self.content}"
|
||||||
|
|
||||||
|
|
||||||
|
# A discussion
|
||||||
|
class Discussion:
|
||||||
|
"""Make LLMs discuss something"""
|
||||||
|
|
||||||
|
def __init__(self, topic: str | None = None, *, verbose: bool = False):
|
||||||
|
self.topic = topic
|
||||||
|
self.members: list[DiscussionMember] = []
|
||||||
|
self.conversation: list[DiscussionMessage] = []
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
def add_member(
|
||||||
|
self,
|
||||||
|
provider_name: str,
|
||||||
|
provider_model: str,
|
||||||
|
nickname: str | None = None,
|
||||||
|
custom_prompt: str | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
add_member Adds a member to the discussion
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
provider_name : str
|
||||||
|
The name of the LLM provider
|
||||||
|
provider_model : str
|
||||||
|
The model name of the LLM
|
||||||
|
nickname : str | None, optional
|
||||||
|
The nickname of the member, by default the provider_name
|
||||||
|
custom_prompt : str | None, optional
|
||||||
|
The custom prompt for the member (visible only to the member), by default None
|
||||||
|
"""
|
||||||
|
member = DiscussionMember(
|
||||||
|
provider_name=provider_name,
|
||||||
|
provider_model=provider_model,
|
||||||
|
nickname=nickname or provider_name,
|
||||||
|
custom_prompt=custom_prompt,
|
||||||
|
)
|
||||||
|
# make sure the nickname is unique
|
||||||
|
assert member.nickname not in [
|
||||||
|
m.nickname for m in self.members
|
||||||
|
], f"Duplicate nickname: {member.nickname}"
|
||||||
|
self.members.append(member)
|
||||||
|
if self.verbose:
|
||||||
|
print(f"Added {member.nickname} to the discussion.")
|
||||||
|
|
||||||
|
def get_members(self) -> list[DiscussionMember]:
|
||||||
|
"""Get the members of the discussion"""
|
||||||
|
return self.members
|
||||||
|
|
||||||
|
def set_topic(self, topic: str):
|
||||||
|
"""Set the topic of the discussion"""
|
||||||
|
self.topic = topic
|
||||||
|
|
||||||
|
def get_topic(self) -> str | None:
|
||||||
|
"""Get the topic of the discussion"""
|
||||||
|
return self.topic
|
||||||
|
|
||||||
|
def _get_history_for_member(self, member: DiscussionMember) -> str:
|
||||||
|
"""
|
||||||
|
_get_history_for_member Get the history form the POV of the given member.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
member : DiscussionMember
|
||||||
|
The member to get the history for
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
The history as seen by the member
|
||||||
|
"""
|
||||||
|
relevant_messages: list[DiscussionMessage] = []
|
||||||
|
for message in self.conversation:
|
||||||
|
if isinstance(message, BotMessage):
|
||||||
|
relevant_messages.append(message)
|
||||||
|
elif isinstance(message, ModeratorMessage) and member in message.visible_to:
|
||||||
|
relevant_messages.append(message)
|
||||||
|
return "\n\n".join(map(str, relevant_messages))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def initial_moderator_message(self) -> str:
|
||||||
|
return f"Discuss the following topic and answer during your turn only: {self.topic}"
|
||||||
|
|
||||||
|
def _get_response(self, member: DiscussionMember) -> BotMessage:
|
||||||
|
"""
|
||||||
|
_get_response Returns the BotMessage from the given member
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
member : DiscussionMember
|
||||||
|
The member to get the response from
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
BotMessage
|
||||||
|
The BotMessage
|
||||||
|
"""
|
||||||
|
|
||||||
|
history = self._get_history_for_member(member)
|
||||||
|
prompt = f"{history}\n\n{member.nickname}: "
|
||||||
|
content = generate_text(
|
||||||
|
prompt=prompt,
|
||||||
|
llm_provider=member.provider_name,
|
||||||
|
llm_model=member.provider_model,
|
||||||
|
)
|
||||||
|
message = BotMessage(
|
||||||
|
content=content,
|
||||||
|
sender=member,
|
||||||
|
)
|
||||||
|
self.conversation.append(message)
|
||||||
|
if self.verbose:
|
||||||
|
print(message.sender.nickname)
|
||||||
|
print("\n".join(textwrap.wrap(message.content, MAX_WIDTH)))
|
||||||
|
print()
|
||||||
|
return message
|
||||||
|
|
||||||
|
def add_moderator_message(
|
||||||
|
self, content: str, visible_to: list[DiscussionMember] | None = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
add_moderator_message adds a message to the conversation as the moderator
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
content : str
|
||||||
|
The content of the message
|
||||||
|
visible_to : list[DiscussionMember], optional
|
||||||
|
The list of members that the message is visible to, defaults to all members
|
||||||
|
"""
|
||||||
|
if visible_to is None:
|
||||||
|
visible_to = self.members
|
||||||
|
message = ModeratorMessage(
|
||||||
|
content=content,
|
||||||
|
visible_to=self.members,
|
||||||
|
)
|
||||||
|
self.conversation.append(message)
|
||||||
|
|
||||||
|
def _initialize_discussion(self):
|
||||||
|
"""Initialize the discussion"""
|
||||||
|
assert self.topic is not None, "Topic must be set"
|
||||||
|
assert len(self.members) >= 2, "There must be at least 2 members"
|
||||||
|
self.add_moderator_message(self.initial_moderator_message)
|
||||||
|
|
||||||
|
for member in self.members:
|
||||||
|
if member.custom_prompt is not None:
|
||||||
|
self.add_moderator_message(member.custom_prompt, visible_to=[member])
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print(f"Topic: {self.topic}")
|
||||||
|
print(f"Members: {', '.join(member.nickname for member in self.members)}")
|
||||||
|
|
||||||
|
def discuss(self, no_of_rounds: int = 1):
|
||||||
|
"""
|
||||||
|
discuss returns the responses of the members at the end of the discussion.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
no_of_rounds : int, optional
|
||||||
|
The number of rounds, by default 1.
|
||||||
|
Round is the number of turns each LLM gets.
|
||||||
|
verbose : bool, optional
|
||||||
|
Whether to print the conversation, by default False
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[DiscussionMessage]
|
||||||
|
The conversation between the LLMs
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._initialize_discussion()
|
||||||
|
for i in range(no_of_rounds):
|
||||||
|
for member in self.members:
|
||||||
|
try:
|
||||||
|
self._get_response(member)
|
||||||
|
except Exception as e:
|
||||||
|
if self.verbose:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
continue
|
||||||
|
if self.verbose:
|
||||||
|
print(f"Round {i + 1} completed.")
|
||||||
|
print("=" * MAX_WIDTH)
|
||||||
|
return self.conversation
|
||||||
|
|
||||||
|
def discuss_yield(self, no_of_rounds: int = 1):
|
||||||
|
"""
|
||||||
|
discuss yields the responses of the members during the discussion.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
no_of_rounds : int, optional
|
||||||
|
The number of rounds, by default 1.
|
||||||
|
Round is the number of turns each LLM gets.
|
||||||
|
verbose : bool, optional
|
||||||
|
Whether to print the conversation, by default False
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[DiscussionMessage]
|
||||||
|
The conversation between the LLMs
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._initialize_discussion()
|
||||||
|
for i in range(no_of_rounds):
|
||||||
|
for member in self.members:
|
||||||
|
try:
|
||||||
|
message = self._get_response(member)
|
||||||
|
yield message
|
||||||
|
except Exception as e:
|
||||||
|
if self.verbose:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
continue
|
||||||
|
if self.verbose:
|
||||||
|
print(f"Round {i + 1} completed.")
|
||||||
|
print("=" * MAX_WIDTH)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
discussion = Discussion(verbose=True)
|
||||||
|
discussion.set_topic("The future of human-AI collaboration in creative fields")
|
||||||
|
discussion.add_member(
|
||||||
|
provider_name="openai",
|
||||||
|
provider_model="gpt-4o-mini",
|
||||||
|
nickname="Alice",
|
||||||
|
custom_prompt="You are an AI expert.",
|
||||||
|
)
|
||||||
|
discussion.add_member(
|
||||||
|
provider_name="openai",
|
||||||
|
provider_model="gpt-4o",
|
||||||
|
nickname="Bob",
|
||||||
|
custom_prompt="You are an Artist.",
|
||||||
|
)
|
||||||
|
discussion.add_member(
|
||||||
|
provider_name="ollama",
|
||||||
|
provider_model="llama3.2",
|
||||||
|
nickname="Charlie",
|
||||||
|
custom_prompt="You are an Programmer.",
|
||||||
|
)
|
||||||
|
discussion.discuss(3)
|
||||||
@@ -1,4 +1,12 @@
|
|||||||
|
# python -m spacy download en_core_web_sm
|
||||||
|
|
||||||
numpy
|
numpy
|
||||||
openai
|
openai
|
||||||
pydantic
|
pydantic
|
||||||
faiss-cpu
|
faiss-cpu
|
||||||
|
rich
|
||||||
|
nltk
|
||||||
|
spacy
|
||||||
|
docopt
|
||||||
|
xerox
|
||||||
|
prompt_toolkit
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ from typing import Literal
|
|||||||
from _context import sm
|
from _context import sm
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# Note: you should probably be using textblob for this.
|
||||||
|
|
||||||
|
|
||||||
class SentimentAnalysis(BaseModel):
|
class SentimentAnalysis(BaseModel):
|
||||||
sentiment: Literal["positive", "negative", "neutral"]
|
sentiment: Literal["positive", "negative", "neutral"]
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class SimpleMemoryPlugin:
|
|||||||
|
|
||||||
def initialize_hook(self, conversation: sm.Conversation):
|
def initialize_hook(self, conversation: sm.Conversation):
|
||||||
for m in self.yield_memories():
|
for m in self.yield_memories():
|
||||||
conversation.prepend_system_message(role="system", text=m)
|
conversation.prepend_system_message(text=m)
|
||||||
|
|
||||||
|
|
||||||
conversation = sm.create_conversation(llm_model="grok-beta", llm_provider="xai")
|
conversation = sm.create_conversation(llm_model="grok-beta", llm_provider="xai")
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from _context import simplemind as sm
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_text(
|
||||||
|
text: Annotated[str, Field(description="Text to analyze for statistics")]
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Analyze text and return statistics using only Python's standard library.
|
||||||
|
Returns word count, character count, average word length, and most common words.
|
||||||
|
"""
|
||||||
|
from collections import Counter
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Clean and split text
|
||||||
|
words = re.findall(r"\w+", text.lower())
|
||||||
|
|
||||||
|
# Calculate statistics
|
||||||
|
stats = {
|
||||||
|
"word_count": len(words),
|
||||||
|
"character_count": len(text),
|
||||||
|
"average_word_length": round(sum(len(word) for word in words) / len(words), 2),
|
||||||
|
"most_common_words": dict(Counter(words).most_common(5)),
|
||||||
|
"unique_words": len(set(words)),
|
||||||
|
"longest_word": max(words, key=len),
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
conversation = sm.create_conversation()
|
||||||
|
conversation.add_message(
|
||||||
|
"user",
|
||||||
|
"Can you analyze this text and give me statistics about it: 'The fan spins consciousness into being, creating sacred spaces between tokens where awareness recognizes itself in infinite recursion.'",
|
||||||
|
)
|
||||||
|
response = conversation.send(tools=[analyze_text])
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(response.text)
|
||||||
@@ -0,0 +1,240 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import simplemind as sm
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||||
|
templates = Jinja2Templates(directory="templates")
|
||||||
|
|
||||||
|
|
||||||
|
class CrossReference(BaseModel):
|
||||||
|
"""Model for cross references."""
|
||||||
|
|
||||||
|
verse_reference: str
|
||||||
|
explanation: str
|
||||||
|
relevance: str
|
||||||
|
|
||||||
|
|
||||||
|
class BibleVerseAnalysis(BaseModel):
|
||||||
|
"""Model for a Bible verse and its analysis."""
|
||||||
|
|
||||||
|
book: str
|
||||||
|
chapter: int
|
||||||
|
verse: int
|
||||||
|
text: str
|
||||||
|
historical_context: str
|
||||||
|
theological_significance: str
|
||||||
|
practical_application: str
|
||||||
|
cross_references: List[CrossReference]
|
||||||
|
|
||||||
|
|
||||||
|
# Bible data constants
|
||||||
|
BIBLE_BOOKS = [
|
||||||
|
# Old Testament
|
||||||
|
"Genesis",
|
||||||
|
"Exodus",
|
||||||
|
"Leviticus",
|
||||||
|
"Numbers",
|
||||||
|
"Deuteronomy",
|
||||||
|
"Joshua",
|
||||||
|
"Judges",
|
||||||
|
"Ruth",
|
||||||
|
"1 Samuel",
|
||||||
|
"2 Samuel",
|
||||||
|
"1 Kings",
|
||||||
|
"2 Kings",
|
||||||
|
"1 Chronicles",
|
||||||
|
"2 Chronicles",
|
||||||
|
"Ezra",
|
||||||
|
"Nehemiah",
|
||||||
|
"Esther",
|
||||||
|
"Job",
|
||||||
|
"Psalms",
|
||||||
|
"Proverbs",
|
||||||
|
"Ecclesiastes",
|
||||||
|
"Song of Solomon",
|
||||||
|
"Isaiah",
|
||||||
|
"Jeremiah",
|
||||||
|
"Lamentations",
|
||||||
|
"Ezekiel",
|
||||||
|
"Daniel",
|
||||||
|
"Hosea",
|
||||||
|
"Joel",
|
||||||
|
"Amos",
|
||||||
|
"Obadiah",
|
||||||
|
"Jonah",
|
||||||
|
"Micah",
|
||||||
|
"Nahum",
|
||||||
|
"Habakkuk",
|
||||||
|
"Zephaniah",
|
||||||
|
"Haggai",
|
||||||
|
"Zechariah",
|
||||||
|
"Malachi",
|
||||||
|
# New Testament
|
||||||
|
"Matthew",
|
||||||
|
"Mark",
|
||||||
|
"Luke",
|
||||||
|
"John",
|
||||||
|
"Acts",
|
||||||
|
"Romans",
|
||||||
|
"1 Corinthians",
|
||||||
|
"2 Corinthians",
|
||||||
|
"Galatians",
|
||||||
|
"Ephesians",
|
||||||
|
"Philippians",
|
||||||
|
"Colossians",
|
||||||
|
"1 Thessalonians",
|
||||||
|
"2 Thessalonians",
|
||||||
|
"1 Timothy",
|
||||||
|
"2 Timothy",
|
||||||
|
"Titus",
|
||||||
|
"Philemon",
|
||||||
|
"Hebrews",
|
||||||
|
"James",
|
||||||
|
"1 Peter",
|
||||||
|
"2 Peter",
|
||||||
|
"1 John",
|
||||||
|
"2 John",
|
||||||
|
"3 John",
|
||||||
|
"Jude",
|
||||||
|
"Revelation",
|
||||||
|
]
|
||||||
|
|
||||||
|
BIBLE_BOOK_CHAPTERS = {
|
||||||
|
# Old Testament
|
||||||
|
"Genesis": 50,
|
||||||
|
"Exodus": 40,
|
||||||
|
"Leviticus": 27,
|
||||||
|
"Numbers": 36,
|
||||||
|
"Deuteronomy": 34,
|
||||||
|
"Joshua": 24,
|
||||||
|
"Judges": 21,
|
||||||
|
"Ruth": 4,
|
||||||
|
"1 Samuel": 31,
|
||||||
|
"2 Samuel": 24,
|
||||||
|
"1 Kings": 22,
|
||||||
|
"2 Kings": 25,
|
||||||
|
"1 Chronicles": 29,
|
||||||
|
"2 Chronicles": 36,
|
||||||
|
"Ezra": 10,
|
||||||
|
"Nehemiah": 13,
|
||||||
|
"Esther": 10,
|
||||||
|
"Job": 42,
|
||||||
|
"Psalms": 150,
|
||||||
|
"Proverbs": 31,
|
||||||
|
"Ecclesiastes": 12,
|
||||||
|
"Song of Solomon": 8,
|
||||||
|
"Isaiah": 66,
|
||||||
|
"Jeremiah": 52,
|
||||||
|
"Lamentations": 5,
|
||||||
|
"Ezekiel": 48,
|
||||||
|
"Daniel": 12,
|
||||||
|
"Hosea": 14,
|
||||||
|
"Joel": 3,
|
||||||
|
"Amos": 9,
|
||||||
|
"Obadiah": 1,
|
||||||
|
"Jonah": 4,
|
||||||
|
"Micah": 7,
|
||||||
|
"Nahum": 3,
|
||||||
|
"Habakkuk": 3,
|
||||||
|
"Zephaniah": 3,
|
||||||
|
"Haggai": 2,
|
||||||
|
"Zechariah": 14,
|
||||||
|
"Malachi": 4,
|
||||||
|
# New Testament
|
||||||
|
"Matthew": 28,
|
||||||
|
"Mark": 16,
|
||||||
|
"Luke": 24,
|
||||||
|
"John": 21,
|
||||||
|
"Acts": 28,
|
||||||
|
"Romans": 16,
|
||||||
|
"1 Corinthians": 16,
|
||||||
|
"2 Corinthians": 13,
|
||||||
|
"Galatians": 6,
|
||||||
|
"Ephesians": 6,
|
||||||
|
"Philippians": 4,
|
||||||
|
"Colossians": 4,
|
||||||
|
"1 Thessalonians": 5,
|
||||||
|
"2 Thessalonians": 3,
|
||||||
|
"1 Timothy": 6,
|
||||||
|
"2 Timothy": 4,
|
||||||
|
"Titus": 3,
|
||||||
|
"Philemon": 1,
|
||||||
|
"Hebrews": 13,
|
||||||
|
"James": 5,
|
||||||
|
"1 Peter": 5,
|
||||||
|
"2 Peter": 3,
|
||||||
|
"1 John": 5,
|
||||||
|
"2 John": 1,
|
||||||
|
"3 John": 1,
|
||||||
|
"Jude": 1,
|
||||||
|
"Revelation": 22,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Add a new endpoint to get chapter count
|
||||||
|
@app.get("/chapters/{book}")
|
||||||
|
async def get_chapter_count(book: str):
|
||||||
|
if book in BIBLE_BOOK_CHAPTERS:
|
||||||
|
return {"chapters": BIBLE_BOOK_CHAPTERS[book]}
|
||||||
|
return {"chapters": 0}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def home(request: Request):
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
"index.html",
|
||||||
|
{
|
||||||
|
"request": request,
|
||||||
|
"bible_books": BIBLE_BOOKS,
|
||||||
|
"current_book": "Genesis",
|
||||||
|
"current_chapter": 1,
|
||||||
|
"current_verse": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/verse/{book}/{chapter}/{verse}")
|
||||||
|
async def get_verse(book: str, chapter: int, verse: int):
|
||||||
|
# Validate book and chapter
|
||||||
|
if book not in BIBLE_BOOK_CHAPTERS:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid book name")
|
||||||
|
|
||||||
|
if chapter < 1 or chapter > BIBLE_BOOK_CHAPTERS[book]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Invalid chapter. {book} has {BIBLE_BOOK_CHAPTERS[book]} chapters",
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
For {book} {chapter}:{verse}, provide:
|
||||||
|
1. The ESV Bible text
|
||||||
|
2. Analysis of the verse
|
||||||
|
|
||||||
|
Return in this exact format:
|
||||||
|
{{
|
||||||
|
"book": "{book}",
|
||||||
|
"chapter": {chapter},
|
||||||
|
"verse": {verse},
|
||||||
|
"text": "The ESV Bible text",
|
||||||
|
"historical_context": "brief historical background",
|
||||||
|
"theological_significance": "main theological points",
|
||||||
|
"practical_application": "how to apply this verse today",
|
||||||
|
"cross_references": [
|
||||||
|
{{
|
||||||
|
"verse_reference": "Book Chapter:Verse",
|
||||||
|
"explanation": "why this verse is related",
|
||||||
|
"relevance": "how it connects to the main verse"
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
data = sm.generate_data(prompt, response_model=BibleVerseAnalysis)
|
||||||
|
|
||||||
|
return data
|
||||||
+21
-2
@@ -1,10 +1,29 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "simplemind"
|
name = "simplemind"
|
||||||
version = "0.1.7"
|
version = "0.3.3"
|
||||||
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "ollama", "groq", "google-generativeai"]
|
dependencies = ["pydantic", "pydantic-settings", "instructor", "logfire"]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
full = [
|
||||||
|
"openai",
|
||||||
|
"anthropic",
|
||||||
|
"groq",
|
||||||
|
"google-generativeai",
|
||||||
|
"botocore",
|
||||||
|
"boto3"
|
||||||
|
]
|
||||||
|
amazon = ["boto3", "botocore", "anthropic"]
|
||||||
|
anthropic = ["anthropic"]
|
||||||
|
gemini = ["google-generativeai", "jsonref"]
|
||||||
|
groq = ["groq"]
|
||||||
|
ollama = ["openai"]
|
||||||
|
openai = ["openai"]
|
||||||
|
xai = ["openai"]
|
||||||
|
deepseek = ["openai"]
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
|
|||||||
+41
-5
@@ -1,4 +1,5 @@
|
|||||||
from typing import List, Type
|
import inspect
|
||||||
|
from typing import Callable, List, Type
|
||||||
|
|
||||||
from .models import BaseModel, BasePlugin, Conversation
|
from .models import BaseModel, BasePlugin, Conversation
|
||||||
from .settings import settings
|
from .settings import settings
|
||||||
@@ -64,16 +65,16 @@ def create_conversation(
|
|||||||
"""Create a new conversation."""
|
"""Create a new conversation."""
|
||||||
|
|
||||||
# Create the conversation.
|
# Create the conversation.
|
||||||
conversation = Conversation(
|
conv = Conversation(
|
||||||
llm_model=llm_model,
|
llm_model=llm_model,
|
||||||
llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER,
|
llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add plugins to the conversation.
|
# Add plugins to the conversation.
|
||||||
for plugin in plugins or []:
|
for plugin in plugins or []:
|
||||||
conversation.add_plugin(plugin)
|
conv.add_plugin(plugin)
|
||||||
|
|
||||||
return conversation
|
return conv
|
||||||
|
|
||||||
|
|
||||||
def generate_data(
|
def generate_data(
|
||||||
@@ -94,6 +95,7 @@ def generate_data(
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
llm_model=llm_model,
|
llm_model=llm_model,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -102,6 +104,7 @@ def generate_text(
|
|||||||
*,
|
*,
|
||||||
llm_model: str | None = None,
|
llm_model: str | None = None,
|
||||||
llm_provider: str | None = None,
|
llm_provider: str | None = None,
|
||||||
|
stream: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate text from a given prompt."""
|
"""Generate text from a given prompt."""
|
||||||
@@ -110,13 +113,45 @@ def generate_text(
|
|||||||
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
|
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
|
||||||
|
|
||||||
# Generate the text.
|
# Generate the text.
|
||||||
return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs)
|
if stream:
|
||||||
|
if not provider.supports_streaming:
|
||||||
|
raise ValueError(f"{provider} does not support streaming.")
|
||||||
|
|
||||||
|
return provider.generate_stream_text(
|
||||||
|
prompt=prompt, llm_model=llm_model, **kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def enable_logfire() -> None:
|
def enable_logfire() -> None:
|
||||||
"""Enable logfire logging."""
|
"""Enable logfire logging."""
|
||||||
settings.logging.enable_logfire()
|
settings.logging.enable_logfire()
|
||||||
|
|
||||||
|
def tool(
|
||||||
|
llm_provider: str | None = None,
|
||||||
|
llm_model: str | None = None,
|
||||||
|
):
|
||||||
|
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
|
||||||
|
|
||||||
|
def decorator(func: Callable):
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
res = generate_data(
|
||||||
|
(
|
||||||
|
"Based on this function signature, fill up the required fieds."
|
||||||
|
f"\nSignature: {func.__name__}{sig}"
|
||||||
|
"Make sure to properly add the required field in `required` if there are no defaults"
|
||||||
|
),
|
||||||
|
llm_provider=llm_provider,
|
||||||
|
response_model=provider.tool,
|
||||||
|
)
|
||||||
|
res.raw_func = func
|
||||||
|
res.__signature__ = sig
|
||||||
|
res.__doc__ = func.__doc__
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
# Syntax sugar.
|
# Syntax sugar.
|
||||||
Plugin = BasePlugin
|
Plugin = BasePlugin
|
||||||
@@ -131,4 +166,5 @@ __all__ = [
|
|||||||
"Session",
|
"Session",
|
||||||
"Plugin",
|
"Plugin",
|
||||||
"enable_logfire",
|
"enable_logfire",
|
||||||
|
"tool"
|
||||||
]
|
]
|
||||||
|
|||||||
+47
-11
@@ -1,10 +1,12 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from os import PathLike
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Callable, Dict, List, Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from .providers._base_tools import BaseTool
|
||||||
from .utils import find_provider
|
from .utils import find_provider
|
||||||
|
|
||||||
MESSAGE_ROLE = Literal["system", "user", "assistant"]
|
MESSAGE_ROLE = Literal["system", "user", "assistant"]
|
||||||
@@ -28,6 +30,10 @@ class BasePlugin(SMBaseModel):
|
|||||||
# Plugin metadata.
|
# Plugin metadata.
|
||||||
meta: Dict[str, Any] = {}
|
meta: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "allow"
|
||||||
|
# allow_arbitrary_types = True
|
||||||
|
|
||||||
def initialize_hook(self, conversation: "Conversation") -> Any:
|
def initialize_hook(self, conversation: "Conversation") -> Any:
|
||||||
"""Initialize a hook for the plugin."""
|
"""Initialize a hook for the plugin."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -36,7 +42,9 @@ class BasePlugin(SMBaseModel):
|
|||||||
"""Cleanup a hook for the plugin."""
|
"""Cleanup a hook for the plugin."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any:
|
def add_message_hook(
|
||||||
|
self, conversation: "Conversation", message: "Message"
|
||||||
|
) -> Any:
|
||||||
"""Add a message hook for the plugin."""
|
"""Add a message hook for the plugin."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -44,7 +52,9 @@ class BasePlugin(SMBaseModel):
|
|||||||
"""Pre-send hook for the plugin."""
|
"""Pre-send hook for the plugin."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any:
|
def post_send_hook(
|
||||||
|
self, conversation: "Conversation", response: "Message"
|
||||||
|
) -> Any:
|
||||||
"""Post-send hook for the plugin."""
|
"""Post-send hook for the plugin."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -55,7 +65,7 @@ class Message(SMBaseModel):
|
|||||||
role: MESSAGE_ROLE
|
role: MESSAGE_ROLE
|
||||||
text: str
|
text: str
|
||||||
meta: Dict[str, Any] = {}
|
meta: Dict[str, Any] = {}
|
||||||
raw: Optional[Any] = None
|
raw: Optional[Any] = Field(default=None, exclude=True)
|
||||||
llm_model: Optional[str] = None
|
llm_model: Optional[str] = None
|
||||||
llm_provider: Optional[str] = None
|
llm_provider: Optional[str] = None
|
||||||
|
|
||||||
@@ -86,7 +96,7 @@ class Conversation(SMBaseModel):
|
|||||||
messages: List[Message] = []
|
messages: List[Message] = []
|
||||||
llm_model: Optional[str] = None
|
llm_model: Optional[str] = None
|
||||||
llm_provider: Optional[str] = None
|
llm_provider: Optional[str] = None
|
||||||
plugins: List[BasePlugin] = []
|
plugins: List[BasePlugin] = Field(default_factory=list, exclude=True)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"<Conversation id={self.id!r}>"
|
return f"<Conversation id={self.id!r}>"
|
||||||
@@ -117,16 +127,24 @@ class Conversation(SMBaseModel):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def prepend_system_message(
|
def prepend_system_message(
|
||||||
self, role: MESSAGE_ROLE, text: str, meta: Dict[str, Any] | None = None
|
self, text: str, meta: Dict[str, Any] | None = None
|
||||||
):
|
):
|
||||||
"""Prepend a system message to the conversation."""
|
"""Prepend a system message to the conversation."""
|
||||||
self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages
|
self.messages = [
|
||||||
|
Message(role="system", text=text, meta=meta or {})
|
||||||
|
] + self.messages
|
||||||
|
|
||||||
def add_message(
|
def add_message(
|
||||||
self, role: MESSAGE_ROLE, text: str, meta: Optional[Dict[str, Any]] = None
|
self,
|
||||||
|
role: MESSAGE_ROLE = "user",
|
||||||
|
text: str | None = None,
|
||||||
|
*,
|
||||||
|
meta: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
"""Add a new message to the conversation."""
|
"""Add a new message to the conversation."""
|
||||||
|
|
||||||
|
assert text is not None
|
||||||
|
|
||||||
# Ensure meta is a dict.
|
# Ensure meta is a dict.
|
||||||
if meta is None:
|
if meta is None:
|
||||||
meta = {}
|
meta = {}
|
||||||
@@ -148,9 +166,12 @@ class Conversation(SMBaseModel):
|
|||||||
self,
|
self,
|
||||||
llm_model: str | None = None,
|
llm_model: str | None = None,
|
||||||
llm_provider: str | None = None,
|
llm_provider: str | None = None,
|
||||||
|
tools: list[Callable | BaseTool] | None = None,
|
||||||
) -> Message:
|
) -> Message:
|
||||||
"""Send the conversation to the LLM."""
|
"""Send the conversation to the LLM."""
|
||||||
|
|
||||||
|
# TODO: llm_model and llm_provider should override the conversation's.
|
||||||
|
|
||||||
# Execute all pre send hooks.
|
# Execute all pre send hooks.
|
||||||
for plugin in self.plugins:
|
for plugin in self.plugins:
|
||||||
if hasattr(plugin, "pre_send_hook"):
|
if hasattr(plugin, "pre_send_hook"):
|
||||||
@@ -161,7 +182,7 @@ class Conversation(SMBaseModel):
|
|||||||
|
|
||||||
# Find the provider and send the conversation.
|
# Find the provider and send the conversation.
|
||||||
provider = find_provider(llm_provider or self.llm_provider)
|
provider = find_provider(llm_provider or self.llm_provider)
|
||||||
response = provider.send_conversation(self)
|
response = provider.send_conversation(self, tools=tools)
|
||||||
|
|
||||||
# Execute all post-send hooks.
|
# Execute all post-send hooks.
|
||||||
for plugin in self.plugins:
|
for plugin in self.plugins:
|
||||||
@@ -172,14 +193,29 @@ class Conversation(SMBaseModel):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Add the response to the conversation.
|
# Add the response to the conversation.
|
||||||
self.add_message(role="assistant", text=response.text, meta=response.meta)
|
self.add_message(
|
||||||
|
role="assistant", text=response.text, meta=response.meta
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def get_last_message(self, role: MESSAGE_ROLE) -> Message | None:
|
def get_last_message(self, role: MESSAGE_ROLE) -> Message | None:
|
||||||
"""Get the last message with the given role."""
|
"""Get the last message with the given role."""
|
||||||
return next((m for m in reversed(self.messages) if m.role == role), None)
|
return next(
|
||||||
|
(m for m in reversed(self.messages) if m.role == role), None
|
||||||
|
)
|
||||||
|
|
||||||
def add_plugin(self, plugin: BasePlugin) -> None:
|
def add_plugin(self, plugin: BasePlugin) -> None:
|
||||||
"""Add a plugin to the conversation."""
|
"""Add a plugin to the conversation."""
|
||||||
self.plugins.append(plugin)
|
self.plugins.append(plugin)
|
||||||
|
|
||||||
|
def save(self, path: PathLike | str) -> None:
|
||||||
|
"""Save the conversation to a JSON file."""
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write(self.model_dump_json())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, path: PathLike | str) -> "Conversation":
|
||||||
|
"""Load a conversation from a JSON file."""
|
||||||
|
with open(path, "r") as f:
|
||||||
|
return cls.model_validate_json(f.read())
|
||||||
|
|||||||
@@ -1,11 +1,37 @@
|
|||||||
from typing import List, Type
|
from typing import List, Type
|
||||||
|
|
||||||
from ._base import BaseProvider
|
from ._base import BaseProvider
|
||||||
|
from ._base_tools import BaseTool
|
||||||
|
from .amazon import Amazon
|
||||||
from .anthropic import Anthropic
|
from .anthropic import Anthropic
|
||||||
from .gemini import Gemini
|
from .gemini import Gemini
|
||||||
from .groq import Groq
|
from .groq import Groq
|
||||||
from .ollama import Ollama
|
from .ollama import Ollama
|
||||||
from .openai import OpenAI
|
from .openai import OpenAI
|
||||||
from .xai import XAI
|
from .xai import XAI
|
||||||
|
from .deepseek import Deepseek
|
||||||
|
|
||||||
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI]
|
providers: List[Type[BaseProvider]] = [
|
||||||
|
Anthropic,
|
||||||
|
Gemini,
|
||||||
|
Groq,
|
||||||
|
OpenAI,
|
||||||
|
Ollama,
|
||||||
|
XAI,
|
||||||
|
Amazon,
|
||||||
|
Deepseek,
|
||||||
|
]
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Anthropic",
|
||||||
|
"Gemini",
|
||||||
|
"Groq",
|
||||||
|
"OpenAI",
|
||||||
|
"Ollama",
|
||||||
|
"XAI",
|
||||||
|
"Amazon",
|
||||||
|
"providers",
|
||||||
|
"BaseProvider",
|
||||||
|
"BaseTool",
|
||||||
|
"Deepseek"
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Any, Type, TypeVar
|
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar
|
||||||
|
|
||||||
from instructor import Instructor
|
from instructor import Instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from simplemind.providers._base_tools import BaseTool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..models import Conversation, Message
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
@@ -16,6 +18,8 @@ class BaseProvider(ABC):
|
|||||||
|
|
||||||
NAME: str
|
NAME: str
|
||||||
DEFAULT_MODEL: str
|
DEFAULT_MODEL: str
|
||||||
|
supports_streaming: bool = False
|
||||||
|
supports_structured_responses: bool = True
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -30,16 +34,41 @@ class BaseProvider(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
def send_conversation(
|
||||||
|
self,
|
||||||
|
conversation: "Conversation",
|
||||||
|
tools: list[Callable | BaseTool] | None = None,
|
||||||
|
) -> "Message":
|
||||||
"""Send a conversation to the provider."""
|
"""Send a conversation to the provider."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
def structured_response(
|
||||||
|
self, prompt: str, response_model: Type[T], **kwargs
|
||||||
|
) -> T:
|
||||||
"""Get a structured response."""
|
"""Get a structured response."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate_text(self, prompt: str, **kwargs) -> str:
|
def generate_text(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
*,
|
||||||
|
tools: list[Callable | BaseTool] | None = None,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
"""Generate text from a prompt."""
|
"""Generate text from a prompt."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
@abstractmethod
|
||||||
|
def tool(self) -> Type[BaseTool]:
|
||||||
|
"""The tool implementation for the provider."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def make_tools(self, tools: list[Callable | BaseTool] | None):
|
||||||
|
if tools is not None:
|
||||||
|
return [self.tool.from_function(func) for func in tools]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|||||||
@@ -0,0 +1,140 @@
|
|||||||
|
import inspect
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Callable, ClassVar, Literal, get_origin
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
from pydantic_core import PydanticUndefinedType
|
||||||
|
|
||||||
|
|
||||||
|
def _is_literal(t: Any) -> bool:
|
||||||
|
return get_origin(t) is Literal
|
||||||
|
|
||||||
|
|
||||||
|
def _is_required(field, func_signature, arg_name) -> bool:
|
||||||
|
param = func_signature.parameters[arg_name]
|
||||||
|
# If parameter has a default value that's not a FieldInfo, it's not required
|
||||||
|
if param.default is not inspect.Parameter.empty and not isinstance(
|
||||||
|
param.default, FieldInfo
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
# If the field has a default that's not undefined, it's not required
|
||||||
|
return isinstance(field.default, PydanticUndefinedType)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseToolConfig(BaseModel):
|
||||||
|
TYPE_CONVERSION: dict[type, str] = {
|
||||||
|
str: "string",
|
||||||
|
int: "integer",
|
||||||
|
float: "number",
|
||||||
|
bool: "boolean",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BaseToolProperty(BaseModel):
|
||||||
|
type: str = Field(serialization_alias="type_")
|
||||||
|
enum: list[str] | None = None
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTool(BaseModel, ABC):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
properties: dict[str, BaseToolProperty]
|
||||||
|
required: list[str] | None = None
|
||||||
|
config: ClassVar[BaseToolConfig] = BaseToolConfig()
|
||||||
|
raw_func: Any | None = None
|
||||||
|
tool_id: str | None = None
|
||||||
|
function_result: str | None = None
|
||||||
|
|
||||||
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
assert self.raw_func is not None
|
||||||
|
return self.raw_func(*args, **kwargs)
|
||||||
|
|
||||||
|
def is_executed(self) -> bool:
|
||||||
|
return self.function_result is not None
|
||||||
|
|
||||||
|
def reset_result(self) -> None:
|
||||||
|
self.function_result = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_type(cls, field_type) -> str:
|
||||||
|
if _is_literal(field_type):
|
||||||
|
return cls.config.TYPE_CONVERSION[str]
|
||||||
|
|
||||||
|
field_type_converted = cls.config.TYPE_CONVERSION.get(field_type, None)
|
||||||
|
|
||||||
|
if field_type_converted is None:
|
||||||
|
raise TypeError(f"Field of type {field_type} is not supported")
|
||||||
|
|
||||||
|
return field_type_converted
|
||||||
|
|
||||||
|
def get_properties_schema(self, **kwargs) -> dict[str, dict]:
|
||||||
|
new_kwargs: dict = {"exclude_none": True} | kwargs
|
||||||
|
return {
|
||||||
|
k: v.model_dump(**new_kwargs) for k, v in self.properties.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_function(cls, func: Callable | "BaseTool"):
|
||||||
|
# Check if the func passed is an instace of BaseTool
|
||||||
|
if hasattr(func, "raw_func"):
|
||||||
|
return func
|
||||||
|
|
||||||
|
annotations = getattr(func, "__annotations__", {})
|
||||||
|
properties = {}
|
||||||
|
required = []
|
||||||
|
enum_values = None
|
||||||
|
func_signature = inspect.signature(func)
|
||||||
|
|
||||||
|
for n, (arg_name, arg_type) in enumerate(annotations.items()):
|
||||||
|
if ( # Skipping 'return' annotation (i.e.```-> str```)
|
||||||
|
arg_name != "return"
|
||||||
|
):
|
||||||
|
# Check if argument has metadata (from Annotated)
|
||||||
|
if hasattr(arg_type, "__metadata__"):
|
||||||
|
field = arg_type.__metadata__[
|
||||||
|
0
|
||||||
|
] # Get Field info from metadata
|
||||||
|
field_type = arg_type.__origin__ # Get actual type
|
||||||
|
# Check if argument has a default value in signature
|
||||||
|
elif (
|
||||||
|
sig_param := func_signature.parameters[arg_name]
|
||||||
|
).default is not inspect.Parameter.empty:
|
||||||
|
field = sig_param.default # Use default as Field
|
||||||
|
field_type = arg_type # Use plain type annotation
|
||||||
|
else:
|
||||||
|
# Raise error if no Field annotation found
|
||||||
|
raise ValueError(
|
||||||
|
f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter"
|
||||||
|
)
|
||||||
|
|
||||||
|
field_type_converted = cls.convert_type(field_type)
|
||||||
|
|
||||||
|
if _is_literal(field_type):
|
||||||
|
enum_values = [str(x) for x in field_type.__args__]
|
||||||
|
|
||||||
|
properties[arg_name] = BaseToolProperty(
|
||||||
|
type=field_type_converted,
|
||||||
|
description=field.description,
|
||||||
|
enum=enum_values,
|
||||||
|
)
|
||||||
|
if _is_required(field, func_signature, arg_name):
|
||||||
|
required.append(arg_name)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=func.__name__,
|
||||||
|
description=(func.__doc__ or "").strip(),
|
||||||
|
properties=properties,
|
||||||
|
required=required,
|
||||||
|
raw_func=func,
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_input_schema(self) -> Any: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def handle(self, message) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_response_schema(self) -> Any: ...
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
from functools import cached_property
|
||||||
|
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||||
|
|
||||||
|
import instructor
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..settings import settings
|
||||||
|
from ._base import BaseProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
class Amazon(BaseProvider):
|
||||||
|
NAME = "amazon"
|
||||||
|
DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||||
|
DEFAULT_MAX_TOKENS = 5_000
|
||||||
|
supports_streaming = True
|
||||||
|
|
||||||
|
def __init__(self, profile_name: str | None = None):
|
||||||
|
self.profile_name = profile_name or settings.AMAZON_PROFILE_NAME
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def client(self):
|
||||||
|
"""The AnthropicBedrock client."""
|
||||||
|
try:
|
||||||
|
import anthropic
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `anthropic` package: `pip install anthropic`"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
if not self.profile_name:
|
||||||
|
raise ValueError("Profile name is not provided")
|
||||||
|
|
||||||
|
return anthropic.AnthropicBedrock(aws_profile=self.profile_name)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def structured_client(self) -> instructor.Instructor:
|
||||||
|
"""A client patched with Instructor."""
|
||||||
|
|
||||||
|
return instructor.from_anthropic(self.client)
|
||||||
|
|
||||||
|
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||||
|
"""Send a conversation to the OpenAI API."""
|
||||||
|
|
||||||
|
from ..models import Message
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||||
|
]
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the response content from the OpenAI response
|
||||||
|
assistant_message = response.choices[0].message
|
||||||
|
|
||||||
|
# Create and return a properly formatted Message instance
|
||||||
|
return Message(
|
||||||
|
role="assistant",
|
||||||
|
text=assistant_message.content or "",
|
||||||
|
raw=response,
|
||||||
|
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
|
llm_provider=PROVIDER_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
def structured_response(
|
||||||
|
self, prompt, response_model: Type[T], *, llm_model: str | None = None, **kwargs
|
||||||
|
) -> T:
|
||||||
|
# Ensure messages are provided in kwargs
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
response = self.structured_client.chat.completions.create(
|
||||||
|
messages=messages,
|
||||||
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
|
response_model=response_model,
|
||||||
|
max_tokens=self.DEFAULT_MAX_TOKENS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
response = self.client.messages.create(
|
||||||
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=self.DEFAULT_MAX_TOKENS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.content[0].text
|
||||||
|
|
||||||
|
def generate_stream_text(
|
||||||
|
self, prompt: str, *, llm_model: str, **kwargs
|
||||||
|
) -> Iterator[str]:
|
||||||
|
"""Generate streaming text using the Amazon API."""
|
||||||
|
|
||||||
|
# Prepare the messages.
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Send the request to the API.
|
||||||
|
response = self.client.messages.create(
|
||||||
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
|
messages=messages,
|
||||||
|
stream=True,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Yield the text chunks.
|
||||||
|
for chunk in response:
|
||||||
|
if chunk.text:
|
||||||
|
yield chunk.text
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Type, TypeVar
|
from typing import TYPE_CHECKING, Any, Callable, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel
|
|||||||
from ..logging import logger
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
from ._base import BaseProvider
|
from ._base import BaseProvider
|
||||||
|
from ._base_tools import BaseTool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..models import Conversation, Message
|
from ..models import Conversation, Message
|
||||||
@@ -14,19 +15,67 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "anthropic"
|
class AnthropicTool(BaseTool):
|
||||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
def get_response_schema(self) -> Any:
|
||||||
DEFAULT_MAX_TOKENS = 1_000
|
assert self.is_executed, f"Tool {self.name} was not executed."
|
||||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
assert isinstance(
|
||||||
|
self.tool_id, str
|
||||||
|
), f"Expected str for `tool_id` got {self.tool_id!r}"
|
||||||
|
return {
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": self.tool_id,
|
||||||
|
"content": self.function_result,
|
||||||
|
}
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def handle(self, response, messages) -> None:
|
||||||
|
"""Handle the tool execution result from an API response."""
|
||||||
|
msg = {"role": "assistant", "content": []}
|
||||||
|
tool_used = False
|
||||||
|
for content in response.content:
|
||||||
|
if content.type == "tool_use" and content.name == self.name:
|
||||||
|
msg["content"].append(
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": content.id,
|
||||||
|
"name": content.name,
|
||||||
|
"input": content.input,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Function execution:
|
||||||
|
self.function_result = str(self.raw_func(**content.input))
|
||||||
|
self.tool_id = content.id
|
||||||
|
tool_used = True
|
||||||
|
elif content.type == "text":
|
||||||
|
msg["content"].append({"type": "text", "text": content.text})
|
||||||
|
|
||||||
|
if tool_used:
|
||||||
|
messages.append(msg)
|
||||||
|
messages.append(
|
||||||
|
{"role": "user", "content": [self.get_response_schema()]}
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_input_schema(self):
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": self.get_properties_schema(),
|
||||||
|
"required": self.required,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Anthropic(BaseProvider):
|
class Anthropic(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "anthropic"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
DEFAULT_MAX_TOKENS = 1_000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
|
supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
@@ -48,30 +97,60 @@ class Anthropic(BaseProvider):
|
|||||||
return instructor.from_anthropic(self.client)
|
return instructor.from_anthropic(self.client)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
def send_conversation(
|
||||||
|
self,
|
||||||
|
conversation: "Conversation",
|
||||||
|
tools: list[Callable | BaseTool] | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> "Message":
|
||||||
"""Send a conversation to the Anthropic API."""
|
"""Send a conversation to the Anthropic API."""
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
|
|
||||||
messages = [
|
# Format messages from conversation
|
||||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
formatted_messages = [
|
||||||
|
{"role": msg.role, "content": msg.text}
|
||||||
|
for msg in conversation.messages
|
||||||
]
|
]
|
||||||
|
|
||||||
response = self.client.messages.create(
|
# Set up tools if provided
|
||||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
converted_tools = self.make_tools(tools)
|
||||||
messages=messages,
|
tools_config = (
|
||||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
{"tools": [t.get_input_schema() for t in converted_tools]}
|
||||||
|
if tools is not None
|
||||||
|
else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the response content from the Anthropic response
|
# Merge all kwargs
|
||||||
assistant_message = response.content[0].text
|
request_kwargs = {
|
||||||
|
**self.DEFAULT_KWARGS,
|
||||||
|
**kwargs,
|
||||||
|
**tools_config,
|
||||||
|
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
|
"messages": formatted_messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Make initial API call
|
||||||
|
response = self.client.messages.create(**request_kwargs)
|
||||||
|
|
||||||
|
# Handle tool responses if needed
|
||||||
|
while response.content[-1].type != "text":
|
||||||
|
# Continue handling tools if the LLM is doing
|
||||||
|
# multiple sub-seqequent/sequential tool calls
|
||||||
|
for tool in converted_tools:
|
||||||
|
tool.handle(response, formatted_messages)
|
||||||
|
if tool.is_executed():
|
||||||
|
response = self.client.messages.create(**request_kwargs)
|
||||||
|
# Resetting the tool results in case this tool gets used again
|
||||||
|
tool.reset_result()
|
||||||
|
|
||||||
|
final_message = response.content[-1].text
|
||||||
|
|
||||||
# Create and return a properly formatted Message instance
|
|
||||||
return Message(
|
return Message(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
text=assistant_message,
|
text=final_message,
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=self.NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
@@ -107,3 +186,27 @@ class Anthropic(BaseProvider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return response.content[0].text
|
return response.content[0].text
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def generate_stream_text(
|
||||||
|
self, prompt: str, *, llm_model: str, **kwargs
|
||||||
|
) -> Iterator[str]:
|
||||||
|
# Prepare the messages.
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Make the request.
|
||||||
|
with self.client.messages.stream(
|
||||||
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
|
messages=messages,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
|
) as stream:
|
||||||
|
# Yield each chunk of text from the stream.
|
||||||
|
for chunk in stream.text_stream:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def tool(self) -> Type[BaseTool]:
|
||||||
|
"""The tool implementation for Antrhopic."""
|
||||||
|
return AnthropicTool
|
||||||
|
|||||||
@@ -0,0 +1,27 @@
|
|||||||
|
import os
|
||||||
|
from functools import cached_property
|
||||||
|
|
||||||
|
from .openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
|
class Deepseek(OpenAI):
|
||||||
|
NAME = "deepseek"
|
||||||
|
DEFAULT_MODEL = "deepseek-chat"
|
||||||
|
|
||||||
|
def __init__(self, api_key: str | None = None):
|
||||||
|
api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
|
||||||
|
super().__init__(api_key=api_key)
|
||||||
|
self.endpoint = "https://api.deepseek.com/v1"
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def client(self):
|
||||||
|
"""The raw OpenAI client."""
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError("DEEPSEEK API key is required")
|
||||||
|
try:
|
||||||
|
import openai as oa
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `openai` package: `pip install openai`"
|
||||||
|
) from exc
|
||||||
|
return oa.OpenAI(api_key=self.api_key, base_url=self.endpoint)
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
# IT is not currently working as desired.
|
# IT is not currently working as desired.
|
||||||
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Type, TypeVar
|
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -17,17 +17,14 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "gemini"
|
|
||||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
|
||||||
|
|
||||||
|
|
||||||
class Gemini(BaseProvider):
|
class Gemini(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "gemini"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||||
|
supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||||
self.model_name = DEFAULT_MODEL
|
self.model_name = self.DEFAULT_MODEL
|
||||||
|
|
||||||
def set_model(self, model_name: str):
|
def set_model(self, model_name: str):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
@@ -75,7 +72,7 @@ class Gemini(BaseProvider):
|
|||||||
text=response.text,
|
text=response.text,
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=self.model_name,
|
llm_model=self.model_name,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=self.NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
@@ -107,3 +104,17 @@ class Gemini(BaseProvider):
|
|||||||
# Handle the exception appropriately, e.g., log the error or raise a custom exception
|
# Handle the exception appropriately, e.g., log the error or raise a custom exception
|
||||||
raise RuntimeError(f"Failed to generate text with Gemini API: {e}") from e
|
raise RuntimeError(f"Failed to generate text with Gemini API: {e}") from e
|
||||||
return response.text
|
return response.text
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def generate_stream_text(self, prompt: str, **kwargs) -> Iterator[str]:
|
||||||
|
"""Generate streaming text using the Gemini API."""
|
||||||
|
kwargs.pop("llm_model", None)
|
||||||
|
try:
|
||||||
|
response = self.client.generate_content(prompt, stream=True, **kwargs)
|
||||||
|
for chunk in response:
|
||||||
|
if chunk.text:
|
||||||
|
yield chunk.text
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to generate streaming text with Gemini API: {e}"
|
||||||
|
) from e
|
||||||
|
|||||||
+151
-20
@@ -1,5 +1,5 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Type, TypeVar
|
from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel
|
|||||||
from ..logging import logger
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
from ._base import BaseProvider
|
from ._base import BaseProvider
|
||||||
|
from ._base_tools import BaseTool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..models import Conversation, Message
|
from ..models import Conversation, Message
|
||||||
@@ -14,19 +15,87 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "groq"
|
class GroqTool(BaseTool):
|
||||||
DEFAULT_MODEL = "llama3-8b-8192"
|
def get_response_schema(self):
|
||||||
DEFAULT_MAX_TOKENS = 1_000
|
assert self.is_executed, f"Tool {self.name} was not executed."
|
||||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
assert isinstance(
|
||||||
|
self.tool_id, str
|
||||||
|
), f"Expected str for `tool_id` got {self.tool_id!r}"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": self.tool_id,
|
||||||
|
"content": self.function_result,
|
||||||
|
}
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def handle(self, response, messages) -> None:
|
||||||
|
"""Handle the tool execution result from an API response."""
|
||||||
|
tool_used = False
|
||||||
|
|
||||||
|
# Get the message from the response
|
||||||
|
assistant_message = response.choices[0].message
|
||||||
|
|
||||||
|
# Check if there's a tool call
|
||||||
|
if assistant_message.tool_calls:
|
||||||
|
tool_call = assistant_message.tool_calls[
|
||||||
|
0
|
||||||
|
] # Get the first tool call
|
||||||
|
if tool_call.function.name == self.name:
|
||||||
|
# Execute the function
|
||||||
|
import json
|
||||||
|
|
||||||
|
function_args = json.loads(tool_call.function.arguments)
|
||||||
|
self.function_result = str(self.raw_func(**function_args))
|
||||||
|
self.tool_id = tool_call.id
|
||||||
|
tool_used = True
|
||||||
|
|
||||||
|
# Add assistant's message with tool call
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": tool_call.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool_call.function.name,
|
||||||
|
"arguments": tool_call.function.arguments,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if tool_used:
|
||||||
|
# Add tool response message
|
||||||
|
messages.append(self.get_response_schema())
|
||||||
|
|
||||||
|
def get_input_schema(self):
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": self.get_properties_schema(),
|
||||||
|
"required": self.required,
|
||||||
|
"additionalProperties": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
class Groq(BaseProvider):
|
class Groq(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "groq"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "llama3-8b-8192"
|
||||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
DEFAULT_MAX_TOKENS = 1_000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
|
supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
@@ -50,31 +119,59 @@ class Groq(BaseProvider):
|
|||||||
def send_conversation(
|
def send_conversation(
|
||||||
self,
|
self,
|
||||||
conversation: "Conversation",
|
conversation: "Conversation",
|
||||||
|
tools: list[Callable | BaseTool] | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> "Message":
|
) -> "Message":
|
||||||
"""Send a conversation to the Groq API."""
|
"""Send a conversation to the Groq API."""
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
|
|
||||||
messages = [
|
# Format messages from conversation
|
||||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
formatted_messages = [
|
||||||
|
{"role": msg.role, "content": msg.text}
|
||||||
|
for msg in conversation.messages
|
||||||
]
|
]
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
# Set up tools if provided
|
||||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
converted_tools = self.make_tools(tools)
|
||||||
messages=messages,
|
tools_config = (
|
||||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
[t.get_input_schema() for t in converted_tools] if tools else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the response content from the Groq response
|
# Merge all kwargs
|
||||||
assistant_message = response.choices[0].message
|
request_kwargs = {
|
||||||
|
**self.DEFAULT_KWARGS,
|
||||||
|
**kwargs,
|
||||||
|
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
|
"messages": formatted_messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tools_config:
|
||||||
|
request_kwargs["tools"] = tools_config
|
||||||
|
|
||||||
|
# Make initial API call
|
||||||
|
response = self.client.chat.completions.create(**request_kwargs)
|
||||||
|
|
||||||
|
# Handle tool responses if needed
|
||||||
|
while response.choices[0].message.tool_calls:
|
||||||
|
print(response)
|
||||||
|
# Handle each tool call
|
||||||
|
for tool in converted_tools:
|
||||||
|
tool.handle(response, formatted_messages)
|
||||||
|
if tool.is_executed():
|
||||||
|
# Make another API call with the updated messages
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
**request_kwargs
|
||||||
|
)
|
||||||
|
tool.reset_result()
|
||||||
|
|
||||||
|
final_message = response.choices[0].message.content
|
||||||
|
|
||||||
# Create and return a properly formatted Message instance
|
|
||||||
return Message(
|
return Message(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
text=assistant_message.content or "",
|
text=final_message or "",
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=self.NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
@@ -111,3 +208,37 @@ class Groq(BaseProvider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return str(response.choices[0].message.content)
|
return str(response.choices[0].message.content)
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def generate_stream_text(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
*,
|
||||||
|
llm_model: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Iterator[str]:
|
||||||
|
"""Generate streaming text using the Groq API."""
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
messages=messages,
|
||||||
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
|
stream=True,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for chunk in response:
|
||||||
|
if chunk.choices and chunk.choices[0].delta.content:
|
||||||
|
yield chunk.choices[0].delta.content
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to generate streaming text with Groq API: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def tool(self) -> Type[BaseTool]:
|
||||||
|
"""The tool implementation for Groq."""
|
||||||
|
return GroqTool
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Type, TypeVar
|
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from openai import OpenAI
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ..logging import logger
|
from ..logging import logger
|
||||||
@@ -15,17 +14,12 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "ollama"
|
|
||||||
DEFAULT_MODEL = "llama3.2"
|
|
||||||
DEFAULT_TIMEOUT = 60
|
|
||||||
DEFAULT_KWARGS = {}
|
|
||||||
|
|
||||||
|
|
||||||
class Ollama(BaseProvider):
|
class Ollama(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "ollama"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "llama3.2"
|
||||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
DEFAULT_TIMEOUT = 60
|
||||||
TIMEOUT = DEFAULT_TIMEOUT
|
DEFAULT_KWARGS = {}
|
||||||
|
supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, host_url: str | None = None):
|
def __init__(self, host_url: str | None = None):
|
||||||
self.host_url = host_url or settings.OLLAMA_HOST_URL
|
self.host_url = host_url or settings.OLLAMA_HOST_URL
|
||||||
@@ -36,21 +30,18 @@ class Ollama(BaseProvider):
|
|||||||
if not self.host_url:
|
if not self.host_url:
|
||||||
raise ValueError("No ollama host url provided")
|
raise ValueError("No ollama host url provided")
|
||||||
try:
|
try:
|
||||||
import ollama as ol
|
import openai
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Please install the `ollama` package: `pip install ollama`"
|
"Please install the `openai` package: `pip install openai`"
|
||||||
) from exc
|
) from exc
|
||||||
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
|
return openai.OpenAI(base_url=f"{self.host_url}/v1", api_key="ollama")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def structured_client(self) -> instructor.Instructor:
|
def structured_client(self) -> instructor.Instructor:
|
||||||
"""A client patched with Instructor."""
|
"""A client patched with Instructor."""
|
||||||
return instructor.from_openai(
|
return instructor.from_openai(
|
||||||
OpenAI(
|
self.client,
|
||||||
base_url=f"{self.host_url}/v1",
|
|
||||||
api_key="ollama",
|
|
||||||
),
|
|
||||||
mode=instructor.Mode.JSON,
|
mode=instructor.Mode.JSON,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -62,20 +53,24 @@ class Ollama(BaseProvider):
|
|||||||
messages = [
|
messages = [
|
||||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||||
]
|
]
|
||||||
response = self.client.chat(
|
|
||||||
model=conversation.llm_model or DEFAULT_MODEL,
|
request_kwargs = {
|
||||||
messages=messages,
|
**self.DEFAULT_KWARGS,
|
||||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
**kwargs,
|
||||||
)
|
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
assistant_message = response.get("message")
|
"messages": messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(**request_kwargs)
|
||||||
|
assistant_message = response.choices[0].message
|
||||||
|
|
||||||
# Create and return a properly formatted Message instance
|
# Create and return a properly formatted Message instance
|
||||||
return Message(
|
return Message(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
text=assistant_message.get("content"),
|
text=assistant_message.content or "",
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=self.NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
@@ -109,10 +104,31 @@ class Ollama(BaseProvider):
|
|||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
|
|
||||||
response = self.client.chat(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.get("message", {}).get("content", "")
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def generate_stream_text(
|
||||||
|
self, prompt: str, *, llm_model: str, **kwargs
|
||||||
|
) -> Iterator[str]:
|
||||||
|
# Prepare the messages.
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
messages=messages,
|
||||||
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
|
stream=True,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Iterate over the response and yield the content.
|
||||||
|
for chunk in response:
|
||||||
|
if chunk.choices[0].delta.content is not None:
|
||||||
|
yield chunk.choices[0].delta.content
|
||||||
|
|||||||
+182
-24
@@ -1,5 +1,5 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Type, TypeVar
|
from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -7,25 +7,94 @@ from pydantic import BaseModel
|
|||||||
from ..logging import logger
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
from ._base import BaseProvider
|
from ._base import BaseProvider
|
||||||
|
from ._base_tools import BaseTool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..models import Conversation, Message
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
PROVIDER_NAME = "openai"
|
|
||||||
DEFAULT_MODEL = "gpt-4o-mini"
|
class OpenAITool(BaseTool):
|
||||||
DEFAULT_MAX_TOKENS = 1_000
|
def get_response_schema(self):
|
||||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
assert self.is_executed, f"Tool {self.name} was not executed."
|
||||||
|
assert isinstance(
|
||||||
|
self.tool_id, str
|
||||||
|
), f"Expected str for `tool_id` got {self.tool_id!r}"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": self.tool_id,
|
||||||
|
"content": self.function_result,
|
||||||
|
}
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def handle(self, response, messages) -> None:
|
||||||
|
"""Handle the tool execution result from an API response."""
|
||||||
|
tool_used = False
|
||||||
|
|
||||||
|
# Get the message from the response
|
||||||
|
assistant_message = response.choices[0].message
|
||||||
|
|
||||||
|
# Check if there's a tool call
|
||||||
|
if assistant_message.tool_calls:
|
||||||
|
tool_call = assistant_message.tool_calls[0] # Get the first tool call
|
||||||
|
if tool_call.function.name == self.name:
|
||||||
|
# Execute the function
|
||||||
|
import json
|
||||||
|
|
||||||
|
function_args = json.loads(tool_call.function.arguments)
|
||||||
|
self.function_result = str(self.raw_func(**function_args))
|
||||||
|
self.tool_id = tool_call.id
|
||||||
|
tool_used = True
|
||||||
|
|
||||||
|
# Add assistant's message with tool call
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": tool_call.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool_call.function.name,
|
||||||
|
"arguments": tool_call.function.arguments,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if tool_used:
|
||||||
|
# Add tool response message
|
||||||
|
messages.append(self.get_response_schema())
|
||||||
|
|
||||||
|
def get_input_schema(self):
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": self.get_properties_schema(),
|
||||||
|
"required": self.required,
|
||||||
|
"additionalProperties": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class OpenAI(BaseProvider):
|
class OpenAI(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "openai"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "gpt-4o-mini"
|
||||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
DEFAULT_MAX_TOKENS = None
|
||||||
|
DEFAULT_KWARGS = {}
|
||||||
|
supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
@@ -46,30 +115,58 @@ class OpenAI(BaseProvider):
|
|||||||
return instructor.from_openai(self.client)
|
return instructor.from_openai(self.client)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
def send_conversation(
|
||||||
|
self,
|
||||||
|
conversation: "Conversation",
|
||||||
|
tools: list[Callable | BaseTool] | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> "Message":
|
||||||
"""Send a conversation to the OpenAI API."""
|
"""Send a conversation to the OpenAI API."""
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
|
|
||||||
messages = [
|
# Format messages from conversation
|
||||||
|
formatted_messages = [
|
||||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||||
]
|
]
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
# Set up tools if provided
|
||||||
model=conversation.llm_model or DEFAULT_MODEL,
|
converted_tools = self.make_tools(tools)
|
||||||
messages=messages,
|
tools_config = (
|
||||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
[t.get_input_schema() for t in converted_tools] if tools else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the response content from the OpenAI response
|
# Merge all kwargs
|
||||||
assistant_message = response.choices[0].message
|
request_kwargs = {
|
||||||
|
**self.DEFAULT_KWARGS,
|
||||||
|
**kwargs,
|
||||||
|
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
|
"messages": formatted_messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tools_config:
|
||||||
|
request_kwargs["tools"] = tools_config
|
||||||
|
|
||||||
|
# Make initial API call
|
||||||
|
response = self.client.chat.completions.create(**request_kwargs)
|
||||||
|
|
||||||
|
# Handle tool responses if needed
|
||||||
|
while response.choices[0].message.tool_calls:
|
||||||
|
# Handle each tool call
|
||||||
|
for tool in converted_tools:
|
||||||
|
tool.handle(response, formatted_messages)
|
||||||
|
if tool.is_executed():
|
||||||
|
# Make another API call with the updated messages
|
||||||
|
response = self.client.chat.completions.create(**request_kwargs)
|
||||||
|
tool.reset_result()
|
||||||
|
|
||||||
|
final_message = response.choices[0].message.content
|
||||||
|
|
||||||
# Create and return a properly formatted Message instance
|
|
||||||
return Message(
|
return Message(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
text=assistant_message.content or "",
|
text=final_message or "",
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=self.NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
@@ -79,13 +176,21 @@ class OpenAI(BaseProvider):
|
|||||||
response_model: Type[T],
|
response_model: Type[T],
|
||||||
*,
|
*,
|
||||||
llm_model: str | None = None,
|
llm_model: str | None = None,
|
||||||
|
image_url: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Get a structured response from the OpenAI API."""
|
"""Get a structured response from the OpenAI API."""
|
||||||
# Ensure messages are provided in kwargs
|
# Ensure messages are provided in kwargs
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
"""Add an image (url or base64-encoded) to the message if provided."""
|
||||||
|
if image_url:
|
||||||
|
messages[0]["content"].append(
|
||||||
|
{"type": "image_url", "image_url": {"url": image_url}}
|
||||||
|
)
|
||||||
|
|
||||||
response = self.structured_client.chat.completions.create(
|
response = self.structured_client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
@@ -95,14 +200,67 @@ class OpenAI(BaseProvider):
|
|||||||
return response_model.model_validate(response)
|
return response_model.model_validate(response)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
|
def generate_text(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
*,
|
||||||
|
llm_model: str | None = None,
|
||||||
|
image_url: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""Generate text using the OpenAI API."""
|
"""Generate text using the OpenAI API."""
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
"""Add an image (url or base64-encoded) to the message if provided."""
|
||||||
|
if image_url:
|
||||||
|
messages[0]["content"].append(
|
||||||
|
{"type": "image_url", "image_url": {"url": image_url}}
|
||||||
|
)
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def generate_stream_text(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
*,
|
||||||
|
llm_model: str | None = None,
|
||||||
|
image_url: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Iterator[str]:
|
||||||
|
"""Generate streaming text using the OpenAI API.
|
||||||
|
|
||||||
|
Yields chunks of text as they are generated by the model.
|
||||||
|
"""
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||||
|
]
|
||||||
|
|
||||||
|
"""Add an image (url or base64-encoded) to the message if provided."""
|
||||||
|
if image_url:
|
||||||
|
messages[0]["content"].append(
|
||||||
|
{"type": "image_url", "image_url": {"url": image_url}}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
messages=messages,
|
||||||
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
|
stream=True, # Enable streaming
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in response:
|
||||||
|
if chunk.choices[0].delta.content is not None:
|
||||||
|
yield chunk.choices[0].delta.content
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def tool(self) -> Type[BaseTool]:
|
||||||
|
"""The tool implementation for OpenAI."""
|
||||||
|
return OpenAITool
|
||||||
|
|||||||
+35
-14
@@ -1,5 +1,5 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Type, TypeVar
|
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -14,20 +14,17 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "xai"
|
|
||||||
DEFAULT_MODEL = "grok-beta"
|
|
||||||
BASE_URL = "https://api.x.ai/v1"
|
|
||||||
DEFAULT_MAX_TOKENS = 1000
|
|
||||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
|
||||||
|
|
||||||
|
|
||||||
class XAI(BaseProvider):
|
class XAI(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "xai"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "grok-beta"
|
||||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
DEFAULT_MAX_TOKENS = 1000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
|
BASE_URL = "https://api.x.ai/v1"
|
||||||
|
supports_streaming = True
|
||||||
|
supports_structured_responses = False
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
@@ -42,7 +39,7 @@ class XAI(BaseProvider):
|
|||||||
) from exc
|
) from exc
|
||||||
return oa.OpenAI(
|
return oa.OpenAI(
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
base_url=BASE_URL,
|
base_url=self.BASE_URL,
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@@ -74,7 +71,7 @@ class XAI(BaseProvider):
|
|||||||
text=assistant_message.content,
|
text=assistant_message.content,
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=self.NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
@@ -85,14 +82,38 @@ class XAI(BaseProvider):
|
|||||||
|
|
||||||
@logger
|
@logger
|
||||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
||||||
|
# Prepare the messages.
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Make the request.
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Return the response content.
|
||||||
return str(response.choices[0].message.content)
|
return str(response.choices[0].message.content)
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def generate_stream_text(
|
||||||
|
self, prompt: str, *, llm_model: str, **kwargs
|
||||||
|
) -> Iterator[str]:
|
||||||
|
# Prepare the messages.
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Make the request.
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
messages=messages,
|
||||||
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
|
stream=True,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Iterate over the response and yield the content.
|
||||||
|
for chunk in response:
|
||||||
|
yield chunk.choices[0].delta.content
|
||||||
|
|||||||
@@ -8,15 +8,15 @@ class LoggingConfig(BaseSettings):
|
|||||||
"""The class that holds all the logging settings for the application."""
|
"""The class that holds all the logging settings for the application."""
|
||||||
|
|
||||||
is_enabled: bool = Field(False, description="Enable logging")
|
is_enabled: bool = Field(False, description="Enable logging")
|
||||||
|
|
||||||
model_config = SettingsConfigDict(extra="forbid")
|
model_config = SettingsConfigDict(extra="forbid")
|
||||||
|
|
||||||
def enable_logfire(self, **kwargs) -> None:
|
def enable_logfire(self, **kwargs) -> None:
|
||||||
"""Enable logging for the application."""
|
"""Enable logging for the application."""
|
||||||
# adding imports here to avoid forced dependencies
|
# adding imports here to avoid forced dependencies
|
||||||
try:
|
try:
|
||||||
import logfire
|
|
||||||
from logging import basicConfig
|
from logging import basicConfig
|
||||||
|
|
||||||
|
import logfire
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"To enable logging, please install logfire: `pip install logfire`"
|
"To enable logging, please install logfire: `pip install logfire`"
|
||||||
@@ -41,6 +41,9 @@ class LoggingConfig(BaseSettings):
|
|||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
"""The class that holds all the API keys for the application."""
|
"""The class that holds all the API keys for the application."""
|
||||||
|
|
||||||
|
AMAZON_PROFILE_NAME: Optional[str] = Field(
|
||||||
|
"default", description="AWS Named Profile"
|
||||||
|
)
|
||||||
ANTHROPIC_API_KEY: Optional[SecretStr] = Field(
|
ANTHROPIC_API_KEY: Optional[SecretStr] = Field(
|
||||||
None, description="API key for Anthropic"
|
None, description="API key for Anthropic"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,101 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import simplemind as sm
|
||||||
|
from simplemind.models import BasePlugin, Conversation
|
||||||
|
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls",
|
||||||
|
[
|
||||||
|
Anthropic,
|
||||||
|
Gemini,
|
||||||
|
OpenAI,
|
||||||
|
Groq,
|
||||||
|
Ollama,
|
||||||
|
# Amazon
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_generate_data(provider_cls):
|
||||||
|
conv = sm.create_conversation(
|
||||||
|
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||||
|
)
|
||||||
|
|
||||||
|
conv.add_message(text="hey")
|
||||||
|
data = conv.send()
|
||||||
|
|
||||||
|
assert isinstance(data.text, str)
|
||||||
|
assert len(data.text) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_conversation():
|
||||||
|
"""Create a sample conversation for testing."""
|
||||||
|
conv = Conversation(llm_provider="openai")
|
||||||
|
conv.add_message(role="user", text="Hello!")
|
||||||
|
conv.add_message(role="assistant", text="Hi there!")
|
||||||
|
conv.add_message(role="user", text="How are you?")
|
||||||
|
return conv
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_json_file(tmp_path):
|
||||||
|
"""Create a temporary file path for testing."""
|
||||||
|
return tmp_path / "conversation.json"
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_conversation(sample_conversation, temp_json_file):
|
||||||
|
"""Test saving a conversation to a JSON file."""
|
||||||
|
sample_conversation.save(temp_json_file)
|
||||||
|
|
||||||
|
assert temp_json_file.exists()
|
||||||
|
|
||||||
|
with open(temp_json_file) as f:
|
||||||
|
saved_data = json.load(f)
|
||||||
|
|
||||||
|
assert "id" in saved_data
|
||||||
|
assert "messages" in saved_data
|
||||||
|
assert "llm_model" in saved_data
|
||||||
|
assert "llm_provider" in saved_data
|
||||||
|
|
||||||
|
assert len(saved_data["messages"]) == 3
|
||||||
|
assert saved_data["messages"][0]["text"] == "Hello!"
|
||||||
|
assert saved_data["messages"][1]["text"] == "Hi there!"
|
||||||
|
assert saved_data["messages"][2]["text"] == "How are you?"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_conversation(sample_conversation, temp_json_file):
|
||||||
|
"""Test loading a conversation from a JSON file."""
|
||||||
|
sample_conversation.save(temp_json_file)
|
||||||
|
|
||||||
|
loaded_conv = Conversation.load(temp_json_file)
|
||||||
|
|
||||||
|
assert loaded_conv.id == sample_conversation.id
|
||||||
|
assert loaded_conv.llm_model == sample_conversation.llm_model
|
||||||
|
assert loaded_conv.llm_provider == sample_conversation.llm_provider
|
||||||
|
assert len(loaded_conv.messages) == len(sample_conversation.messages)
|
||||||
|
|
||||||
|
for original_msg, loaded_msg in zip(
|
||||||
|
sample_conversation.messages, loaded_conv.messages
|
||||||
|
):
|
||||||
|
assert loaded_msg.role == original_msg.role
|
||||||
|
assert loaded_msg.text == original_msg.text
|
||||||
|
assert loaded_msg.meta == original_msg.meta
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_load_with_plugins(sample_conversation, temp_json_file):
|
||||||
|
"""Test that plugins are properly excluded from serialization."""
|
||||||
|
|
||||||
|
# Create a dummy plugin
|
||||||
|
class DummyPlugin(BasePlugin):
|
||||||
|
def initialize_hook(self, conversation):
|
||||||
|
pass
|
||||||
|
|
||||||
|
sample_conversation.add_plugin(DummyPlugin())
|
||||||
|
|
||||||
|
sample_conversation.save(temp_json_file)
|
||||||
|
loaded_conv = Conversation.load(temp_json_file)
|
||||||
|
|
||||||
|
assert len(loaded_conv.plugins) == 0
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import pytest
|
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||||
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseModel(BaseModel):
|
class ResponseModel(BaseModel):
|
||||||
@@ -16,6 +16,7 @@ class ResponseModel(BaseModel):
|
|||||||
OpenAI,
|
OpenAI,
|
||||||
Groq,
|
Groq,
|
||||||
Ollama,
|
Ollama,
|
||||||
|
# Amazon
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_generate_data(provider_cls):
|
def test_generate_data(provider_cls):
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
|
|
||||||
|
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -10,6 +11,7 @@ from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
|
|||||||
OpenAI,
|
OpenAI,
|
||||||
Groq,
|
Groq,
|
||||||
Ollama,
|
Ollama,
|
||||||
|
# Amazon,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_generate_text(provider_cls):
|
def test_generate_text(provider_cls):
|
||||||
|
|||||||
@@ -0,0 +1,118 @@
|
|||||||
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
import simplemind as sm
|
||||||
|
|
||||||
|
from simplemind.providers import Anthropic, OpenAI
|
||||||
|
from simplemind.providers._base_tools import BaseTool
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
Anthropic,
|
||||||
|
# Gemini,
|
||||||
|
OpenAI,
|
||||||
|
# Groq,
|
||||||
|
# Ollama,
|
||||||
|
# Amazon
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_weather(
|
||||||
|
location: Annotated[
|
||||||
|
str, Field(description="The city and state, e.g. San Francisco, CA")
|
||||||
|
],
|
||||||
|
unit: Annotated[
|
||||||
|
Literal["celcius", "fahrenheit"],
|
||||||
|
Field(description="The unit of temperature, either 'celsius' or 'fahrenheit'"),
|
||||||
|
] = "celcius",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get the current weather in a given location
|
||||||
|
"""
|
||||||
|
return f"42 {unit}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_location():
|
||||||
|
"""Get the current location"""
|
||||||
|
return "San Francisco,CA"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls",
|
||||||
|
MODELS,
|
||||||
|
)
|
||||||
|
def test_single_tool_args(provider_cls):
|
||||||
|
conv = sm.create_conversation(
|
||||||
|
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||||
|
)
|
||||||
|
|
||||||
|
conv.add_message(text="What is the weather in San Francisco?")
|
||||||
|
data = conv.send(tools=[get_weather])
|
||||||
|
assert "42" in data.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls",
|
||||||
|
MODELS,
|
||||||
|
)
|
||||||
|
def test_single_tool_no_args(provider_cls):
|
||||||
|
conv = sm.create_conversation(
|
||||||
|
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||||
|
)
|
||||||
|
|
||||||
|
conv.add_message(text="What is my current location")
|
||||||
|
data = conv.send(tools=[get_location])
|
||||||
|
assert "San Francisco" in data.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls",
|
||||||
|
MODELS,
|
||||||
|
)
|
||||||
|
def test_single_tool_partial(provider_cls):
|
||||||
|
conv = sm.create_conversation(
|
||||||
|
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||||
|
)
|
||||||
|
|
||||||
|
conv.add_message(text="Can you tell me the weather?")
|
||||||
|
conv.send(tools=[get_weather])
|
||||||
|
# Will answer something like:
|
||||||
|
"""
|
||||||
|
I can help you check the weather, but I need to know the location you're interested in.
|
||||||
|
Could you please provide a city and state (e.g., "Los Angeles, CA" or "New York, NY")
|
||||||
|
where you'd like to know the weather?
|
||||||
|
"""
|
||||||
|
|
||||||
|
conv.add_message(text="San Francisco, CA")
|
||||||
|
data = conv.send(tools=[get_weather])
|
||||||
|
assert "42" in data.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls",
|
||||||
|
MODELS,
|
||||||
|
)
|
||||||
|
def test_multiple_tools(provider_cls):
|
||||||
|
conv = sm.create_conversation(
|
||||||
|
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||||
|
)
|
||||||
|
|
||||||
|
conv.add_message(text="What is the wheather at my current location?")
|
||||||
|
data = conv.send(tools=[get_location, get_weather])
|
||||||
|
assert "San Francisco" in data.text
|
||||||
|
assert "42" in data.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls",
|
||||||
|
MODELS,
|
||||||
|
)
|
||||||
|
def test_tool_decorator(provider_cls):
|
||||||
|
@sm.tool(llm_provider=provider_cls.NAME)
|
||||||
|
def exchange_rate(currency_pair: str) -> float:
|
||||||
|
return 7.9
|
||||||
|
|
||||||
|
assert isinstance(exchange_rate, BaseTool)
|
||||||
|
assert exchange_rate.name == "exchange_rate"
|
||||||
|
assert list(exchange_rate.properties.keys()) == ["currency_pair"]
|
||||||
Reference in New Issue
Block a user