Refactor code to add support for generating custom temperature text

This commit is contained in:
2024-10-29 07:11:19 -04:00
parent 6ad5e21d0a
commit 58e3c6a3bd
6 changed files with 54 additions and 20 deletions
+11
View File
@@ -0,0 +1,11 @@
from _context import sm
text = sm.generate_text(
prompt="Write a short summary of 'Pride and Prejudice'.",
llm_provider="openai",
llm_model="gpt-4o",
temperature=0.5,
max_tokens=150,
)
print(text)
+2 -2
View File
@@ -23,11 +23,11 @@ def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=N
)
def generate_text(prompt, *, llm_model=None, llm_provider=None):
def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs):
"""Generate text from a given prompt."""
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
return provider.generate_text(prompt=prompt, llm_model=llm_model)
return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs)
__all__ = [
+7 -3
View File
@@ -30,7 +30,7 @@ class Anthropic(BaseProvider):
"""A client patched with Instructor."""
return instructor.from_anthropic(self.client)
def send_conversation(self, conversation: "Conversation"):
def send_conversation(self, conversation: "Conversation", **kwargs):
"""Send a conversation to the Anthropic API."""
from ..models import Message
@@ -42,6 +42,7 @@ class Anthropic(BaseProvider):
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
max_tokens=DEFAULT_MAX_TOKENS,
**kwargs,
)
# Get the response content from the Anthropic response
@@ -62,13 +63,16 @@ class Anthropic(BaseProvider):
)
return response
def generate_text(self, prompt, *, llm_model):
def generate_text(self, prompt, *, llm_model, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
response = self.client.messages.create(
model=llm_model, messages=messages, max_tokens=DEFAULT_MAX_TOKENS
model=llm_model,
messages=messages,
max_tokens=DEFAULT_MAX_TOKENS,
**kwargs,
)
return response.content[0].text
+20 -5
View File
@@ -29,7 +29,11 @@ class Groq(BaseProvider):
"""A client patched with Instructor."""
return instructor.from_groq(self.client)
def send_conversation(self, conversation: "Conversation") -> "Message":
def send_conversation(
self,
conversation: "Conversation",
**kwargs,
) -> "Message":
"""Send a conversation to the Groq API."""
from ..models import Message
@@ -38,7 +42,9 @@ class Groq(BaseProvider):
]
response = self.client.chat.completions.create(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**kwargs,
)
# Get the response content from the Groq response
@@ -53,7 +59,7 @@ class Groq(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt: str, response_model):
def structured_response(self, prompt: str, response_model, **kwargs):
# Ensure messages are provided in kwargs
messages = [
{"role": "user", "content": prompt},
@@ -62,16 +68,25 @@ class Groq(BaseProvider):
response = self.structured_client.chat.completions.create(
messages=messages,
response_model=response_model,
**kwargs,
)
return response
def generate_text(self, prompt: str, *, llm_model: str):
def generate_text(
self,
prompt: str,
*,
llm_model: str,
**kwargs,
):
messages = [
{"role": "user", "content": prompt},
]
response = self.structured_client.chat.completions.create(
messages=messages, model=llm_model
messages=messages,
model=llm_model,
**kwargs,
)
return response.choices[0].message.content
+6 -6
View File
@@ -29,7 +29,7 @@ class OpenAI(BaseProvider):
"""A OpenAI client with Instructor."""
return instructor.from_openai(self.client)
def send_conversation(self, conversation: "Conversation"):
def send_conversation(self, conversation: "Conversation", **kwargs):
"""Send a conversation to the OpenAI API."""
from ..models import Message
@@ -38,7 +38,7 @@ class OpenAI(BaseProvider):
]
response = self.client.chat.completions.create(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
)
# Get the response content from the OpenAI response
@@ -53,24 +53,24 @@ class OpenAI(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt, response_model, *, llm_model: str):
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
# Ensure messages are provided in kwargs
messages = [
{"role": "user", "content": prompt},
]
response = self.structured_client.chat.completions.create(
messages=messages, model=llm_model, response_model=response_model
messages=messages, model=llm_model, response_model=response_model, **kwargs
)
return response
def generate_text(self, prompt, *, llm_model):
def generate_text(self, prompt, *, llm_model, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
response = self.client.chat.completions.create(
messages=messages, model=llm_model
messages=messages, model=llm_model, **kwargs
)
return response.choices[0].message.content
+8 -4
View File
@@ -34,7 +34,7 @@ class XAI(BaseProvider):
"""A client patched with Instructor."""
return instructor.from_openai(self.client)
def send_conversation(self, conversation: "Conversation"):
def send_conversation(self, conversation: "Conversation", **kwargs):
"""Send a conversation to the OpenAI API."""
from ..models import Message
@@ -43,7 +43,9 @@ class XAI(BaseProvider):
]
response = self.client.chat.completions.create(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**kwargs,
)
# Get the response content from the OpenAI response
@@ -61,13 +63,15 @@ class XAI(BaseProvider):
def structured_response(self, prompt: str, response_model, *, llm_model):
raise NotImplementedError("XAI does not support structured responses")
def generate_text(self, prompt, *, llm_model):
def generate_text(self, prompt, *, llm_model, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
response = self.client.chat.completions.create(
messages=messages, model=llm_model
messages=messages,
model=llm_model,
**kwargs,
)
return response.choices[0].message.content