From bd0c739c9aef5c63add07c65e7a83554cf36e5d8 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Thu, 31 Oct 2024 11:42:38 +0530 Subject: [PATCH] improved type hinting --- simplemind/__init__.py | 23 +++++++++++++++++------ simplemind/models.py | 13 +++++++------ simplemind/providers/anthropic.py | 8 +++----- simplemind/providers/groq.py | 4 +--- simplemind/providers/ollama.py | 8 +++++--- simplemind/providers/openai.py | 10 +++++----- simplemind/providers/xai.py | 8 +++----- simplemind/utils.py | 1 - 8 files changed, 41 insertions(+), 34 deletions(-) diff --git a/simplemind/__init__.py b/simplemind/__init__.py index 3d87feb..49a3f99 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Type +from typing import List, Type from .models import Conversation, BasePlugin, BaseModel from .utils import find_provider @@ -56,9 +56,9 @@ class Session: def create_conversation( *, - llm_model=None, - llm_provider=None, - plugins: Optional[List[BasePlugin]] = None, + llm_model: str | None = None, + llm_provider: str | None = None, + plugins: List[BasePlugin] | None = None, **kwargs, ) -> Conversation: """Create a new conversation.""" @@ -77,7 +77,12 @@ def create_conversation( def generate_data( - prompt, *, llm_model=None, llm_provider=None, response_model=None, **kwargs + prompt: str, + *, + llm_model: str | None = None, + llm_provider: str | None = None, + response_model: Type[BaseModel] = None, + **kwargs, ) -> BaseModel: """Generate structured data from a given prompt.""" @@ -92,7 +97,13 @@ def generate_data( ) -def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs) -> str: +def generate_text( + prompt: str, + *, + llm_model: str | None = None, + llm_provider: str | None = None, + **kwargs, +) -> str: """Generate text from a given prompt.""" # Find the provider. diff --git a/simplemind/models.py b/simplemind/models.py index ec9f173..c2a8e67 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -1,5 +1,4 @@ import uuid -from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Dict, List, Literal, Optional @@ -73,7 +72,7 @@ class Conversation(SMBaseModel): messages: List[Message] = [] llm_model: Optional[str] = None llm_provider: Optional[str] = None - plugins: List[Any] = [] + plugins: List[BasePlugin] = [] def __str__(self): return f"" @@ -99,7 +98,7 @@ class Conversation(SMBaseModel): pass def prepend_system_message( - self, role: str, text: str, meta: Optional[Dict[str, Any]] = None + self, role: str, text: str, meta: Dict[str, Any] | None = None ): """Prepend a system message to the conversation.""" self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages @@ -127,7 +126,9 @@ class Conversation(SMBaseModel): self.messages.append(Message(role=role, text=text, meta=meta)) def send( - self, llm_model: Optional[str] = None, llm_provider: Optional[str] = None + self, + llm_model: str | None = None, + llm_provider: str | None = None, ) -> Message: """Send the conversation to the LLM.""" @@ -156,10 +157,10 @@ class Conversation(SMBaseModel): return response - def get_last_message(self, role: MESSAGE_ROLE) -> Optional[Message]: + def get_last_message(self, role: MESSAGE_ROLE) -> Message | None: """Get the last message with the given role.""" return next((m for m in reversed(self.messages) if m.role == role), None) - def add_plugin(self, plugin: Any): + def add_plugin(self, plugin: BasePlugin) -> None: """Add a plugin to the conversation.""" self.plugins.append(plugin) diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 1028b09..77e4d20 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -1,5 +1,3 @@ -from typing import Union - import anthropic import instructor @@ -15,7 +13,7 @@ class Anthropic(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL - def __init__(self, api_key: Union[str, None] = None): + def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @property @@ -57,13 +55,13 @@ class Anthropic(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, model, response_model, **kwargs): + def structured_response(self, model: str, response_model, **kwargs): response = self.structured_client.messages.create( model=model, response_model=response_model or self.DEFAULT_MODEL, **kwargs ) return response - def generate_text(self, prompt, *, llm_model, **kwargs): + def generate_text(self, prompt: str, *, llm_model: str, **kwargs): messages = [ {"role": "user", "content": prompt}, ] diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 47f14ae..62bd54c 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -1,5 +1,3 @@ -from typing import Union - import groq import instructor @@ -14,7 +12,7 @@ class Groq(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL - def __init__(self, api_key: Union[str, None] = None): + def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @property diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index 5ba019a..24a1ea7 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -15,7 +15,7 @@ class Ollama(BaseProvider): DEFAULT_MODEL = DEFAULT_MODEL TIMEOUT = DEFAULT_TIMEOUT - def __init__(self, host_url: str = None): + def __init__(self, host_url: str | None = None): self.host_url = host_url or settings.OLLAMA_HOST_URL @property @@ -57,7 +57,9 @@ class Ollama(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs): + def structured_response( + self, prompt: str, response_model, *, llm_model: str, **kwargs + ): messages = [ {"role": "user", "content": prompt}, ] @@ -70,7 +72,7 @@ class Ollama(BaseProvider): ) return response - def generate_text(self, prompt, *, llm_model): + def generate_text(self, prompt: str, *, llm_model: str): messages = [ {"role": "user", "content": prompt}, ] diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 3895096..2a296cc 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,5 +1,3 @@ -from typing import Union - import instructor import openai as oa @@ -14,7 +12,7 @@ class OpenAI(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL - def __init__(self, api_key: Union[str, None] = None): + def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @property @@ -53,7 +51,9 @@ class OpenAI(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs): + def structured_response( + self, prompt: str, response_model, *, llm_model: str, **kwargs + ): # Ensure messages are provided in kwargs messages = [ {"role": "user", "content": prompt}, @@ -67,7 +67,7 @@ class OpenAI(BaseProvider): ) return response - def generate_text(self, prompt, *, llm_model, **kwargs): + def generate_text(self, prompt: str, *, llm_model: str, **kwargs): messages = [ {"role": "user", "content": prompt}, ] diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 8b8b84d..697c407 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -1,5 +1,3 @@ -from typing import Union - import instructor import openai as oa @@ -16,7 +14,7 @@ class XAI(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL - def __init__(self, api_key: Union[str, None] = None): + def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @property @@ -60,10 +58,10 @@ class XAI(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt: str, response_model, *, llm_model): + def structured_response(self, prompt: str, response_model, *, llm_model: str): raise NotImplementedError("XAI does not support structured responses") - def generate_text(self, prompt, *, llm_model, **kwargs): + def generate_text(self, prompt: str, *, llm_model: str, **kwargs): messages = [ {"role": "user", "content": prompt}, ] diff --git a/simplemind/utils.py b/simplemind/utils.py index 67477ca..67723c5 100644 --- a/simplemind/utils.py +++ b/simplemind/utils.py @@ -1,5 +1,4 @@ import difflib -from typing import Optional, Type from .providers import providers, BaseProvider