mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b5a901efaf | |||
| 9ccef9abdc | |||
| 3421de0fc1 | |||
| 54b0007947 | |||
| 90af44ace0 | |||
| cff3bff3d5 | |||
| 3abbb79f6c | |||
| 59c1bd3a0f | |||
| 8181f37fed |
@@ -5,3 +5,4 @@ export OLLAMA_HOST_URL=""
|
||||
export OPENAI_API_KEY=""
|
||||
export XAI_API_KEY=""
|
||||
export AMAZON_PROFILE_NAME=""
|
||||
export DEEPSEEK_API_KEY=""
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
Release History
|
||||
===============
|
||||
|
||||
## 0.3.3 (2024-02-08)
|
||||
|
||||
- Improve openai provider by removing debug print statements.
|
||||
|
||||
## 0.3.2 (2024-01-27)
|
||||
|
||||
- Improve Deepseek provider.
|
||||
|
||||
@@ -88,7 +88,7 @@ First, authenticate your API keys by setting them in the environment variables:
|
||||
$ export OPENAI_API_KEY="sk-..."
|
||||
```
|
||||
|
||||
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
|
||||
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `DEEPSEEK_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
|
||||
|
||||
Next, import Simplemind and start using it:
|
||||
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "simplemind"
|
||||
version = "0.3.2"
|
||||
version = "0.3.3"
|
||||
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
+117
-11
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
from ._base_tools import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
@@ -14,6 +15,78 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class GroqTool(BaseTool):
|
||||
def get_response_schema(self):
|
||||
assert self.is_executed, f"Tool {self.name} was not executed."
|
||||
assert isinstance(
|
||||
self.tool_id, str
|
||||
), f"Expected str for `tool_id` got {self.tool_id!r}"
|
||||
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": self.tool_id,
|
||||
"content": self.function_result,
|
||||
}
|
||||
|
||||
@logger
|
||||
def handle(self, response, messages) -> None:
|
||||
"""Handle the tool execution result from an API response."""
|
||||
tool_used = False
|
||||
|
||||
# Get the message from the response
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# Check if there's a tool call
|
||||
if assistant_message.tool_calls:
|
||||
tool_call = assistant_message.tool_calls[
|
||||
0
|
||||
] # Get the first tool call
|
||||
if tool_call.function.name == self.name:
|
||||
# Execute the function
|
||||
import json
|
||||
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
self.function_result = str(self.raw_func(**function_args))
|
||||
self.tool_id = tool_call.id
|
||||
tool_used = True
|
||||
|
||||
# Add assistant's message with tool call
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
if tool_used:
|
||||
# Add tool response message
|
||||
messages.append(self.get_response_schema())
|
||||
|
||||
def get_input_schema(self):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": self.get_properties_schema(),
|
||||
"required": self.required,
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
class Groq(BaseProvider):
|
||||
NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
@@ -46,28 +119,56 @@ class Groq(BaseProvider):
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
**kwargs,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the Groq API."""
|
||||
from ..models import Message
|
||||
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
# Format messages from conversation
|
||||
formatted_messages = [
|
||||
{"role": msg.role, "content": msg.text}
|
||||
for msg in conversation.messages
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
# Set up tools if provided
|
||||
converted_tools = self.make_tools(tools)
|
||||
tools_config = (
|
||||
[t.get_input_schema() for t in converted_tools] if tools else None
|
||||
)
|
||||
|
||||
# Get the response content from the Groq response
|
||||
assistant_message = response.choices[0].message
|
||||
# Merge all kwargs
|
||||
request_kwargs = {
|
||||
**self.DEFAULT_KWARGS,
|
||||
**kwargs,
|
||||
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||
"messages": formatted_messages,
|
||||
}
|
||||
|
||||
if tools_config:
|
||||
request_kwargs["tools"] = tools_config
|
||||
|
||||
# Make initial API call
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
|
||||
# Handle tool responses if needed
|
||||
while response.choices[0].message.tool_calls:
|
||||
print(response)
|
||||
# Handle each tool call
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, formatted_messages)
|
||||
if tool.is_executed():
|
||||
# Make another API call with the updated messages
|
||||
response = self.client.chat.completions.create(
|
||||
**request_kwargs
|
||||
)
|
||||
tool.reset_result()
|
||||
|
||||
final_message = response.choices[0].message.content
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message.content or "",
|
||||
text=final_message or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=self.NAME,
|
||||
@@ -136,3 +237,8 @@ class Groq(BaseProvider):
|
||||
raise RuntimeError(
|
||||
f"Failed to generate streaming text with Groq API: {e}"
|
||||
) from e
|
||||
|
||||
@cached_property
|
||||
def tool(self) -> Type[BaseTool]:
|
||||
"""The tool implementation for Groq."""
|
||||
return GroqTool
|
||||
@@ -38,9 +38,7 @@ class OpenAITool(BaseTool):
|
||||
|
||||
# Check if there's a tool call
|
||||
if assistant_message.tool_calls:
|
||||
tool_call = assistant_message.tool_calls[
|
||||
0
|
||||
] # Get the first tool call
|
||||
tool_call = assistant_message.tool_calls[0] # Get the first tool call
|
||||
if tool_call.function.name == self.name:
|
||||
# Execute the function
|
||||
import json
|
||||
@@ -128,8 +126,7 @@ class OpenAI(BaseProvider):
|
||||
|
||||
# Format messages from conversation
|
||||
formatted_messages = [
|
||||
{"role": msg.role, "content": msg.text}
|
||||
for msg in conversation.messages
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
|
||||
# Set up tools if provided
|
||||
@@ -154,15 +151,12 @@ class OpenAI(BaseProvider):
|
||||
|
||||
# Handle tool responses if needed
|
||||
while response.choices[0].message.tool_calls:
|
||||
print(response)
|
||||
# Handle each tool call
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, formatted_messages)
|
||||
if tool.is_executed():
|
||||
# Make another API call with the updated messages
|
||||
response = self.client.chat.completions.create(
|
||||
**request_kwargs
|
||||
)
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
tool.reset_result()
|
||||
|
||||
final_message = response.choices[0].message.content
|
||||
@@ -182,13 +176,21 @@ class OpenAI(BaseProvider):
|
||||
response_model: Type[T],
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
image_url: str | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Get a structured response from the OpenAI API."""
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||
]
|
||||
|
||||
"""Add an image (url or base64-encoded) to the message if provided."""
|
||||
if image_url:
|
||||
messages[0]["content"].append(
|
||||
{"type": "image_url", "image_url": {"url": image_url}}
|
||||
)
|
||||
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
@@ -199,12 +201,24 @@ class OpenAI(BaseProvider):
|
||||
|
||||
@logger
|
||||
def generate_text(
|
||||
self, prompt: str, *, llm_model: str | None = None, **kwargs
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
image_url: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate text using the OpenAI API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||
]
|
||||
|
||||
"""Add an image (url or base64-encoded) to the message if provided."""
|
||||
if image_url:
|
||||
messages[0]["content"].append(
|
||||
{"type": "image_url", "image_url": {"url": image_url}}
|
||||
)
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
@@ -214,15 +228,27 @@ class OpenAI(BaseProvider):
|
||||
|
||||
@logger
|
||||
def generate_stream_text(
|
||||
self, prompt: str, *, llm_model: str | None = None, **kwargs
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
image_url: str | None = None,
|
||||
**kwargs,
|
||||
) -> Iterator[str]:
|
||||
"""Generate streaming text using the OpenAI API.
|
||||
|
||||
Yields chunks of text as they are generated by the model.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||
]
|
||||
|
||||
"""Add an image (url or base64-encoded) to the message if provided."""
|
||||
if image_url:
|
||||
messages[0]["content"].append(
|
||||
{"type": "image_url", "image_url": {"url": image_url}}
|
||||
)
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
|
||||
Reference in New Issue
Block a user