mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
Fix for KeyError in MlflowCallbackHandler (#7051)
- Description: `MlflowCallbackHandler` fails with `KeyError: "['name'] not in index"`. See https://github.com/hwchase17/langchain/issues/5770 for more details. Root cause is that LangChain does not pass "name" as a part of `serialized` argument to `on_llm_start()` callback method. The commit where this change was made is probably this: https://github.com/hwchase17/langchain/commit/18af149e91e62b3ac7728ddea420688d41043734. My bug fix derives "name" from "id" field. - Issue: https://github.com/hwchase17/langchain/issues/5770 --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -551,8 +551,18 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
on_llm_start_records_df = pd.DataFrame(self.records["on_llm_start_records"])
|
||||
on_llm_end_records_df = pd.DataFrame(self.records["on_llm_end_records"])
|
||||
|
||||
llm_input_columns = ["step", "prompt"]
|
||||
if "name" in on_llm_start_records_df.columns:
|
||||
llm_input_columns.append("name")
|
||||
elif "id" in on_llm_start_records_df.columns:
|
||||
# id is llm class's full import path. For example:
|
||||
# ["langchain", "llms", "openai", "AzureOpenAI"]
|
||||
on_llm_start_records_df["name"] = on_llm_start_records_df["id"].apply(
|
||||
lambda id_: id_[-1]
|
||||
)
|
||||
llm_input_columns.append("name")
|
||||
llm_input_prompts_df = (
|
||||
on_llm_start_records_df[["step", "prompt", "name"]]
|
||||
on_llm_start_records_df[llm_input_columns]
|
||||
.dropna(axis=1)
|
||||
.rename({"step": "prompt_step"}, axis=1)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user