Merge pull request #51 from lucianosrp/feat/save-conversation

feat: add conversation save/load functionality
This commit is contained in:
2025-01-09 17:08:14 -05:00
committed by GitHub
2 changed files with 89 additions and 3 deletions
+14 -2
View File
@@ -1,5 +1,6 @@
import uuid import uuid
from datetime import datetime from datetime import datetime
from os import PathLike
from types import TracebackType from types import TracebackType
from typing import Any, Callable, Dict, List, Literal, Optional from typing import Any, Callable, Dict, List, Literal, Optional
@@ -64,7 +65,7 @@ class Message(SMBaseModel):
role: MESSAGE_ROLE role: MESSAGE_ROLE
text: str text: str
meta: Dict[str, Any] = {} meta: Dict[str, Any] = {}
raw: Optional[Any] = None raw: Optional[Any] = Field(default=None, exclude=True)
llm_model: Optional[str] = None llm_model: Optional[str] = None
llm_provider: Optional[str] = None llm_provider: Optional[str] = None
@@ -95,7 +96,7 @@ class Conversation(SMBaseModel):
messages: List[Message] = [] messages: List[Message] = []
llm_model: Optional[str] = None llm_model: Optional[str] = None
llm_provider: Optional[str] = None llm_provider: Optional[str] = None
plugins: List[BasePlugin] = [] plugins: List[BasePlugin] = Field(default_factory=list, exclude=True)
def __str__(self): def __str__(self):
return f"<Conversation id={self.id!r}>" return f"<Conversation id={self.id!r}>"
@@ -207,3 +208,14 @@ class Conversation(SMBaseModel):
def add_plugin(self, plugin: BasePlugin) -> None: def add_plugin(self, plugin: BasePlugin) -> None:
"""Add a plugin to the conversation.""" """Add a plugin to the conversation."""
self.plugins.append(plugin) self.plugins.append(plugin)
def save(self, path: PathLike | str) -> None:
"""Save the conversation to a JSON file."""
with open(path, "w") as f:
f.write(self.model_dump_json())
@classmethod
def load(cls, path: PathLike | str) -> "Conversation":
"""Load a conversation from a JSON file."""
with open(path, "r") as f:
return cls.model_validate_json(f.read())
+75 -1
View File
@@ -1,7 +1,10 @@
import json
import pytest import pytest
import simplemind as sm import simplemind as sm
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI from simplemind.models import BasePlugin, Conversation
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -25,3 +28,74 @@ def test_generate_data(provider_cls):
assert isinstance(data.text, str) assert isinstance(data.text, str)
assert len(data.text) > 0 assert len(data.text) > 0
@pytest.fixture
def sample_conversation():
"""Create a sample conversation for testing."""
conv = Conversation(llm_provider="openai")
conv.add_message(role="user", text="Hello!")
conv.add_message(role="assistant", text="Hi there!")
conv.add_message(role="user", text="How are you?")
return conv
@pytest.fixture
def temp_json_file(tmp_path):
"""Create a temporary file path for testing."""
return tmp_path / "conversation.json"
def test_save_conversation(sample_conversation, temp_json_file):
"""Test saving a conversation to a JSON file."""
sample_conversation.save(temp_json_file)
assert temp_json_file.exists()
with open(temp_json_file) as f:
saved_data = json.load(f)
assert "id" in saved_data
assert "messages" in saved_data
assert "llm_model" in saved_data
assert "llm_provider" in saved_data
assert len(saved_data["messages"]) == 3
assert saved_data["messages"][0]["text"] == "Hello!"
assert saved_data["messages"][1]["text"] == "Hi there!"
assert saved_data["messages"][2]["text"] == "How are you?"
def test_load_conversation(sample_conversation, temp_json_file):
"""Test loading a conversation from a JSON file."""
sample_conversation.save(temp_json_file)
loaded_conv = Conversation.load(temp_json_file)
assert loaded_conv.id == sample_conversation.id
assert loaded_conv.llm_model == sample_conversation.llm_model
assert loaded_conv.llm_provider == sample_conversation.llm_provider
assert len(loaded_conv.messages) == len(sample_conversation.messages)
for original_msg, loaded_msg in zip(
sample_conversation.messages, loaded_conv.messages
):
assert loaded_msg.role == original_msg.role
assert loaded_msg.text == original_msg.text
assert loaded_msg.meta == original_msg.meta
def test_save_load_with_plugins(sample_conversation, temp_json_file):
"""Test that plugins are properly excluded from serialization."""
# Create a dummy plugin
class DummyPlugin(BasePlugin):
def initialize_hook(self, conversation):
pass
sample_conversation.add_plugin(DummyPlugin())
sample_conversation.save(temp_json_file)
loaded_conv = Conversation.load(temp_json_file)
assert len(loaded_conv.plugins) == 0