Add Session class to manage API call configurations and enhance conversation creation

This commit is contained in:
2024-10-30 18:24:45 -04:00
parent e9e47e27a1
commit 8474f101f2
2 changed files with 53 additions and 2 deletions
+52 -2
View File
@@ -1,10 +1,59 @@
from typing import List, Optional
from typing import List, Optional, Type
from .models import Conversation, BasePlugin
from .models import Conversation, BasePlugin, BaseModel
from .utils import find_provider
from .settings import settings
class Session:
"""A session object that maintains configuration across multiple API calls.
Similar to `requests.Session`, this allows you to specify default settings
that will be used for all operations within the session.
"""
def __init__(
self,
*,
llm_provider: str = settings.DEFAULT_LLM_PROVIDER,
llm_model: str = settings.DEFAULT_LLM_MODEL,
**kwargs,
):
self.llm_provider = llm_provider
self.llm_model = llm_model
self.default_kwargs = kwargs
def generate_text(self, prompt: str, **kwargs) -> str:
"""Generate text using the session's default provider and model."""
merged_kwargs = {**self.default_kwargs, **kwargs}
return generate_text(
prompt=prompt,
llm_provider=self.llm_provider,
llm_model=self.llm_model,
**merged_kwargs,
)
def generate_data(
self, prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""Generate structured data using the session's default provider and model."""
merged_kwargs = {**self.default_kwargs, **kwargs}
return generate_data(
prompt=prompt,
response_model=response_model,
llm_provider=self.llm_provider,
llm_model=self.llm_model,
**merged_kwargs,
)
def create_conversation(self, **kwargs) -> "Conversation":
"""Create a conversation using the session's default provider and model."""
merged_kwargs = {**self.default_kwargs, **kwargs}
return create_conversation(
llm_provider=self.llm_provider, llm_model=self.llm_model, **merged_kwargs
)
def create_conversation(
llm_model=None, llm_provider=None, *, plugins: Optional[List[BasePlugin]] = None
):
@@ -53,4 +102,5 @@ __all__ = [
"generate_text",
"settings",
"BasePlugin",
"Session",
]
+1
View File
@@ -17,6 +17,7 @@ class Settings(BaseSettings):
)
XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI")
DEFAULT_LLM_PROVIDER: str = Field("openai", description="The default LLM provider")
DEFAULT_LLM_MODEL: str = Field("gpt-4o-mini", description="The default LLM model")
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"