Callbacks Refactor [base] (#3256)

Co-authored-by: Nuno Campos <nuno@boringbits.io>
Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com>
Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Ankush Gola
2023-04-30 11:14:09 -07:00
committed by GitHub
parent 18ec22fe56
commit d3ec00b566
208 changed files with 6394 additions and 3353 deletions
+7 -20
View File
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import pytest
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import BaseMemory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -25,11 +25,9 @@ class FakeMemory(BaseMemory):
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Pass."""
pass
def clear(self) -> None:
"""Pass."""
pass
class FakeChain(Chain):
@@ -49,7 +47,11 @@ class FakeChain(Chain):
"""Output key of bar."""
return self.the_output_keys
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
if self.be_correct:
return {"bar": "baz"}
else:
@@ -143,25 +145,10 @@ def test_run_with_callback() -> None:
"""Test run method works when callback manager is passed."""
handler = FakeCallbackHandler()
chain = FakeChain(
callback_manager=CallbackManager(handlers=[handler]), verbose=True
callbacks=[handler],
)
output = chain.run("bar")
assert output == "baz"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_run_with_callback_not_verbose() -> None:
"""Test run method works when callback manager is passed and not verbose."""
import langchain
langchain.verbose = False
handler = FakeCallbackHandler()
chain = FakeChain(callback_manager=CallbackManager(handlers=[handler]))
output = chain.run("bar")
assert output == "baz"
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0