mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
eef62bf4e9
<!-- Thank you for contributing to LangChain! Your PR will appear in our release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. Finally, we'd love to show appreciation for your contribution - if you'd like us to shout you out on Twitter, please also include your handle! --> <!-- Remove if not applicable --> Added support for `similarity_search_by_vector` to the Qdrant vector store. <!-- If you're adding a new integration, please include: 1. a test for the integration - favor unit tests that do not rely on network access. 2. an example notebook showing its use. See contribution guidelines for more information on how to write tests, lint, etc.: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> ### Who can review VectorStores / Retrievers / Memory - @dev2049 <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Tracing / Callbacks - @agola11 Async - @agola11 DataLoaders - @eyurtsev Models - @hwchase17 - @agola11 Agents / Tools / Toolkits - @hwchase17 -->
815 lines
31 KiB
Python
815 lines
31 KiB
Python
"""Wrapper around Qdrant vector database."""
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
import warnings
|
|
from itertools import islice
|
|
from operator import itemgetter
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
Iterable,
|
|
List,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
Type,
|
|
Union,
|
|
)
|
|
|
|
import numpy as np
|
|
|
|
from langchain.docstore.document import Document
|
|
from langchain.embeddings.base import Embeddings
|
|
from langchain.vectorstores import VectorStore
|
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
|
|
|
if TYPE_CHECKING:
    # Imported only for static type checking. These names do not exist at
    # runtime; annotations that mention them work because the module uses
    # `from __future__ import annotations` (annotations stay unevaluated).
    from qdrant_client.conversions import common_types
    from qdrant_client.http import models as rest

    # Legacy (deprecated) metadata filter given as a plain dict; keys are
    # metadata field names, values are matched against the stored payload.
    DictFilter = Dict[str, Union[str, int, bool, dict, list]]
    # A filter argument may be either the legacy dict form or a native
    # qdrant-client filter object. Must stay inside this guard because
    # `common_types` is only imported here.
    MetadataFilter = Union[DictFilter, common_types.Filter]
