mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 06:46:18 +00:00
Refactor conversation creation and message handling
This commit refactors the `create_conversation` function in `simplemind/__init__.py` to use a more descriptive variable name (`conv`) for the conversation object. It also updates the `add_plugin` method to use the new variable name (`conv`) instead of `conversation`. In `simplemind/models.py`, the `prepend_system_message` method now accepts an optional `meta` parameter. The method also adds a system message to the conversation by prepending it to the list of messages. Additionally, the `add_message` method in `simplemind/models.py` has been modified to include type annotations and a default value for the `role` parameter. The method now requires the `text` parameter to be provided explicitly. A new test file, `tests/test_conversations.py`, has been added to the repository. This file contains a test case for the `generate_data` function, which tests the functionality of different LLM providers. Lastly, the test files `tests/test_generate_data.py` and `tests/test_generate_text.py` have been modified to remove the unused `Amazon` provider from the list of test cases.
This commit is contained in:
@@ -64,16 +64,16 @@ def create_conversation(
|
|||||||
"""Create a new conversation."""
|
"""Create a new conversation."""
|
||||||
|
|
||||||
# Create the conversation.
|
# Create the conversation.
|
||||||
conversation = Conversation(
|
conv = Conversation(
|
||||||
llm_model=llm_model,
|
llm_model=llm_model,
|
||||||
llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER,
|
llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add plugins to the conversation.
|
# Add plugins to the conversation.
|
||||||
for plugin in plugins or []:
|
for plugin in plugins or []:
|
||||||
conversation.add_plugin(plugin)
|
conv.add_plugin(plugin)
|
||||||
|
|
||||||
return conversation
|
return conv
|
||||||
|
|
||||||
|
|
||||||
def generate_data(
|
def generate_data(
|
||||||
|
|||||||
+13
-5
@@ -116,17 +116,23 @@ class Conversation(SMBaseModel):
|
|||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def prepend_system_message(
|
def prepend_system_message(self, text: str, meta: Dict[str, Any] | None = None):
|
||||||
self, text: str, meta: Dict[str, Any] | None = None
|
|
||||||
):
|
|
||||||
"""Prepend a system message to the conversation."""
|
"""Prepend a system message to the conversation."""
|
||||||
self.messages = [Message(role="system", text=text, meta=meta or {})] + self.messages
|
self.messages = [
|
||||||
|
Message(role="system", text=text, meta=meta or {})
|
||||||
|
] + self.messages
|
||||||
|
|
||||||
def add_message(
|
def add_message(
|
||||||
self, role: MESSAGE_ROLE, text: str, meta: Optional[Dict[str, Any]] = None
|
self,
|
||||||
|
role: MESSAGE_ROLE = "user",
|
||||||
|
text: str | None = None,
|
||||||
|
*,
|
||||||
|
meta: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
"""Add a new message to the conversation."""
|
"""Add a new message to the conversation."""
|
||||||
|
|
||||||
|
assert text is not None
|
||||||
|
|
||||||
# Ensure meta is a dict.
|
# Ensure meta is a dict.
|
||||||
if meta is None:
|
if meta is None:
|
||||||
meta = {}
|
meta = {}
|
||||||
@@ -151,6 +157,8 @@ class Conversation(SMBaseModel):
|
|||||||
) -> Message:
|
) -> Message:
|
||||||
"""Send the conversation to the LLM."""
|
"""Send the conversation to the LLM."""
|
||||||
|
|
||||||
|
# TODO: llm_model and llm_provider should override the conversation's.
|
||||||
|
|
||||||
# Execute all pre send hooks.
|
# Execute all pre send hooks.
|
||||||
for plugin in self.plugins:
|
for plugin in self.plugins:
|
||||||
if hasattr(plugin, "pre_send_hook"):
|
if hasattr(plugin, "pre_send_hook"):
|
||||||
|
|||||||
@@ -0,0 +1,28 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||||
|
|
||||||
|
import simplemind as sm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls",
|
||||||
|
[
|
||||||
|
Anthropic,
|
||||||
|
Gemini,
|
||||||
|
OpenAI,
|
||||||
|
Groq,
|
||||||
|
Ollama,
|
||||||
|
# Amazon
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_generate_data(provider_cls):
|
||||||
|
conv = sm.create_conversation(
|
||||||
|
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||||
|
)
|
||||||
|
|
||||||
|
conv.add_message(text="hey")
|
||||||
|
data = conv.send()
|
||||||
|
|
||||||
|
assert isinstance(data.text, str)
|
||||||
|
assert len(data.text) > 0
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||||
@@ -18,7 +17,7 @@ class ResponseModel(BaseModel):
|
|||||||
OpenAI,
|
OpenAI,
|
||||||
Groq,
|
Groq,
|
||||||
Ollama,
|
Ollama,
|
||||||
Amazon
|
# Amazon
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_generate_data(provider_cls):
|
def test_generate_data(provider_cls):
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
|||||||
OpenAI,
|
OpenAI,
|
||||||
Groq,
|
Groq,
|
||||||
Ollama,
|
Ollama,
|
||||||
Amazon,
|
# Amazon,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_generate_text(provider_cls):
|
def test_generate_text(provider_cls):
|
||||||
|
|||||||
Reference in New Issue
Block a user