diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 5e6ca01..065985f 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -21,8 +21,40 @@ DEFAULT_MAX_TOKENS = 1_000 DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} -class AntrhopicTool(BaseTool): - def get_schema(self): +class AnthropicTool(BaseTool): + def get_response_schema(self) -> Any: + assert self.is_executed, f"Tool {self.name} was not executed." + return { + "type": "tool_result", + "tool_use_id": self.tool_id, + "content": self.function_result, + } + + @logger + def handle(self, response, messages) -> None: + """Handle the tool execution result from an API response.""" + msg = {"role": "assistant", "content": []} + for content in response.content: + if content.type == "tool_use" and content.name == self.name: + msg["content"].append( + { + "type": "tool_use", + "id": content.id, + "name": content.name, + "input": content.input, + } + ) + # Function execution: + self.function_result = str(self.raw_func(**content.input)) + self.tool_id = content.id + else: + msg["content"].append({"type": "text", "text": content.text}) + messages.append(msg) + messages.append( + {"role": "user", "content": [self.get_response_schema()]} + ) + + def get_input_schema(self): return { "name": self.name, "description": self.description, @@ -77,14 +109,28 @@ class Anthropic(BaseProvider): for msg in conversation.messages ] + converted_tools = self.make_tools(tools) + tools_kwarg = ( + {} + if tools is None + else {"tools": [t.get_input_schema() for t in converted_tools]} + ) + response = self.client.messages.create( model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, - tools=self.make_tools(tools), - **{**self.DEFAULT_KWARGS, **kwargs}, + **{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg}, ) - # Get the response content from the Anthropic response + for tool in converted_tools: + tool.handle(response, messages) + if tool.is_executed(): + response = self.client.messages.create( + model=conversation.llm_model or self.DEFAULT_MODEL, + messages=messages, + **{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg}, + ) + assistant_message = response.content[0].text # Create and return a properly formatted Message instance @@ -152,4 +198,4 @@ class Anthropic(BaseProvider): @cached_property def tool(self) -> Type[BaseTool]: """The tool implementation for Antrhopic.""" - return AntrhopicTool + return AnthropicTool diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 0000000..1926831 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,104 @@ +from typing import Annotated, Literal + +import pytest +from pydantic import Field + +import simplemind as sm +from simplemind.providers import Anthropic + +MODELS = [ + Anthropic, + # Gemini, + # OpenAI, + # Groq, + # Ollama, + # Amazon +] + + +def get_weather( + location: Annotated[ + str, Field(description="The city and state, e.g. San Francisco, CA") + ], + unit: Annotated[ + Literal["celcius", "fahrenheit"], + Field( + description="The unit of temperature, either 'celsius' or 'fahrenheit'" + ), + ] = "celcius", +): + """ + Get the current weather in a given location + """ + return f"42 {unit}" + + +def get_location(): + """Get the current location""" + return "San Francisco,CA" + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_single_tool_args(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="What is the weather in San Francisco?") + data = conv.send(tools=[get_weather]) + assert "42" in data.text + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_single_tool_no_args(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="What is my current location") + data = conv.send(tools=[get_location]) + assert "San Francisco" in data.text + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_single_tool_partial(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="Can you tell me the weather?") + conv.send(tools=[get_weather]) + # Will answer something like: + """ + I can help you check the weather, but I need to know the location you're interested in. + Could you please provide a city and state (e.g., "Los Angeles, CA" or "New York, NY") + where you'd like to know the weather? + """ + + conv.add_message(text="San Francisco, CA") + data = conv.send(tools=[get_weather]) + assert "42" in data.text + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_multiple_tools(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="What is the wheather at my current location?") + data = conv.send(tools=[get_location, get_weather]) + assert "San Francisco" in data.text + assert "42" in data.text