mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
Add Session class to manage API call configurations and enhance conversation creation
This commit is contained in:
+52
-2
@@ -1,10 +1,59 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from .models import Conversation, BasePlugin
|
||||
from .models import Conversation, BasePlugin, BaseModel
|
||||
from .utils import find_provider
|
||||
from .settings import settings
|
||||
|
||||
|
||||
class Session:
|
||||
"""A session object that maintains configuration across multiple API calls.
|
||||
|
||||
Similar to `requests.Session`, this allows you to specify default settings
|
||||
that will be used for all operations within the session.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
llm_provider: str = settings.DEFAULT_LLM_PROVIDER,
|
||||
llm_model: str = settings.DEFAULT_LLM_MODEL,
|
||||
**kwargs,
|
||||
):
|
||||
self.llm_provider = llm_provider
|
||||
self.llm_model = llm_model
|
||||
self.default_kwargs = kwargs
|
||||
|
||||
def generate_text(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate text using the session's default provider and model."""
|
||||
merged_kwargs = {**self.default_kwargs, **kwargs}
|
||||
return generate_text(
|
||||
prompt=prompt,
|
||||
llm_provider=self.llm_provider,
|
||||
llm_model=self.llm_model,
|
||||
**merged_kwargs,
|
||||
)
|
||||
|
||||
def generate_data(
|
||||
self, prompt: str, response_model: Type[BaseModel], **kwargs
|
||||
) -> BaseModel:
|
||||
"""Generate structured data using the session's default provider and model."""
|
||||
merged_kwargs = {**self.default_kwargs, **kwargs}
|
||||
return generate_data(
|
||||
prompt=prompt,
|
||||
response_model=response_model,
|
||||
llm_provider=self.llm_provider,
|
||||
llm_model=self.llm_model,
|
||||
**merged_kwargs,
|
||||
)
|
||||
|
||||
def create_conversation(self, **kwargs) -> "Conversation":
|
||||
"""Create a conversation using the session's default provider and model."""
|
||||
merged_kwargs = {**self.default_kwargs, **kwargs}
|
||||
return create_conversation(
|
||||
llm_provider=self.llm_provider, llm_model=self.llm_model, **merged_kwargs
|
||||
)
|
||||
|
||||
|
||||
def create_conversation(
|
||||
llm_model=None, llm_provider=None, *, plugins: Optional[List[BasePlugin]] = None
|
||||
):
|
||||
@@ -53,4 +102,5 @@ __all__ = [
|
||||
"generate_text",
|
||||
"settings",
|
||||
"BasePlugin",
|
||||
"Session",
|
||||
]
|
||||
|
||||
@@ -17,6 +17,7 @@ class Settings(BaseSettings):
|
||||
)
|
||||
XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI")
|
||||
DEFAULT_LLM_PROVIDER: str = Field("openai", description="The default LLM provider")
|
||||
DEFAULT_LLM_MODEL: str = Field("gpt-4o-mini", description="The default LLM model")
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
|
||||
|
||||
Reference in New Issue
Block a user