Callbacks Refactor [base] (#3256)

Co-authored-by: Nuno Campos <nuno@boringbits.io>
Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com>
Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Ankush Gola
2023-04-30 11:14:09 -07:00
committed by GitHub
parent 18ec22fe56
commit d3ec00b566
208 changed files with 6394 additions and 3353 deletions
+7 -20
View File
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import pytest
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import BaseMemory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -25,11 +25,9 @@ class FakeMemory(BaseMemory):
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Pass."""
pass
def clear(self) -> None:
"""Pass."""
pass
class FakeChain(Chain):
@@ -49,7 +47,11 @@ class FakeChain(Chain):
"""Output key of bar."""
return self.the_output_keys
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
if self.be_correct:
return {"bar": "baz"}
else:
@@ -143,25 +145,10 @@ def test_run_with_callback() -> None:
"""Test run method works when callback manager is passed."""
handler = FakeCallbackHandler()
chain = FakeChain(
callback_manager=CallbackManager(handlers=[handler]), verbose=True
callbacks=[handler],
)
output = chain.run("bar")
assert output == "baz"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_run_with_callback_not_verbose() -> None:
"""Test run method works when callback manager is passed and not verbose."""
import langchain
langchain.verbose = False
handler = FakeCallbackHandler()
chain = FakeChain(callback_manager=CallbackManager(handlers=[handler]))
output = chain.run("bar")
assert output == "baz"
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0
+12 -2
View File
@@ -3,6 +3,10 @@ from typing import List, Optional
import numpy as np
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.hyde.prompts import PROMPT_MAP
from langchain.embeddings.base import Embeddings
@@ -28,12 +32,18 @@ class FakeLLM(BaseLLM):
n: int = 1
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
+4 -4
View File
@@ -3,8 +3,8 @@ import sys
import pytest
from langchain.chains.llm_bash.base import BashOutputParser, LLMBashChain
from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE
from langchain.chains.llm_bash.base import LLMBashChain
from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE, BashOutputParser
from langchain.schema import OutputParserException
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -43,7 +43,7 @@ def test_simple_question() -> None:
prompt = _PROMPT_TEMPLATE.format(question=question)
queries = {prompt: "```bash\nexpr 1 + 1\n```"}
fake_llm = FakeLLM(queries=queries)
fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
output = fake_llm_bash_chain.run(question)
assert output == "2\n"
@@ -71,7 +71,7 @@ echo 'hello world'
"""
}
fake_llm = FakeLLM(queries=queries)
fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
with pytest.raises(OutputParserException):
fake_llm_bash_chain.run(question)
+1 -1
View File
@@ -33,7 +33,7 @@ def fake_llm_checker_chain() -> LLMCheckerChain:
): "I still don't know.",
}
fake_llm = FakeLLM(queries=queries)
return LLMCheckerChain(llm=fake_llm, input_key="q", output_key="a")
return LLMCheckerChain.from_llm(fake_llm, input_key="q", output_key="a")
def test_simple_question(fake_llm_checker_chain: LLMCheckerChain) -> None:
+1 -1
View File
@@ -17,7 +17,7 @@ def fake_llm_math_chain() -> LLMMathChain:
_PROMPT_TEMPLATE.format(question="foo"): "foo",
}
fake_llm = FakeLLM(queries=queries)
return LLMMathChain(llm=fake_llm, input_key="q", output_key="a")
return LLMMathChain.from_llm(fake_llm, input_key="q", output_key="a")
def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None:
@@ -32,7 +32,9 @@ def fake_llm_summarization_checker_chain() -> LLMSummarizationCheckerChain:
): "True",
}
fake_llm = FakeLLM(queries=queries)
return LLMSummarizationCheckerChain(llm=fake_llm, input_key="q", output_key="a")
return LLMSummarizationCheckerChain.from_llm(
fake_llm, input_key="q", output_key="a"
)
def test_simple_text(
+10 -4
View File
@@ -2,6 +2,7 @@
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chains.natbot.base import NatBotChain
from langchain.llms.base import LLM
@@ -9,7 +10,12 @@ from langchain.llms.base import LLM
class FakeLLM(LLM):
"""Fake LLM wrapper for testing purposes."""
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Return `foo` if longer than 10000 words, else `bar`."""
if len(prompt) > 10000:
return "foo"
@@ -28,7 +34,7 @@ class FakeLLM(LLM):
def test_proper_inputs() -> None:
"""Test that natbot shortens inputs correctly."""
nat_bot_chain = NatBotChain(llm=FakeLLM(), objective="testing")
nat_bot_chain = NatBotChain.from_llm(FakeLLM(), objective="testing")
url = "foo" * 10000
browser_content = "foo" * 10000
output = nat_bot_chain.execute(url, browser_content)
@@ -37,8 +43,8 @@ def test_proper_inputs() -> None:
def test_variable_key_naming() -> None:
"""Test that natbot handles variable key naming correctly."""
nat_bot_chain = NatBotChain(
llm=FakeLLM(),
nat_bot_chain = NatBotChain.from_llm(
FakeLLM(),
objective="testing",
input_url_key="u",
input_browser_content_key="b",
+7 -2
View File
@@ -1,8 +1,9 @@
"""Test pipeline functionality."""
from typing import Dict, List
from typing import Dict, List, Optional
import pytest
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
from langchain.memory.simple import SimpleMemory
@@ -24,7 +25,11 @@ class FakeChain(Chain):
"""Input keys this chain returns."""
return self.output_variables
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
outputs = {}
for var in self.output_variables:
variables = [inputs[k] for k in self.input_variables]