From 4d722884873a339f5e903feacbdcdb51bd67da73 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 9 Aug 2023 00:25:38 -0700 Subject: [PATCH] async output parser (#8894) Co-authored-by: Nuno Campos --- libs/langchain/langchain/agents/agent.py | 3 +- .../langchain/schema/output_parser.py | 81 +++++++++++++++++++ libs/langchain/langchain/schema/runnable.py | 32 ++++++++ 3 files changed, 115 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 505cecc99..9ac57726f 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -475,7 +475,8 @@ class Agent(BaseSingleActionAgent): """ full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs) - return self.output_parser.parse(full_output) + agent_output = await self.output_parser.aparse(full_output) + return agent_output def get_full_inputs( self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index b21ef5943..cce0f7056 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from abc import ABC, abstractmethod from typing import Any, Dict, Generic, List, Optional, TypeVar, Union @@ -27,6 +28,20 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC): Structured output. """ + async def aparse_result(self, result: List[Generation]) -> T: + """Parse a list of candidate model Generations into a specific format. + + Args: + result: A list of Generations to be parsed. The Generations are assumed + to be different candidate outputs for a single model input. + + Returns: + Structured output. + """ + return await asyncio.get_running_loop().run_in_executor( + None, self.parse_result, result + ) + class BaseGenerationOutputParser( BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] @@ -51,6 +66,26 @@ class BaseGenerationOutputParser( run_type="parser", ) + async def ainvoke( + self, input: str | BaseMessage, config: RunnableConfig | None = None + ) -> T: + if isinstance(input, BaseMessage): + return await self._acall_with_config( + lambda inner_input: self.aparse_result( + [ChatGeneration(message=inner_input)] + ), + input, + config, + run_type="parser", + ) + else: + return await self._acall_with_config( + lambda inner_input: self.aparse_result([Generation(text=inner_input)]), + input, + config, + run_type="parser", + ) + class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]): """Base class to parse the output of an LLM call. @@ -99,6 +134,26 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] run_type="parser", ) + async def ainvoke( + self, input: str | BaseMessage, config: RunnableConfig | None = None + ) -> T: + if isinstance(input, BaseMessage): + return await self._acall_with_config( + lambda inner_input: self.aparse_result( + [ChatGeneration(message=inner_input)] + ), + input, + config, + run_type="parser", + ) + else: + return await self._acall_with_config( + lambda inner_input: self.aparse_result([Generation(text=inner_input)]), + input, + config, + run_type="parser", + ) + def parse_result(self, result: List[Generation]) -> T: """Parse a list of candidate model Generations into a specific format. @@ -125,6 +180,32 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] Structured output. """ + async def aparse_result(self, result: List[Generation]) -> T: + """Parse a list of candidate model Generations into a specific format. + + The return value is parsed from only the first Generation in the result, which + is assumed to be the highest-likelihood Generation. + + Args: + result: A list of Generations to be parsed. The Generations are assumed + to be different candidate outputs for a single model input. + + Returns: + Structured output. + """ + return await self.aparse(result[0].text) + + async def aparse(self, text: str) -> T: + """Parse a single string model output into some structure. + + Args: + text: String output of a language model. + + Returns: + Structured output. + """ + return await asyncio.get_running_loop().run_in_executor(None, self.parse, text) + # TODO: rename 'completion' -> 'text'. def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any: """Parse the output of an LLM call with the input prompt for context. diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py index 4cfd3f913..5ff837ee7 100644 --- a/libs/langchain/langchain/schema/runnable.py +++ b/libs/langchain/langchain/schema/runnable.py @@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import ( Any, AsyncIterator, + Awaitable, Callable, Coroutine, Dict, @@ -192,6 +193,37 @@ class Runnable(Generic[Input, Output], ABC): ) return output + async def _acall_with_config( + self, + func: Callable[[Input], Awaitable[Output]], + input: Input, + config: Optional[RunnableConfig], + run_type: Optional[str] = None, + ) -> Output: + from langchain.callbacks.manager import AsyncCallbackManager + + config = config or {} + callback_manager = AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) + run_manager = await callback_manager.on_chain_start( + dumpd(self), + input if isinstance(input, dict) else {"input": input}, + run_type=run_type, + ) + try: + output = await func(input) + except Exception as e: + await run_manager.on_chain_error(e) + raise + else: + await run_manager.on_chain_end( + output if isinstance(output, dict) else {"output": output} + ) + return output + def with_fallbacks( self, fallbacks: Sequence[Runnable[Input, Output]],