From 58e3c6a3bd048fba8532a8400aeba078e9c266a9 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Tue, 29 Oct 2024 07:11:19 -0400 Subject: [PATCH] Refactor code to add support for generating custom temperature text --- examples/custom_temperature.py | 11 +++++++++++ simplemind/__init__.py | 4 ++-- simplemind/providers/anthropic.py | 10 +++++++--- simplemind/providers/groq.py | 25 ++++++++++++++++++++----- simplemind/providers/openai.py | 12 ++++++------ simplemind/providers/xai.py | 12 ++++++++---- 6 files changed, 54 insertions(+), 20 deletions(-) create mode 100644 examples/custom_temperature.py diff --git a/examples/custom_temperature.py b/examples/custom_temperature.py new file mode 100644 index 0000000..9123105 --- /dev/null +++ b/examples/custom_temperature.py @@ -0,0 +1,11 @@ +from _context import sm + +text = sm.generate_text( + prompt="Write a short summary of 'Pride and Prejudice'.", + llm_provider="openai", + llm_model="gpt-4o", + temperature=0.5, + max_tokens=150, +) + +print(text) diff --git a/simplemind/__init__.py b/simplemind/__init__.py index 2e06d1c..8438abe 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -23,11 +23,11 @@ def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=N ) -def generate_text(prompt, *, llm_model=None, llm_provider=None): +def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs): """Generate text from a given prompt.""" provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER) - return provider.generate_text(prompt=prompt, llm_model=llm_model) + return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs) __all__ = [ diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index b164c83..5bb2886 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -30,7 +30,7 @@ class Anthropic(BaseProvider): """A client patched with Instructor.""" return instructor.from_anthropic(self.client) - def send_conversation(self, conversation: "Conversation"): + def send_conversation(self, conversation: "Conversation", **kwargs): """Send a conversation to the Anthropic API.""" from ..models import Message @@ -42,6 +42,7 @@ class Anthropic(BaseProvider): model=conversation.llm_model or DEFAULT_MODEL, messages=messages, max_tokens=DEFAULT_MAX_TOKENS, + **kwargs, ) # Get the response content from the Anthropic response @@ -62,13 +63,16 @@ class Anthropic(BaseProvider): ) return response - def generate_text(self, prompt, *, llm_model): + def generate_text(self, prompt, *, llm_model, **kwargs): messages = [ {"role": "user", "content": prompt}, ] response = self.client.messages.create( - model=llm_model, messages=messages, max_tokens=DEFAULT_MAX_TOKENS + model=llm_model, + messages=messages, + max_tokens=DEFAULT_MAX_TOKENS, + **kwargs, ) return response.content[0].text diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 9fd4b71..ee1175d 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -29,7 +29,11 @@ class Groq(BaseProvider): """A client patched with Instructor.""" return instructor.from_groq(self.client) - def send_conversation(self, conversation: "Conversation") -> "Message": + def send_conversation( + self, + conversation: "Conversation", + **kwargs, + ) -> "Message": """Send a conversation to the Groq API.""" from ..models import Message @@ -38,7 +42,9 @@ class Groq(BaseProvider): ] response = self.client.chat.completions.create( - model=conversation.llm_model or DEFAULT_MODEL, messages=messages + model=conversation.llm_model or DEFAULT_MODEL, + messages=messages, + **kwargs, ) # Get the response content from the Groq response @@ -53,7 +59,7 @@ class Groq(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt: str, response_model): + def structured_response(self, prompt: str, response_model, **kwargs): # Ensure messages are provided in kwargs messages = [ {"role": "user", "content": prompt}, @@ -62,16 +68,25 @@ class Groq(BaseProvider): response = self.structured_client.chat.completions.create( messages=messages, response_model=response_model, + **kwargs, ) return response - def generate_text(self, prompt: str, *, llm_model: str): + def generate_text( + self, + prompt: str, + *, + llm_model: str, + **kwargs, + ): messages = [ {"role": "user", "content": prompt}, ] response = self.structured_client.chat.completions.create( - messages=messages, model=llm_model + messages=messages, + model=llm_model, + **kwargs, ) return response.choices[0].message.content diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index ee82105..cfa8471 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -29,7 +29,7 @@ class OpenAI(BaseProvider): """A OpenAI client with Instructor.""" return instructor.from_openai(self.client) - def send_conversation(self, conversation: "Conversation"): + def send_conversation(self, conversation: "Conversation", **kwargs): """Send a conversation to the OpenAI API.""" from ..models import Message @@ -38,7 +38,7 @@ class OpenAI(BaseProvider): ] response = self.client.chat.completions.create( - model=conversation.llm_model or DEFAULT_MODEL, messages=messages + model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs ) # Get the response content from the OpenAI response @@ -53,24 +53,24 @@ class OpenAI(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt, response_model, *, llm_model: str): + def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs): # Ensure messages are provided in kwargs messages = [ {"role": "user", "content": prompt}, ] response = self.structured_client.chat.completions.create( - messages=messages, model=llm_model, response_model=response_model + messages=messages, model=llm_model, response_model=response_model, **kwargs ) return response - def generate_text(self, prompt, *, llm_model): + def generate_text(self, prompt, *, llm_model, **kwargs): messages = [ {"role": "user", "content": prompt}, ] response = self.client.chat.completions.create( - messages=messages, model=llm_model + messages=messages, model=llm_model, **kwargs ) return response.choices[0].message.content diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 1d6e58b..a5dbac6 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -34,7 +34,7 @@ class XAI(BaseProvider): """A client patched with Instructor.""" return instructor.from_openai(self.client) - def send_conversation(self, conversation: "Conversation"): + def send_conversation(self, conversation: "Conversation", **kwargs): """Send a conversation to the OpenAI API.""" from ..models import Message @@ -43,7 +43,9 @@ class XAI(BaseProvider): ] response = self.client.chat.completions.create( - model=conversation.llm_model or DEFAULT_MODEL, messages=messages + model=conversation.llm_model or DEFAULT_MODEL, + messages=messages, + **kwargs, ) # Get the response content from the OpenAI response @@ -61,13 +63,15 @@ class XAI(BaseProvider): def structured_response(self, prompt: str, response_model, *, llm_model): raise NotImplementedError("XAI does not support structured responses") - def generate_text(self, prompt, *, llm_model): + def generate_text(self, prompt, *, llm_model, **kwargs): messages = [ {"role": "user", "content": prompt}, ] response = self.client.chat.completions.create( - messages=messages, model=llm_model + messages=messages, + model=llm_model, + **kwargs, ) return response.choices[0].message.content