diff --git a/simplemind/models.py b/simplemind/models.py index f8ff5be..a51ad5b 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -1,5 +1,6 @@ import uuid from datetime import datetime +from os import PathLike from types import TracebackType from typing import Any, Callable, Dict, List, Literal, Optional @@ -64,7 +65,7 @@ class Message(SMBaseModel): role: MESSAGE_ROLE text: str meta: Dict[str, Any] = {} - raw: Optional[Any] = None + raw: Optional[Any] = Field(default=None, exclude=True) llm_model: Optional[str] = None llm_provider: Optional[str] = None @@ -95,7 +96,7 @@ class Conversation(SMBaseModel): messages: List[Message] = [] llm_model: Optional[str] = None llm_provider: Optional[str] = None - plugins: List[BasePlugin] = [] + plugins: List[BasePlugin] = Field(default_factory=list, exclude=True) def __str__(self): return f"" @@ -207,3 +208,14 @@ class Conversation(SMBaseModel): def add_plugin(self, plugin: BasePlugin) -> None: """Add a plugin to the conversation.""" self.plugins.append(plugin) + + def save(self, path: PathLike | str) -> None: + """Save the conversation to a JSON file.""" + with open(path, "w") as f: + f.write(self.model_dump_json()) + + @classmethod + def load(cls, path: PathLike | str) -> "Conversation": + """Load a conversation from a JSON file.""" + with open(path, "r") as f: + return cls.model_validate_json(f.read()) diff --git a/tests/test_conversations.py b/tests/test_conversations.py index 4747f65..7488cb2 100644 --- a/tests/test_conversations.py +++ b/tests/test_conversations.py @@ -1,7 +1,10 @@ +import json + import pytest import simplemind as sm -from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI +from simplemind.models import BasePlugin, Conversation +from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI @pytest.mark.parametrize( @@ -25,3 +28,74 @@ def test_generate_data(provider_cls): assert isinstance(data.text, str) assert len(data.text) > 0 + + +@pytest.fixture +def sample_conversation(): + """Create a sample conversation for testing.""" + conv = Conversation(llm_provider="openai") + conv.add_message(role="user", text="Hello!") + conv.add_message(role="assistant", text="Hi there!") + conv.add_message(role="user", text="How are you?") + return conv + + +@pytest.fixture +def temp_json_file(tmp_path): + """Create a temporary file path for testing.""" + return tmp_path / "conversation.json" + + +def test_save_conversation(sample_conversation, temp_json_file): + """Test saving a conversation to a JSON file.""" + sample_conversation.save(temp_json_file) + + assert temp_json_file.exists() + + with open(temp_json_file) as f: + saved_data = json.load(f) + + assert "id" in saved_data + assert "messages" in saved_data + assert "llm_model" in saved_data + assert "llm_provider" in saved_data + + assert len(saved_data["messages"]) == 3 + assert saved_data["messages"][0]["text"] == "Hello!" + assert saved_data["messages"][1]["text"] == "Hi there!" + assert saved_data["messages"][2]["text"] == "How are you?" + + +def test_load_conversation(sample_conversation, temp_json_file): + """Test loading a conversation from a JSON file.""" + sample_conversation.save(temp_json_file) + + loaded_conv = Conversation.load(temp_json_file) + + assert loaded_conv.id == sample_conversation.id + assert loaded_conv.llm_model == sample_conversation.llm_model + assert loaded_conv.llm_provider == sample_conversation.llm_provider + assert len(loaded_conv.messages) == len(sample_conversation.messages) + + for original_msg, loaded_msg in zip( + sample_conversation.messages, loaded_conv.messages + ): + assert loaded_msg.role == original_msg.role + assert loaded_msg.text == original_msg.text + assert loaded_msg.meta == original_msg.meta + + +def test_save_load_with_plugins(sample_conversation, temp_json_file): + """Test that plugins are properly excluded from serialization.""" + + # Create a dummy plugin + class DummyPlugin(BasePlugin): + def initialize_hook(self, conversation): + pass + + sample_conversation.add_plugin(DummyPlugin()) + + sample_conversation.save(temp_json_file) + loaded_conv = Conversation.load(temp_json_file) + + assert len(loaded_conv.plugins) == 0