Refactor conversation creation and message handling

This commit refactors the `create_conversation` function in `simplemind/__init__.py` to use a more descriptive variable name (`conv`) for the conversation object. It also updates the `add_plugin` method to use the new variable name (`conv`) instead of `conversation`.

In `simplemind/models.py`, the `prepend_system_message` method now accepts an optional `meta` parameter. The method also adds a system message to the conversation by prepending it to the list of messages.

Additionally, the `add_message` method in `simplemind/models.py` has been modified to include type annotations and a default value for the `role` parameter. The method now requires the `text` parameter to be provided explicitly.

A new test file, `tests/test_conversations.py`, has been added to the repository. This file contains a test case for the `generate_data` function, which tests the functionality of different LLM providers.

Lastly, the test files `tests/test_generate_data.py` and `tests/test_generate_text.py` have been modified to remove the unused `Amazon` provider from the list of test cases.
This commit is contained in:
2024-11-02 11:24:26 -04:00
parent 931285f8ce
commit 33e53562ae
5 changed files with 46 additions and 11 deletions
+3 -3
View File
@@ -64,16 +64,16 @@ def create_conversation(
"""Create a new conversation."""
# Create the conversation.
conversation = Conversation(
conv = Conversation(
llm_model=llm_model,
llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER,
)
# Add plugins to the conversation.
for plugin in plugins or []:
conversation.add_plugin(plugin)
conv.add_plugin(plugin)
return conversation
return conv
def generate_data(
+13 -5
View File
@@ -116,17 +116,23 @@ class Conversation(SMBaseModel):
except NotImplementedError:
pass
def prepend_system_message(
self, text: str, meta: Dict[str, Any] | None = None
):
def prepend_system_message(self, text: str, meta: Dict[str, Any] | None = None):
"""Prepend a system message to the conversation."""
self.messages = [Message(role="system", text=text, meta=meta or {})] + self.messages
self.messages = [
Message(role="system", text=text, meta=meta or {})
] + self.messages
def add_message(
self, role: MESSAGE_ROLE, text: str, meta: Optional[Dict[str, Any]] = None
self,
role: MESSAGE_ROLE = "user",
text: str | None = None,
*,
meta: Optional[Dict[str, Any]] = None,
):
"""Add a new message to the conversation."""
assert text is not None
# Ensure meta is a dict.
if meta is None:
meta = {}
@@ -151,6 +157,8 @@ class Conversation(SMBaseModel):
) -> Message:
"""Send the conversation to the LLM."""
# TODO: llm_model and llm_provider should override the conversation's.
# Execute all pre send hooks.
for plugin in self.plugins:
if hasattr(plugin, "pre_send_hook"):
+28
View File
@@ -0,0 +1,28 @@
import pytest
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
import simplemind as sm
@pytest.mark.parametrize(
"provider_cls",
[
Anthropic,
Gemini,
OpenAI,
Groq,
Ollama,
# Amazon
],
)
def test_generate_data(provider_cls):
conv = sm.create_conversation(
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
)
conv.add_message(text="hey")
data = conv.send()
assert isinstance(data.text, str)
assert len(data.text) > 0
+1 -2
View File
@@ -1,4 +1,3 @@
import pytest
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
@@ -18,7 +17,7 @@ class ResponseModel(BaseModel):
OpenAI,
Groq,
Ollama,
Amazon
# Amazon
],
)
def test_generate_data(provider_cls):
+1 -1
View File
@@ -11,7 +11,7 @@ from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
OpenAI,
Groq,
Ollama,
Amazon,
# Amazon,
],
)
def test_generate_text(provider_cls):