Callbacks Refactor [base] (#3256)

Co-authored-by: Nuno Campos <nuno@boringbits.io>
Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com>
Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Ankush Gola
2023-04-30 11:14:09 -07:00
committed by GitHub
parent 18ec22fe56
commit d3ec00b566
208 changed files with 6394 additions and 3353 deletions
@@ -27,7 +27,7 @@ def test_sql_database_run() -> None:
with engine.connect() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db)
db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db)
output = db_chain.run("What company does Harrison work at?")
expected_output = " Harrison works at Foo."
assert output == expected_output
@@ -41,7 +41,7 @@ def test_sql_database_run_update() -> None:
with engine.connect() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db)
db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db)
output = db_chain.run("Update Harrison's workplace to Bar")
expected_output = " Harrison's workplace has been updated to Bar."
assert output == expected_output
@@ -59,9 +59,7 @@ def test_sql_database_sequential_chain_run() -> None:
with engine.connect() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
db_chain = SQLDatabaseSequentialChain.from_llm(
llm=OpenAI(temperature=0), database=db
)
db_chain = SQLDatabaseSequentialChain.from_llm(OpenAI(temperature=0), db)
output = db_chain.run("What company does Harrison work at?")
expected_output = " Harrison works at Foo."
assert output == expected_output
@@ -77,7 +75,7 @@ def test_sql_database_sequential_chain_intermediate_steps() -> None:
conn.execute(stmt)
db = SQLDatabase(engine)
db_chain = SQLDatabaseSequentialChain.from_llm(
llm=OpenAI(temperature=0), database=db, return_intermediate_steps=True
OpenAI(temperature=0), db, return_intermediate_steps=True
)
output = db_chain("What company does Harrison work at?")
expected_output = " Harrison works at Foo."