mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
add tool decorator and example
This commit is contained in:
@@ -303,7 +303,70 @@ def get_weather(
|
||||
return f"42 {unit}"
|
||||
```
|
||||
|
||||
Functions can be defined with type hints and Pydantic models for validation. The AI will intelligently choose when to call the functions and incorporate the results into its responses.
|
||||
Functions can be defined with type hints and Pydantic models for validation. The LLM will intelligently choose when to call the functions and incorporate the results into its responses.
|
||||
|
||||
#### 🪄 Using LLM for automatic tool definition (Experimental)
|
||||
|
||||
Simplemind provides a decorator to automatically transform Python functions into tools with AI-generated metadata. Simply use the `@simplemind.tool` decorator to have the LLM analyze your function and generate appropriate descriptions and schema:
|
||||
|
||||
```python
|
||||
@simplemind.tool(llm_provider="anthropic")
|
||||
def haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
|
||||
r = 6371
|
||||
phi1 = math.radians(lat1)
|
||||
phi2 = math.radians(lat2)
|
||||
delta_phi = math.radians(lat2 - lat1)
|
||||
delta_lambda = math.radians(lon2 - lon1)
|
||||
|
||||
a = (
|
||||
math.sin(delta_phi / 2) ** 2
|
||||
+ math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
|
||||
)
|
||||
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
|
||||
d = r * c
|
||||
return d
|
||||
```
|
||||
Notice how we have not added any docstrings or `Field` for the function.
|
||||
The decorator will use the specified LLM provider to generate the tool schema, including descriptions and parameter details:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "haversine",
|
||||
"description": "Calculates the great-circle distance between two points on Earth given their latitude and longitude coordinates",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"lat1": {
|
||||
"type": "number",
|
||||
"description": "Latitude of the first point in decimal degrees",
|
||||
},
|
||||
"lon1": {
|
||||
"type": "number",
|
||||
"description": "Longitude of the first point in decimal degrees",
|
||||
},
|
||||
"lat2": {
|
||||
"type": "number",
|
||||
"description": "Latitude of the second point in decimal degrees",
|
||||
},
|
||||
"lon2": {
|
||||
"type": "number",
|
||||
"description": "Longitude of the second point in decimal degrees",
|
||||
}
|
||||
},
|
||||
"required": ["lat1", "lon1", "lat2", "lon2"],
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
The decorated function can then be used like any other tool with the conversation API.
|
||||
|
||||
```python
|
||||
conversation = sm.create_conversation()
|
||||
conversation.add_message("user", "How far is London from my location")
|
||||
response = conversation.send(tools=[get_location, get_coords, haversine]) # Multiple tools can be passed
|
||||
```
|
||||
|
||||
See [examples/distance_calculator.py](examples/distance_calculator.py) for more.
|
||||
|
||||
### Logging
|
||||
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
import math
|
||||
|
||||
from _context import sm
|
||||
from pydantic import Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@sm.tool(llm_provider="anthropic")
|
||||
def haversine(
|
||||
lat1: float,
|
||||
lon1: float,
|
||||
lat2: float,
|
||||
lon2: float,
|
||||
unit: Literal["km", "miles"],
|
||||
) -> float:
|
||||
r = 6378.0937 if unit == "km" else 3961
|
||||
phi1 = math.radians(lat1)
|
||||
phi2 = math.radians(lat2)
|
||||
delta_phi = math.radians(lat2 - lat1)
|
||||
delta_lambda = math.radians(lon2 - lon1)
|
||||
|
||||
a = (
|
||||
math.sin(delta_phi / 2) ** 2
|
||||
+ math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
|
||||
)
|
||||
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
|
||||
d = r * c
|
||||
return d
|
||||
|
||||
|
||||
def get_user_location() -> str:
|
||||
"""Get the closest city from the user"""
|
||||
return "San Francisco"
|
||||
|
||||
|
||||
def get_coords(
|
||||
city_name: str = Field(
|
||||
description="The name of the city to take the coordinates from (e.g. London, Rome, Los Angeles)"
|
||||
),
|
||||
):
|
||||
"""Get latitude and logitude of a City."""
|
||||
_data = {
|
||||
"Rome": (41.9028, 12.4964),
|
||||
"London": (51.5074, -0.1278),
|
||||
"Madrid": (40.4168, -3.7038),
|
||||
"San Francisco": (37.7749, -122.4194),
|
||||
"Los Angeles": (34.0522, -118.2437),
|
||||
}
|
||||
|
||||
return _data.get(city_name)
|
||||
|
||||
|
||||
def distance_calculator(prompt: str):
|
||||
conversation = sm.create_conversation(llm_provider="anthropic")
|
||||
conversation.add_message("user", prompt)
|
||||
return conversation.send(
|
||||
tools=[get_user_location, get_coords, haversine]
|
||||
).text
|
||||
|
||||
|
||||
print(distance_calculator("How far is London from where I am?"))
|
||||
# Prints something like:
|
||||
"""
|
||||
The distance between your location (San Francisco) and London is approximately 5,357 miles.
|
||||
"""
|
||||
|
||||
print(
|
||||
distance_calculator(
|
||||
"What is the distance between Rome and Madrid in Kilometers?"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
The distance between Rome and Madrid is approximately 1,366 kilometers.
|
||||
"""
|
||||
+27
-1
@@ -1,4 +1,5 @@
|
||||
from typing import Generator, List, Type
|
||||
import inspect
|
||||
from typing import Callable, List, Type
|
||||
|
||||
from .models import BaseModel, BasePlugin, Conversation
|
||||
from .settings import settings
|
||||
@@ -127,6 +128,30 @@ def enable_logfire() -> None:
|
||||
"""Enable logfire logging."""
|
||||
settings.logging.enable_logfire()
|
||||
|
||||
def tool(
|
||||
llm_provider: str | None = None,
|
||||
llm_model: str | None = None,
|
||||
):
|
||||
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
|
||||
|
||||
def decorator(func: Callable):
|
||||
sig = inspect.signature(func)
|
||||
res = generate_data(
|
||||
(
|
||||
"Based on this function signature, fill up the required fieds."
|
||||
f"\nSignature: {func.__name__}{sig}"
|
||||
"Make sure to properly add the required field in `required` if there are no defaults"
|
||||
),
|
||||
llm_provider=llm_provider,
|
||||
response_model=provider.tool,
|
||||
)
|
||||
res.raw_func = func
|
||||
res.__signature__ = sig
|
||||
res.__doc__ = func.__doc__
|
||||
|
||||
return res
|
||||
|
||||
return decorator
|
||||
|
||||
# Syntax sugar.
|
||||
Plugin = BasePlugin
|
||||
@@ -141,4 +166,5 @@ __all__ = [
|
||||
"Session",
|
||||
"Plugin",
|
||||
"enable_logfire",
|
||||
"tool"
|
||||
]
|
||||
|
||||
+19
-7
@@ -1,10 +1,11 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .providers._base_tools import BaseTool
|
||||
from .utils import find_provider
|
||||
|
||||
MESSAGE_ROLE = Literal["system", "user", "assistant"]
|
||||
@@ -40,7 +41,9 @@ class BasePlugin(SMBaseModel):
|
||||
"""Cleanup a hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any:
|
||||
def add_message_hook(
|
||||
self, conversation: "Conversation", message: "Message"
|
||||
) -> Any:
|
||||
"""Add a message hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -48,7 +51,9 @@ class BasePlugin(SMBaseModel):
|
||||
"""Pre-send hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any:
|
||||
def post_send_hook(
|
||||
self, conversation: "Conversation", response: "Message"
|
||||
) -> Any:
|
||||
"""Post-send hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -120,7 +125,9 @@ class Conversation(SMBaseModel):
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
def prepend_system_message(self, text: str, meta: Dict[str, Any] | None = None):
|
||||
def prepend_system_message(
|
||||
self, text: str, meta: Dict[str, Any] | None = None
|
||||
):
|
||||
"""Prepend a system message to the conversation."""
|
||||
self.messages = [
|
||||
Message(role="system", text=text, meta=meta or {})
|
||||
@@ -158,6 +165,7 @@ class Conversation(SMBaseModel):
|
||||
self,
|
||||
llm_model: str | None = None,
|
||||
llm_provider: str | None = None,
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
) -> Message:
|
||||
"""Send the conversation to the LLM."""
|
||||
|
||||
@@ -173,7 +181,7 @@ class Conversation(SMBaseModel):
|
||||
|
||||
# Find the provider and send the conversation.
|
||||
provider = find_provider(llm_provider or self.llm_provider)
|
||||
response = provider.send_conversation(self)
|
||||
response = provider.send_conversation(self, tools=tools)
|
||||
|
||||
# Execute all post-send hooks.
|
||||
for plugin in self.plugins:
|
||||
@@ -184,13 +192,17 @@ class Conversation(SMBaseModel):
|
||||
pass
|
||||
|
||||
# Add the response to the conversation.
|
||||
self.add_message(role="assistant", text=response.text, meta=response.meta)
|
||||
self.add_message(
|
||||
role="assistant", text=response.text, meta=response.meta
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def get_last_message(self, role: MESSAGE_ROLE) -> Message | None:
|
||||
"""Get the last message with the given role."""
|
||||
return next((m for m in reversed(self.messages) if m.role == role), None)
|
||||
return next(
|
||||
(m for m in reversed(self.messages) if m.role == role), None
|
||||
)
|
||||
|
||||
def add_plugin(self, plugin: BasePlugin) -> None:
|
||||
"""Add a plugin to the conversation."""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import List, Type
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ._base_tools import BaseTool
|
||||
from .amazon import Amazon
|
||||
from .anthropic import Anthropic
|
||||
from .gemini import Gemini
|
||||
@@ -29,4 +30,5 @@ __all__ = [
|
||||
"Amazon",
|
||||
"providers",
|
||||
"BaseProvider",
|
||||
"BaseTool",
|
||||
]
|
||||
|
||||
@@ -37,7 +37,7 @@ class BaseProvider(ABC):
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable] | None = None,
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the provider."""
|
||||
raise NotImplementedError
|
||||
@@ -54,7 +54,7 @@ class BaseProvider(ABC):
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
tools: list[Callable] | None = None,
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
@@ -67,7 +67,7 @@ class BaseProvider(ABC):
|
||||
"""The tool implementation for the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
def make_tools(self, tools: list[Callable] | None):
|
||||
def make_tools(self, tools: list[Callable | BaseTool] | None):
|
||||
if tools is not None:
|
||||
return [self.tool.from_function(func) for func in tools]
|
||||
else:
|
||||
|
||||
@@ -26,6 +26,7 @@ class BaseToolConfig(BaseModel):
|
||||
TYPE_CONVERSION: dict[type, str] = {
|
||||
str: "string",
|
||||
int: "integer",
|
||||
float: "number",
|
||||
bool: "boolean",
|
||||
}
|
||||
|
||||
@@ -42,13 +43,20 @@ class BaseTool(BaseModel, ABC):
|
||||
properties: dict[str, BaseToolProperty]
|
||||
required: list[str] | None = None
|
||||
config: ClassVar[BaseToolConfig] = BaseToolConfig()
|
||||
raw_func: Callable
|
||||
raw_func: Any | None = None
|
||||
tool_id: str | None = None
|
||||
function_result: str | None = None
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
assert self.raw_func is not None
|
||||
return self.raw_func(*args, **kwargs)
|
||||
|
||||
def is_executed(self) -> bool:
|
||||
return self.function_result is not None
|
||||
|
||||
def reset_result(self) -> None:
|
||||
self.function_result = None
|
||||
|
||||
@classmethod
|
||||
def convert_type(cls, field_type) -> str:
|
||||
if _is_literal(field_type):
|
||||
@@ -68,7 +76,11 @@ class BaseTool(BaseModel, ABC):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_function(cls, func: Callable):
|
||||
def from_function(cls, func: Callable | "BaseTool"):
|
||||
# Check if the func passed is an instace of BaseTool
|
||||
if hasattr(func, "raw_func"):
|
||||
return func
|
||||
|
||||
annotations = getattr(func, "__annotations__", {})
|
||||
properties = {}
|
||||
required = []
|
||||
@@ -76,34 +88,39 @@ class BaseTool(BaseModel, ABC):
|
||||
func_signature = inspect.signature(func)
|
||||
|
||||
for n, (arg_name, arg_type) in enumerate(annotations.items()):
|
||||
# Check if argument has metadata (from Annotated)
|
||||
if hasattr(arg_type, "__metadata__"):
|
||||
field = arg_type.__metadata__[0] # Get Field info from metadata
|
||||
field_type = arg_type.__origin__ # Get actual type
|
||||
# Check if argument has a default value in signature
|
||||
elif (
|
||||
sig_param := func_signature.parameters[arg_name]
|
||||
).default is not inspect.Parameter.empty:
|
||||
field = sig_param.default # Use default as Field
|
||||
field_type = arg_type # Use plain type annotation
|
||||
else:
|
||||
# Raise error if no Field annotation found
|
||||
raise ValueError(
|
||||
f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter"
|
||||
if ( # Skipping 'return' annotation (i.e.```-> str```)
|
||||
arg_name != "return"
|
||||
):
|
||||
# Check if argument has metadata (from Annotated)
|
||||
if hasattr(arg_type, "__metadata__"):
|
||||
field = arg_type.__metadata__[
|
||||
0
|
||||
] # Get Field info from metadata
|
||||
field_type = arg_type.__origin__ # Get actual type
|
||||
# Check if argument has a default value in signature
|
||||
elif (
|
||||
sig_param := func_signature.parameters[arg_name]
|
||||
).default is not inspect.Parameter.empty:
|
||||
field = sig_param.default # Use default as Field
|
||||
field_type = arg_type # Use plain type annotation
|
||||
else:
|
||||
# Raise error if no Field annotation found
|
||||
raise ValueError(
|
||||
f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter"
|
||||
)
|
||||
|
||||
field_type_converted = cls.convert_type(field_type)
|
||||
|
||||
if _is_literal(field_type):
|
||||
enum_values = [str(x) for x in field_type.__args__]
|
||||
|
||||
properties[arg_name] = BaseToolProperty(
|
||||
type=field_type_converted,
|
||||
description=field.description,
|
||||
enum=enum_values,
|
||||
)
|
||||
|
||||
field_type_converted = cls.convert_type(field_type)
|
||||
|
||||
if _is_literal(field_type):
|
||||
enum_values = [str(x) for x in field_type.__args__]
|
||||
|
||||
properties[arg_name] = BaseToolProperty(
|
||||
type=field_type_converted,
|
||||
description=field.description,
|
||||
enum=enum_values,
|
||||
)
|
||||
if _is_required(field, func_signature, arg_name):
|
||||
required.append(arg_name)
|
||||
if _is_required(field, func_signature, arg_name):
|
||||
required.append(arg_name)
|
||||
|
||||
return cls(
|
||||
name=func.__name__,
|
||||
|
||||
@@ -18,6 +18,9 @@ T = TypeVar("T", bound=BaseModel)
|
||||
class AnthropicTool(BaseTool):
|
||||
def get_response_schema(self) -> Any:
|
||||
assert self.is_executed, f"Tool {self.name} was not executed."
|
||||
assert isinstance(
|
||||
self.tool_id, str
|
||||
), f"Expected str for `tool_id` got {self.tool_id!r}"
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": self.tool_id,
|
||||
@@ -28,6 +31,7 @@ class AnthropicTool(BaseTool):
|
||||
def handle(self, response, messages) -> None:
|
||||
"""Handle the tool execution result from an API response."""
|
||||
msg = {"role": "assistant", "content": []}
|
||||
tool_used = False
|
||||
for content in response.content:
|
||||
if content.type == "tool_use" and content.name == self.name:
|
||||
msg["content"].append(
|
||||
@@ -41,12 +45,15 @@ class AnthropicTool(BaseTool):
|
||||
# Function execution:
|
||||
self.function_result = str(self.raw_func(**content.input))
|
||||
self.tool_id = content.id
|
||||
else:
|
||||
tool_used = True
|
||||
elif content.type == "text":
|
||||
msg["content"].append({"type": "text", "text": content.text})
|
||||
messages.append(msg)
|
||||
messages.append(
|
||||
{"role": "user", "content": [self.get_response_schema()]}
|
||||
)
|
||||
|
||||
if tool_used:
|
||||
messages.append(msg)
|
||||
messages.append(
|
||||
{"role": "user", "content": [self.get_response_schema()]}
|
||||
)
|
||||
|
||||
def get_input_schema(self):
|
||||
return {
|
||||
@@ -93,7 +100,7 @@ class Anthropic(BaseProvider):
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable] | None = None,
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
**kwargs,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the Anthropic API."""
|
||||
@@ -117,16 +124,19 @@ class Anthropic(BaseProvider):
|
||||
**{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg},
|
||||
)
|
||||
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, messages)
|
||||
if tool.is_executed():
|
||||
response = self.client.messages.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg},
|
||||
)
|
||||
while response.content[-1].type != "text":
|
||||
print(response)
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, messages)
|
||||
if tool.is_executed():
|
||||
response = self.client.messages.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg},
|
||||
)
|
||||
tool.reset_result()
|
||||
|
||||
assistant_message = response.content[0].text
|
||||
assistant_message = response.content[-1].text
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
|
||||
Reference in New Issue
Block a user