mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
add conversation save/load functionality
This commit is contained in:
+29
-7
@@ -1,5 +1,6 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from os import PathLike
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
@@ -40,7 +41,9 @@ class BasePlugin(SMBaseModel):
|
||||
"""Cleanup a hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any:
|
||||
def add_message_hook(
|
||||
self, conversation: "Conversation", message: "Message"
|
||||
) -> Any:
|
||||
"""Add a message hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -48,7 +51,9 @@ class BasePlugin(SMBaseModel):
|
||||
"""Pre-send hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any:
|
||||
def post_send_hook(
|
||||
self, conversation: "Conversation", response: "Message"
|
||||
) -> Any:
|
||||
"""Post-send hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -59,7 +64,7 @@ class Message(SMBaseModel):
|
||||
role: MESSAGE_ROLE
|
||||
text: str
|
||||
meta: Dict[str, Any] = {}
|
||||
raw: Optional[Any] = None
|
||||
raw: Optional[Any] = Field(default=None, exclude=True)
|
||||
llm_model: Optional[str] = None
|
||||
llm_provider: Optional[str] = None
|
||||
|
||||
@@ -90,7 +95,7 @@ class Conversation(SMBaseModel):
|
||||
messages: List[Message] = []
|
||||
llm_model: Optional[str] = None
|
||||
llm_provider: Optional[str] = None
|
||||
plugins: List[BasePlugin] = []
|
||||
plugins: List[BasePlugin] = Field(default_factory=list, exclude=True)
|
||||
|
||||
def __str__(self):
|
||||
return f"<Conversation id={self.id!r}>"
|
||||
@@ -120,7 +125,9 @@ class Conversation(SMBaseModel):
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
def prepend_system_message(self, text: str, meta: Dict[str, Any] | None = None):
|
||||
def prepend_system_message(
|
||||
self, text: str, meta: Dict[str, Any] | None = None
|
||||
):
|
||||
"""Prepend a system message to the conversation."""
|
||||
self.messages = [
|
||||
Message(role="system", text=text, meta=meta or {})
|
||||
@@ -184,14 +191,29 @@ class Conversation(SMBaseModel):
|
||||
pass
|
||||
|
||||
# Add the response to the conversation.
|
||||
self.add_message(role="assistant", text=response.text, meta=response.meta)
|
||||
self.add_message(
|
||||
role="assistant", text=response.text, meta=response.meta
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def get_last_message(self, role: MESSAGE_ROLE) -> Message | None:
|
||||
"""Get the last message with the given role."""
|
||||
return next((m for m in reversed(self.messages) if m.role == role), None)
|
||||
return next(
|
||||
(m for m in reversed(self.messages) if m.role == role), None
|
||||
)
|
||||
|
||||
def add_plugin(self, plugin: BasePlugin) -> None:
|
||||
"""Add a plugin to the conversation."""
|
||||
self.plugins.append(plugin)
|
||||
|
||||
def save(self, path: PathLike | str) -> None:
|
||||
"""Save the conversation to a JSON file."""
|
||||
with open(path, "w") as f:
|
||||
f.write(self.model_dump_json())
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: PathLike | str) -> "Conversation":
|
||||
"""Load a conversation from a JSON file."""
|
||||
with open(path, "r") as f:
|
||||
return cls.model_validate_json(f.read())
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
|
||||
import simplemind as sm
|
||||
from simplemind.models import BasePlugin, Conversation
|
||||
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -26,3 +28,74 @@ def test_generate_data(provider_cls):
|
||||
|
||||
assert isinstance(data.text, str)
|
||||
assert len(data.text) > 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation():
|
||||
"""Create a sample conversation for testing."""
|
||||
conv = Conversation(llm_provider="openai")
|
||||
conv.add_message(role="user", text="Hello!")
|
||||
conv.add_message(role="assistant", text="Hi there!")
|
||||
conv.add_message(role="user", text="How are you?")
|
||||
return conv
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_json_file(tmp_path):
|
||||
"""Create a temporary file path for testing."""
|
||||
return tmp_path / "conversation.json"
|
||||
|
||||
|
||||
def test_save_conversation(sample_conversation, temp_json_file):
|
||||
"""Test saving a conversation to a JSON file."""
|
||||
sample_conversation.save(temp_json_file)
|
||||
|
||||
assert temp_json_file.exists()
|
||||
|
||||
with open(temp_json_file) as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
assert "id" in saved_data
|
||||
assert "messages" in saved_data
|
||||
assert "llm_model" in saved_data
|
||||
assert "llm_provider" in saved_data
|
||||
|
||||
assert len(saved_data["messages"]) == 3
|
||||
assert saved_data["messages"][0]["text"] == "Hello!"
|
||||
assert saved_data["messages"][1]["text"] == "Hi there!"
|
||||
assert saved_data["messages"][2]["text"] == "How are you?"
|
||||
|
||||
|
||||
def test_load_conversation(sample_conversation, temp_json_file):
|
||||
"""Test loading a conversation from a JSON file."""
|
||||
sample_conversation.save(temp_json_file)
|
||||
|
||||
loaded_conv = Conversation.load(temp_json_file)
|
||||
|
||||
assert loaded_conv.id == sample_conversation.id
|
||||
assert loaded_conv.llm_model == sample_conversation.llm_model
|
||||
assert loaded_conv.llm_provider == sample_conversation.llm_provider
|
||||
assert len(loaded_conv.messages) == len(sample_conversation.messages)
|
||||
|
||||
for original_msg, loaded_msg in zip(
|
||||
sample_conversation.messages, loaded_conv.messages
|
||||
):
|
||||
assert loaded_msg.role == original_msg.role
|
||||
assert loaded_msg.text == original_msg.text
|
||||
assert loaded_msg.meta == original_msg.meta
|
||||
|
||||
|
||||
def test_save_load_with_plugins(sample_conversation, temp_json_file):
|
||||
"""Test that plugins are properly excluded from serialization."""
|
||||
|
||||
# Create a dummy plugin
|
||||
class DummyPlugin(BasePlugin):
|
||||
def initialize_hook(self, conversation):
|
||||
pass
|
||||
|
||||
sample_conversation.add_plugin(DummyPlugin())
|
||||
|
||||
sample_conversation.save(temp_json_file)
|
||||
loaded_conv = Conversation.load(temp_json_file)
|
||||
|
||||
assert len(loaded_conv.plugins) == 0
|
||||
|
||||
Reference in New Issue
Block a user