diff --git a/libs/langchain/langchain/retrievers/parent_document_retriever.py b/libs/langchain/langchain/retrievers/parent_document_retriever.py
index f427b1f7b..a4f775e44 100644
--- a/libs/langchain/langchain/retrievers/parent_document_retriever.py
+++ b/libs/langchain/langchain/retrievers/parent_document_retriever.py
@@ -1,7 +1,9 @@
 import uuid
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
-from langchain.callbacks.base import Callbacks
+from pydantic import Field
+
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema.document import Document
 from langchain.schema.retriever import BaseRetriever
 from langchain.schema.storage import BaseStore
@@ -71,17 +73,20 @@ class ParentDocumentRetriever(BaseRetriever):
     parent_splitter: Optional[TextSplitter] = None
     """The text splitter to use to create parent documents.
     If none, then the parent documents will be the raw documents passed in."""
+    search_kwargs: dict = Field(default_factory=dict)
+    """Keyword arguments to pass to the search function."""
 
-    def get_relevant_documents(
-        self,
-        query: str,
-        *,
-        callbacks: Callbacks = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        **kwargs: Any,
+    def _get_relevant_documents(
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:
-        sub_docs = self.vectorstore.similarity_search(query)
+        """Get documents relevant to a query.
+        Args:
+            query: String to find relevant documents for
+            run_manager: The callbacks handler to use
+        Returns:
+            List of relevant documents
+        """
+        sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
         # We do this to maintain the order of the ids that are returned
         ids = []
         for d in sub_docs: