From a2991eec0c18562d3b5ed7ccb14902d445af75ea Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Tue, 26 Nov 2024 23:37:30 +0800 Subject: [PATCH] add conversation save/load functionality --- simplemind/models.py | 36 +++++++++++++---- tests/test_conversations.py | 77 ++++++++++++++++++++++++++++++++++++- 2 files changed, 104 insertions(+), 9 deletions(-) diff --git a/simplemind/models.py b/simplemind/models.py index 5ba0a2b..3f7a522 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -1,5 +1,6 @@ import uuid from datetime import datetime +from os import PathLike from types import TracebackType from typing import Any, Dict, List, Literal, Optional @@ -40,7 +41,9 @@ class BasePlugin(SMBaseModel): """Cleanup a hook for the plugin.""" raise NotImplementedError - def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any: + def add_message_hook( + self, conversation: "Conversation", message: "Message" + ) -> Any: """Add a message hook for the plugin.""" raise NotImplementedError @@ -48,7 +51,9 @@ class BasePlugin(SMBaseModel): """Pre-send hook for the plugin.""" raise NotImplementedError - def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any: + def post_send_hook( + self, conversation: "Conversation", response: "Message" + ) -> Any: """Post-send hook for the plugin.""" raise NotImplementedError @@ -59,7 +64,7 @@ class Message(SMBaseModel): role: MESSAGE_ROLE text: str meta: Dict[str, Any] = {} - raw: Optional[Any] = None + raw: Optional[Any] = Field(default=None, exclude=True) llm_model: Optional[str] = None llm_provider: Optional[str] = None @@ -90,7 +95,7 @@ class Conversation(SMBaseModel): messages: List[Message] = [] llm_model: Optional[str] = None llm_provider: Optional[str] = None - plugins: List[BasePlugin] = [] + plugins: List[BasePlugin] = Field(default_factory=list, exclude=True) def __str__(self): return f"" @@ -120,7 +125,9 @@ class Conversation(SMBaseModel): except NotImplementedError: pass - def prepend_system_message(self, text: str, meta: Dict[str, Any] | None = None): + def prepend_system_message( + self, text: str, meta: Dict[str, Any] | None = None + ): """Prepend a system message to the conversation.""" self.messages = [ Message(role="system", text=text, meta=meta or {}) @@ -184,14 +191,29 @@ class Conversation(SMBaseModel): pass # Add the response to the conversation. - self.add_message(role="assistant", text=response.text, meta=response.meta) + self.add_message( + role="assistant", text=response.text, meta=response.meta + ) return response def get_last_message(self, role: MESSAGE_ROLE) -> Message | None: """Get the last message with the given role.""" - return next((m for m in reversed(self.messages) if m.role == role), None) + return next( + (m for m in reversed(self.messages) if m.role == role), None + ) def add_plugin(self, plugin: BasePlugin) -> None: """Add a plugin to the conversation.""" self.plugins.append(plugin) + + def save(self, path: PathLike | str) -> None: + """Save the conversation to a JSON file.""" + with open(path, "w") as f: + f.write(self.model_dump_json()) + + @classmethod + def load(cls, path: PathLike | str) -> "Conversation": + """Load a conversation from a JSON file.""" + with open(path, "r") as f: + return cls.model_validate_json(f.read()) diff --git a/tests/test_conversations.py b/tests/test_conversations.py index b0c85dd..7488cb2 100644 --- a/tests/test_conversations.py +++ b/tests/test_conversations.py @@ -1,8 +1,10 @@ +import json + import pytest -from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon - import simplemind as sm +from simplemind.models import BasePlugin, Conversation +from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI @pytest.mark.parametrize( @@ -26,3 +28,74 @@ def test_generate_data(provider_cls): assert isinstance(data.text, str) assert len(data.text) > 0 + + +@pytest.fixture +def sample_conversation(): + """Create a sample conversation for testing.""" + conv = Conversation(llm_provider="openai") + conv.add_message(role="user", text="Hello!") + conv.add_message(role="assistant", text="Hi there!") + conv.add_message(role="user", text="How are you?") + return conv + + +@pytest.fixture +def temp_json_file(tmp_path): + """Create a temporary file path for testing.""" + return tmp_path / "conversation.json" + + +def test_save_conversation(sample_conversation, temp_json_file): + """Test saving a conversation to a JSON file.""" + sample_conversation.save(temp_json_file) + + assert temp_json_file.exists() + + with open(temp_json_file) as f: + saved_data = json.load(f) + + assert "id" in saved_data + assert "messages" in saved_data + assert "llm_model" in saved_data + assert "llm_provider" in saved_data + + assert len(saved_data["messages"]) == 3 + assert saved_data["messages"][0]["text"] == "Hello!" + assert saved_data["messages"][1]["text"] == "Hi there!" + assert saved_data["messages"][2]["text"] == "How are you?" + + +def test_load_conversation(sample_conversation, temp_json_file): + """Test loading a conversation from a JSON file.""" + sample_conversation.save(temp_json_file) + + loaded_conv = Conversation.load(temp_json_file) + + assert loaded_conv.id == sample_conversation.id + assert loaded_conv.llm_model == sample_conversation.llm_model + assert loaded_conv.llm_provider == sample_conversation.llm_provider + assert len(loaded_conv.messages) == len(sample_conversation.messages) + + for original_msg, loaded_msg in zip( + sample_conversation.messages, loaded_conv.messages + ): + assert loaded_msg.role == original_msg.role + assert loaded_msg.text == original_msg.text + assert loaded_msg.meta == original_msg.meta + + +def test_save_load_with_plugins(sample_conversation, temp_json_file): + """Test that plugins are properly excluded from serialization.""" + + # Create a dummy plugin + class DummyPlugin(BasePlugin): + def initialize_hook(self, conversation): + pass + + sample_conversation.add_plugin(DummyPlugin()) + + sample_conversation.save(temp_json_file) + loaded_conv = Conversation.load(temp_json_file) + + assert len(loaded_conv.plugins) == 0