From 4e28a7a5132e4e49bffe6bd261aa40d22214b813 Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Fri, 29 Sep 2023 14:12:48 +0100
Subject: [PATCH] Implement diff

---
 .../langchain/schema/output_parser.py | 29 ++++++++++++++-----
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py
index 5f09ee48a..89a065e9a 100644
--- a/libs/langchain/langchain/schema/output_parser.py
+++ b/libs/langchain/langchain/schema/output_parser.py
@@ -337,11 +337,17 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
 class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
     """Base class for an output parser that can handle streaming input."""
 
+    diff: bool = False
+
+    def _diff(self, prev: Optional[T], next: T) -> T:
+        raise NotImplementedError()
+
     def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
+        prev_parsed = None
         acc_gen = None
         for chunk in input:
             if isinstance(chunk, BaseMessageChunk):
-                chunk_gen = ChatGenerationChunk(message=chunk)
+                chunk_gen: Generation = ChatGenerationChunk(message=chunk)
             elif isinstance(chunk, BaseMessage):
                 chunk_gen = ChatGenerationChunk(
                     message=BaseMessageChunk(**chunk.dict())
@@ -355,16 +361,21 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
                 acc_gen += chunk_gen
 
             parsed = self.parse_result([acc_gen])
-            if parsed is not None:
-                yield parsed
+            if parsed is not None and parsed != prev_parsed:
+                if self.diff:
+                    yield self._diff(prev_parsed, parsed)
+                else:
+                    yield parsed
+                prev_parsed = parsed
 
     async def _atransform(
         self, input: AsyncIterator[Union[str, BaseMessage]]
     ) -> AsyncIterator[T]:
+        prev_parsed = None
         acc_gen = None
-        for chunk in input:
+        async for chunk in input:
             if isinstance(chunk, BaseMessageChunk):
-                chunk_gen = ChatGenerationChunk(message=chunk)
+                chunk_gen: Generation = ChatGenerationChunk(message=chunk)
             elif isinstance(chunk, BaseMessage):
                 chunk_gen = ChatGenerationChunk(
                     message=BaseMessageChunk(**chunk.dict())
@@ -378,8 +389,12 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
                 acc_gen += chunk_gen
 
             parsed = self.parse_result([acc_gen])
-            if parsed is not None:
-                yield parsed
+            if parsed is not None and parsed != prev_parsed:
+                if self.diff:
+                    yield self._diff(prev_parsed, parsed)
+                else:
+                    yield parsed
+                prev_parsed = parsed
 
 
 class StrOutputParser(BaseTransformOutputParser[str]):