improved type hinting

This commit is contained in:
Siddhesh Agarwal
2024-10-31 11:42:38 +05:30
parent 473a054afa
commit bd0c739c9a
8 changed files with 41 additions and 34 deletions
+17 -6
View File
@@ -1,4 +1,4 @@
from typing import List, Optional, Type
from typing import List, Type
from .models import Conversation, BasePlugin, BaseModel
from .utils import find_provider
@@ -56,9 +56,9 @@ class Session:
def create_conversation(
*,
llm_model=None,
llm_provider=None,
plugins: Optional[List[BasePlugin]] = None,
llm_model: str | None = None,
llm_provider: str | None = None,
plugins: List[BasePlugin] | None = None,
**kwargs,
) -> Conversation:
"""Create a new conversation."""
@@ -77,7 +77,12 @@ def create_conversation(
def generate_data(
prompt, *, llm_model=None, llm_provider=None, response_model=None, **kwargs
prompt: str,
*,
llm_model: str | None = None,
llm_provider: str | None = None,
response_model: Type[BaseModel] = None,
**kwargs,
) -> BaseModel:
"""Generate structured data from a given prompt."""
@@ -92,7 +97,13 @@ def generate_data(
)
def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs) -> str:
def generate_text(
prompt: str,
*,
llm_model: str | None = None,
llm_provider: str | None = None,
**kwargs,
) -> str:
"""Generate text from a given prompt."""
# Find the provider.
+7 -6
View File
@@ -1,5 +1,4 @@
import uuid
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
@@ -73,7 +72,7 @@ class Conversation(SMBaseModel):
messages: List[Message] = []
llm_model: Optional[str] = None
llm_provider: Optional[str] = None
plugins: List[Any] = []
plugins: List[BasePlugin] = []
def __str__(self):
return f"<Conversation id={self.id!r}>"
@@ -99,7 +98,7 @@ class Conversation(SMBaseModel):
pass
def prepend_system_message(
self, role: str, text: str, meta: Optional[Dict[str, Any]] = None
self, role: str, text: str, meta: Dict[str, Any] | None = None
):
"""Prepend a system message to the conversation."""
self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages
@@ -127,7 +126,9 @@ class Conversation(SMBaseModel):
self.messages.append(Message(role=role, text=text, meta=meta))
def send(
self, llm_model: Optional[str] = None, llm_provider: Optional[str] = None
self,
llm_model: str | None = None,
llm_provider: str | None = None,
) -> Message:
"""Send the conversation to the LLM."""
@@ -156,10 +157,10 @@ class Conversation(SMBaseModel):
return response
def get_last_message(self, role: MESSAGE_ROLE) -> Optional[Message]:
def get_last_message(self, role: MESSAGE_ROLE) -> Message | None:
"""Get the last message with the given role."""
return next((m for m in reversed(self.messages) if m.role == role), None)
def add_plugin(self, plugin: Any):
def add_plugin(self, plugin: BasePlugin) -> None:
"""Add a plugin to the conversation."""
self.plugins.append(plugin)
+3 -5
View File
@@ -1,5 +1,3 @@
from typing import Union
import anthropic
import instructor
@@ -15,7 +13,7 @@ class Anthropic(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@property
@@ -57,13 +55,13 @@ class Anthropic(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, model, response_model, **kwargs):
def structured_response(self, model: str, response_model, **kwargs):
response = self.structured_client.messages.create(
model=model, response_model=response_model or self.DEFAULT_MODEL, **kwargs
)
return response
def generate_text(self, prompt, *, llm_model, **kwargs):
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
+1 -3
View File
@@ -1,5 +1,3 @@
from typing import Union
import groq
import instructor
@@ -14,7 +12,7 @@ class Groq(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@property
+5 -3
View File
@@ -15,7 +15,7 @@ class Ollama(BaseProvider):
DEFAULT_MODEL = DEFAULT_MODEL
TIMEOUT = DEFAULT_TIMEOUT
def __init__(self, host_url: str = None):
def __init__(self, host_url: str | None = None):
self.host_url = host_url or settings.OLLAMA_HOST_URL
@property
@@ -57,7 +57,9 @@ class Ollama(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
def structured_response(
self, prompt: str, response_model, *, llm_model: str, **kwargs
):
messages = [
{"role": "user", "content": prompt},
]
@@ -70,7 +72,7 @@ class Ollama(BaseProvider):
)
return response
def generate_text(self, prompt, *, llm_model):
def generate_text(self, prompt: str, *, llm_model: str):
messages = [
{"role": "user", "content": prompt},
]
+5 -5
View File
@@ -1,5 +1,3 @@
from typing import Union
import instructor
import openai as oa
@@ -14,7 +12,7 @@ class OpenAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@property
@@ -53,7 +51,9 @@ class OpenAI(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
def structured_response(
self, prompt: str, response_model, *, llm_model: str, **kwargs
):
# Ensure messages are provided in kwargs
messages = [
{"role": "user", "content": prompt},
@@ -67,7 +67,7 @@ class OpenAI(BaseProvider):
)
return response
def generate_text(self, prompt, *, llm_model, **kwargs):
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
+3 -5
View File
@@ -1,5 +1,3 @@
from typing import Union
import instructor
import openai as oa
@@ -16,7 +14,7 @@ class XAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@property
@@ -60,10 +58,10 @@ class XAI(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt: str, response_model, *, llm_model):
def structured_response(self, prompt: str, response_model, *, llm_model: str):
raise NotImplementedError("XAI does not support structured responses")
def generate_text(self, prompt, *, llm_model, **kwargs):
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
-1
View File
@@ -1,5 +1,4 @@
import difflib
from typing import Optional, Type
from .providers import providers, BaseProvider