|
|
|
|
|
|
class Qdrant(VectorStore):
    """Wrapper around Qdrant vector database.

    To use you should have the ``qdrant-client`` package installed.

    Example:
        .. code-block:: python

            from qdrant_client import QdrantClient
            from langchain import Qdrant

            client = QdrantClient()
            collection_name = "MyCollection"
            qdrant = Qdrant(client, collection_name, embeddings)
    """

    CONTENT_KEY = "page_content"
    METADATA_KEY = "metadata"

    def __init__(
        self,
        client: Any,
        collection_name: str,
        embeddings: Optional[Embeddings] = None,
        content_payload_key: str = CONTENT_KEY,
        metadata_payload_key: str = METADATA_KEY,
        embedding_function: Optional[Callable] = None,  # deprecated
    ):
        """Initialize with necessary components.

        Args:
            client: A ``qdrant_client.QdrantClient`` instance.
            collection_name: Name of the Qdrant collection to operate on.
            embeddings: ``Embeddings`` instance used to embed documents and
                queries. For backward compatibility a plain callable is also
                accepted here and treated as ``embedding_function``.
            content_payload_key: Payload key holding the document text.
            metadata_payload_key: Payload key holding the document metadata.
            embedding_function: Deprecated. Callable that embeds a single text.

        Raises:
            ValueError: If ``qdrant-client`` is not installed, ``client`` is not
                a ``QdrantClient``, or not exactly one of ``embeddings`` /
                ``embedding_function`` is given.
        """
        try:
            import qdrant_client
        except ImportError as e:
            raise ValueError(
                "Could not import qdrant-client python package. "
                "Please install it with `pip install qdrant-client`."
            ) from e

        if not isinstance(client, qdrant_client.QdrantClient):
            raise ValueError(
                f"client should be an instance of qdrant_client.QdrantClient, "
                f"got {type(client)}"
            )

        if embeddings is None and embedding_function is None:
            raise ValueError(
                "`embeddings` value can't be None. Pass `Embeddings` instance."
            )

        if embeddings is not None and embedding_function is not None:
            raise ValueError(
                "Both `embeddings` and `embedding_function` are passed. "
                "Use `embeddings` only."
            )

        self.embeddings = embeddings
        self._embeddings_function = embedding_function
        self.client: qdrant_client.QdrantClient = client
        self.collection_name = collection_name
        self.content_payload_key = content_payload_key or self.CONTENT_KEY
        self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY

        if embedding_function is not None:
            warnings.warn(
                "Using `embedding_function` is deprecated. "
                "Pass `Embeddings` instance to `embeddings` instead."
            )

        # Only reinterpret `embeddings` when it was actually supplied. Without
        # the `is not None` check, passing only `embedding_function` would
        # overwrite it with None and every query would fail.
        if embeddings is not None and not isinstance(embeddings, Embeddings):
            warnings.warn(
                "`embeddings` should be an instance of `Embeddings`. "
                "Using `embeddings` as `embedding_function` which is deprecated"
            )
            self._embeddings_function = embeddings
            self.embeddings = None

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[Sequence[str]] = None,
        batch_size: int = 64,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore. May be a
                one-shot iterator; it is consumed exactly once.
            metadatas: Optional list of metadatas associated with the texts.
            ids:
                Optional list of ids to associate with the texts. Ids have to be
                uuid-like strings. When omitted (or empty), a random uuid4 hex
                id is generated for every text.
            batch_size:
                How many vectors to upload per request.
                Default: 64

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        from qdrant_client.http import models as rest

        added_ids: List[str] = []
        texts_iterator = iter(texts)
        metadatas_iterator = iter(metadatas or [])
        # Ids are generated per batch rather than up front: building them by
        # iterating `texts` a second time would exhaust a one-shot iterator
        # before any text was uploaded.
        ids_iterator = iter(ids) if ids else None
        while batch_texts := list(islice(texts_iterator, batch_size)):
            # Take the corresponding metadata and id for each text in a batch
            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
            if ids_iterator is not None:
                batch_ids = list(islice(ids_iterator, batch_size))
            else:
                batch_ids = [uuid.uuid4().hex for _ in batch_texts]

            self.client.upsert(
                collection_name=self.collection_name,
                points=rest.Batch.construct(
                    ids=batch_ids,
                    vectors=self._embed_texts(batch_texts),
                    payloads=self._build_payloads(
                        batch_texts,
                        batch_metadatas,
                        self.content_payload_key,
                        self.metadata_payload_key,
                    ),
                ),
            )

            added_ids.extend(batch_ids)

        return added_ids

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[MetadataFilter] = None,
        search_params: Optional[common_types.SearchParams] = None,
        offset: int = 0,
        score_threshold: Optional[float] = None,
        consistency: Optional[common_types.ReadConsistency] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.
            search_params: Additional search params
            offset:
                Offset of the first result to return.
                May be used to paginate results.
                Note: large offset values may cause performance issues.
            score_threshold:
                Define a minimal score threshold for the result.
                If defined, less similar results will not be returned.
                Score of the returned result might be higher or smaller than the
                threshold depending on the Distance function used.
                E.g. for cosine similarity only higher scores will be returned.
            consistency:
                Read consistency of the search. Defines how many replicas should be
                queried before returning the result.
                Values:
                - int - number of replicas to query, values should be present in
                    all queried replicas
                - 'majority' - query all replicas, but return values present in the
                    majority of replicas
                - 'quorum' - query the majority of replicas, return values present in
                    all of them
                - 'all' - query all replicas, and return values present in all replicas
            **kwargs: Passed through to ``QdrantClient.search``.

        Returns:
            List of Documents most similar to the query.
        """
        results = self.similarity_search_with_score(
            query,
            k,
            filter=filter,
            search_params=search_params,
            offset=offset,
            score_threshold=score_threshold,
            consistency=consistency,
            **kwargs,
        )
        return list(map(itemgetter(0), results))

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[MetadataFilter] = None,
        search_params: Optional[common_types.SearchParams] = None,
        offset: int = 0,
        score_threshold: Optional[float] = None,
        consistency: Optional[common_types.ReadConsistency] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query, together with their scores.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.
            search_params: Additional search params
            offset:
                Offset of the first result to return.
                May be used to paginate results.
                Note: large offset values may cause performance issues.
            score_threshold:
                Define a minimal score threshold for the result.
                If defined, less similar results will not be returned.
                Score of the returned result might be higher or smaller than the
                threshold depending on the Distance function used.
                E.g. for cosine similarity only higher scores will be returned.
            consistency:
                Read consistency of the search. Defines how many replicas should be
                queried before returning the result.
                Values:
                - int - number of replicas to query, values should be present in
                    all queried replicas
                - 'majority' - query all replicas, but return values present in the
                    majority of replicas
                - 'quorum' - query the majority of replicas, return values present in
                    all of them
                - 'all' - query all replicas, and return values present in all replicas
            **kwargs: Passed through to ``QdrantClient.search``.

        Returns:
            List of (document, score) tuples, as ordered by Qdrant. The score is
            the raw value reported by Qdrant; its meaning depends on the
            collection's distance function (e.g. for cosine similarity a
            higher score means more similar).
        """
        return self.similarity_search_with_score_by_vector(
            self._embed_query(query),
            k,
            filter=filter,
            search_params=search_params,
            offset=offset,
            score_threshold=score_threshold,
            consistency=consistency,
            **kwargs,
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[MetadataFilter] = None,
        search_params: Optional[common_types.SearchParams] = None,
        offset: int = 0,
        score_threshold: Optional[float] = None,
        consistency: Optional[common_types.ReadConsistency] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding vector to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.
            search_params: Additional search params
            offset:
                Offset of the first result to return.
                May be used to paginate results.
                Note: large offset values may cause performance issues.
            score_threshold:
                Define a minimal score threshold for the result.
                If defined, less similar results will not be returned.
                Score of the returned result might be higher or smaller than the
                threshold depending on the Distance function used.
                E.g. for cosine similarity only higher scores will be returned.
            consistency:
                Read consistency of the search. Defines how many replicas should be
                queried before returning the result.
                Values:
                - int - number of replicas to query, values should be present in
                    all queried replicas
                - 'majority' - query all replicas, but return values present in the
                    majority of replicas
                - 'quorum' - query the majority of replicas, return values present in
                    all of them
                - 'all' - query all replicas, and return values present in all replicas
            **kwargs: Passed through to ``QdrantClient.search``.

        Returns:
            List of Documents most similar to the query.
        """
        results = self.similarity_search_with_score_by_vector(
            embedding,
            k,
            filter=filter,
            search_params=search_params,
            offset=offset,
            score_threshold=score_threshold,
            consistency=consistency,
            **kwargs,
        )
        return list(map(itemgetter(0), results))

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[MetadataFilter] = None,
        search_params: Optional[common_types.SearchParams] = None,
        offset: int = 0,
        score_threshold: Optional[float] = None,
        consistency: Optional[common_types.ReadConsistency] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to embedding vector, with their scores.

        Args:
            embedding: Embedding vector to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None. Dict filters are
                deprecated and emit a ``DeprecationWarning``.
            search_params: Additional search params
            offset:
                Offset of the first result to return.
                May be used to paginate results.
                Note: large offset values may cause performance issues.
            score_threshold:
                Define a minimal score threshold for the result.
                If defined, less similar results will not be returned.
                Score of the returned result might be higher or smaller than the
                threshold depending on the Distance function used.
                E.g. for cosine similarity only higher scores will be returned.
            consistency:
                Read consistency of the search. Defines how many replicas should be
                queried before returning the result.
                Values:
                - int - number of replicas to query, values should be present in
                    all queried replicas
                - 'majority' - query all replicas, but return values present in the
                    majority of replicas
                - 'quorum' - query the majority of replicas, return values present in
                    all of them
                - 'all' - query all replicas, and return values present in all replicas
            **kwargs: Passed through to ``QdrantClient.search``.

        Returns:
            List of (document, score) tuples, as ordered by Qdrant. The score is
            the raw value reported by Qdrant; its meaning depends on the
            collection's distance function (e.g. for cosine similarity a
            higher score means more similar).
        """
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            query_filter=self._resolve_filter(filter),
            search_params=search_params,
            limit=k,
            offset=offset,
            with_payload=True,
            with_vectors=False,  # Langchain does not expect vectors to be returned
            score_threshold=score_threshold,
            consistency=consistency,
            **kwargs,
        )
        return [
            (
                self._document_from_scored_point(
                    result, self.content_payload_key, self.metadata_payload_key
                ),
                result.score,
            )
            for result in results
        ]

    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and relevance scores.

        Intended range is [0, 1], where 0 is dissimilar and 1 is most similar.
        NOTE(review): the scores are returned exactly as Qdrant reports them
        (no normalisation happens here), so they only fall in [0, 1] for some
        distance functions / embedding models. Confirm for your collection.

        Args:
            query: input text
            k: Number of Documents to return. Defaults to 4.
            **kwargs: kwargs to be passed to similarity search. Should include:
                score_threshold: Optional, a floating point value between 0 to 1 to
                    filter the resulting set of retrieved docs

        Returns:
            List of Tuples of (doc, similarity_score)
        """
        return self.similarity_search_with_score(query, k, **kwargs)

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[MetadataFilter] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
                     Defaults to 20.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
            filter: Filter by metadata applied to the candidate fetch.
                Defaults to None. Dict filters are deprecated and emit a
                ``DeprecationWarning``.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        embedding = self._embed_query(query)
        # Vectors are needed client-side to compute MMR diversity.
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            query_filter=self._resolve_filter(filter),
            with_payload=True,
            with_vectors=True,
            limit=fetch_k,
        )
        embeddings = [result.vector for result in results]
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
        )
        return [
            self._document_from_scored_point(
                results[i], self.content_payload_key, self.metadata_payload_key
            )
            for i in mmr_selected
        ]

    @classmethod
    def from_texts(
        cls: Type[Qdrant],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[Sequence[str]] = None,
        location: Optional[str] = None,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        path: Optional[str] = None,
        collection_name: Optional[str] = None,
        distance_func: str = "Cosine",
        content_payload_key: str = CONTENT_KEY,
        metadata_payload_key: str = METADATA_KEY,
        batch_size: int = 64,
        shard_number: Optional[int] = None,
        replication_factor: Optional[int] = None,
        write_consistency_factor: Optional[int] = None,
        on_disk_payload: Optional[bool] = None,
        hnsw_config: Optional[common_types.HnswConfigDiff] = None,
        optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
        wal_config: Optional[common_types.WalConfigDiff] = None,
        quantization_config: Optional[common_types.QuantizationConfig] = None,
        init_from: Optional[common_types.InitFrom] = None,
        **kwargs: Any,
    ) -> Qdrant:
        """Construct Qdrant wrapper from a list of texts.

        Args:
            texts: A list of texts to be indexed in Qdrant.
            embedding: A subclass of `Embeddings`, responsible for text vectorization.
            metadatas:
                An optional list of metadata. If provided it has to be of the same
                length as a list of texts.
            ids:
                Optional list of ids to associate with the texts. Ids have to be
                uuid-like strings.
            location:
                If `:memory:` - use in-memory Qdrant instance.
                If `str` - use it as a `url` parameter.
                If `None` - fallback to relying on `host` and `port` parameters.
            url: either host or str of "Optional[scheme], host, Optional[port],
                Optional[prefix]". Default: `None`
            port: Port of the REST API interface. Default: 6333
            grpc_port: Port of the gRPC interface. Default: 6334
            prefer_grpc:
                If true - use gRPC interface whenever possible in custom methods.
                Default: False
            https: If true - use HTTPS(SSL) protocol. Default: None
            api_key: API key for authentication in Qdrant Cloud. Default: None
            prefix:
                If not None - add prefix to the REST URL path.
                Example: service/v1 will result in
                    http://localhost:6333/service/v1/{qdrant-endpoint} for REST API.
                Default: None
            timeout:
                Timeout for REST and gRPC API requests.
                Default: 5.0 seconds for REST and unlimited for gRPC
            host:
                Host name of Qdrant service. If url and host are None, set to
                'localhost'. Default: None
            path:
                Path in which the vectors will be stored while using local mode.
                Default: None
            collection_name:
                Name of the Qdrant collection to be used. If not provided,
                it will be created randomly. Default: None
            distance_func:
                Distance function. One of: "Cosine" / "Euclid" / "Dot".
                Default: "Cosine"
            content_payload_key:
                A payload key used to store the content of the document.
                Default: "page_content"
            metadata_payload_key:
                A payload key used to store the metadata of the document.
                Default: "metadata"
            batch_size:
                How many vectors to upload per request.
                Default: 64
            shard_number: Number of shards in collection. Default is 1, minimum is 1.
            replication_factor:
                Replication factor for collection. Default is 1, minimum is 1.
                Defines how many copies of each shard will be created.
                Has effect only in distributed mode.
            write_consistency_factor:
                Write consistency factor for collection. Default is 1, minimum is 1.
                Defines how many replicas should apply the operation for us to consider
                it successful. Increasing this number will make the collection more
                resilient to inconsistencies, but will also make it fail if not enough
                replicas are available.
                Does not have any performance impact.
                Has effect only in distributed mode.
            on_disk_payload:
                If true - point's payload will not be stored in memory.
                It will be read from the disk every time it is requested.
                This setting saves RAM by (slightly) increasing the response time.
                Note: those payload values that are involved in filtering and are
                indexed - remain in RAM.
            hnsw_config: Params for HNSW index
            optimizers_config: Params for optimizer
            wal_config: Params for Write-Ahead-Log
            quantization_config:
                Params for quantization, if None - quantization will be disabled
            init_from:
                Use data stored in another collection to initialize this collection
            **kwargs:
                Additional arguments passed directly into REST client initialization

        This is a user-friendly interface that:
            1. Creates embeddings, one for each text
            2. Creates a Qdrant client (connecting to 'localhost' unless
               `location`, `url`, `host` or `path` say otherwise; pass
               `location=":memory:"` for an in-memory instance) and
               (re)creates the collection
            3. Adds the text embeddings to the Qdrant database

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import Qdrant
                from langchain.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                qdrant = Qdrant.from_texts(texts, embeddings, host="localhost")
        """
        try:
            import qdrant_client
        except ImportError as e:
            raise ValueError(
                "Could not import qdrant-client python package. "
                "Please install it with `pip install qdrant-client`."
            ) from e

        from qdrant_client.http import models as rest

        # Just do a single quick embedding to get vector size
        partial_embeddings = embedding.embed_documents(texts[:1])
        vector_size = len(partial_embeddings[0])

        collection_name = collection_name or uuid.uuid4().hex
        distance_func = distance_func.upper()

        client = qdrant_client.QdrantClient(
            location=location,
            url=url,
            port=port,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            https=https,
            api_key=api_key,
            prefix=prefix,
            timeout=timeout,
            host=host,
            path=path,
            **kwargs,
        )

        # NOTE: recreate_collection drops any existing collection of that name.
        client.recreate_collection(
            collection_name=collection_name,
            vectors_config=rest.VectorParams(
                size=vector_size,
                distance=rest.Distance[distance_func],
            ),
            shard_number=shard_number,
            replication_factor=replication_factor,
            write_consistency_factor=write_consistency_factor,
            on_disk_payload=on_disk_payload,
            hnsw_config=hnsw_config,
            optimizers_config=optimizers_config,
            wal_config=wal_config,
            quantization_config=quantization_config,
            init_from=init_from,
            timeout=timeout,  # type: ignore[arg-type]
        )

        texts_iterator = iter(texts)
        metadatas_iterator = iter(metadatas or [])
        ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
        while batch_texts := list(islice(texts_iterator, batch_size)):
            # Take the corresponding metadata and id for each text in a batch
            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
            batch_ids = list(islice(ids_iterator, batch_size))

            # Generate the embeddings for all the texts in a batch
            batch_embeddings = embedding.embed_documents(batch_texts)

            client.upsert(
                collection_name=collection_name,
                points=rest.Batch.construct(
                    ids=batch_ids,
                    vectors=batch_embeddings,
                    payloads=cls._build_payloads(
                        batch_texts,
                        batch_metadatas,
                        content_payload_key,
                        metadata_payload_key,
                    ),
                ),
            )

        return cls(
            client=client,
            collection_name=collection_name,
            embeddings=embedding,
            content_payload_key=content_payload_key,
            metadata_payload_key=metadata_payload_key,
        )

    @classmethod
    def _build_payloads(
        cls,
        texts: Iterable[str],
        metadatas: Optional[List[dict]],
        content_payload_key: str,
        metadata_payload_key: str,
    ) -> List[dict]:
        """Build one Qdrant payload dict per text.

        Each payload maps ``content_payload_key`` to the text and
        ``metadata_payload_key`` to the matching entry of ``metadatas``
        (or None when no metadatas are given).

        Raises:
            ValueError: If any text is None.
        """
        payloads = []
        for i, text in enumerate(texts):
            if text is None:
                raise ValueError(
                    "At least one of the texts is None. Please remove it before "
                    "calling .from_texts or .add_texts on Qdrant instance."
                )
            metadata = metadatas[i] if metadatas is not None else None
            payloads.append(
                {
                    content_payload_key: text,
                    metadata_payload_key: metadata,
                }
            )

        return payloads

    @classmethod
    def _document_from_scored_point(
        cls,
        scored_point: Any,
        content_payload_key: str,
        metadata_payload_key: str,
    ) -> Document:
        """Convert a Qdrant search hit into a ``Document``.

        Missing or None metadata becomes an empty dict.
        """
        return Document(
            page_content=scored_point.payload.get(content_payload_key),
            metadata=scored_point.payload.get(metadata_payload_key) or {},
        )

    def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
        """Recursively translate one dict-filter entry into field conditions.

        Nested dicts become dotted keys (``a.b``); list items that are dicts use
        the ``a[]`` array path; scalars become exact-match conditions on
        ``<metadata_payload_key>.<key>``.
        """
        from qdrant_client.http import models as rest

        out = []

        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                out.extend(self._build_condition(f"{key}.{sub_key}", sub_value))
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    out.extend(self._build_condition(f"{key}[]", item))
                else:
                    out.extend(self._build_condition(f"{key}", item))
        else:
            out.append(
                rest.FieldCondition(
                    key=f"{self.metadata_payload_key}.{key}",
                    match=rest.MatchValue(value=value),
                )
            )

        return out

    def _qdrant_filter_from_dict(
        self, filter: Optional[DictFilter]
    ) -> Optional[rest.Filter]:
        """Translate a legacy dict filter into a ``rest.Filter``.

        All generated conditions are combined with ``must`` (logical AND).
        Returns None for an empty or missing filter.
        """
        from qdrant_client.http import models as rest

        if not filter:
            return None

        return rest.Filter(
            must=[
                condition
                for key, value in filter.items()
                for condition in self._build_condition(key, value)
            ]
        )

    def _resolve_filter(
        self, filter: Optional[MetadataFilter]
    ) -> Optional[common_types.Filter]:
        """Normalise a user-supplied filter for ``QdrantClient.search``.

        Dict filters are deprecated but still supported: they are converted
        with ``_qdrant_filter_from_dict`` and a ``DeprecationWarning`` is
        emitted. Any other value (including None) is passed through unchanged.
        """
        if isinstance(filter, dict):
            warnings.warn(
                "Using dict as a `filter` is deprecated. Please use qdrant-client "
                "filters directly: "
                "https://qdrant.tech/documentation/concepts/filtering/",
                DeprecationWarning,
            )
            return self._qdrant_filter_from_dict(filter)
        return filter

    def _embed_query(self, query: str) -> List[float]:
        """Embed query text.

        Used to provide backward compatibility with `embedding_function` argument.

        Args:
            query: Query text.

        Returns:
            List of floats representing the query embedding.

        Raises:
            ValueError: If neither embeddings nor embedding_function is set.
        """
        if self.embeddings is not None:
            embedding = self.embeddings.embed_query(query)
        else:
            if self._embeddings_function is not None:
                embedding = self._embeddings_function(query)
            else:
                raise ValueError("Neither of embeddings or embedding_function is set")
        # Array-like results (e.g. numpy) are converted to plain lists.
        return embedding.tolist() if hasattr(embedding, "tolist") else embedding

    def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
        """Embed search texts.

        Used to provide backward compatibility with `embedding_function` argument.

        Args:
            texts: Iterable of texts to embed.

        Returns:
            List of embedding vectors (lists of floats), one per text.

        Raises:
            ValueError: If neither embeddings nor embedding_function is set.
        """
        if self.embeddings is not None:
            embeddings = self.embeddings.embed_documents(list(texts))
            if hasattr(embeddings, "tolist"):
                embeddings = embeddings.tolist()
        elif self._embeddings_function is not None:
            embeddings = []
            for text in texts:
                embedding = self._embeddings_function(text)
                # Check the per-text result, not the accumulator list: a numpy
                # vector must become a plain list before upload.
                if hasattr(embedding, "tolist"):
                    embedding = embedding.tolist()
                embeddings.append(embedding)
        else:
            raise ValueError("Neither of embeddings or embedding_function is set")

        return embeddings