From 5505a3e18d22d6fb850ae7b58dadacc15c994dbd Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Thu, 31 Oct 2024 18:42:54 +0530 Subject: [PATCH] improved type hinting --- simplemind/models.py | 49 ++++++++++++++++++++++++++++++++++---------- simplemind/utils.py | 15 +++++++------- 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/simplemind/models.py b/simplemind/models.py index 6923ead..121e128 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -1,15 +1,19 @@ +from abc import ABC, abstractmethod +from types import TracebackType import uuid from datetime import datetime from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, Field -from .utils import find_provider +from .providers import find_provider MESSAGE_ROLE = Literal["system", "user", "assistant"] class SMBaseModel(BaseModel): + """The base SimpleMind model class.""" + date_created: datetime = Field(default_factory=datetime.now) def __str__(self): @@ -19,34 +23,41 @@ class SMBaseModel(BaseModel): return str(self) -class BasePlugin: +class BasePlugin(SMBaseModel, ABC): """The base conversation plugin class.""" # Plugin metadata. meta: Dict[str, Any] = {} - def initialize_hook(self, conversation: "Conversation"): + @abstractmethod + def initialize_hook(self, conversation: "Conversation") -> Any: """Initialize a hook for the plugin.""" raise NotImplementedError - def cleanup_hook(self, conversation: "Conversation"): + @abstractmethod + def cleanup_hook(self, conversation: "Conversation") -> Any: """Cleanup a hook for the plugin.""" raise NotImplementedError - def add_message_hook(self, conversation: "Conversation", message: "Message"): + @abstractmethod + def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any: """Add a message hook for the plugin.""" raise NotImplementedError - def pre_send_hook(self, conversation: "Conversation"): + @abstractmethod + def pre_send_hook(self, conversation: "Conversation") -> Any: """Pre-send hook for the plugin.""" raise NotImplementedError - def post_send_hook(self, conversation: "Conversation", response: "Message"): + @abstractmethod + def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any: """Post-send hook for the plugin.""" raise NotImplementedError class Message(SMBaseModel): + """A message in a conversation.""" + role: MESSAGE_ROLE text: str meta: Dict[str, Any] = {} @@ -58,7 +69,16 @@ class Message(SMBaseModel): return f"" @classmethod - def from_raw_response(cls, *, text: str, raw): + def from_raw_response(cls, *, text: str, raw: Any) -> "Message": + """Create a Message instance from a raw response. + + Args: + text (str): The message text. + raw (Any): The raw response data. + + Returns: + Message: A new Message instance. + """ self = cls() self.text = text self.raw = raw @@ -66,6 +86,8 @@ class Message(SMBaseModel): class Conversation(SMBaseModel): + """A conversation between a user and an assistant.""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) messages: List[Message] = [] llm_model: Optional[str] = None @@ -86,8 +108,13 @@ class Conversation(SMBaseModel): return self - def __exit__(self, exc_type, exc_value, traceback): - # Execute all cleanup hooks. + def __exit__( + self, + exc_type: type[BaseException], + exc_value: BaseException, + traceback: TracebackType, + ) -> None: + """Execute all cleanup hooks.""" for plugin in self.plugins: if hasattr(plugin, "cleanup_hook"): try: @@ -96,7 +123,7 @@ class Conversation(SMBaseModel): pass def prepend_system_message( - self, role: str, text: str, meta: Dict[str, Any] | None = None + self, role: MESSAGE_ROLE, text: str, meta: Dict[str, Any] | None = None ): """Prepend a system message to the conversation.""" self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages diff --git a/simplemind/utils.py b/simplemind/utils.py index baabbae..0226686 100644 --- a/simplemind/utils.py +++ b/simplemind/utils.py @@ -5,19 +5,22 @@ from .providers import BaseProvider, providers _PROVIDER_NAMES = [provider.NAME.lower() for provider in providers] -def find_provider(provider_name: str) -> BaseProvider: +def find_provider(provider_name: str | None) -> BaseProvider: """ Find and instantiate a provider by name. Parameters: - provider_name (Union[str, None]): The name of the provider to find. + provider_name (Union[str, None]): The name of the provider to find. Returns: - An instance of the provider class if found. + An instance of the provider class if found. Raises: - ValueError: If the provider is not found, with a suggestion for the closest match. + ValueError: If the provider is not specified or is not found, with a suggestion for the closest match. """ + if provider_name is None: + raise ValueError("No provider specified.") + # Find the provider by name. for provider_class in providers: if provider_class.NAME.lower() == provider_name.lower(): @@ -28,10 +31,8 @@ def find_provider(provider_name: str) -> BaseProvider: provider_found = difflib.get_close_matches( provider_name.lower(), _PROVIDER_NAMES, n=1 ) - if provider_found: raise ValueError( f"Provider {provider_name!r} not found. Did you mean {provider_found[0]!r}?" ) - else: - raise ValueError(f"Provider {provider_name} not found.") + raise ValueError(f"Provider {provider_name} not found.")