mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
add token reduction to ConversationalRetrievalChain (#2075)
This worked for me, but I'm not sure if it's the right way to approach something like this, so I'm open to suggestions. Adds class properties `reduce_k_below_max_tokens: bool` and `max_tokens_limit: int` to the `ConversationalRetrievalChain`. The code is basically copied from [`RetrievalQAWithSourcesChain`](https://github.com/nkov/langchain/blob/46d141c6cb6c0fdebb308336d8ae140d8368945a/langchain/chains/qa_with_sources/retrieval.py#L24)
This commit is contained in:
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
@@ -116,9 +117,31 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
|
||||
"""Chain for chatting with an index."""
|
||||
|
||||
retriever: BaseRetriever
|
||||
"""Index to connect to."""
|
||||
max_tokens_limit: Optional[int] = None
|
||||
"""If set, restricts the docs to return from store based on tokens, enforced only
|
||||
for StuffDocumentsChain"""
|
||||
|
||||
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
    """Drop trailing documents until the total token count fits the limit.

    Trimming only happens when ``max_tokens_limit`` is truthy and the
    combine-documents chain is a ``StuffDocumentsChain``; otherwise a copy
    of ``docs`` is returned unchanged.

    Args:
        docs: Retrieved documents, in the order they should be kept.

    Returns:
        A leading slice of ``docs`` whose summed token count, as measured
        by the combine chain's LLM, does not exceed ``max_tokens_limit``.
    """
    keep = len(docs)
    limit = self.max_tokens_limit

    # Guard clause: nothing to enforce without a limit or a stuff chain.
    if not limit or not isinstance(self.combine_docs_chain, StuffDocumentsChain):
        return docs[:keep]

    count_tokens = self.combine_docs_chain.llm_chain.llm.get_num_tokens
    sizes = [count_tokens(doc.page_content) for doc in docs]

    # Peel documents off the end, one at a time, until the budget is met.
    total = sum(sizes[:keep])
    while total > limit:
        keep -= 1
        total -= sizes[keep]

    return docs[:keep]
|
||||
|
||||
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
    """Fetch documents relevant to ``question`` and trim them to the token budget.

    Args:
        question: Query string passed to the retriever.
        inputs: Chain inputs; not used by this implementation.

    Returns:
        The retriever's documents after ``_reduce_tokens_below_limit`` has
        been applied to them.
    """
    # The raw retriever results must not be returned directly: doing so
    # bypasses the token-limit reduction and leaves max_tokens_limit
    # without any effect.
    docs = self.retriever.get_relevant_documents(question)
    return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
|
||||
Reference in New Issue
Block a user