mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 06:46:18 +00:00
Refactor code to add support for generating custom temperature text
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from _context import sm
|
||||
|
||||
text = sm.generate_text(
|
||||
prompt="Write a short summary of 'Pride and Prejudice'.",
|
||||
llm_provider="openai",
|
||||
llm_model="gpt-4o",
|
||||
temperature=0.5,
|
||||
max_tokens=150,
|
||||
)
|
||||
|
||||
print(text)
|
||||
@@ -23,11 +23,11 @@ def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=N
|
||||
)
|
||||
|
||||
|
||||
def generate_text(prompt, *, llm_model=None, llm_provider=None):
|
||||
def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs):
|
||||
"""Generate text from a given prompt."""
|
||||
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
|
||||
|
||||
return provider.generate_text(prompt=prompt, llm_model=llm_model)
|
||||
return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -30,7 +30,7 @@ class Anthropic(BaseProvider):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_anthropic(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation"):
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
"""Send a conversation to the Anthropic API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -42,6 +42,7 @@ class Anthropic(BaseProvider):
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Get the response content from the Anthropic response
|
||||
@@ -62,13 +63,16 @@ class Anthropic(BaseProvider):
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt, *, llm_model):
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.messages.create(
|
||||
model=llm_model, messages=messages, max_tokens=DEFAULT_MAX_TOKENS
|
||||
model=llm_model,
|
||||
messages=messages,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return response.content[0].text
|
||||
|
||||
@@ -29,7 +29,11 @@ class Groq(BaseProvider):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_groq(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
**kwargs,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the Groq API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -38,7 +42,9 @@ class Groq(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Get the response content from the Groq response
|
||||
@@ -53,7 +59,7 @@ class Groq(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt: str, response_model):
|
||||
def structured_response(self, prompt: str, response_model, **kwargs):
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
@@ -62,16 +68,25 @@ class Groq(BaseProvider):
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt: str, *, llm_model: str):
|
||||
def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str,
|
||||
**kwargs,
|
||||
):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages, model=llm_model
|
||||
messages=messages,
|
||||
model=llm_model,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
@@ -29,7 +29,7 @@ class OpenAI(BaseProvider):
|
||||
"""A OpenAI client with Instructor."""
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation"):
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -38,7 +38,7 @@ class OpenAI(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
@@ -53,24 +53,24 @@ class OpenAI(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt, response_model, *, llm_model: str):
|
||||
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages, model=llm_model, response_model=response_model
|
||||
messages=messages, model=llm_model, response_model=response_model, **kwargs
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt, *, llm_model):
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages, model=llm_model
|
||||
messages=messages, model=llm_model, **kwargs
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
@@ -34,7 +34,7 @@ class XAI(BaseProvider):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation"):
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -43,7 +43,9 @@ class XAI(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
@@ -61,13 +63,15 @@ class XAI(BaseProvider):
|
||||
def structured_response(self, prompt: str, response_model, *, llm_model):
|
||||
raise NotImplementedError("XAI does not support structured responses")
|
||||
|
||||
def generate_text(self, prompt, *, llm_model):
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages, model=llm_model
|
||||
messages=messages,
|
||||
model=llm_model,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
Reference in New Issue
Block a user