mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
Callbacks Refactor [base] (#3256)
Co-authored-by: Nuno Campos <nuno@boringbits.io> Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com> Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema import BaseMemory
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
@@ -25,11 +25,9 @@ class FakeMemory(BaseMemory):
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Pass."""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Pass."""
|
||||
pass
|
||||
|
||||
|
||||
class FakeChain(Chain):
|
||||
@@ -49,7 +47,11 @@ class FakeChain(Chain):
|
||||
"""Output key of bar."""
|
||||
return self.the_output_keys
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
if self.be_correct:
|
||||
return {"bar": "baz"}
|
||||
else:
|
||||
@@ -143,25 +145,10 @@ def test_run_with_callback() -> None:
|
||||
"""Test run method works when callback manager is passed."""
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(
|
||||
callback_manager=CallbackManager(handlers=[handler]), verbose=True
|
||||
callbacks=[handler],
|
||||
)
|
||||
output = chain.run("bar")
|
||||
assert output == "baz"
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_run_with_callback_not_verbose() -> None:
|
||||
"""Test run method works when callback manager is passed and not verbose."""
|
||||
import langchain
|
||||
|
||||
langchain.verbose = False
|
||||
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(callback_manager=CallbackManager(handlers=[handler]))
|
||||
output = chain.run("bar")
|
||||
assert output == "baz"
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
||||
@@ -3,6 +3,10 @@ from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
||||
from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@@ -28,12 +32,18 @@ class FakeLLM(BaseLLM):
|
||||
n: int = 1
|
||||
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
|
||||
|
||||
async def _agenerate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm_bash.base import BashOutputParser, LLMBashChain
|
||||
from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE
|
||||
from langchain.chains.llm_bash.base import LLMBashChain
|
||||
from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE, BashOutputParser
|
||||
from langchain.schema import OutputParserException
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
@@ -43,7 +43,7 @@ def test_simple_question() -> None:
|
||||
prompt = _PROMPT_TEMPLATE.format(question=question)
|
||||
queries = {prompt: "```bash\nexpr 1 + 1\n```"}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||
output = fake_llm_bash_chain.run(question)
|
||||
assert output == "2\n"
|
||||
|
||||
@@ -71,7 +71,7 @@ echo 'hello world'
|
||||
"""
|
||||
}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||
with pytest.raises(OutputParserException):
|
||||
fake_llm_bash_chain.run(question)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ def fake_llm_checker_chain() -> LLMCheckerChain:
|
||||
): "I still don't know.",
|
||||
}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
return LLMCheckerChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
return LLMCheckerChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||
|
||||
|
||||
def test_simple_question(fake_llm_checker_chain: LLMCheckerChain) -> None:
|
||||
|
||||
@@ -17,7 +17,7 @@ def fake_llm_math_chain() -> LLMMathChain:
|
||||
_PROMPT_TEMPLATE.format(question="foo"): "foo",
|
||||
}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
return LLMMathChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
return LLMMathChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||
|
||||
|
||||
def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None:
|
||||
|
||||
@@ -32,7 +32,9 @@ def fake_llm_summarization_checker_chain() -> LLMSummarizationCheckerChain:
|
||||
): "True",
|
||||
}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
return LLMSummarizationCheckerChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
return LLMSummarizationCheckerChain.from_llm(
|
||||
fake_llm, input_key="q", output_key="a"
|
||||
)
|
||||
|
||||
|
||||
def test_simple_text(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chains.natbot.base import NatBotChain
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
@@ -9,7 +10,12 @@ from langchain.llms.base import LLM
|
||||
class FakeLLM(LLM):
|
||||
"""Fake LLM wrapper for testing purposes."""
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Return `foo` if longer than 10000 words, else `bar`."""
|
||||
if len(prompt) > 10000:
|
||||
return "foo"
|
||||
@@ -28,7 +34,7 @@ class FakeLLM(LLM):
|
||||
|
||||
def test_proper_inputs() -> None:
|
||||
"""Test that natbot shortens inputs correctly."""
|
||||
nat_bot_chain = NatBotChain(llm=FakeLLM(), objective="testing")
|
||||
nat_bot_chain = NatBotChain.from_llm(FakeLLM(), objective="testing")
|
||||
url = "foo" * 10000
|
||||
browser_content = "foo" * 10000
|
||||
output = nat_bot_chain.execute(url, browser_content)
|
||||
@@ -37,8 +43,8 @@ def test_proper_inputs() -> None:
|
||||
|
||||
def test_variable_key_naming() -> None:
|
||||
"""Test that natbot handles variable key naming correctly."""
|
||||
nat_bot_chain = NatBotChain(
|
||||
llm=FakeLLM(),
|
||||
nat_bot_chain = NatBotChain.from_llm(
|
||||
FakeLLM(),
|
||||
objective="testing",
|
||||
input_url_key="u",
|
||||
input_browser_content_key="b",
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Test pipeline functionality."""
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||
from langchain.memory.simple import SimpleMemory
|
||||
@@ -24,7 +25,11 @@ class FakeChain(Chain):
|
||||
"""Input keys this chain returns."""
|
||||
return self.output_variables
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
outputs = {}
|
||||
for var in self.output_variables:
|
||||
variables = [inputs[k] for k in self.input_variables]
|
||||
|
||||
Reference in New Issue
Block a user