mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
7b5e160d28
Follow-up to @hinthornw's PR: - Move the Tool abstraction to a separate file (`BaseTool`). - The `Tool` implementation of `BaseTool` takes a function and an optional coroutine, which makes backwards compatibility easier to maintain. - Add a Toolkit abstraction that can own the generation of tools around a shared concept or state. --------- Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Francisco Ingham <fpingham@gmail.com> Co-authored-by: Dhruv Anand <105786647+dhruv-anand-aintech@users.noreply.github.com> Co-authored-by: cragwolfe <cragcw@gmail.com> Co-authored-by: Anton Troynikov <atroyn@users.noreply.github.com> Co-authored-by: Oliver Klingefjord <oliver@klingefjord.com> Co-authored-by: William Fu-Hinthorn <whinthorn@Williams-MBP-3.attlocal.net> Co-authored-by: Bruno Bornsztein <bruno.bornsztein@gmail.com>
108 lines
3.1 KiB
Python
108 lines
3.1 KiB
Python
"""Test tool utils."""
|
|
import pytest
|
|
|
|
from langchain.agents.tools import Tool, tool
|
|
from langchain.schema import AgentAction
|
|
|
|
|
|
def test_unnamed_decorator() -> None:
    """Check that a bare ``tool`` decorator turns a function into a ``Tool``.

    Without an explicit name, the tool is named after the wrapped function.
    It must not be marked ``return_direct``, and calling it runs the function.
    """

    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    # Explicit application; equivalent to stacking ``@tool`` on the def.
    search_tool = tool(search_api)

    assert isinstance(search_tool, Tool)
    assert search_tool.name == "search_api"
    assert not search_tool.return_direct
    output = search_tool("test")
    assert output == "API result"
|
|
|
|
|
|
def test_named_tool_decorator() -> None:
    """Test functionality when arguments are provided as input to decorator.

    An explicit name passed to ``@tool`` replaces the function name. The
    resulting tool defaults to ``return_direct`` being falsy and remains
    callable.
    """

    @tool("search")
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    assert isinstance(search_api, Tool)
    assert search_api.name == "search"
    assert not search_api.return_direct
    # Giving the tool a custom name must not change what calling it returns.
    assert search_api("test") == "API result"
|
|
|
|
|
|
def test_named_tool_decorator_return_direct() -> None:
    """Test functionality when arguments and return direct are provided as input.

    Both the explicit name and ``return_direct=True`` must be reflected on the
    resulting ``Tool``, and the tool must still be callable.
    """

    @tool("search", return_direct=True)
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    assert isinstance(search_api, Tool)
    assert search_api.name == "search"
    assert search_api.return_direct
    # Setting return_direct must not change what calling the tool returns.
    assert search_api("test") == "API result"
|
|
|
|
|
|
def test_unnamed_tool_decorator_return_direct() -> None:
    """Test functionality when only return direct is provided.

    With no explicit name, the tool is named after the wrapped function, and
    ``return_direct=True`` is still honoured. The tool must remain callable.
    """

    @tool(return_direct=True)
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    assert isinstance(search_api, Tool)
    assert search_api.name == "search_api"
    assert search_api.return_direct
    # Setting return_direct must not change what calling the tool returns.
    assert search_api("test") == "API result"
|
|
|
|
|
|
def test_missing_docstring() -> None:
    """Test error is raised when docstring is missing."""
    # Decorating a function that has no docstring should fail immediately,
    # at decoration time, with an AssertionError. The function itself is
    # never called inside the ``raises`` block.
    with pytest.raises(AssertionError):

        @tool
        def search_api(query: str) -> str:
            return "API result"
|
|
|
|
|
|
def test_create_tool_posistional_args() -> None:
    """Check that ``Tool`` accepts name, func and description positionally."""

    def _identity(value: str) -> str:
        return value

    test_tool = Tool("test_name", _identity, "test_description")

    assert test_tool("foo") == "foo"
    assert (test_tool.name, test_tool.description) == (
        "test_name",
        "test_description",
    )
|
|
|
|
|
|
def test_create_tool_keyword_args() -> None:
    """Check that ``Tool`` can be built purely from keyword arguments."""
    # Keyword order is deliberately different from the positional order.
    test_tool = Tool(
        description="test_description",
        name="test_name",
        func=lambda x: x,
    )

    assert test_tool("foo") == "foo"
    assert (test_tool.name, test_tool.description) == (
        "test_name",
        "test_description",
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_async_tool() -> None:
    """Check that a ``Tool`` can be given a coroutine alongside its sync func.

    The sync call and the awaited ``arun`` call should both echo the input.
    """

    async def _echo(text: str) -> str:
        return text

    test_tool = Tool(
        name="test_name",
        func=lambda x: x,
        description="test_description",
        coroutine=_echo,
    )

    assert test_tool("foo") == "foo"
    assert test_tool.name == "test_name"
    assert test_tool.description == "test_description"
    assert test_tool.coroutine is not None

    action = AgentAction(tool_input="foo", tool="test_name", log="")
    async_result = await test_tool.arun(action)
    assert async_result == "foo"
|