mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 06:46:18 +00:00
Refactor conversation creation and message handling
This commit refactors the `create_conversation` function in `simplemind/__init__.py` to use a more descriptive variable name (`conv`) for the conversation object. It also updates the `add_plugin` method to use the new variable name (`conv`) instead of `conversation`. In `simplemind/models.py`, the `prepend_system_message` method now accepts an optional `meta` parameter. The method also adds a system message to the conversation by prepending it to the list of messages. Additionally, the `add_message` method in `simplemind/models.py` has been modified to include type annotations and a default value for the `role` parameter. The method now requires the `text` parameter to be provided explicitly. A new test file, `tests/test_conversations.py`, has been added to the repository. This file contains a test case for the `generate_data` function, which tests the functionality of different LLM providers. Lastly, the test files `tests/test_generate_data.py` and `tests/test_generate_text.py` have been modified to remove the unused `Amazon` provider from the list of test cases.
This commit is contained in:
@@ -64,16 +64,16 @@ def create_conversation(
|
||||
"""Create a new conversation."""
|
||||
|
||||
# Create the conversation.
|
||||
conversation = Conversation(
|
||||
conv = Conversation(
|
||||
llm_model=llm_model,
|
||||
llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER,
|
||||
)
|
||||
|
||||
# Add plugins to the conversation.
|
||||
for plugin in plugins or []:
|
||||
conversation.add_plugin(plugin)
|
||||
conv.add_plugin(plugin)
|
||||
|
||||
return conversation
|
||||
return conv
|
||||
|
||||
|
||||
def generate_data(
|
||||
|
||||
+13
-5
@@ -116,17 +116,23 @@ class Conversation(SMBaseModel):
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
def prepend_system_message(
|
||||
self, text: str, meta: Dict[str, Any] | None = None
|
||||
):
|
||||
def prepend_system_message(self, text: str, meta: Dict[str, Any] | None = None):
|
||||
"""Prepend a system message to the conversation."""
|
||||
self.messages = [Message(role="system", text=text, meta=meta or {})] + self.messages
|
||||
self.messages = [
|
||||
Message(role="system", text=text, meta=meta or {})
|
||||
] + self.messages
|
||||
|
||||
def add_message(
|
||||
self, role: MESSAGE_ROLE, text: str, meta: Optional[Dict[str, Any]] = None
|
||||
self,
|
||||
role: MESSAGE_ROLE = "user",
|
||||
text: str | None = None,
|
||||
*,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Add a new message to the conversation."""
|
||||
|
||||
assert text is not None
|
||||
|
||||
# Ensure meta is a dict.
|
||||
if meta is None:
|
||||
meta = {}
|
||||
@@ -151,6 +157,8 @@ class Conversation(SMBaseModel):
|
||||
) -> Message:
|
||||
"""Send the conversation to the LLM."""
|
||||
|
||||
# TODO: llm_model and llm_provider should override the conversation's.
|
||||
|
||||
# Execute all pre send hooks.
|
||||
for plugin in self.plugins:
|
||||
if hasattr(plugin, "pre_send_hook"):
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
|
||||
import simplemind as sm
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
[
|
||||
Anthropic,
|
||||
Gemini,
|
||||
OpenAI,
|
||||
Groq,
|
||||
Ollama,
|
||||
# Amazon
|
||||
],
|
||||
)
|
||||
def test_generate_data(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="hey")
|
||||
data = conv.send()
|
||||
|
||||
assert isinstance(data.text, str)
|
||||
assert len(data.text) > 0
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
@@ -18,7 +17,7 @@ class ResponseModel(BaseModel):
|
||||
OpenAI,
|
||||
Groq,
|
||||
Ollama,
|
||||
Amazon
|
||||
# Amazon
|
||||
],
|
||||
)
|
||||
def test_generate_data(provider_cls):
|
||||
|
||||
@@ -11,7 +11,7 @@ from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
OpenAI,
|
||||
Groq,
|
||||
Ollama,
|
||||
Amazon,
|
||||
# Amazon,
|
||||
],
|
||||
)
|
||||
def test_generate_text(provider_cls):
|
||||
|
||||
Reference in New Issue
Block a user