From 34088107481a029b60bc6dd166770e80ff72fadd Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 22 Aug 2023 12:31:18 -0400 Subject: [PATCH] Add batch util (#9620) Add `batch` utility to langchain --- libs/langchain/langchain/utils/iter.py | 12 +++++++++++ .../tests/unit_tests/utils/__init__.py | 0 .../tests/unit_tests/utils/test_iter.py | 21 +++++++++++++++++++ 3 files changed, 33 insertions(+) create mode 100644 libs/langchain/tests/unit_tests/utils/__init__.py create mode 100644 libs/langchain/tests/unit_tests/utils/test_iter.py diff --git a/libs/langchain/langchain/utils/iter.py b/libs/langchain/langchain/utils/iter.py index 1b95f180e..60834163c 100644 --- a/libs/langchain/langchain/utils/iter.py +++ b/libs/langchain/langchain/utils/iter.py @@ -1,10 +1,12 @@ from collections import deque +from itertools import islice from typing import ( Any, ContextManager, Deque, Generator, Generic, + Iterable, Iterator, List, Optional, @@ -161,3 +163,13 @@ class Tee(Generic[T]): # Why this is needed https://stackoverflow.com/a/44638570 safetee = Tee + + +def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]: + """Utility batching function.""" + it = iter(iterable) + while True: + chunk = list(islice(it, size)) + if not chunk: + return + yield chunk diff --git a/libs/langchain/tests/unit_tests/utils/__init__.py b/libs/langchain/tests/unit_tests/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/langchain/tests/unit_tests/utils/test_iter.py b/libs/langchain/tests/unit_tests/utils/test_iter.py new file mode 100644 index 000000000..f0fd8bf4c --- /dev/null +++ b/libs/langchain/tests/unit_tests/utils/test_iter.py @@ -0,0 +1,21 @@ +from typing import List + +import pytest + +from langchain.utils.iter import batch_iterate + + +@pytest.mark.parametrize( + "input_size, input_iterable, expected_output", + [ + (2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]), + (3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]), + (1, [100, 200, 300], [[100], [200], [300]]), + (4, [], []), + ], +) +def test_batch_iterate( + input_size: int, input_iterable: List[str], expected_output: List[str] +) -> None: + """Test batching function.""" + assert list(batch_iterate(input_size, input_iterable)) == expected_output