mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
Callbacks Refactor [base] (#3256)
Co-authored-by: Nuno Campos <nuno@boringbits.io> Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com> Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
@@ -1,15 +1,12 @@
|
||||
"""Test CallbackManager."""
|
||||
from typing import Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
AsyncCallbackManager,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.schema import AgentFinish, LLMResult
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
BaseFakeCallbackHandler,
|
||||
FakeAsyncCallbackHandler,
|
||||
@@ -18,19 +15,26 @@ from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
|
||||
|
||||
def _test_callback_manager(
|
||||
manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler
|
||||
manager: CallbackManager, *handlers: BaseFakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
manager.on_llm_start({}, [])
|
||||
manager.on_llm_end(LLMResult(generations=[]))
|
||||
manager.on_llm_error(Exception())
|
||||
manager.on_chain_start({"name": "foo"}, {})
|
||||
manager.on_chain_end({})
|
||||
manager.on_chain_error(Exception())
|
||||
manager.on_tool_start({}, "")
|
||||
manager.on_tool_end("")
|
||||
manager.on_tool_error(Exception())
|
||||
manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
run_manager = manager.on_llm_start({}, [])
|
||||
run_manager.on_llm_end(LLMResult(generations=[]))
|
||||
run_manager.on_llm_error(Exception())
|
||||
run_manager.on_llm_new_token("foo")
|
||||
run_manager.on_text("foo")
|
||||
|
||||
run_manager_chain = manager.on_chain_start({"name": "foo"}, {})
|
||||
run_manager_chain.on_chain_end({})
|
||||
run_manager_chain.on_chain_error(Exception())
|
||||
run_manager_chain.on_agent_action(AgentAction(tool_input="foo", log="", tool=""))
|
||||
run_manager_chain.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
run_manager_chain.on_text("foo")
|
||||
|
||||
run_manager_tool = manager.on_tool_start({}, "")
|
||||
run_manager_tool.on_tool_end("")
|
||||
run_manager_tool.on_tool_error(Exception())
|
||||
run_manager_tool.on_text("foo")
|
||||
_check_num_calls(handlers)
|
||||
|
||||
|
||||
@@ -38,75 +42,62 @@ async def _test_callback_manager_async(
|
||||
manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
await manager.on_llm_start({}, [])
|
||||
await manager.on_llm_end(LLMResult(generations=[]))
|
||||
await manager.on_llm_error(Exception())
|
||||
await manager.on_chain_start({"name": "foo"}, {})
|
||||
await manager.on_chain_end({})
|
||||
await manager.on_chain_error(Exception())
|
||||
await manager.on_tool_start({}, "")
|
||||
await manager.on_tool_end("")
|
||||
await manager.on_tool_error(Exception())
|
||||
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
run_manager = await manager.on_llm_start({}, [])
|
||||
await run_manager.on_llm_end(LLMResult(generations=[]))
|
||||
await run_manager.on_llm_error(Exception())
|
||||
await run_manager.on_llm_new_token("foo")
|
||||
await run_manager.on_text("foo")
|
||||
|
||||
run_manager_chain = await manager.on_chain_start({"name": "foo"}, {})
|
||||
await run_manager_chain.on_chain_end({})
|
||||
await run_manager_chain.on_chain_error(Exception())
|
||||
await run_manager_chain.on_agent_action(
|
||||
AgentAction(tool_input="foo", log="", tool="")
|
||||
)
|
||||
await run_manager_chain.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
await run_manager_chain.on_text("foo")
|
||||
|
||||
run_manager_tool = await manager.on_tool_start({}, "")
|
||||
await run_manager_tool.on_tool_end("")
|
||||
await run_manager_tool.on_tool_error(Exception())
|
||||
await run_manager_tool.on_text("foo")
|
||||
_check_num_calls(handlers)
|
||||
|
||||
|
||||
def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None:
|
||||
for handler in handlers:
|
||||
if handler.always_verbose:
|
||||
assert handler.starts == 3
|
||||
assert handler.ends == 4
|
||||
assert handler.errors == 3
|
||||
else:
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def _test_callback_manager_pass_in_verbose(
|
||||
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
manager.on_llm_start({}, [], verbose=True)
|
||||
manager.on_llm_end(LLMResult(generations=[]), verbose=True)
|
||||
manager.on_llm_error(Exception(), verbose=True)
|
||||
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
|
||||
manager.on_chain_end({}, verbose=True)
|
||||
manager.on_chain_error(Exception(), verbose=True)
|
||||
manager.on_tool_start({}, "", verbose=True)
|
||||
manager.on_tool_end("", verbose=True)
|
||||
manager.on_tool_error(Exception(), verbose=True)
|
||||
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
|
||||
for handler in handlers:
|
||||
assert handler.starts == 3
|
||||
assert handler.starts == 4
|
||||
assert handler.ends == 4
|
||||
assert handler.errors == 3
|
||||
assert handler.text == 3
|
||||
|
||||
assert handler.llm_starts == 1
|
||||
assert handler.llm_ends == 1
|
||||
assert handler.llm_streams == 1
|
||||
|
||||
assert handler.chain_starts == 1
|
||||
assert handler.chain_ends == 1
|
||||
|
||||
assert handler.tool_starts == 1
|
||||
assert handler.tool_ends == 1
|
||||
|
||||
|
||||
def test_callback_manager() -> None:
|
||||
"""Test the CallbackManager."""
|
||||
handler1 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=False)
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager([handler1, handler2])
|
||||
_test_callback_manager(manager, handler1, handler2)
|
||||
|
||||
|
||||
def test_callback_manager_pass_in_verbose() -> None:
|
||||
"""Test the CallbackManager."""
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager([handler1, handler2])
|
||||
_test_callback_manager_pass_in_verbose(manager, handler1, handler2)
|
||||
|
||||
|
||||
def test_ignore_llm() -> None:
|
||||
"""Test ignore llm param for callback handlers."""
|
||||
handler1 = FakeCallbackHandler(ignore_llm_=True, always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler1 = FakeCallbackHandler(ignore_llm_=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
manager.on_llm_start({}, [], verbose=True)
|
||||
manager.on_llm_end(LLMResult(generations=[]), verbose=True)
|
||||
manager.on_llm_error(Exception(), verbose=True)
|
||||
run_manager = manager.on_llm_start({}, [])
|
||||
run_manager.on_llm_end(LLMResult(generations=[]))
|
||||
run_manager.on_llm_error(Exception())
|
||||
assert handler1.starts == 0
|
||||
assert handler1.ends == 0
|
||||
assert handler1.errors == 0
|
||||
@@ -117,12 +108,12 @@ def test_ignore_llm() -> None:
|
||||
|
||||
def test_ignore_chain() -> None:
|
||||
"""Test ignore chain param for callback handlers."""
|
||||
handler1 = FakeCallbackHandler(ignore_chain_=True, always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler1 = FakeCallbackHandler(ignore_chain_=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
|
||||
manager.on_chain_end({}, verbose=True)
|
||||
manager.on_chain_error(Exception(), verbose=True)
|
||||
run_manager = manager.on_chain_start({"name": "foo"}, {})
|
||||
run_manager.on_chain_end({})
|
||||
run_manager.on_chain_error(Exception())
|
||||
assert handler1.starts == 0
|
||||
assert handler1.ends == 0
|
||||
assert handler1.errors == 0
|
||||
@@ -133,39 +124,24 @@ def test_ignore_chain() -> None:
|
||||
|
||||
def test_ignore_agent() -> None:
|
||||
"""Test ignore agent param for callback handlers."""
|
||||
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler1 = FakeCallbackHandler(ignore_agent_=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
manager.on_tool_start({}, "", verbose=True)
|
||||
manager.on_tool_end("", verbose=True)
|
||||
manager.on_tool_error(Exception(), verbose=True)
|
||||
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
|
||||
run_manager = manager.on_tool_start({}, "")
|
||||
run_manager.on_tool_end("")
|
||||
run_manager.on_tool_error(Exception())
|
||||
assert handler1.starts == 0
|
||||
assert handler1.ends == 0
|
||||
assert handler1.errors == 0
|
||||
assert handler2.starts == 1
|
||||
assert handler2.ends == 2
|
||||
assert handler2.ends == 1
|
||||
assert handler2.errors == 1
|
||||
|
||||
|
||||
def test_shared_callback_manager() -> None:
|
||||
"""Test the SharedCallbackManager."""
|
||||
manager1 = SharedCallbackManager()
|
||||
manager2 = SharedCallbackManager()
|
||||
|
||||
assert manager1 is manager2
|
||||
|
||||
handler1 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager1.add_handler(handler1)
|
||||
manager2.add_handler(handler2)
|
||||
_test_callback_manager(manager1, handler1, handler2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_callback_manager() -> None:
|
||||
"""Test the AsyncCallbackManager."""
|
||||
handler1 = FakeAsyncCallbackHandler(always_verbose_=True)
|
||||
handler1 = FakeAsyncCallbackHandler()
|
||||
handler2 = FakeAsyncCallbackHandler()
|
||||
manager = AsyncCallbackManager([handler1, handler2])
|
||||
await _test_callback_manager_async(manager, handler1, handler2)
|
||||
@@ -174,8 +150,95 @@ async def test_async_callback_manager() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_callback_manager_sync_handler() -> None:
|
||||
"""Test the AsyncCallbackManager."""
|
||||
handler1 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeAsyncCallbackHandler()
|
||||
handler3 = FakeAsyncCallbackHandler(always_verbose_=True)
|
||||
handler3 = FakeAsyncCallbackHandler()
|
||||
manager = AsyncCallbackManager([handler1, handler2, handler3])
|
||||
await _test_callback_manager_async(manager, handler1, handler2, handler3)
|
||||
|
||||
|
||||
def test_callback_manager_inheritance() -> None:
|
||||
handler1, handler2, handler3, handler4 = (
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
)
|
||||
|
||||
callback_manager1 = CallbackManager([handler1, handler2])
|
||||
assert callback_manager1.handlers == [handler1, handler2]
|
||||
assert callback_manager1.inheritable_handlers == []
|
||||
|
||||
callback_manager2 = CallbackManager([])
|
||||
assert callback_manager2.handlers == []
|
||||
assert callback_manager2.inheritable_handlers == []
|
||||
|
||||
callback_manager2.set_handlers([handler1, handler2])
|
||||
assert callback_manager2.handlers == [handler1, handler2]
|
||||
assert callback_manager2.inheritable_handlers == [handler1, handler2]
|
||||
|
||||
callback_manager2.set_handlers([handler3, handler4], inherit=False)
|
||||
assert callback_manager2.handlers == [handler3, handler4]
|
||||
assert callback_manager2.inheritable_handlers == []
|
||||
|
||||
callback_manager2.add_handler(handler1)
|
||||
assert callback_manager2.handlers == [handler3, handler4, handler1]
|
||||
assert callback_manager2.inheritable_handlers == [handler1]
|
||||
|
||||
callback_manager2.add_handler(handler2, inherit=False)
|
||||
assert callback_manager2.handlers == [handler3, handler4, handler1, handler2]
|
||||
assert callback_manager2.inheritable_handlers == [handler1]
|
||||
|
||||
run_manager = callback_manager2.on_chain_start({"name": "foo"}, {})
|
||||
child_manager = run_manager.get_child()
|
||||
assert child_manager.handlers == [handler1]
|
||||
assert child_manager.inheritable_handlers == [handler1]
|
||||
|
||||
run_manager_tool = child_manager.on_tool_start({}, "")
|
||||
assert run_manager_tool.handlers == [handler1]
|
||||
assert run_manager_tool.inheritable_handlers == [handler1]
|
||||
|
||||
child_manager2 = run_manager_tool.get_child()
|
||||
assert child_manager2.handlers == [handler1]
|
||||
assert child_manager2.inheritable_handlers == [handler1]
|
||||
|
||||
|
||||
def test_callback_manager_configure() -> None:
|
||||
"""Test callback manager configuration."""
|
||||
handler1, handler2, handler3, handler4 = (
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
)
|
||||
|
||||
inheritable_callbacks: List[BaseCallbackHandler] = [handler1, handler2]
|
||||
local_callbacks: List[BaseCallbackHandler] = [handler3, handler4]
|
||||
configured_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=inheritable_callbacks,
|
||||
local_callbacks=local_callbacks,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert len(configured_manager.handlers) == 5
|
||||
assert len(configured_manager.inheritable_handlers) == 2
|
||||
assert configured_manager.inheritable_handlers == inheritable_callbacks
|
||||
assert configured_manager.handlers[:4] == inheritable_callbacks + local_callbacks
|
||||
assert isinstance(configured_manager.handlers[4], StdOutCallbackHandler)
|
||||
assert isinstance(configured_manager, CallbackManager)
|
||||
|
||||
async_local_callbacks = AsyncCallbackManager([handler3, handler4])
|
||||
async_configured_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=inheritable_callbacks,
|
||||
local_callbacks=async_local_callbacks,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
assert len(async_configured_manager.handlers) == 4
|
||||
assert len(async_configured_manager.inheritable_handlers) == 2
|
||||
assert async_configured_manager.inheritable_handlers == inheritable_callbacks
|
||||
assert async_configured_manager.handlers == inheritable_callbacks + [
|
||||
handler3,
|
||||
handler4,
|
||||
]
|
||||
assert isinstance(async_configured_manager, AsyncCallbackManager)
|
||||
|
||||
Reference in New Issue
Block a user