mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
Callbacks Refactor [base] (#3256)
Co-authored-by: Nuno Campos <nuno@boringbits.io> Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com> Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema import BaseMemory
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
@@ -25,11 +25,9 @@ class FakeMemory(BaseMemory):
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Pass."""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Pass."""
|
||||
pass
|
||||
|
||||
|
||||
class FakeChain(Chain):
|
||||
@@ -49,7 +47,11 @@ class FakeChain(Chain):
|
||||
"""Output key of bar."""
|
||||
return self.the_output_keys
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
if self.be_correct:
|
||||
return {"bar": "baz"}
|
||||
else:
|
||||
@@ -143,25 +145,10 @@ def test_run_with_callback() -> None:
|
||||
"""Test run method works when callback manager is passed."""
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(
|
||||
callback_manager=CallbackManager(handlers=[handler]), verbose=True
|
||||
callbacks=[handler],
|
||||
)
|
||||
output = chain.run("bar")
|
||||
assert output == "baz"
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_run_with_callback_not_verbose() -> None:
|
||||
"""Test run method works when callback manager is passed and not verbose."""
|
||||
import langchain
|
||||
|
||||
langchain.verbose = False
|
||||
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(callback_manager=CallbackManager(handlers=[handler]))
|
||||
output = chain.run("bar")
|
||||
assert output == "baz"
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
||||
Reference in New Issue
Block a user