mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 06:46:18 +00:00
improved type hinting
This commit is contained in:
+38
-11
@@ -1,15 +1,19 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from types import TracebackType
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .utils import find_provider
|
||||
from .providers import find_provider
|
||||
|
||||
MESSAGE_ROLE = Literal["system", "user", "assistant"]
|
||||
|
||||
|
||||
class SMBaseModel(BaseModel):
|
||||
"""The base SimpleMind model class."""
|
||||
|
||||
date_created: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
def __str__(self):
|
||||
@@ -19,34 +23,41 @@ class SMBaseModel(BaseModel):
|
||||
return str(self)
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
class BasePlugin(SMBaseModel, ABC):
|
||||
"""The base conversation plugin class."""
|
||||
|
||||
# Plugin metadata.
|
||||
meta: Dict[str, Any] = {}
|
||||
|
||||
def initialize_hook(self, conversation: "Conversation"):
|
||||
@abstractmethod
|
||||
def initialize_hook(self, conversation: "Conversation") -> Any:
|
||||
"""Initialize a hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
def cleanup_hook(self, conversation: "Conversation"):
|
||||
@abstractmethod
|
||||
def cleanup_hook(self, conversation: "Conversation") -> Any:
|
||||
"""Cleanup a hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_message_hook(self, conversation: "Conversation", message: "Message"):
|
||||
@abstractmethod
|
||||
def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any:
|
||||
"""Add a message hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
def pre_send_hook(self, conversation: "Conversation"):
|
||||
@abstractmethod
|
||||
def pre_send_hook(self, conversation: "Conversation") -> Any:
|
||||
"""Pre-send hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
def post_send_hook(self, conversation: "Conversation", response: "Message"):
|
||||
@abstractmethod
|
||||
def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any:
|
||||
"""Post-send hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Message(SMBaseModel):
|
||||
"""A message in a conversation."""
|
||||
|
||||
role: MESSAGE_ROLE
|
||||
text: str
|
||||
meta: Dict[str, Any] = {}
|
||||
@@ -58,7 +69,16 @@ class Message(SMBaseModel):
|
||||
return f"<Message role={self.role} text={self.text!r}>"
|
||||
|
||||
@classmethod
|
||||
def from_raw_response(cls, *, text: str, raw):
|
||||
def from_raw_response(cls, *, text: str, raw: Any) -> "Message":
|
||||
"""Create a Message instance from a raw response.
|
||||
|
||||
Args:
|
||||
text (str): The message text.
|
||||
raw (Any): The raw response data.
|
||||
|
||||
Returns:
|
||||
Message: A new Message instance.
|
||||
"""
|
||||
self = cls()
|
||||
self.text = text
|
||||
self.raw = raw
|
||||
@@ -66,6 +86,8 @@ class Message(SMBaseModel):
|
||||
|
||||
|
||||
class Conversation(SMBaseModel):
|
||||
"""A conversation between a user and an assistant."""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
messages: List[Message] = []
|
||||
llm_model: Optional[str] = None
|
||||
@@ -86,8 +108,13 @@ class Conversation(SMBaseModel):
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
# Execute all cleanup hooks.
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException],
|
||||
exc_value: BaseException,
|
||||
traceback: TracebackType,
|
||||
) -> None:
|
||||
"""Execute all cleanup hooks."""
|
||||
for plugin in self.plugins:
|
||||
if hasattr(plugin, "cleanup_hook"):
|
||||
try:
|
||||
@@ -96,7 +123,7 @@ class Conversation(SMBaseModel):
|
||||
pass
|
||||
|
||||
def prepend_system_message(
|
||||
self, role: str, text: str, meta: Dict[str, Any] | None = None
|
||||
self, role: MESSAGE_ROLE, text: str, meta: Dict[str, Any] | None = None
|
||||
):
|
||||
"""Prepend a system message to the conversation."""
|
||||
self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages
|
||||
|
||||
+8
-7
@@ -5,19 +5,22 @@ from .providers import BaseProvider, providers
|
||||
_PROVIDER_NAMES = [provider.NAME.lower() for provider in providers]
|
||||
|
||||
|
||||
def find_provider(provider_name: str) -> BaseProvider:
|
||||
def find_provider(provider_name: str | None) -> BaseProvider:
|
||||
"""
|
||||
Find and instantiate a provider by name.
|
||||
|
||||
Parameters:
|
||||
provider_name (Union[str, None]): The name of the provider to find.
|
||||
provider_name (Union[str, None]): The name of the provider to find.
|
||||
|
||||
Returns:
|
||||
An instance of the provider class if found.
|
||||
An instance of the provider class if found.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not found, with a suggestion for the closest match.
|
||||
ValueError: If the provider is not specified or is not found, with a suggestion for the closest match.
|
||||
"""
|
||||
if provider_name is None:
|
||||
raise ValueError("No provider specified.")
|
||||
|
||||
# Find the provider by name.
|
||||
for provider_class in providers:
|
||||
if provider_class.NAME.lower() == provider_name.lower():
|
||||
@@ -28,10 +31,8 @@ def find_provider(provider_name: str) -> BaseProvider:
|
||||
provider_found = difflib.get_close_matches(
|
||||
provider_name.lower(), _PROVIDER_NAMES, n=1
|
||||
)
|
||||
|
||||
if provider_found:
|
||||
raise ValueError(
|
||||
f"Provider {provider_name!r} not found. Did you mean {provider_found[0]!r}?"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Provider {provider_name} not found.")
|
||||
raise ValueError(f"Provider {provider_name} not found.")
|
||||
|
||||
Reference in New Issue
Block a user