diff --git a/simplemind/__init__.py b/simplemind/__init__.py index 53f6a7e..33d4229 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -64,16 +64,16 @@ def create_conversation( """Create a new conversation.""" # Create the conversation. - conversation = Conversation( + conv = Conversation( llm_model=llm_model, llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER, ) # Add plugins to the conversation. for plugin in plugins or []: - conversation.add_plugin(plugin) + conv.add_plugin(plugin) - return conversation + return conv def generate_data( diff --git a/simplemind/models.py b/simplemind/models.py index 8a32fb9..e2e22ab 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -116,17 +116,23 @@ class Conversation(SMBaseModel): except NotImplementedError: pass - def prepend_system_message( - self, text: str, meta: Dict[str, Any] | None = None - ): + def prepend_system_message(self, text: str, meta: Dict[str, Any] | None = None): """Prepend a system message to the conversation.""" - self.messages = [Message(role="system", text=text, meta=meta or {})] + self.messages + self.messages = [ + Message(role="system", text=text, meta=meta or {}) + ] + self.messages def add_message( - self, role: MESSAGE_ROLE, text: str, meta: Optional[Dict[str, Any]] = None + self, + role: MESSAGE_ROLE = "user", + text: str | None = None, + *, + meta: Optional[Dict[str, Any]] = None, ): """Add a new message to the conversation.""" + assert text is not None + # Ensure meta is a dict. if meta is None: meta = {} @@ -151,6 +157,8 @@ class Conversation(SMBaseModel): ) -> Message: """Send the conversation to the LLM.""" + # TODO: llm_model and llm_provider should override the conversation's. + # Execute all pre send hooks. for plugin in self.plugins: if hasattr(plugin, "pre_send_hook"): diff --git a/tests/test_conversations.py b/tests/test_conversations.py new file mode 100644 index 0000000..b0c85dd --- /dev/null +++ b/tests/test_conversations.py @@ -0,0 +1,28 @@ +import pytest + +from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon + +import simplemind as sm + + +@pytest.mark.parametrize( + "provider_cls", + [ + Anthropic, + Gemini, + OpenAI, + Groq, + Ollama, + # Amazon + ], +) +def test_generate_data(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="hey") + data = conv.send() + + assert isinstance(data.text, str) + assert len(data.text) > 0 diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index dc07300..610c96a 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -1,4 +1,3 @@ - import pytest from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon @@ -18,7 +17,7 @@ class ResponseModel(BaseModel): OpenAI, Groq, Ollama, - Amazon + # Amazon ], ) def test_generate_data(provider_cls): diff --git a/tests/test_generate_text.py b/tests/test_generate_text.py index 0611b1d..4ab62cf 100644 --- a/tests/test_generate_text.py +++ b/tests/test_generate_text.py @@ -11,7 +11,7 @@ from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon OpenAI, Groq, Ollama, - Amazon, + # Amazon, ], ) def test_generate_text(provider_cls):