From b599f91e3333cc2ba450ddb39d775644cdc34d0a Mon Sep 17 00:00:00 2001 From: Hugues Date: Fri, 29 Sep 2023 00:00:38 +0200 Subject: [PATCH] LLMonitor Callback handler: fix bug (#11128) Here is a small bug fix for the LLMonitor callback handler. I've also added user identification capabilities. --- .../integrations/callbacks/llmonitor.md | 20 ++++- .../langchain/callbacks/llmonitor_callback.py | 81 ++++++++++++++++--- 2 files changed, 88 insertions(+), 13 deletions(-) diff --git a/docs/extras/integrations/callbacks/llmonitor.md b/docs/extras/integrations/callbacks/llmonitor.md index 9d81ce12f..9cbf1e367 100644 --- a/docs/extras/integrations/callbacks/llmonitor.md +++ b/docs/extras/integrations/callbacks/llmonitor.md @@ -37,10 +37,10 @@ llm = OpenAI( callbacks=[handler], ) -chat = ChatOpenAI( - callbacks=[handler], - metadata={"userId": "123"}, # you can assign user ids to models in the metadata -) +chat = ChatOpenAI(callbacks=[handler]) + +llm("Tell me a joke") + ``` ## Usage with chains and agents @@ -100,6 +100,18 @@ agent.run( ) ``` +## User Tracking +User tracking allows you to identify your users, track their cost, conversations and more. + +```python +from langchain.callbacks.llmonitor_callback import LLMonitorCallbackHandler, identify + +with identify("user-123"): + llm("Tell me a joke") + +with identify("user-456", user_props={"email": "user456@test.com"}): + agen.run("Who is Leo DiCaprio's girlfriend?") +``` ## Support For any question or issue with integration you can reach out to the LLMonitor team on [Discord](http://discord.com/invite/8PafSG58kK) or via [email](mailto:vince@llmonitor.com). diff --git a/libs/langchain/langchain/callbacks/llmonitor_callback.py b/libs/langchain/langchain/callbacks/llmonitor_callback.py index 6e9e4e532..beadce906 100644 --- a/libs/langchain/langchain/callbacks/llmonitor_callback.py +++ b/libs/langchain/langchain/callbacks/llmonitor_callback.py @@ -1,5 +1,6 @@ import os import traceback +from contextvars import ContextVar from datetime import datetime from typing import Any, Dict, List, Literal, Union from uuid import UUID @@ -13,6 +14,26 @@ from langchain.schema.output import LLMResult DEFAULT_API_URL = "https://app.llmonitor.com" +user_ctx = ContextVar[Union[str, None]]("user_ctx", default=None) +user_props_ctx = ContextVar[Union[str, None]]("user_props_ctx", default=None) + + +class UserContextManager: + def __init__(self, user_id: str, user_props: Any = None) -> None: + user_ctx.set(user_id) + user_props_ctx.set(user_props) + + def __enter__(self) -> Any: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> Any: + user_ctx.set(None) + user_props_ctx.set(None) + + +def identify(user_id: str, user_props: Any = None) -> UserContextManager: + return UserContextManager(user_id, user_props) + def _serialize(obj: Any) -> Union[Dict[str, Any], List[Any], Any]: if hasattr(obj, "to_json"): @@ -94,13 +115,24 @@ def _parse_lc_role( def _get_user_id(metadata: Any) -> Any: + if user_ctx.get() is not None: + return user_ctx.get() + metadata = metadata or {} user_id = metadata.get("user_id") if user_id is None: - user_id = metadata.get("userId") + user_id = metadata.get("userId") # legacy, to delete in the future return user_id +def _get_user_props(metadata: Any) -> Any: + if user_props_ctx.get() is not None: + return user_props_ctx.get() + + metadata = metadata or {} + return metadata.get("user_props") + + def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]: parsed = {"text": message.content, "role": _parse_lc_role(message.type)} @@ -198,10 +230,13 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): metadata: Union[Dict[str, Any], None] = None, **kwargs: Any, ) -> None: + user_id = _get_user_id(metadata) + user_props = _get_user_props(metadata) + event = { "event": "start", "type": "llm", - "userId": (metadata or {}).get("userId"), + "userId": user_id, "runId": str(run_id), "parentRunId": str(parent_run_id) if parent_run_id else None, "input": _parse_input(prompts), @@ -209,6 +244,9 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): "tags": tags, "metadata": metadata, } + if user_props: + event["userProps"] = user_props + self.__send_event(event) def on_chat_model_start( @@ -223,6 +261,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): **kwargs: Any, ) -> Any: user_id = _get_user_id(metadata) + user_props = _get_user_props(metadata) event = { "event": "start", @@ -235,6 +274,9 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): "tags": tags, "metadata": metadata, } + if user_props: + event["userProps"] = user_props + self.__send_event(event) def on_llm_end( @@ -247,12 +289,24 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): ) -> None: token_usage = (response.llm_output or {}).get("token_usage", {}) - parsed_output = _parse_lc_messages( - map( - lambda o: o.message if hasattr(o, "message") else None, - response.generations[0], - ) - ) + parsed_output = [ + { + "text": generation.text, + "role": "ai", + **( + { + "functionCall": generation.message.additional_kwargs[ + "function_call" + ] + } + if hasattr(generation, "message") + and hasattr(generation.message, "additional_kwargs") + and "function_call" in generation.message.additional_kwargs + else {} + ), + } + for generation in response.generations[0] + ] event = { "event": "end", @@ -279,6 +333,8 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): **kwargs: Any, ) -> None: user_id = _get_user_id(metadata) + user_props = _get_user_props(metadata) + event = { "event": "start", "type": "tool", @@ -290,6 +346,9 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): "tags": tags, "metadata": metadata, } + if user_props: + event["userProps"] = user_props + self.__send_event(event) def on_tool_end( @@ -339,6 +398,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): type = "chain" user_id = _get_user_id(metadata) + user_props = _get_user_props(metadata) event = { "event": "start", @@ -351,6 +411,9 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): "metadata": metadata, "name": name, } + if user_props: + event["userProps"] = user_props + self.__send_event(event) def on_chain_end( @@ -456,4 +519,4 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): self.__send_event(event) -__all__ = ["LLMonitorCallbackHandler"] +__all__ = ["LLMonitorCallbackHandler", "identify"]