mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
improved type hinting
This commit is contained in:
+17
-6
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Type
|
||||
from typing import List, Type
|
||||
|
||||
from .models import Conversation, BasePlugin, BaseModel
|
||||
from .utils import find_provider
|
||||
@@ -56,9 +56,9 @@ class Session:
|
||||
|
||||
def create_conversation(
|
||||
*,
|
||||
llm_model=None,
|
||||
llm_provider=None,
|
||||
plugins: Optional[List[BasePlugin]] = None,
|
||||
llm_model: str | None = None,
|
||||
llm_provider: str | None = None,
|
||||
plugins: List[BasePlugin] | None = None,
|
||||
**kwargs,
|
||||
) -> Conversation:
|
||||
"""Create a new conversation."""
|
||||
@@ -77,7 +77,12 @@ def create_conversation(
|
||||
|
||||
|
||||
def generate_data(
|
||||
prompt, *, llm_model=None, llm_provider=None, response_model=None, **kwargs
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
llm_provider: str | None = None,
|
||||
response_model: Type[BaseModel] = None,
|
||||
**kwargs,
|
||||
) -> BaseModel:
|
||||
"""Generate structured data from a given prompt."""
|
||||
|
||||
@@ -92,7 +97,13 @@ def generate_data(
|
||||
)
|
||||
|
||||
|
||||
def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs) -> str:
|
||||
def generate_text(
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
llm_provider: str | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Generate text from a given prompt."""
|
||||
|
||||
# Find the provider.
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
@@ -73,7 +72,7 @@ class Conversation(SMBaseModel):
|
||||
messages: List[Message] = []
|
||||
llm_model: Optional[str] = None
|
||||
llm_provider: Optional[str] = None
|
||||
plugins: List[Any] = []
|
||||
plugins: List[BasePlugin] = []
|
||||
|
||||
def __str__(self):
|
||||
return f"<Conversation id={self.id!r}>"
|
||||
@@ -99,7 +98,7 @@ class Conversation(SMBaseModel):
|
||||
pass
|
||||
|
||||
def prepend_system_message(
|
||||
self, role: str, text: str, meta: Optional[Dict[str, Any]] = None
|
||||
self, role: str, text: str, meta: Dict[str, Any] | None = None
|
||||
):
|
||||
"""Prepend a system message to the conversation."""
|
||||
self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages
|
||||
@@ -127,7 +126,9 @@ class Conversation(SMBaseModel):
|
||||
self.messages.append(Message(role=role, text=text, meta=meta))
|
||||
|
||||
def send(
|
||||
self, llm_model: Optional[str] = None, llm_provider: Optional[str] = None
|
||||
self,
|
||||
llm_model: str | None = None,
|
||||
llm_provider: str | None = None,
|
||||
) -> Message:
|
||||
"""Send the conversation to the LLM."""
|
||||
|
||||
@@ -156,10 +157,10 @@ class Conversation(SMBaseModel):
|
||||
|
||||
return response
|
||||
|
||||
def get_last_message(self, role: MESSAGE_ROLE) -> Optional[Message]:
|
||||
def get_last_message(self, role: MESSAGE_ROLE) -> Message | None:
|
||||
"""Get the last message with the given role."""
|
||||
return next((m for m in reversed(self.messages) if m.role == role), None)
|
||||
|
||||
def add_plugin(self, plugin: Any):
|
||||
def add_plugin(self, plugin: BasePlugin) -> None:
|
||||
"""Add a plugin to the conversation."""
|
||||
self.plugins.append(plugin)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Union
|
||||
|
||||
import anthropic
|
||||
import instructor
|
||||
|
||||
@@ -15,7 +13,7 @@ class Anthropic(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@property
|
||||
@@ -57,13 +55,13 @@ class Anthropic(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, model, response_model, **kwargs):
|
||||
def structured_response(self, model: str, response_model, **kwargs):
|
||||
response = self.structured_client.messages.create(
|
||||
model=model, response_model=response_model or self.DEFAULT_MODEL, **kwargs
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Union
|
||||
|
||||
import groq
|
||||
import instructor
|
||||
|
||||
@@ -14,7 +12,7 @@ class Groq(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@property
|
||||
|
||||
@@ -15,7 +15,7 @@ class Ollama(BaseProvider):
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
TIMEOUT = DEFAULT_TIMEOUT
|
||||
|
||||
def __init__(self, host_url: str = None):
|
||||
def __init__(self, host_url: str | None = None):
|
||||
self.host_url = host_url or settings.OLLAMA_HOST_URL
|
||||
|
||||
@property
|
||||
@@ -57,7 +57,9 @@ class Ollama(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
|
||||
def structured_response(
|
||||
self, prompt: str, response_model, *, llm_model: str, **kwargs
|
||||
):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -70,7 +72,7 @@ class Ollama(BaseProvider):
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt, *, llm_model):
|
||||
def generate_text(self, prompt: str, *, llm_model: str):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Union
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
|
||||
@@ -14,7 +12,7 @@ class OpenAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@property
|
||||
@@ -53,7 +51,9 @@ class OpenAI(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
|
||||
def structured_response(
|
||||
self, prompt: str, response_model, *, llm_model: str, **kwargs
|
||||
):
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
@@ -67,7 +67,7 @@ class OpenAI(BaseProvider):
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Union
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
|
||||
@@ -16,7 +14,7 @@ class XAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@property
|
||||
@@ -60,10 +58,10 @@ class XAI(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt: str, response_model, *, llm_model):
|
||||
def structured_response(self, prompt: str, response_model, *, llm_model: str):
|
||||
raise NotImplementedError("XAI does not support structured responses")
|
||||
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import difflib
|
||||
from typing import Optional, Type
|
||||
|
||||
from .providers import providers, BaseProvider
|
||||
|
||||
|
||||
Reference in New Issue
Block a user