mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
first basic working version (anthropic)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -21,8 +21,40 @@ DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class AntrhopicTool(BaseTool):
|
||||
def get_schema(self):
|
||||
class AnthropicTool(BaseTool):
|
||||
def get_response_schema(self) -> Any:
|
||||
assert self.is_executed, f"Tool {self.name} was not executed."
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": self.tool_id,
|
||||
"content": self.function_result,
|
||||
}
|
||||
|
||||
@logger
|
||||
def handle(self, response, messages) -> None:
|
||||
"""Handle the tool execution result from an API response."""
|
||||
msg = {"role": "assistant", "content": []}
|
||||
for content in response.content:
|
||||
if content.type == "tool_use" and content.name == self.name:
|
||||
msg["content"].append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": content.id,
|
||||
"name": content.name,
|
||||
"input": content.input,
|
||||
}
|
||||
)
|
||||
# Function execution:
|
||||
self.function_result = str(self.raw_func(**content.input))
|
||||
self.tool_id = content.id
|
||||
else:
|
||||
msg["content"].append({"type": "text", "text": content.text})
|
||||
messages.append(msg)
|
||||
messages.append(
|
||||
{"role": "user", "content": [self.get_response_schema()]}
|
||||
)
|
||||
|
||||
def get_input_schema(self):
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
@@ -77,14 +109,28 @@ class Anthropic(BaseProvider):
|
||||
for msg in conversation.messages
|
||||
]
|
||||
|
||||
converted_tools = self.make_tools(tools)
|
||||
tools_kwarg = (
|
||||
{}
|
||||
if tools is None
|
||||
else {"tools": [t.get_input_schema() for t in converted_tools]}
|
||||
)
|
||||
|
||||
response = self.client.messages.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
tools=self.make_tools(tools),
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
**{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg},
|
||||
)
|
||||
|
||||
# Get the response content from the Anthropic response
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, messages)
|
||||
if tool.is_executed():
|
||||
response = self.client.messages.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg},
|
||||
)
|
||||
|
||||
assistant_message = response.content[0].text
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
@@ -152,4 +198,4 @@ class Anthropic(BaseProvider):
|
||||
@cached_property
|
||||
def tool(self) -> Type[BaseTool]:
|
||||
"""The tool implementation for Antrhopic."""
|
||||
return AntrhopicTool
|
||||
return AnthropicTool
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
import simplemind as sm
|
||||
from simplemind.providers import Anthropic
|
||||
|
||||
MODELS = [
|
||||
Anthropic,
|
||||
# Gemini,
|
||||
# OpenAI,
|
||||
# Groq,
|
||||
# Ollama,
|
||||
# Amazon
|
||||
]
|
||||
|
||||
|
||||
def get_weather(
|
||||
location: Annotated[
|
||||
str, Field(description="The city and state, e.g. San Francisco, CA")
|
||||
],
|
||||
unit: Annotated[
|
||||
Literal["celcius", "fahrenheit"],
|
||||
Field(
|
||||
description="The unit of temperature, either 'celsius' or 'fahrenheit'"
|
||||
),
|
||||
] = "celcius",
|
||||
):
|
||||
"""
|
||||
Get the current weather in a given location
|
||||
"""
|
||||
return f"42 {unit}"
|
||||
|
||||
|
||||
def get_location():
|
||||
"""Get the current location"""
|
||||
return "San Francisco,CA"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_single_tool_args(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="What is the weather in San Francisco?")
|
||||
data = conv.send(tools=[get_weather])
|
||||
assert "42" in data.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_single_tool_no_args(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="What is my current location")
|
||||
data = conv.send(tools=[get_location])
|
||||
assert "San Francisco" in data.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_single_tool_partial(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="Can you tell me the weather?")
|
||||
conv.send(tools=[get_weather])
|
||||
# Will answer something like:
|
||||
"""
|
||||
I can help you check the weather, but I need to know the location you're interested in.
|
||||
Could you please provide a city and state (e.g., "Los Angeles, CA" or "New York, NY")
|
||||
where you'd like to know the weather?
|
||||
"""
|
||||
|
||||
conv.add_message(text="San Francisco, CA")
|
||||
data = conv.send(tools=[get_weather])
|
||||
assert "42" in data.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_multiple_tools(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="What is the wheather at my current location?")
|
||||
data = conv.send(tools=[get_location, get_weather])
|
||||
assert "San Francisco" in data.text
|
||||
assert "42" in data.text
|
||||
Reference in New Issue
Block a user