mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 90b85ce08a | |||
| de36bc1328 | |||
| d78aec4e1a | |||
| b47f04c557 | |||
| 7d89af37f1 | |||
| 2e448b9c3d | |||
| 4d38ac02cc | |||
| 88e82d1ad1 | |||
| e44201b800 | |||
| 97f745f230 | |||
| 3af715d650 | |||
| 285f996082 | |||
| 9a5c7ff61b | |||
| 1ecd4a4966 | |||
| b7287ad32a | |||
| 6045d5b5d2 | |||
| d4cfce01ba | |||
| da9958ef46 | |||
| 918705e2d5 | |||
| eae68d1ee1 | |||
| 5bf4fc81e7 | |||
| ca0246a3bb | |||
| 30885beda7 | |||
| a1dfe65084 | |||
| 641de59138 | |||
| 3c4ed48786 | |||
| 467f67d283 | |||
| b109964340 |
@@ -1,6 +1,18 @@
|
||||
Release History
|
||||
===============
|
||||
|
||||
|
||||
## 0.1.3 (2024-10-30)
|
||||
|
||||
- Make Conversation a context manager.
|
||||
- Add more robust conversation plugin hooks — replace `send_hook` with `pre_send_hook` and `post_send_hook`.
|
||||
- Change plugin hooks to try/except NotImplementedError.
|
||||
- Implement 'did you mean' with provider names. Can do this eventually with model names, as well.
|
||||
|
||||
## 0.1.2 (2024-10-29)
|
||||
|
||||
- Add ollama provider.
|
||||
|
||||
## 0.1.1 (2024-10-29)
|
||||
|
||||
- Fix Groq provider.
|
||||
|
||||
@@ -21,6 +21,7 @@ To specify a specific provider or model, you can use the `llm_provider` and `llm
|
||||
- **[Anthropic's Claude](https://www.anthropic.com/claude)**
|
||||
- **[xAI's Grok](https://x.ai/)**
|
||||
- **[Groq's Groq](https://groq.com/)**
|
||||
- **[Ollama](https://ollama.com)**
|
||||
|
||||
If you'd like to see SimpleMind support additional providers or models, please send a pull request!
|
||||
|
||||
@@ -40,7 +41,7 @@ First, authenticate your API keys by setting them in the environment variables:
|
||||
$ export OPENAI_API_KEY="sk-..."
|
||||
```
|
||||
|
||||
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `GROK_API_KEY`, `XAI_API_KEY`, and `GROQ_API_KEY`.
|
||||
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, and `GROQ_API_KEY`.
|
||||
|
||||
Next, import SimpleMind and start using it:
|
||||
|
||||
@@ -122,7 +123,7 @@ class SimpleMemoryPlugin:
|
||||
def yield_memories(self):
|
||||
return (m for m in self.memories)
|
||||
|
||||
def send_hook(self, conversation: sm.Conversation):
|
||||
def pre_send_hook(self, conversation: sm.Conversation):
|
||||
for m in self.yield_memories():
|
||||
conversation.add_message(role="system", text=m)
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = .
|
||||
BUILDDIR = _build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
@@ -0,0 +1,34 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath(".."))
|
||||
|
||||
import simplemind
|
||||
|
||||
project = "simplemind"
|
||||
copyright = "2024, Kenneth Reitz"
|
||||
author = "Kenneth Reitz"
|
||||
release = "v0.1.3"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
extensions = ["sphinx.ext.autodoc"]
|
||||
|
||||
templates_path = ["_templates"]
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_theme = "alabaster"
|
||||
html_static_path = ["_static"]
|
||||
+236
@@ -0,0 +1,236 @@
|
||||
.. simplemind documentation master file, created by
|
||||
sphinx-quickstart on Wed Oct 30 08:08:14 2024.
|
||||
You can adapt this file completely to your liking, but it should at least
|
||||
contain the root `toctree` directive.
|
||||
|
||||
SimpleMind: AI for Humans™
|
||||
==========================
|
||||
|
||||
**SimpleMind** is a versatile Python library designed to simplify interactions with various AI models. It provides a consistent and user-friendly interface to numerous AI providers, enabling developers to seamlessly integrate powerful AI capabilities into their applications without the overhead of managing multiple APIs and configurations.
|
||||
|
||||
Features
|
||||
--------
|
||||
|
||||
- **Unified Interface**: Interact with multiple AI providers using a single, consistent API
|
||||
- **Plugin Architecture**: Extend functionality with custom plugins for tasks like memory management and sentiment analysis
|
||||
- **Structured Data Support**: Generate and manipulate structured data using Pydantic models
|
||||
- **Human-Centered Design**: Prioritizes readability and ease of use, making AI integration accessible to all developers
|
||||
- **Minimal Configuration**: Quickly get started without extensive setup or configuration
|
||||
|
||||
Supported Providers
|
||||
------------------
|
||||
|
||||
SimpleMind supports a variety of AI providers:
|
||||
|
||||
- `OpenAI's GPT <https://openai.com/gpt>`_
|
||||
- `Anthropic's Claude <https://www.anthropic.com/claude>`_
|
||||
- `xAI's Grok <https://x.ai/>`_
|
||||
- `Groq's Groq <https://groq.com/>`_
|
||||
- `Ollama <https://ollama.com>`_
|
||||
|
||||
Installation
|
||||
-----------
|
||||
|
||||
Install SimpleMind using pip:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
$ pip install simplemind
|
||||
|
||||
Quickstart
|
||||
----------
|
||||
|
||||
1. Set your API keys as environment variables:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ export OPENAI_API_KEY="sk-..."
|
||||
$ export ANTHROPIC_API_KEY="..."
|
||||
$ export XAI_API_KEY="..."
|
||||
$ export GROQ_API_KEY="..."
|
||||
|
||||
This is the only required configuration.
|
||||
|
||||
2. Import and use SimpleMind:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import simplemind as sm
|
||||
|
||||
# Generate text using the default provider (OpenAI)
|
||||
response = sm.generate_text("Write a poem about the moon.", llm_model="gpt-4o-mini")
|
||||
print(response)
|
||||
|
||||
Things to know:
|
||||
|
||||
- The primary function for generating text is ``generate_text()``, which is used in the example above.
|
||||
- To generate structured data, use ``generate_data()``, which most providers support. This is extremely useful.
|
||||
- The third function, ``create_conversation()``, is used to engage in conversations with AI models.
|
||||
|
||||
All of these functions accept an ``llm_model`` and ``llm_provider`` parameter, which allows you to specify the AI model to use. If not provided, the default model for the given provider will be used.
|
||||
|
||||
|
||||
Usage Examples
|
||||
--------------
|
||||
|
||||
Here are some examples demonstrating SimpleMind's key features. From generating creative text and structured data to engaging in conversations and extending functionality with plugins, these examples showcase the library's versatility and ease of use.
|
||||
|
||||
Feel free to adapt these examples to your specific use cases!
|
||||
|
||||
|
||||
Text Generation
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
This example generates a poem about the moon using the ``gpt-4o-mini`` model.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import simplemind as sm
|
||||
|
||||
poem = sm.generate_text("Write a poem about the moon.", llm_model="gpt-4o-mini")
|
||||
print(poem)
|
||||
|
||||
Structured Data Generation
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
This example generates a poem about love using the ``gpt-4o-mini`` model.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Poem(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
|
||||
poem = sm.generate_data(
|
||||
prompt="Write a poem about love",
|
||||
llm_model="gpt-4o-mini",
|
||||
response_model=Poem,
|
||||
)
|
||||
print(poem)
|
||||
|
||||
Conversational AI
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
This example engages in a conversation with the ``gpt-4o-mini`` model.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
conversation = sm.create_conversation(llm_model="gpt-4o-mini")
|
||||
conversation.add_message("user", "Hi there, how are you?")
|
||||
response = conversation.send()
|
||||
print(response.text)
|
||||
|
||||
Plugins
|
||||
~~~~~~~
|
||||
|
||||
This example adds a simple custom memory plugin to the conversation.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class SimpleMemoryPlugin:
|
||||
def __init__(self):
|
||||
self.memories = ["the moon is made of cheese."]
|
||||
|
||||
def send_hook(self, conversation):
|
||||
for memory in self.memories:
|
||||
conversation.add_message(role="system", text=memory)
|
||||
|
||||
conversation = sm.create_conversation()
|
||||
conversation.add_plugin(SimpleMemoryPlugin())
|
||||
conversation.add_message("user", "Write a poem about the moon")
|
||||
print(conversation.send().text)
|
||||
|
||||
Plugin Development
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Plugins in SimpleMind follow a simple hook-based architecture. The ``send_hook`` method shown above is just one of several hooks available. Here's a more detailed example showing the complete plugin interface:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from simplemind.plugins import BasePlugin
|
||||
|
||||
class CustomPlugin(BasePlugin):
|
||||
def __init__(self):
|
||||
self.conversation_history = []
|
||||
|
||||
def initialize_hook(self, conversation):
|
||||
"""Called when the plugin is first added to a conversation."""
|
||||
print("Plugin initialized!")
|
||||
|
||||
def pre_send_hook(self, conversation):
|
||||
"""Called before the conversation is sent to the AI provider."""
|
||||
# Add any system messages or modify the conversation
|
||||
conversation.add_message("system", "Remember to be helpful.")
|
||||
|
||||
def send_hook(self, conversation):
|
||||
"""Called during the send process."""
|
||||
# Add messages or modify the conversation
|
||||
self.conversation_history.append(conversation.messages)
|
||||
|
||||
def post_send_hook(self, conversation, response):
|
||||
"""Called after receiving a response from the AI provider."""
|
||||
# Process or modify the response
|
||||
return response
|
||||
|
||||
def cleanup_hook(self):
|
||||
"""Called when the plugin is removed or the conversation ends."""
|
||||
self.conversation_history.clear()
|
||||
|
||||
All plugins should inherit from ``BasePlugin``, which provides default no-op implementations of these hooks. You only need to implement the hooks you want to use. Here's a simpler example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from simplemind.plugins import BasePlugin
|
||||
|
||||
class LoggingPlugin(BasePlugin):
|
||||
def pre_send_hook(self, conversation):
|
||||
print(f"Sending conversation with {len(conversation.messages)} messages")
|
||||
|
||||
def post_send_hook(self, conversation, response):
|
||||
print(f"Received response: {response.text[:50]}...")
|
||||
return response
|
||||
|
||||
conversation = sm.create_conversation()
|
||||
conversation.add_plugin(LoggingPlugin())
|
||||
conversation.add_message("user", "Hello!")
|
||||
response = conversation.send()
|
||||
|
||||
Plugins can be used to implement features like:
|
||||
|
||||
- Conversation logging
|
||||
- Memory management
|
||||
- Response filtering
|
||||
- Token counting
|
||||
- Custom prompt engineering
|
||||
- Analytics and monitoring
|
||||
|
||||
Multiple plugins can be added to a single conversation, and they will be executed in the order they were added.
|
||||
|
||||
|
||||
Contributing
|
||||
-----------
|
||||
|
||||
1. Fork the Repository
|
||||
2. Create a New Branch
|
||||
3. Make Your Changes
|
||||
4. Submit a Pull Request
|
||||
|
||||
Please review our `Code of Conduct <LICENSE>`_ before contributing.
|
||||
|
||||
License
|
||||
-------
|
||||
|
||||
SimpleMind is licensed under the `Apache 2.0 License <LICENSE>`_.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Contents:
|
||||
|
||||
installation
|
||||
usage
|
||||
api
|
||||
contributing
|
||||
changelog
|
||||
@@ -0,0 +1,35 @@
|
||||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=.
|
||||
set BUILDDIR=_build
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.https://www.sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:end
|
||||
popd
|
||||
@@ -9,7 +9,12 @@ import pickle
|
||||
|
||||
|
||||
class ContextualMemoryPlugin:
|
||||
def __init__(self, api_key: str, memory_file: str = "memories.pkl", embedding_model: str = "text-embedding-ada-002"):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
memory_file: str = "memories.pkl",
|
||||
embedding_model: str = "text-embedding-ada-002",
|
||||
):
|
||||
openai.api_key = api_key
|
||||
self.memory_file = memory_file
|
||||
self.embedding_model = embedding_model
|
||||
@@ -35,29 +40,29 @@ class ContextualMemoryPlugin:
|
||||
def build_faiss_index(self):
|
||||
if self.embeddings:
|
||||
self.index = faiss.IndexFlatL2(len(self.embeddings[0]))
|
||||
self.index.add(np.array(self.embeddings).astype('float32'))
|
||||
self.index.add(np.array(self.embeddings).astype("float32"))
|
||||
else:
|
||||
self.index = faiss.IndexFlatL2(1536)
|
||||
|
||||
def get_embedding(self, text: str) -> list:
|
||||
response = openai.Embedding.create(input=text, model=self.embedding_model)
|
||||
return response['data'][0]['embedding']
|
||||
return response["data"][0]["embedding"]
|
||||
|
||||
def add_memory(self, memory: str):
|
||||
embedding = self.get_embedding(memory)
|
||||
self.memories.append(memory)
|
||||
self.embeddings.append(embedding)
|
||||
self.index.add(np.array([embedding]).astype('float32'))
|
||||
self.index.add(np.array([embedding]).astype("float32"))
|
||||
self.save_memories()
|
||||
|
||||
def retrieve_memories(self, query: str, top_k: int = 3) -> list:
|
||||
if not self.index or len(self.embeddings) == 0:
|
||||
return []
|
||||
query_embedding = self.get_embedding(query)
|
||||
D, I = self.index.search(np.array([query_embedding]).astype('float32'), top_k)
|
||||
D, I = self.index.search(np.array([query_embedding]).astype("float32"), top_k)
|
||||
return [self.memories[i] for i in I[0] if i < len(self.memories)]
|
||||
|
||||
def send_hook(self, conversation: sm.Conversation):
|
||||
def pre_send_hook(self, conversation: sm.Conversation):
|
||||
# Retrieve relevant memories based on the latest user message
|
||||
if conversation.messages:
|
||||
last_user_message = conversation.messages[-1].text
|
||||
@@ -69,13 +74,16 @@ class ContextualMemoryPlugin:
|
||||
# Optionally, add the AI's response to memories
|
||||
self.add_memory(response)
|
||||
|
||||
|
||||
# Example Usage
|
||||
|
||||
|
||||
# Define a Pydantic model if needed
|
||||
class Story(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
|
||||
|
||||
# Initialize the conversation with the ContextualMemoryPlugin
|
||||
memory_plugin = ContextualMemoryPlugin(api_key=sm.settings.OPENAI_API_KEY)
|
||||
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import simplemind as sm
|
||||
|
||||
|
||||
class LoggingPlugin(sm.BasePlugin):
|
||||
def pre_send_hook(self, conversation):
|
||||
print(f"Sending conversation with {len(conversation.messages)} messages")
|
||||
|
||||
def add_message_hook(self, conversation, message):
|
||||
print(f"Adding message to conversation: {message.text}")
|
||||
|
||||
def cleanup_hook(self, conversation):
|
||||
print(f"Cleaning up conversation with {len(conversation.messages)} messages")
|
||||
|
||||
def initialize_hook(self, conversation):
|
||||
print("Initializing conversation")
|
||||
|
||||
def post_send_hook(self, conversation, response):
|
||||
print(f"Received response: {response.text}")
|
||||
|
||||
|
||||
with sm.create_conversation() as conversation:
|
||||
# Add the logging plugin.
|
||||
conversation.add_plugin(LoggingPlugin())
|
||||
|
||||
# Add a message to the conversation.
|
||||
conversation.add_message("user", "Hello!", meta={})
|
||||
|
||||
# Send the conversation.
|
||||
response = conversation.send()
|
||||
|
||||
print(f"Response: {response.text}")
|
||||
@@ -1,7 +1,7 @@
|
||||
from _context import sm
|
||||
|
||||
|
||||
class SimpleMemoryPlugin:
|
||||
class SimpleMemoryPlugin(sm.BasePlugin):
|
||||
def __init__(self):
|
||||
self.memories = [
|
||||
"the earth has fictionally beeen destroyed.",
|
||||
@@ -11,9 +11,9 @@ class SimpleMemoryPlugin:
|
||||
def yield_memories(self):
|
||||
return (m for m in self.memories)
|
||||
|
||||
def send_hook(self, conversation: sm.Conversation):
|
||||
def initialize_hook(self, conversation: sm.Conversation):
|
||||
for m in self.yield_memories():
|
||||
conversation.add_message(role="system", text=m)
|
||||
conversation.prepend_system_message(role="system", text=m)
|
||||
|
||||
|
||||
conversation = sm.create_conversation(llm_model="grok-beta", llm_provider="xai")
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "simplemind"
|
||||
version = "0.1.1"
|
||||
version = "0.1.3"
|
||||
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
|
||||
+20
-4
@@ -1,21 +1,34 @@
|
||||
from .models import Conversation
|
||||
from typing import List, Optional
|
||||
|
||||
from .models import Conversation, BasePlugin
|
||||
from .utils import find_provider
|
||||
from .settings import settings
|
||||
|
||||
|
||||
def create_conversation(llm_model=None, llm_provider=None):
|
||||
def create_conversation(
|
||||
llm_model=None, llm_provider=None, *, plugins: Optional[List[BasePlugin]] = None
|
||||
):
|
||||
"""Create a new conversation."""
|
||||
|
||||
return Conversation(
|
||||
# Create the conversation.
|
||||
conversation = Conversation(
|
||||
llm_model=llm_model, llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER
|
||||
)
|
||||
|
||||
# Add plugins to the conversation.
|
||||
for plugin in plugins or []:
|
||||
conversation.add_plugin(plugin)
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=None):
|
||||
"""Generate structured data from a given prompt."""
|
||||
|
||||
# Find the provider.
|
||||
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
|
||||
|
||||
# Generate the data.
|
||||
return provider.structured_response(
|
||||
prompt=prompt,
|
||||
llm_model=llm_model,
|
||||
@@ -25,16 +38,19 @@ def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=N
|
||||
|
||||
def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs):
|
||||
"""Generate text from a given prompt."""
|
||||
|
||||
# Find the provider.
|
||||
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
|
||||
|
||||
# Generate the text.
|
||||
return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Conversation",
|
||||
"create_conversation",
|
||||
"find_provider",
|
||||
"generate_data",
|
||||
"generate_text",
|
||||
"settings",
|
||||
"BasePlugin",
|
||||
]
|
||||
|
||||
+83
-6
@@ -25,9 +25,32 @@ class SMBaseModel(BaseModel):
|
||||
class BasePlugin(ABC):
|
||||
"""The base conversation plugin class."""
|
||||
|
||||
@abstractmethod
|
||||
def send_hook(self, conversation: "Conversation"):
|
||||
"""Send a hook to the plugin."""
|
||||
# Plugin metadata.
|
||||
meta: Dict[str, Any] = {}
|
||||
|
||||
# @abstractmethod
|
||||
def initialize_hook(self, conversation: "Conversation"):
|
||||
"""Initialize a hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
# @abstractmethod
|
||||
def cleanup_hook(self, conversation: "Conversation"):
|
||||
"""Cleanup a hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
# @abstractmethod
|
||||
def add_message_hook(self, conversation: "Conversation", message: "Message"):
|
||||
"""Add a message hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
# @abstractmethod
|
||||
def pre_send_hook(self, conversation: "Conversation"):
|
||||
"""Pre-send hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
# @abstractmethod
|
||||
def post_send_hook(self, conversation: "Conversation", response: "Message"):
|
||||
"""Post-send hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -60,28 +83,82 @@ class Conversation(SMBaseModel):
|
||||
def __str__(self):
|
||||
return f"<Conversation id={self.id!r}>"
|
||||
|
||||
def prepend_system_message(self, role: str, text: str, meta: Optional[Dict[str, Any]] = None):
|
||||
def __enter__(self):
|
||||
# Execute all initialize hooks.
|
||||
for plugin in self.plugins:
|
||||
if hasattr(plugin, "initialize_hook"):
|
||||
try:
|
||||
plugin.initialize_hook(self)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
# Execute all cleanup hooks.
|
||||
for plugin in self.plugins:
|
||||
if hasattr(plugin, "cleanup_hook"):
|
||||
try:
|
||||
plugin.cleanup_hook(self)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
def prepend_system_message(
|
||||
self, role: str, text: str, meta: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""Prepend a system message to the conversation."""
|
||||
self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages
|
||||
|
||||
def add_message(
|
||||
self, role: MESSAGE_ROLE, text: str, meta: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""Add a new message to the conversation."""
|
||||
|
||||
# Ensure meta is a dict.
|
||||
if meta is None:
|
||||
meta = {}
|
||||
|
||||
# Execute all add-message hooks.
|
||||
for plugin in self.plugins:
|
||||
if hasattr(plugin, "add_message_hook"):
|
||||
try:
|
||||
plugin.add_message_hook(
|
||||
self, Message(role=role, text=text, meta=meta)
|
||||
)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
# Add the message to the conversation.
|
||||
self.messages.append(Message(role=role, text=text, meta=meta))
|
||||
|
||||
def send(
|
||||
self, llm_model: Optional[str] = None, llm_provider: Optional[str] = None
|
||||
) -> Message:
|
||||
"""Send the conversation to the LLM."""
|
||||
for plugin in self.plugins:
|
||||
plugin.send_hook(self)
|
||||
|
||||
# Execute all pre send hooks.
|
||||
for plugin in self.plugins:
|
||||
if hasattr(plugin, "pre_send_hook"):
|
||||
try:
|
||||
plugin.pre_send_hook(self)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
# Find the provider and send the conversation.
|
||||
provider = find_provider(llm_provider or self.llm_provider)
|
||||
response = provider.send_conversation(self)
|
||||
|
||||
# Execute all post-send hooks.
|
||||
for plugin in self.plugins:
|
||||
if hasattr(plugin, "post_send_hook"):
|
||||
try:
|
||||
plugin.post_send_hook(self, response)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
# Add the response to the conversation.
|
||||
self.add_message(role="assistant", text=response.text, meta=response.meta)
|
||||
|
||||
return response
|
||||
|
||||
def get_last_message(self, role: MESSAGE_ROLE) -> Optional[Message]:
|
||||
|
||||
@@ -39,7 +39,7 @@ class Anthropic(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.client.messages.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
@@ -53,13 +53,13 @@ class Anthropic(BaseProvider):
|
||||
role="assistant",
|
||||
text=assistant_message,
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, model, response_model, **kwargs):
|
||||
response = self.structured_client.messages.create(
|
||||
model=model, response_model=response_model, **kwargs
|
||||
model=model, response_model=response_model or self.DEFAULT_MODEL, **kwargs
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -69,7 +69,7 @@ class Anthropic(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.client.messages.create(
|
||||
model=llm_model,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
|
||||
@@ -42,7 +42,7 @@ class Groq(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -55,7 +55,7 @@ class Groq(BaseProvider):
|
||||
role="assistant",
|
||||
text=assistant_message.content or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@@ -85,7 +85,7 @@ class Groq(BaseProvider):
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ class Ollama(BaseProvider):
|
||||
role="assistant",
|
||||
text=assistant_message.get("content"),
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@@ -63,7 +63,10 @@ class Ollama(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages, model=llm_model, response_model=response_model, **kwargs
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -72,6 +75,8 @@ class Ollama(BaseProvider):
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat(messages=messages, model=llm_model)
|
||||
response = self.client.chat(
|
||||
messages=messages, model=llm_model or self.DEFAULT_MODEL
|
||||
)
|
||||
|
||||
return response.get("message").get("content")
|
||||
|
||||
@@ -60,7 +60,10 @@ class OpenAI(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages, model=llm_model, response_model=response_model, **kwargs
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -70,7 +73,7 @@ class OpenAI(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages, model=llm_model, **kwargs
|
||||
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
@@ -43,7 +43,7 @@ class XAI(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -56,7 +56,7 @@ class XAI(BaseProvider):
|
||||
role="assistant",
|
||||
text=assistant_message.content,
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@@ -70,7 +70,7 @@ class XAI(BaseProvider):
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,7 +12,9 @@ class Settings(BaseSettings):
|
||||
)
|
||||
GROQ_API_KEY: Optional[SecretStr] = Field(None, description="API key for Groq")
|
||||
OPENAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for OpenAI")
|
||||
OLLAMA_HOST_URL: Optional[str] = Field(None, description="Fully qualified host URL for Ollama")
|
||||
OLLAMA_HOST_URL: Optional[str] = Field(
|
||||
"http://127.0.0.1:11434", description="Fully qualified host URL for Ollama"
|
||||
)
|
||||
XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI")
|
||||
DEFAULT_LLM_PROVIDER: str = Field("openai", description="The default LLM provider")
|
||||
|
||||
|
||||
+14
-1
@@ -1,7 +1,10 @@
|
||||
import difflib
|
||||
from typing import Union
|
||||
|
||||
from .providers import providers
|
||||
|
||||
_PROVIDER_NAMES = [provider.NAME.lower() for provider in providers]
|
||||
|
||||
|
||||
def find_provider(provider_name: Union[str, None]):
|
||||
"""Find a provider by name."""
|
||||
@@ -10,4 +13,14 @@ def find_provider(provider_name: Union[str, None]):
|
||||
if provider_class.NAME.lower() == provider_name.lower():
|
||||
# Instantiate the provider
|
||||
return provider_class()
|
||||
raise ValueError(f"Provider {provider_name} not found")
|
||||
|
||||
provider_found = difflib.get_close_matches(
|
||||
provider_name.lower(), _PROVIDER_NAMES, n=1
|
||||
) # Show only one suggestion
|
||||
|
||||
if provider_found:
|
||||
raise ValueError(
|
||||
f"Provider {provider_name!r} not found. Did you mean {provider_found[0]!r}?"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Provider {provider_name} not found.")
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest import mock
|
||||
import simplemind as sm
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TestOllama(unittest.TestCase):
|
||||
|
||||
def test_generate_text(self):
|
||||
result = sm.generate_text(prompt="What is the meaning of life?", llm_provider="ollama", llm_model="llama3.2")
|
||||
self.assertGreater(len(result), 0)
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
def test_create_conversation(self):
|
||||
conversation = sm.create_conversation(llm_provider="ollama", llm_model="llama3.2")
|
||||
conversation.add_message("user", "Remember the number 42.")
|
||||
result = conversation.send()
|
||||
self.assertIsNotNone(result)
|
||||
self.assertGreaterEqual(len(result.text), 0)
|
||||
self.assertIsInstance(result, sm.models.Message)
|
||||
|
||||
def test_memory(self):
|
||||
class SimpleMemoryPlugin:
|
||||
def __init__(self):
|
||||
self.memories = [
|
||||
"the earth has fictionally been destroyed.",
|
||||
"the moon is made of cheese.",
|
||||
]
|
||||
|
||||
def yield_memories(self):
|
||||
return (m for m in self.memories)
|
||||
|
||||
def send_hook(self, conversation: sm.Conversation):
|
||||
for m in self.yield_memories():
|
||||
conversation.prepend_system_message(role="system", text=m)
|
||||
|
||||
conversation = sm.create_conversation(llm_provider="ollama", llm_model="llama3.2")
|
||||
|
||||
conversation.add_message(
|
||||
role="user",
|
||||
text="Write a poem about the moon",
|
||||
)
|
||||
self.assertGreater(len(conversation.messages), 0)
|
||||
conversation.add_plugin(SimpleMemoryPlugin())
|
||||
result = conversation.send()
|
||||
self.assertGreater(len(conversation.messages), 2)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertIsNotNone(result.text)
|
||||
self.assertGreater(len(result.text), 0)
|
||||
self.assertIsInstance(result, sm.models.Message)
|
||||
|
||||
def test_structure_response(self):
|
||||
class Poem(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
# Test for NotImplementedError
|
||||
with self.assertRaises(NotImplementedError):
|
||||
sm.generate_data(
|
||||
prompt="Write a poem about love",
|
||||
llm_provider="ollama",
|
||||
llm_model="llama3.2",
|
||||
response_model=Poem)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user