add tool decorator and example

This commit is contained in:
Luciano Scarpulla
2024-11-13 12:24:02 +08:00
parent 081baf203c
commit ea997aae7b
8 changed files with 262 additions and 56 deletions
+64 -1
View File
@@ -303,7 +303,70 @@ def get_weather(
return f"42 {unit}"
```
Functions can be defined with type hints and Pydantic models for validation. The AI will intelligently choose when to call the functions and incorporate the results into its responses.
Functions can be defined with type hints and Pydantic models for validation. The LLM will intelligently choose when to call the functions and incorporate the results into its responses.
#### 🪄 Using LLM for automatic tool definition (Experimental)
Simplemind provides a decorator to automatically transform Python functions into tools with AI-generated metadata. Simply use the `@simplemind.tool` decorator to have the LLM analyze your function and generate appropriate descriptions and schema:
```python
@simplemind.tool(llm_provider="anthropic")
def haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
r = 6371
phi1 = math.radians(lat1)
phi2 = math.radians(lat2)
delta_phi = math.radians(lat2 - lat1)
delta_lambda = math.radians(lon2 - lon1)
a = (
math.sin(delta_phi / 2) ** 2
+ math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
)
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
d = r * c
return d
```
Notice how we have not added any docstrings or `Field` for the function.
The decorator will use the specified LLM provider to generate the tool schema, including descriptions and parameter details:
```json
{
"name": "haversine",
"description": "Calculates the great-circle distance between two points on Earth given their latitude and longitude coordinates",
"input_schema": {
"type": "object",
"properties": {
"lat1": {
"type": "number",
"description": "Latitude of the first point in decimal degrees",
},
"lon1": {
"type": "number",
"description": "Longitude of the first point in decimal degrees",
},
"lat2": {
"type": "number",
"description": "Latitude of the second point in decimal degrees",
},
"lon2": {
"type": "number",
"description": "Longitude of the second point in decimal degrees",
}
},
"required": ["lat1", "lon1", "lat2", "lon2"],
},
}
```
The decorated function can then be used like any other tool with the conversation API.
```python
conversation = sm.create_conversation()
conversation.add_message("user", "How far is London from my location")
response = conversation.send(tools=[get_location, get_coords, haversine]) # Multiple tools can be passed
```
See [examples/distance_calculator.py](examples/distance_calculator.py) for more.
### Logging
+76
View File
@@ -0,0 +1,76 @@
import math
from _context import sm
from pydantic import Field
from typing_extensions import Literal
@sm.tool(llm_provider="anthropic")
def haversine(
lat1: float,
lon1: float,
lat2: float,
lon2: float,
unit: Literal["km", "miles"],
) -> float:
r = 6378.0937 if unit == "km" else 3961
phi1 = math.radians(lat1)
phi2 = math.radians(lat2)
delta_phi = math.radians(lat2 - lat1)
delta_lambda = math.radians(lon2 - lon1)
a = (
math.sin(delta_phi / 2) ** 2
+ math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
)
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
d = r * c
return d
def get_user_location() -> str:
"""Get the closest city from the user"""
return "San Francisco"
def get_coords(
city_name: str = Field(
description="The name of the city to take the coordinates from (e.g. London, Rome, Los Angeles)"
),
):
"""Get latitude and logitude of a City."""
_data = {
"Rome": (41.9028, 12.4964),
"London": (51.5074, -0.1278),
"Madrid": (40.4168, -3.7038),
"San Francisco": (37.7749, -122.4194),
"Los Angeles": (34.0522, -118.2437),
}
return _data.get(city_name)
def distance_calculator(prompt: str):
conversation = sm.create_conversation(llm_provider="anthropic")
conversation.add_message("user", prompt)
return conversation.send(
tools=[get_user_location, get_coords, haversine]
).text
print(distance_calculator("How far is London from where I am?"))
# Prints something like:
"""
The distance between your location (San Francisco) and London is approximately 5,357 miles.
"""
print(
distance_calculator(
"What is the distance between Rome and Madrid in Kilometers?"
)
)
"""
The distance between Rome and Madrid is approximately 1,366 kilometers.
"""
+27 -1
View File
@@ -1,4 +1,5 @@
from typing import Generator, List, Type
import inspect
from typing import Callable, List, Type
from .models import BaseModel, BasePlugin, Conversation
from .settings import settings
@@ -127,6 +128,30 @@ def enable_logfire() -> None:
"""Enable logfire logging."""
settings.logging.enable_logfire()
def tool(
llm_provider: str | None = None,
llm_model: str | None = None,
):
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
def decorator(func: Callable):
sig = inspect.signature(func)
res = generate_data(
(
"Based on this function signature, fill up the required fieds."
f"\nSignature: {func.__name__}{sig}"
"Make sure to properly add the required field in `required` if there are no defaults"
),
llm_provider=llm_provider,
response_model=provider.tool,
)
res.raw_func = func
res.__signature__ = sig
res.__doc__ = func.__doc__
return res
return decorator
# Syntax sugar.
Plugin = BasePlugin
@@ -141,4 +166,5 @@ __all__ = [
"Session",
"Plugin",
"enable_logfire",
"tool"
]
+19 -7
View File
@@ -1,10 +1,11 @@
import uuid
from datetime import datetime
from types import TracebackType
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Callable, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
from .providers._base_tools import BaseTool
from .utils import find_provider
MESSAGE_ROLE = Literal["system", "user", "assistant"]
@@ -40,7 +41,9 @@ class BasePlugin(SMBaseModel):
"""Cleanup a hook for the plugin."""
raise NotImplementedError
def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any:
def add_message_hook(
self, conversation: "Conversation", message: "Message"
) -> Any:
"""Add a message hook for the plugin."""
raise NotImplementedError
@@ -48,7 +51,9 @@ class BasePlugin(SMBaseModel):
"""Pre-send hook for the plugin."""
raise NotImplementedError
def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any:
def post_send_hook(
self, conversation: "Conversation", response: "Message"
) -> Any:
"""Post-send hook for the plugin."""
raise NotImplementedError
@@ -120,7 +125,9 @@ class Conversation(SMBaseModel):
except NotImplementedError:
pass
def prepend_system_message(self, text: str, meta: Dict[str, Any] | None = None):
def prepend_system_message(
self, text: str, meta: Dict[str, Any] | None = None
):
"""Prepend a system message to the conversation."""
self.messages = [
Message(role="system", text=text, meta=meta or {})
@@ -158,6 +165,7 @@ class Conversation(SMBaseModel):
self,
llm_model: str | None = None,
llm_provider: str | None = None,
tools: list[Callable | BaseTool] | None = None,
) -> Message:
"""Send the conversation to the LLM."""
@@ -173,7 +181,7 @@ class Conversation(SMBaseModel):
# Find the provider and send the conversation.
provider = find_provider(llm_provider or self.llm_provider)
response = provider.send_conversation(self)
response = provider.send_conversation(self, tools=tools)
# Execute all post-send hooks.
for plugin in self.plugins:
@@ -184,13 +192,17 @@ class Conversation(SMBaseModel):
pass
# Add the response to the conversation.
self.add_message(role="assistant", text=response.text, meta=response.meta)
self.add_message(
role="assistant", text=response.text, meta=response.meta
)
return response
def get_last_message(self, role: MESSAGE_ROLE) -> Message | None:
"""Get the last message with the given role."""
return next((m for m in reversed(self.messages) if m.role == role), None)
return next(
(m for m in reversed(self.messages) if m.role == role), None
)
def add_plugin(self, plugin: BasePlugin) -> None:
"""Add a plugin to the conversation."""
+2
View File
@@ -1,6 +1,7 @@
from typing import List, Type
from ._base import BaseProvider
from ._base_tools import BaseTool
from .amazon import Amazon
from .anthropic import Anthropic
from .gemini import Gemini
@@ -29,4 +30,5 @@ __all__ = [
"Amazon",
"providers",
"BaseProvider",
"BaseTool",
]
+3 -3
View File
@@ -37,7 +37,7 @@ class BaseProvider(ABC):
def send_conversation(
self,
conversation: "Conversation",
tools: list[Callable] | None = None,
tools: list[Callable | BaseTool] | None = None,
) -> "Message":
"""Send a conversation to the provider."""
raise NotImplementedError
@@ -54,7 +54,7 @@ class BaseProvider(ABC):
self,
prompt: str,
*,
tools: list[Callable] | None = None,
tools: list[Callable | BaseTool] | None = None,
stream: bool = False,
**kwargs,
) -> str:
@@ -67,7 +67,7 @@ class BaseProvider(ABC):
"""The tool implementation for the provider."""
raise NotImplementedError
def make_tools(self, tools: list[Callable] | None):
def make_tools(self, tools: list[Callable | BaseTool] | None):
if tools is not None:
return [self.tool.from_function(func) for func in tools]
else:
+46 -29
View File
@@ -26,6 +26,7 @@ class BaseToolConfig(BaseModel):
TYPE_CONVERSION: dict[type, str] = {
str: "string",
int: "integer",
float: "number",
bool: "boolean",
}
@@ -42,13 +43,20 @@ class BaseTool(BaseModel, ABC):
properties: dict[str, BaseToolProperty]
required: list[str] | None = None
config: ClassVar[BaseToolConfig] = BaseToolConfig()
raw_func: Callable
raw_func: Any | None = None
tool_id: str | None = None
function_result: str | None = None
def __call__(self, *args: Any, **kwargs: Any) -> Any:
assert self.raw_func is not None
return self.raw_func(*args, **kwargs)
def is_executed(self) -> bool:
return self.function_result is not None
def reset_result(self) -> None:
self.function_result = None
@classmethod
def convert_type(cls, field_type) -> str:
if _is_literal(field_type):
@@ -68,7 +76,11 @@ class BaseTool(BaseModel, ABC):
}
@classmethod
def from_function(cls, func: Callable):
def from_function(cls, func: Callable | "BaseTool"):
# Check if the func passed is an instace of BaseTool
if hasattr(func, "raw_func"):
return func
annotations = getattr(func, "__annotations__", {})
properties = {}
required = []
@@ -76,34 +88,39 @@ class BaseTool(BaseModel, ABC):
func_signature = inspect.signature(func)
for n, (arg_name, arg_type) in enumerate(annotations.items()):
# Check if argument has metadata (from Annotated)
if hasattr(arg_type, "__metadata__"):
field = arg_type.__metadata__[0] # Get Field info from metadata
field_type = arg_type.__origin__ # Get actual type
# Check if argument has a default value in signature
elif (
sig_param := func_signature.parameters[arg_name]
).default is not inspect.Parameter.empty:
field = sig_param.default # Use default as Field
field_type = arg_type # Use plain type annotation
else:
# Raise error if no Field annotation found
raise ValueError(
f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter"
if ( # Skipping 'return' annotation (i.e.```-> str```)
arg_name != "return"
):
# Check if argument has metadata (from Annotated)
if hasattr(arg_type, "__metadata__"):
field = arg_type.__metadata__[
0
] # Get Field info from metadata
field_type = arg_type.__origin__ # Get actual type
# Check if argument has a default value in signature
elif (
sig_param := func_signature.parameters[arg_name]
).default is not inspect.Parameter.empty:
field = sig_param.default # Use default as Field
field_type = arg_type # Use plain type annotation
else:
# Raise error if no Field annotation found
raise ValueError(
f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter"
)
field_type_converted = cls.convert_type(field_type)
if _is_literal(field_type):
enum_values = [str(x) for x in field_type.__args__]
properties[arg_name] = BaseToolProperty(
type=field_type_converted,
description=field.description,
enum=enum_values,
)
field_type_converted = cls.convert_type(field_type)
if _is_literal(field_type):
enum_values = [str(x) for x in field_type.__args__]
properties[arg_name] = BaseToolProperty(
type=field_type_converted,
description=field.description,
enum=enum_values,
)
if _is_required(field, func_signature, arg_name):
required.append(arg_name)
if _is_required(field, func_signature, arg_name):
required.append(arg_name)
return cls(
name=func.__name__,
+25 -15
View File
@@ -18,6 +18,9 @@ T = TypeVar("T", bound=BaseModel)
class AnthropicTool(BaseTool):
def get_response_schema(self) -> Any:
assert self.is_executed, f"Tool {self.name} was not executed."
assert isinstance(
self.tool_id, str
), f"Expected str for `tool_id` got {self.tool_id!r}"
return {
"type": "tool_result",
"tool_use_id": self.tool_id,
@@ -28,6 +31,7 @@ class AnthropicTool(BaseTool):
def handle(self, response, messages) -> None:
"""Handle the tool execution result from an API response."""
msg = {"role": "assistant", "content": []}
tool_used = False
for content in response.content:
if content.type == "tool_use" and content.name == self.name:
msg["content"].append(
@@ -41,12 +45,15 @@ class AnthropicTool(BaseTool):
# Function execution:
self.function_result = str(self.raw_func(**content.input))
self.tool_id = content.id
else:
tool_used = True
elif content.type == "text":
msg["content"].append({"type": "text", "text": content.text})
messages.append(msg)
messages.append(
{"role": "user", "content": [self.get_response_schema()]}
)
if tool_used:
messages.append(msg)
messages.append(
{"role": "user", "content": [self.get_response_schema()]}
)
def get_input_schema(self):
return {
@@ -93,7 +100,7 @@ class Anthropic(BaseProvider):
def send_conversation(
self,
conversation: "Conversation",
tools: list[Callable] | None = None,
tools: list[Callable | BaseTool] | None = None,
**kwargs,
) -> "Message":
"""Send a conversation to the Anthropic API."""
@@ -117,16 +124,19 @@ class Anthropic(BaseProvider):
**{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg},
)
for tool in converted_tools:
tool.handle(response, messages)
if tool.is_executed():
response = self.client.messages.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg},
)
while response.content[-1].type != "text":
print(response)
for tool in converted_tools:
tool.handle(response, messages)
if tool.is_executed():
response = self.client.messages.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg},
)
tool.reset_result()
assistant_message = response.content[0].text
assistant_message = response.content[-1].text
# Create and return a properly formatted Message instance
return Message(