mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
async output parser (#8894)
Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
@@ -475,7 +475,8 @@ class Agent(BaseSingleActionAgent):
|
||||
"""
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
|
||||
return self.output_parser.parse(full_output)
|
||||
agent_output = await self.output_parser.aparse(full_output)
|
||||
return agent_output
|
||||
|
||||
def get_full_inputs(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
|
||||
|
||||
@@ -27,6 +28,20 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
|
||||
Structured output.
|
||||
"""
|
||||
|
||||
async def aparse_result(self, result: List[Generation]) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.parse_result, result
|
||||
)
|
||||
|
||||
|
||||
class BaseGenerationOutputParser(
|
||||
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
||||
@@ -51,6 +66,26 @@ class BaseGenerationOutputParser(
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: str | BaseMessage, config: RunnableConfig | None = None
|
||||
) -> T:
|
||||
if isinstance(input, BaseMessage):
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result(
|
||||
[ChatGeneration(message=inner_input)]
|
||||
),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
else:
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
|
||||
class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
|
||||
"""Base class to parse the output of an LLM call.
|
||||
@@ -99,6 +134,26 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: str | BaseMessage, config: RunnableConfig | None = None
|
||||
) -> T:
|
||||
if isinstance(input, BaseMessage):
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result(
|
||||
[ChatGeneration(message=inner_input)]
|
||||
),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
else:
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
@@ -125,6 +180,32 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
||||
Structured output.
|
||||
"""
|
||||
|
||||
async def aparse_result(self, result: List[Generation]) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
The return value is parsed from only the first Generation in the result, which
|
||||
is assumed to be the highest-likelihood Generation.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return await self.aparse(result[0].text)
|
||||
|
||||
async def aparse(self, text: str) -> T:
|
||||
"""Parse a single string model output into some structure.
|
||||
|
||||
Args:
|
||||
text: String output of a language model.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(None, self.parse, text)
|
||||
|
||||
# TODO: rename 'completion' -> 'text'.
|
||||
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
||||
"""Parse the output of an LLM call with the input prompt for context.
|
||||
|
||||
@@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
@@ -192,6 +193,37 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
return output
|
||||
|
||||
async def _acall_with_config(
|
||||
self,
|
||||
func: Callable[[Input], Awaitable[Output]],
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
) -> Output:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
config = config or {}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input if isinstance(input, dict) else {"input": input},
|
||||
run_type=run_type,
|
||||
)
|
||||
try:
|
||||
output = await func(input)
|
||||
except Exception as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(
|
||||
output if isinstance(output, dict) else {"output": output}
|
||||
)
|
||||
return output
|
||||
|
||||
def with_fallbacks(
|
||||
self,
|
||||
fallbacks: Sequence[Runnable[Input, Output]],
|
||||
|
||||
Reference in New Issue
Block a user