Refactor create_conversation and generate_data functions to improve type hints and maintain consistency in return types

This commit is contained in:
2024-10-30 19:07:04 -04:00
parent 1f66bac645
commit 74c09d5c87
+6 -6
View File
@@ -46,7 +46,7 @@ class Session:
**merged_kwargs,
)
def create_conversation(self, **kwargs) -> "Conversation":
def create_conversation(self, **kwargs) -> Conversation:
"""Create a conversation using the session's default provider and model."""
merged_kwargs = {**self.default_kwargs, **kwargs}
return create_conversation(
@@ -60,11 +60,9 @@ def create_conversation(
llm_provider=None,
plugins: Optional[List[BasePlugin]] = None,
**kwargs,
):
) -> Conversation:
"""Create a new conversation."""
# Note: kwargs are here to eat up any extra arguments passed in from sessions.
# Create the conversation.
conversation = Conversation(
llm_model=llm_model,
@@ -78,7 +76,9 @@ def create_conversation(
return conversation
def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=None):
def generate_data(
prompt, *, llm_model=None, llm_provider=None, response_model=None, **kwargs
) -> BaseModel:
"""Generate structured data from a given prompt."""
# Find the provider.
@@ -92,7 +92,7 @@ def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=N
)
def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs):
def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs) -> str:
"""Generate text from a given prompt."""
# Find the provider.