From ea997aae7b418d4c57837dd3b820ed3e2aa3525c Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Wed, 13 Nov 2024 12:24:02 +0800 Subject: [PATCH] add tool decorator and example --- README.md | 65 +++++++++++++++++++++++- examples/distance_calculator.py | 76 +++++++++++++++++++++++++++++ simplemind/__init__.py | 28 ++++++++++- simplemind/models.py | 26 +++++++--- simplemind/providers/__init__.py | 2 + simplemind/providers/_base.py | 6 +-- simplemind/providers/_base_tools.py | 75 +++++++++++++++++----------- simplemind/providers/anthropic.py | 40 +++++++++------ 8 files changed, 262 insertions(+), 56 deletions(-) create mode 100644 examples/distance_calculator.py diff --git a/README.md b/README.md index 644859b..7e1c605 100644 --- a/README.md +++ b/README.md @@ -303,7 +303,70 @@ def get_weather( return f"42 {unit}" ``` -Functions can be defined with type hints and Pydantic models for validation. The AI will intelligently choose when to call the functions and incorporate the results into its responses. +Functions can be defined with type hints and Pydantic models for validation. The LLM will intelligently choose when to call the functions and incorporate the results into its responses. + +#### 🪄 Using LLM for automatic tool definition (Experimental) + +Simplemind provides a decorator to automatically transform Python functions into tools with AI-generated metadata. 
Simply use the `@simplemind.tool` decorator to have the LLM analyze your function and generate appropriate descriptions and schema: + +```python +@simplemind.tool(llm_provider="anthropic") +def haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float: + r = 6371 + phi1 = math.radians(lat1) + phi2 = math.radians(lat2) + delta_phi = math.radians(lat2 - lat1) + delta_lambda = math.radians(lon2 - lon1) + + a = ( + math.sin(delta_phi / 2) ** 2 + + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2 + ) + c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) + d = r * c + return d +``` +Notice how we have not added any docstrings or `Field` for the function. +The decorator will use the specified LLM provider to generate the tool schema, including descriptions and parameter details: + +```json +{ + "name": "haversine", + "description": "Calculates the great-circle distance between two points on Earth given their latitude and longitude coordinates", + "input_schema": { + "type": "object", + "properties": { + "lat1": { + "type": "number", + "description": "Latitude of the first point in decimal degrees" + }, + "lon1": { + "type": "number", + "description": "Longitude of the first point in decimal degrees" + }, + "lat2": { + "type": "number", + "description": "Latitude of the second point in decimal degrees" + }, + "lon2": { + "type": "number", + "description": "Longitude of the second point in decimal degrees" + } + }, + "required": ["lat1", "lon1", "lat2", "lon2"] + } +} +``` + +The decorated function can then be used like any other tool with the conversation API. + +```python +conversation = sm.create_conversation() +conversation.add_message("user", "How far is London from my location") +response = conversation.send(tools=[get_location, get_coords, haversine]) # Multiple tools can be passed +``` + +See [examples/distance_calculator.py](examples/distance_calculator.py) for more. 
### Logging diff --git a/examples/distance_calculator.py b/examples/distance_calculator.py new file mode 100644 index 0000000..3d03dea --- /dev/null +++ b/examples/distance_calculator.py @@ -0,0 +1,76 @@ +import math + +from _context import sm +from pydantic import Field +from typing_extensions import Literal + + +@sm.tool(llm_provider="anthropic") +def haversine( + lat1: float, + lon1: float, + lat2: float, + lon2: float, + unit: Literal["km", "miles"], +) -> float: + r = 6378.0937 if unit == "km" else 3961 + phi1 = math.radians(lat1) + phi2 = math.radians(lat2) + delta_phi = math.radians(lat2 - lat1) + delta_lambda = math.radians(lon2 - lon1) + + a = ( + math.sin(delta_phi / 2) ** 2 + + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2 + ) + c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) + d = r * c + return d + + +def get_user_location() -> str: + """Get the city closest to the user.""" + return "San Francisco" + + +def get_coords( + city_name: str = Field( + description="The name of the city to take the coordinates from (e.g. London, Rome, Los Angeles)" + ), +): + """Get latitude and longitude of a city.""" + _data = { + "Rome": (41.9028, 12.4964), + "London": (51.5074, -0.1278), + "Madrid": (40.4168, -3.7038), + "San Francisco": (37.7749, -122.4194), + "Los Angeles": (34.0522, -118.2437), + } + + return _data.get(city_name) + + +def distance_calculator(prompt: str): + conversation = sm.create_conversation(llm_provider="anthropic") + conversation.add_message("user", prompt) + return conversation.send( + tools=[get_user_location, get_coords, haversine] + ).text + + +print(distance_calculator("How far is London from where I am?")) +# Prints something like: +""" +The distance between your location (San Francisco) and London is approximately 5,357 miles. +""" + +print( + distance_calculator( + "What is the distance between Rome and Madrid in Kilometers?" + ) +) + + +""" +The distance between Rome and Madrid is approximately 1,366 kilometers. 
+""" diff --git a/simplemind/__init__.py b/simplemind/__init__.py index 90e9d4f..4f2673b 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Type +import inspect +from typing import Callable, List, Type from .models import BaseModel, BasePlugin, Conversation from .settings import settings @@ -127,6 +128,30 @@ def enable_logfire() -> None: """Enable logfire logging.""" settings.logging.enable_logfire() +def tool( + llm_provider: str | None = None, + llm_model: str | None = None, +): + provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER) + + def decorator(func: Callable): + sig = inspect.signature(func) + res = generate_data( + ( + "Based on this function signature, fill up the required fieds." + f"\nSignature: {func.__name__}{sig}" + "Make sure to properly add the required field in `required` if there are no defaults" + ), + llm_provider=llm_provider, + response_model=provider.tool, + ) + res.raw_func = func + res.__signature__ = sig + res.__doc__ = func.__doc__ + + return res + + return decorator # Syntax sugar. 
Plugin = BasePlugin @@ -141,4 +166,5 @@ __all__ = [ "Session", "Plugin", "enable_logfire", + "tool" ] diff --git a/simplemind/models.py b/simplemind/models.py index 5ba0a2b..f8ff5be 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -1,10 +1,11 @@ import uuid from datetime import datetime from types import TracebackType -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Callable, Dict, List, Literal, Optional from pydantic import BaseModel, Field +from .providers._base_tools import BaseTool from .utils import find_provider MESSAGE_ROLE = Literal["system", "user", "assistant"] @@ -40,7 +41,9 @@ class BasePlugin(SMBaseModel): """Cleanup a hook for the plugin.""" raise NotImplementedError - def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any: + def add_message_hook( + self, conversation: "Conversation", message: "Message" + ) -> Any: """Add a message hook for the plugin.""" raise NotImplementedError @@ -48,7 +51,9 @@ class BasePlugin(SMBaseModel): """Pre-send hook for the plugin.""" raise NotImplementedError - def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any: + def post_send_hook( + self, conversation: "Conversation", response: "Message" + ) -> Any: """Post-send hook for the plugin.""" raise NotImplementedError @@ -120,7 +125,9 @@ class Conversation(SMBaseModel): except NotImplementedError: pass - def prepend_system_message(self, text: str, meta: Dict[str, Any] | None = None): + def prepend_system_message( + self, text: str, meta: Dict[str, Any] | None = None + ): """Prepend a system message to the conversation.""" self.messages = [ Message(role="system", text=text, meta=meta or {}) @@ -158,6 +165,7 @@ class Conversation(SMBaseModel): self, llm_model: str | None = None, llm_provider: str | None = None, + tools: list[Callable | BaseTool] | None = None, ) -> Message: """Send the conversation to the LLM.""" @@ -173,7 +181,7 @@ class Conversation(SMBaseModel): # 
Find the provider and send the conversation. provider = find_provider(llm_provider or self.llm_provider) - response = provider.send_conversation(self) + response = provider.send_conversation(self, tools=tools) # Execute all post-send hooks. for plugin in self.plugins: @@ -184,13 +192,17 @@ class Conversation(SMBaseModel): pass # Add the response to the conversation. - self.add_message(role="assistant", text=response.text, meta=response.meta) + self.add_message( + role="assistant", text=response.text, meta=response.meta + ) return response def get_last_message(self, role: MESSAGE_ROLE) -> Message | None: """Get the last message with the given role.""" - return next((m for m in reversed(self.messages) if m.role == role), None) + return next( + (m for m in reversed(self.messages) if m.role == role), None + ) def add_plugin(self, plugin: BasePlugin) -> None: """Add a plugin to the conversation.""" diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index 1f72f78..1f92912 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -1,6 +1,7 @@ from typing import List, Type from ._base import BaseProvider +from ._base_tools import BaseTool from .amazon import Amazon from .anthropic import Anthropic from .gemini import Gemini @@ -29,4 +30,5 @@ __all__ = [ "Amazon", "providers", "BaseProvider", + "BaseTool", ] diff --git a/simplemind/providers/_base.py b/simplemind/providers/_base.py index 235e030..c3ee6cd 100644 --- a/simplemind/providers/_base.py +++ b/simplemind/providers/_base.py @@ -37,7 +37,7 @@ class BaseProvider(ABC): def send_conversation( self, conversation: "Conversation", - tools: list[Callable] | None = None, + tools: list[Callable | BaseTool] | None = None, ) -> "Message": """Send a conversation to the provider.""" raise NotImplementedError @@ -54,7 +54,7 @@ class BaseProvider(ABC): self, prompt: str, *, - tools: list[Callable] | None = None, + tools: list[Callable | BaseTool] | None = None, stream: 
bool = False, **kwargs, ) -> str: @@ -67,7 +67,7 @@ class BaseProvider(ABC): """The tool implementation for the provider.""" raise NotImplementedError - def make_tools(self, tools: list[Callable] | None): + def make_tools(self, tools: list[Callable | BaseTool] | None): if tools is not None: return [self.tool.from_function(func) for func in tools] else: diff --git a/simplemind/providers/_base_tools.py b/simplemind/providers/_base_tools.py index 7a2f37f..3961e3e 100644 --- a/simplemind/providers/_base_tools.py +++ b/simplemind/providers/_base_tools.py @@ -26,6 +26,7 @@ class BaseToolConfig(BaseModel): TYPE_CONVERSION: dict[type, str] = { str: "string", int: "integer", + float: "number", bool: "boolean", } @@ -42,13 +43,20 @@ class BaseTool(BaseModel, ABC): properties: dict[str, BaseToolProperty] required: list[str] | None = None config: ClassVar[BaseToolConfig] = BaseToolConfig() - raw_func: Callable + raw_func: Any | None = None tool_id: str | None = None function_result: str | None = None + def __call__(self, *args: Any, **kwargs: Any) -> Any: + assert self.raw_func is not None + return self.raw_func(*args, **kwargs) + def is_executed(self) -> bool: return self.function_result is not None + def reset_result(self) -> None: + self.function_result = None + @classmethod def convert_type(cls, field_type) -> str: if _is_literal(field_type): @@ -68,7 +76,11 @@ class BaseTool(BaseModel, ABC): } @classmethod - def from_function(cls, func: Callable): + def from_function(cls, func: Callable | "BaseTool"): + # Anything exposing `raw_func` is treated as an already-built tool and returned unchanged + if hasattr(func, "raw_func"): + return func + annotations = getattr(func, "__annotations__", {}) properties = {} required = [] @@ -76,34 +88,39 @@ class BaseTool(BaseModel, ABC): func_signature = inspect.signature(func) for n, (arg_name, arg_type) in enumerate(annotations.items()): - # Check if argument has metadata (from Annotated) - if hasattr(arg_type, "__metadata__"): - field = arg_type.__metadata__[0] # Get Field 
info from metadata - field_type = arg_type.__origin__ # Get actual type - # Check if argument has a default value in signature - elif ( - sig_param := func_signature.parameters[arg_name] - ).default is not inspect.Parameter.empty: - field = sig_param.default # Use default as Field - field_type = arg_type # Use plain type annotation - else: - # Raise error if no Field annotation found - raise ValueError( - f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter" + if ( # Skipping 'return' annotation (i.e.```-> str```) + arg_name != "return" + ): + # Check if argument has metadata (from Annotated) + if hasattr(arg_type, "__metadata__"): + field = arg_type.__metadata__[ + 0 + ] # Get Field info from metadata + field_type = arg_type.__origin__ # Get actual type + # Check if argument has a default value in signature + elif ( + sig_param := func_signature.parameters[arg_name] + ).default is not inspect.Parameter.empty: + field = sig_param.default # Use default as Field + field_type = arg_type # Use plain type annotation + else: + # Raise error if no Field annotation found + raise ValueError( + f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter" + ) + + field_type_converted = cls.convert_type(field_type) + + if _is_literal(field_type): + enum_values = [str(x) for x in field_type.__args__] + + properties[arg_name] = BaseToolProperty( + type=field_type_converted, + description=field.description, + enum=enum_values, ) - - field_type_converted = cls.convert_type(field_type) - - if _is_literal(field_type): - enum_values = [str(x) for x in field_type.__args__] - - properties[arg_name] = BaseToolProperty( - type=field_type_converted, - description=field.description, - enum=enum_values, - ) - if _is_required(field, func_signature, arg_name): - required.append(arg_name) + if _is_required(field, func_signature, arg_name): + required.append(arg_name) return cls( name=func.__name__, diff --git a/simplemind/providers/anthropic.py 
b/simplemind/providers/anthropic.py index ea4a730..c714b4e 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -18,6 +18,9 @@ T = TypeVar("T", bound=BaseModel) class AnthropicTool(BaseTool): def get_response_schema(self) -> Any: assert self.is_executed, f"Tool {self.name} was not executed." + assert isinstance( + self.tool_id, str + ), f"Expected str for `tool_id` got {self.tool_id!r}" return { "type": "tool_result", "tool_use_id": self.tool_id, @@ -28,6 +31,7 @@ class AnthropicTool(BaseTool): def handle(self, response, messages) -> None: """Handle the tool execution result from an API response.""" msg = {"role": "assistant", "content": []} + tool_used = False for content in response.content: if content.type == "tool_use" and content.name == self.name: msg["content"].append( @@ -41,12 +45,15 @@ class AnthropicTool(BaseTool): # Function execution: self.function_result = str(self.raw_func(**content.input)) self.tool_id = content.id - else: + tool_used = True + elif content.type == "text": msg["content"].append({"type": "text", "text": content.text}) - messages.append(msg) - messages.append( - {"role": "user", "content": [self.get_response_schema()]} - ) + + if tool_used: + messages.append(msg) + messages.append( + {"role": "user", "content": [self.get_response_schema()]} + ) def get_input_schema(self): return { @@ -93,7 +100,7 @@ class Anthropic(BaseProvider): def send_conversation( self, conversation: "Conversation", - tools: list[Callable] | None = None, + tools: list[Callable | BaseTool] | None = None, **kwargs, ) -> "Message": """Send a conversation to the Anthropic API.""" @@ -117,16 +124,19 @@ class Anthropic(BaseProvider): **{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg}, ) - for tool in converted_tools: - tool.handle(response, messages) - if tool.is_executed(): - response = self.client.messages.create( - model=conversation.llm_model or self.DEFAULT_MODEL, - messages=messages, - **{**self.DEFAULT_KWARGS, **kwargs, 
**tools_kwarg}, - ) + while response.content[-1].type != "text": + print(response) + for tool in converted_tools: + tool.handle(response, messages) + if tool.is_executed(): + response = self.client.messages.create( + model=conversation.llm_model or self.DEFAULT_MODEL, + messages=messages, + **{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg}, + ) + tool.reset_result() - assistant_message = response.content[0].text + assistant_message = response.content[-1].text # Create and return a properly formatted Message instance return Message(