coderabbit suggestions fix

This commit is contained in:
Siddhesh Agarwal
2024-10-31 12:58:00 +05:30
parent ec4f6f9c06
commit 42fc0e6bc5
9 changed files with 47 additions and 22 deletions
+1 -1
View File
@@ -8,7 +8,7 @@ import os
import pickle
class ContextualMemoryPlugin:
class ContextualMemoryPlugin(sm.BasePlugin):
def __init__(
self,
api_key: str,
+1 -1
View File
@@ -25,7 +25,7 @@ class QuotesList(BaseModel):
quotes: List[MovieQuote]
def gen_quotes(n=10) -> Iterator[MovieQuote]:
def gen_quotes(n: int = 10) -> Iterator[MovieQuote]:
"""Generate a list of quotes from famous movies."""
for q in sm.generate_data(
+4 -2
View File
@@ -1,9 +1,11 @@
from _context import sm
class MathPlugin:
class MathPlugin(sm.BasePlugin):
def send_hook(self, conversation: sm.Conversation):
last_user_message = conversation.get_last_message(role="user")
if last_user_message is None:
return
if "calculate" in last_user_message.text.lower():
expression = last_user_message.text.lower().replace("calculate", "").strip()
try:
@@ -14,7 +16,7 @@ class MathPlugin:
except Exception:
conversation.add_message(
role="assistant",
text="I'm sorry, I couldn't compute that expression.",
text="I'm sorry, I couldn't compute that expression. Please try again.",
)
+1 -1
View File
@@ -8,7 +8,7 @@ class ConversationPlugin(sm.BasePlugin):
print(f"{conversation.llm_model}:\n{response.text.strip()}\n\n------------\n")
def have_conversation(rounds=3):
def have_conversation(rounds: int = 3):
# Create two conversations - one for each AI
with (
sm.create_conversation(
+1 -1
View File
@@ -57,7 +57,7 @@ class Anthropic(BaseProvider):
def structured_response(self, model: str, response_model, **kwargs):
response = self.structured_client.messages.create(
model=model, response_model=response_model or self.DEFAULT_MODEL, **kwargs
model=model or self.DEFAULT_MODEL, response_model=response_model, **kwargs
)
return response
+25 -11
View File
@@ -23,7 +23,7 @@ class Gemini(BaseProvider):
if not self.api_key:
raise ValueError("Gemini API key is required")
self.model_name = model_name
return genai.GenerativeModel(model_name=model_name)
return genai.GenerativeModel(model_name=self.model_name)
@property
def structured_client(self):
@@ -42,9 +42,13 @@ class Gemini(BaseProvider):
for msg in conversation.messages
]
response = self.structured_client.chat.completions.create(
messages=messages, response_model=None
)
try:
response = self.structured_client.chat.completions.create(
messages=messages, response_model=None
)
except Exception as e:
# Handle the exception appropriately, e.g., log the error or raise a custom exception
raise RuntimeError(f"Failed to send conversation to Gemini API: {e}") from e
# Create and return a properly formatted Message instance
return Message(
@@ -57,15 +61,25 @@ class Gemini(BaseProvider):
def structured_response(self, prompt: str, response_model, **kwargs):
"""Send a structured response to the Gemini API."""
response = self.structured_client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
response_model=response_model,
)
try:
response = self.structured_client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
response_model=response_model,
)
except Exception as e:
# Handle the exception appropriately, e.g., log the error or raise a custom exception
raise RuntimeError(
f"Failed to send structured response to Gemini API: {e}"
) from e
return response
def generate_text(self, prompt: str, **kwargs) -> str:
"""Generate text using the Gemini API."""
response = self.structured_client.chat.completions.create(
messages=[{"role": "user", "content": prompt}], response_model=None
)
try:
response = self.structured_client.chat.completions.create(
messages=[{"role": "user", "content": prompt}], response_model=None
)
except Exception as e:
# Handle the exception appropriately, e.g., log the error or raise a custom exception
raise RuntimeError(f"Failed to generate text with Gemini API: {e}") from e
return str(response)
+3 -3
View File
@@ -26,7 +26,7 @@ class Ollama(BaseProvider):
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
@property
def structured_client(self):
def structured_client(self) -> instructor.Instructor:
"""A client patched with Instructor."""
return instructor.from_openai(
OpenAI(
@@ -36,7 +36,7 @@ class Ollama(BaseProvider):
mode=instructor.Mode.JSON,
)
def send_conversation(self, conversation: "Conversation"):
def send_conversation(self, conversation: "Conversation") -> "Message":
"""Send a conversation to the Ollama API."""
from ..models import Message
@@ -81,4 +81,4 @@ class Ollama(BaseProvider):
messages=messages, model=llm_model or self.DEFAULT_MODEL
)
return response.get("message").get("content")
return response.get("message", {}).get("content", "")
+2 -2
View File
@@ -52,7 +52,7 @@ class OpenAI(BaseProvider):
)
def structured_response(
self, prompt: str, response_model, *, llm_model: str, **kwargs
self, prompt: str, response_model, *, llm_model: str | None = None, **kwargs
):
# Ensure messages are provided in kwargs
messages = [
@@ -67,7 +67,7 @@ class OpenAI(BaseProvider):
)
return response
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
+9
View File
@@ -0,0 +1,9 @@
import simplemind as sm
res = sm.generate_text(
prompt="Wish you a happy Diwali!",
llm_model="gpt-4o",
llm_provider="openai",
)
print(res)