improved type hinting

This commit is contained in:
Siddhesh Agarwal
2024-10-31 18:42:54 +05:30
parent 48291c37c5
commit 5505a3e18d
2 changed files with 46 additions and 18 deletions
+38 -11
View File
@@ -1,15 +1,19 @@
from abc import ABC, abstractmethod
from types import TracebackType
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
from .utils import find_provider
from .providers import find_provider
MESSAGE_ROLE = Literal["system", "user", "assistant"]
class SMBaseModel(BaseModel):
"""The base SimpleMind model class."""
date_created: datetime = Field(default_factory=datetime.now)
def __str__(self):
@@ -19,34 +23,41 @@ class SMBaseModel(BaseModel):
return str(self)
class BasePlugin:
class BasePlugin(SMBaseModel, ABC):
"""The base conversation plugin class."""
# Plugin metadata.
meta: Dict[str, Any] = {}
def initialize_hook(self, conversation: "Conversation"):
@abstractmethod
def initialize_hook(self, conversation: "Conversation") -> Any:
"""Initialize a hook for the plugin."""
raise NotImplementedError
def cleanup_hook(self, conversation: "Conversation"):
@abstractmethod
def cleanup_hook(self, conversation: "Conversation") -> Any:
"""Cleanup a hook for the plugin."""
raise NotImplementedError
def add_message_hook(self, conversation: "Conversation", message: "Message"):
@abstractmethod
def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any:
"""Add a message hook for the plugin."""
raise NotImplementedError
def pre_send_hook(self, conversation: "Conversation"):
@abstractmethod
def pre_send_hook(self, conversation: "Conversation") -> Any:
"""Pre-send hook for the plugin."""
raise NotImplementedError
def post_send_hook(self, conversation: "Conversation", response: "Message"):
@abstractmethod
def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any:
"""Post-send hook for the plugin."""
raise NotImplementedError
class Message(SMBaseModel):
"""A message in a conversation."""
role: MESSAGE_ROLE
text: str
meta: Dict[str, Any] = {}
@@ -58,7 +69,16 @@ class Message(SMBaseModel):
return f"<Message role={self.role} text={self.text!r}>"
@classmethod
def from_raw_response(cls, *, text: str, raw):
def from_raw_response(cls, *, text: str, raw: Any) -> "Message":
"""Create a Message instance from a raw response.
Args:
text (str): The message text.
raw (Any): The raw response data.
Returns:
Message: A new Message instance.
"""
self = cls()
self.text = text
self.raw = raw
@@ -66,6 +86,8 @@ class Message(SMBaseModel):
class Conversation(SMBaseModel):
"""A conversation between a user and an assistant."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
messages: List[Message] = []
llm_model: Optional[str] = None
@@ -86,8 +108,13 @@ class Conversation(SMBaseModel):
return self
def __exit__(self, exc_type, exc_value, traceback):
# Execute all cleanup hooks.
def __exit__(
self,
exc_type: type[BaseException],
exc_value: BaseException,
traceback: TracebackType,
) -> None:
"""Execute all cleanup hooks."""
for plugin in self.plugins:
if hasattr(plugin, "cleanup_hook"):
try:
@@ -96,7 +123,7 @@ class Conversation(SMBaseModel):
pass
def prepend_system_message(
self, role: str, text: str, meta: Dict[str, Any] | None = None
self, role: MESSAGE_ROLE, text: str, meta: Dict[str, Any] | None = None
):
"""Prepend a system message to the conversation."""
self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages
+8 -7
View File
@@ -5,19 +5,22 @@ from .providers import BaseProvider, providers
_PROVIDER_NAMES = [provider.NAME.lower() for provider in providers]
def find_provider(provider_name: str) -> BaseProvider:
def find_provider(provider_name: str | None) -> BaseProvider:
"""
Find and instantiate a provider by name.
Parameters:
provider_name (Union[str, None]): The name of the provider to find.
provider_name (Union[str, None]): The name of the provider to find.
Returns:
An instance of the provider class if found.
An instance of the provider class if found.
Raises:
ValueError: If the provider is not found, with a suggestion for the closest match.
ValueError: If the provider is not specified or is not found, with a suggestion for the closest match.
"""
if provider_name is None:
raise ValueError("No provider specified.")
# Find the provider by name.
for provider_class in providers:
if provider_class.NAME.lower() == provider_name.lower():
@@ -28,10 +31,8 @@ def find_provider(provider_name: str) -> BaseProvider:
provider_found = difflib.get_close_matches(
provider_name.lower(), _PROVIDER_NAMES, n=1
)
if provider_found:
raise ValueError(
f"Provider {provider_name!r} not found. Did you mean {provider_found[0]!r}?"
)
else:
raise ValueError(f"Provider {provider_name} not found.")
raise ValueError(f"Provider {provider_name} not found.")