mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
Callbacks Refactor [base] (#3256)
Co-authored-by: Nuno Campos <nuno@boringbits.io> Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com> Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
@@ -6,7 +6,7 @@ from langchain.chains.pal.base import PALChain
|
||||
|
||||
def test_math_prompt() -> None:
|
||||
"""Test math prompt."""
|
||||
llm = OpenAI(model_name="code-davinci-002", temperature=0, max_tokens=512)
|
||||
llm = OpenAI(temperature=0, max_tokens=512)
|
||||
pal_chain = PALChain.from_math_prompt(llm)
|
||||
question = (
|
||||
"Jan has three times the number of pets as Marcia. "
|
||||
@@ -19,7 +19,7 @@ def test_math_prompt() -> None:
|
||||
|
||||
def test_colored_object_prompt() -> None:
|
||||
"""Test colored object prompt."""
|
||||
llm = OpenAI(model_name="code-davinci-002", temperature=0, max_tokens=512)
|
||||
llm = OpenAI(temperature=0, max_tokens=512)
|
||||
pal_chain = PALChain.from_colored_object_prompt(llm)
|
||||
question = (
|
||||
"On the desk, you see two blue booklets, "
|
||||
|
||||
@@ -27,7 +27,7 @@ def test_sql_database_run() -> None:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(stmt)
|
||||
db = SQLDatabase(engine)
|
||||
db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db)
|
||||
db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db)
|
||||
output = db_chain.run("What company does Harrison work at?")
|
||||
expected_output = " Harrison works at Foo."
|
||||
assert output == expected_output
|
||||
@@ -41,7 +41,7 @@ def test_sql_database_run_update() -> None:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(stmt)
|
||||
db = SQLDatabase(engine)
|
||||
db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db)
|
||||
db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db)
|
||||
output = db_chain.run("Update Harrison's workplace to Bar")
|
||||
expected_output = " Harrison's workplace has been updated to Bar."
|
||||
assert output == expected_output
|
||||
@@ -59,9 +59,7 @@ def test_sql_database_sequential_chain_run() -> None:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(stmt)
|
||||
db = SQLDatabase(engine)
|
||||
db_chain = SQLDatabaseSequentialChain.from_llm(
|
||||
llm=OpenAI(temperature=0), database=db
|
||||
)
|
||||
db_chain = SQLDatabaseSequentialChain.from_llm(OpenAI(temperature=0), db)
|
||||
output = db_chain.run("What company does Harrison work at?")
|
||||
expected_output = " Harrison works at Foo."
|
||||
assert output == expected_output
|
||||
@@ -77,7 +75,7 @@ def test_sql_database_sequential_chain_intermediate_steps() -> None:
|
||||
conn.execute(stmt)
|
||||
db = SQLDatabase(engine)
|
||||
db_chain = SQLDatabaseSequentialChain.from_llm(
|
||||
llm=OpenAI(temperature=0), database=db, return_intermediate_steps=True
|
||||
OpenAI(temperature=0), db, return_intermediate_steps=True
|
||||
)
|
||||
output = db_chain("What company does Harrison work at?")
|
||||
expected_output = " Harrison works at Foo."
|
||||
|
||||
Reference in New Issue
Block a user