mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
Add allowed and disallowed special arguments to BaseOpenAI (#3012)
## Background
This PR fixes the following error, which is raised when the text passed to
the chain contains special tokens:
```
Encountered text corresponding to disallowed special token '<|endofprompt|>'.
If you want this text to be encoded as a special token, pass it to `allowed_special`, e.g. `allowed_special={'<|endofprompt|>', ...}`.
If you want this text to be encoded as normal text, disable the check for this token by passing `disallowed_special=(enc.special_tokens_set - {'<|endofprompt|>'})`.
To disable this check for all special tokens, pass `disallowed_special=()`.
```
Refer to the code snippet below; it breaks on the `chain` line.
```
chain = ConversationalRetrievalChain.from_llm(
ChatOpenAI(openai_api_key=OPENAI_API_KEY),
retriever=vectorstore.as_retriever(),
qa_prompt=prompt,
condense_question_prompt=condense_prompt,
)
answer = chain({"question": f"{question}"})
```
However, the `ChatOpenAI` class does not currently accept `allowed_special`
and `disallowed_special`, so they cannot be passed to `encode()` in the
`get_num_tokens` method to avoid the error.
## Change
- Add `allowed_special` and `disallowed_special` attributes to the
`BaseOpenAI` and `OpenAIChat` classes.
- Pass `allowed_special` and `disallowed_special` as arguments to
tiktoken's `encode()`.
---------
Co-authored-by: samcarmen <carmen.samkahman@gmail.com>
This commit is contained in:
@@ -5,11 +5,14 @@ import logging
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
@@ -150,6 +153,10 @@ class BaseOpenAI(BaseLLM):
|
||||
"""Maximum number of retries to make when generating."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
|
||||
"""Set of special tokens that are allowed."""
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
|
||||
"""Set of special tokens that are not allowed."""
|
||||
|
||||
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
|
||||
"""Initialize the OpenAI object."""
|
||||
@@ -449,7 +456,11 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
enc = tiktoken.encoding_for_model(self.model_name)
|
||||
|
||||
tokenized_text = enc.encode(text)
|
||||
tokenized_text = enc.encode(
|
||||
text,
|
||||
allowed_special=self.allowed_special,
|
||||
disallowed_special=self.disallowed_special,
|
||||
)
|
||||
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
@@ -602,6 +613,10 @@ class OpenAIChat(BaseLLM):
|
||||
"""Series of messages for Chat input."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
|
||||
"""Set of special tokens that are allowed."""
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
|
||||
"""Set of special tokens that are not allowed."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -785,7 +800,11 @@ class OpenAIChat(BaseLLM):
|
||||
enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
||||
|
||||
# encode the text using the GPT-3.5-Turbo encoder
|
||||
tokenized_text = enc.encode(text)
|
||||
tokenized_text = enc.encode(
|
||||
text,
|
||||
allowed_special=self.allowed_special,
|
||||
disallowed_special=self.disallowed_special,
|
||||
)
|
||||
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
Reference in New Issue
Block a user