From 52f34de9b7914b4836fece4bdf2ef9a0a1f7e696 Mon Sep 17 00:00:00 2001 From: Anush Date: Sat, 11 Nov 2023 00:21:52 +0530 Subject: [PATCH] feat: FastEmbed embedding provider (#13109) ## Description: This PR intends to add [Qdrant/FastEmbed](https://qdrant.github.io/fastembed/) as a local embeddings provider, associated tests and documentation. **Documentation preview:** https://langchain-git-fork-anush008-master-langchain.vercel.app/docs/integrations/text_embedding/fastembed --------- Co-authored-by: Eugene Yurtsev --- .../text_embedding/fastembed.ipynb | 154 ++++++++++++++++++ .../langchain/embeddings/__init__.py | 2 + .../langchain/embeddings/fastembed.py | 108 ++++++++++++ .../embeddings/test_fastembed.py | 76 +++++++++ .../unit_tests/embeddings/test_imports.py | 1 + 5 files changed, 341 insertions(+) create mode 100644 docs/docs/integrations/text_embedding/fastembed.ipynb create mode 100644 libs/langchain/langchain/embeddings/fastembed.py create mode 100644 libs/langchain/tests/integration_tests/embeddings/test_fastembed.py diff --git a/docs/docs/integrations/text_embedding/fastembed.ipynb b/docs/docs/integrations/text_embedding/fastembed.ipynb new file mode 100644 index 000000000..9d6826f92 --- /dev/null +++ b/docs/docs/integrations/text_embedding/fastembed.ipynb @@ -0,0 +1,154 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Qdrant FastEmbed\n", + "\n", + "[FastEmbed](https://qdrant.github.io/fastembed/) is a lightweight, fast, Python library built for embedding generation. \n", + "\n", + "- Quantized model weights\n", + "- ONNX Runtime, no PyTorch dependency\n", + "- CPU-first design\n", + "- Data-parallelism for encoding of large datasets." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "2a773d8d", + "metadata": {}, + "source": [ + "## Dependencies\n", + "\n", + "To use FastEmbed with LangChain, install the `fastembed` Python package." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91ea14ce-831d-409a-a88f-30353acdabd1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%pip install fastembed" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "426f1156", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3f5dc9d7-65e3-4b5b-9086-3327d016cfe0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.embeddings.fastembed import FastEmbedEmbeddings" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiating FastEmbed\n", + " \n", + "### Parameters\n", + "- `model_name: str` (default: \"BAAI/bge-small-en-v1.5\")\n", + " > Name of the FastEmbedding model to use. You can find the list of supported models [here](https://qdrant.github.io/fastembed/examples/Supported_Models/).\n", + "\n", + "- `max_length: int` (default: 512)\n", + " > The maximum number of tokens. Unknown behavior for values > 512.\n", + "\n", + "- `cache_dir: Optional[str]`\n", + " > The path to the cache directory. Defaults to `local_cache` in the parent directory.\n", + "\n", + "- `threads: Optional[int]`\n", + " > The number of threads a single onnxruntime session can use. Defaults to None.\n", + "\n", + "- `doc_embed_type: Literal[\"default\", \"passage\"]` (default: \"default\")\n", + " > \"default\": Uses FastEmbed's default embedding method.\n", + " \n", + " > \"passage\": Prefixes the text with \"passage\" before embedding." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6fb585dd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "embeddings = FastEmbedEmbeddings()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Usage\n", + "\n", + "### Generating document embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "document_embeddings = embeddings.embed_documents([\"This is a document\", \"This is some other document\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generating query embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_embeddings = embeddings.embed_query(\"This is a query\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/embeddings/__init__.py b/libs/langchain/langchain/embeddings/__init__.py index dfbb814fb..8ead8ee1c 100644 --- a/libs/langchain/langchain/embeddings/__init__.py +++ b/libs/langchain/langchain/embeddings/__init__.py @@ -32,6 +32,7 @@ from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings from langchain.embeddings.embaas import EmbaasEmbeddings from langchain.embeddings.ernie import ErnieEmbeddings from langchain.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings +from langchain.embeddings.fastembed import FastEmbedEmbeddings from langchain.embeddings.google_palm import GooglePalmEmbeddings from langchain.embeddings.gpt4all import GPT4AllEmbeddings 
class FastEmbedEmbeddings(BaseModel, Embeddings):
    """Qdrant FastEmbedding models.

    FastEmbed is a lightweight, fast, Python library built for embedding generation.
    See more documentation at:
    * https://github.com/qdrant/fastembed/
    * https://qdrant.github.io/fastembed/

    To use this class, you must install the `fastembed` Python package:

        pip install fastembed

    Example:
        from langchain.embeddings import FastEmbedEmbeddings
        fastembed = FastEmbedEmbeddings()
    """

    model_name: str = "BAAI/bge-small-en-v1.5"
    """Name of the FastEmbedding model to use.
    Defaults to "BAAI/bge-small-en-v1.5".
    Find the list of supported models at
    https://qdrant.github.io/fastembed/examples/Supported_Models/
    """

    max_length: int = 512
    """The maximum number of tokens. Defaults to 512.
    Unknown behavior for values > 512.
    """

    cache_dir: Optional[str] = None
    """The path to the cache directory.
    When left as None, FastEmbed picks the location
    (`local_cache` in the parent directory).
    """

    threads: Optional[int] = None
    """The number of threads a single onnxruntime session can use.
    Defaults to None.
    """

    doc_embed_type: Literal["default", "passage"] = "default"
    """Type of embedding to use for documents.
    "default": Uses FastEmbed's default embedding method (`embed`).
    "passage": Prefixes the text with "passage" before embedding
    (`passage_embed`).
    """

    # Holds the FlagEmbedding instance created in `validate_environment`.
    _model: Any  # : :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    # skip_on_failure=True: when a field has already failed validation the
    # validator is not run, so no model is built from incomplete values
    # before the validation error is reported.
    @root_validator(skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that FastEmbed is installed and create the model.

        Args:
            values: The validated field values of the instance being built.

        Returns:
            The same mapping with the constructed model stored under "_model".

        Raises:
            ImportError: If the `fastembed` package cannot be imported.
        """
        # Only the import is guarded, so an ImportError raised while the
        # model itself is being constructed is not misreported as
        # "fastembed is not installed".
        try:
            from fastembed.embedding import FlagEmbedding
        except ImportError as ie:
            raise ImportError(
                "Could not import 'fastembed' Python package. "
                "Please install it with `pip install fastembed`."
            ) from ie

        values["_model"] = FlagEmbedding(
            model_name=values.get("model_name"),
            max_length=values.get("max_length"),
            cache_dir=values.get("cache_dir"),
            threads=values.get("threads"),
        )
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for documents using FastEmbed.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text. An empty input yields an
            empty list.
        """
        if not texts:
            # Nothing to embed; skip the model call entirely.
            return []

        if self.doc_embed_type == "passage":
            embeddings = self._model.passage_embed(texts)
        else:
            embeddings = self._model.embed(texts)
        return [e.tolist() for e in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Generate query embeddings using FastEmbed.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        # `query_embed` is consumed as an iterator; a single query needs
        # only its first item.
        query_embedding: np.ndarray = next(self._model.query_embed(text))
        return query_embedding.tolist()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name", ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"]
)
@pytest.mark.parametrize("max_length", [50, 512])
async def test_fastembed_async_embedding_query(
    model_name: str, max_length: int
) -> None:
    """Check that the async query path returns a 384-element embedding."""
    embedder = FastEmbedEmbeddings(model_name=model_name, max_length=max_length)
    vector = await embedder.aembed_query("foo bar")
    assert len(vector) == 384