first basic working version (anthropic)

This commit is contained in:
Luciano Scarpulla
2024-11-12 11:48:27 +08:00
parent c2303114ab
commit 1709055e1a
2 changed files with 157 additions and 7 deletions
+53 -7
View File
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Iterator, Type, TypeVar
import instructor
from pydantic import BaseModel
@@ -21,8 +21,40 @@ DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class AntrhopicTool(BaseTool):
def get_schema(self):
class AnthropicTool(BaseTool):
def get_response_schema(self) -> Any:
assert self.is_executed, f"Tool {self.name} was not executed."
return {
"type": "tool_result",
"tool_use_id": self.tool_id,
"content": self.function_result,
}
@logger
def handle(self, response, messages) -> None:
"""Handle the tool execution result from an API response."""
msg = {"role": "assistant", "content": []}
for content in response.content:
if content.type == "tool_use" and content.name == self.name:
msg["content"].append(
{
"type": "tool_use",
"id": content.id,
"name": content.name,
"input": content.input,
}
)
# Function execution:
self.function_result = str(self.raw_func(**content.input))
self.tool_id = content.id
else:
msg["content"].append({"type": "text", "text": content.text})
messages.append(msg)
messages.append(
{"role": "user", "content": [self.get_response_schema()]}
)
def get_input_schema(self):
return {
"name": self.name,
"description": self.description,
@@ -77,14 +109,28 @@ class Anthropic(BaseProvider):
for msg in conversation.messages
]
converted_tools = self.make_tools(tools)
tools_kwarg = (
{}
if tools is None
else {"tools": [t.get_input_schema() for t in converted_tools]}
)
response = self.client.messages.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
tools=self.make_tools(tools),
**{**self.DEFAULT_KWARGS, **kwargs},
**{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg},
)
# Get the response content from the Anthropic response
for tool in converted_tools:
tool.handle(response, messages)
if tool.is_executed():
response = self.client.messages.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg},
)
assistant_message = response.content[0].text
# Create and return a properly formatted Message instance
@@ -152,4 +198,4 @@ class Anthropic(BaseProvider):
@cached_property
def tool(self) -> Type[BaseTool]:
"""The tool implementation for Antrhopic."""
return AntrhopicTool
return AnthropicTool
+104
View File
@@ -0,0 +1,104 @@
from typing import Annotated, Literal
import pytest
from pydantic import Field
import simplemind as sm
from simplemind.providers import Anthropic
MODELS = [
Anthropic,
# Gemini,
# OpenAI,
# Groq,
# Ollama,
# Amazon
]
def get_weather(
location: Annotated[
str, Field(description="The city and state, e.g. San Francisco, CA")
],
unit: Annotated[
Literal["celcius", "fahrenheit"],
Field(
description="The unit of temperature, either 'celsius' or 'fahrenheit'"
),
] = "celcius",
):
"""
Get the current weather in a given location
"""
return f"42 {unit}"
def get_location():
"""Get the current location"""
return "San Francisco,CA"
@pytest.mark.parametrize(
"provider_cls",
MODELS,
)
def test_single_tool_args(provider_cls):
conv = sm.create_conversation(
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
)
conv.add_message(text="What is the weather in San Francisco?")
data = conv.send(tools=[get_weather])
assert "42" in data.text
@pytest.mark.parametrize(
"provider_cls",
MODELS,
)
def test_single_tool_no_args(provider_cls):
conv = sm.create_conversation(
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
)
conv.add_message(text="What is my current location")
data = conv.send(tools=[get_location])
assert "San Francisco" in data.text
@pytest.mark.parametrize(
"provider_cls",
MODELS,
)
def test_single_tool_partial(provider_cls):
conv = sm.create_conversation(
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
)
conv.add_message(text="Can you tell me the weather?")
conv.send(tools=[get_weather])
# Will answer something like:
"""
I can help you check the weather, but I need to know the location you're interested in.
Could you please provide a city and state (e.g., "Los Angeles, CA" or "New York, NY")
where you'd like to know the weather?
"""
conv.add_message(text="San Francisco, CA")
data = conv.send(tools=[get_weather])
assert "42" in data.text
@pytest.mark.parametrize(
"provider_cls",
MODELS,
)
def test_multiple_tools(provider_cls):
conv = sm.create_conversation(
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
)
conv.add_message(text="What is the wheather at my current location?")
data = conv.send(tools=[get_location, get_weather])
assert "San Francisco" in data.text
assert "42" in data.text