[Bugfix] Mistral tokenizer encode accept list of str (#12149)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Author: Kunshang Ji <kunshang.ji@intel.com>
Date:   2025-01-18 00:47:53 +08:00 (committed by GitHub)
parent 58fd57ff1d
commit 54cacf008f

@@ -18,6 +18,7 @@ from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
                                                      Tekkenizer)
 
 from vllm.logger import init_logger
+from vllm.utils import is_list_of
 
 if TYPE_CHECKING:
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@@ -27,7 +28,7 @@ logger = init_logger(__name__)
 
 @dataclass
 class Encoding:
-    input_ids: List[int]
+    input_ids: Union[List[int], List[List[int]]]
 
 
 def maybe_serialize_tool_calls(request: ChatCompletionRequest):
@@ -223,17 +224,25 @@ class MistralTokenizer:
 
     def __call__(
         self,
-        prompt: str,
+        prompt: Union[str, List[str], List[int]],
         add_special_tokens: bool = False,
         truncation: bool = False,
         max_length: Optional[int] = None,
     ):
-        # Mistral Tokenizers should not add special tokens
-        input_ids = self.encode(prompt)
-
-        if truncation:
-            input_ids = input_ids[:max_length]
-
+        input_ids: Union[List[int], List[List[int]]]
+        # For List[str], original prompt text
+        if is_list_of(prompt, str):
+            input_ids_: List[List[int]] = []
+            for p in prompt:
+                each_input_ids = self.encode_one(p, truncation, max_length)
+                input_ids_.append(each_input_ids)
+            input_ids = input_ids_
+        # For List[int], apply chat template output, already tokens.
+        elif is_list_of(prompt, int):
+            input_ids = prompt
+        # For str, single prompt text
+        else:
+            input_ids = self.encode_one(prompt, truncation, max_length)
         return Encoding(input_ids=input_ids)
 
     def get_vocab(self) -> Dict[str, int]:
@@ -245,6 +254,19 @@ class MistralTokenizer:
         # Mistral tokenizers have no added vocabulary
         return {}
 
+    def encode_one(
+        self,
+        prompt: str,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ) -> List[int]:
+        # Mistral Tokenizers should not add special tokens
+        input_ids = self.encode(prompt)
+
+        if truncation:
+            input_ids = input_ids[:max_length]
+        return input_ids
+
     def encode(self, prompt: str) -> List[int]:
         # `encode` should only be used for prompt completion
         # it should never be used for chat_completion.
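
After this change, MistralTokenizer.__call__ accepts a single prompt string, a list of prompt strings (each encoded separately via encode_one), or an already-tokenized list of ints (passed through unchanged). A minimal usage sketch, assuming the tokenizer is loaded via MistralTokenizer.from_pretrained; the model id below is illustrative, not prescribed by this commit:

    from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

    # Illustrative model id; any repo with a Mistral tokenizer should work.
    tok = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

    # str -> input_ids is List[int] (encoded via encode_one)
    single = tok("Hello, world!", truncation=True, max_length=16)

    # List[str] -> input_ids is List[List[int]] (one sub-list per prompt)
    batch = tok(["Hello, world!", "Bonjour le monde!"],
                truncation=True, max_length=16)

    # List[int] -> passed through as-is (e.g. chat-template output
    # that is already tokens)
    passthrough = tok(single.input_ids)

    assert isinstance(single.input_ids[0], int)
    assert isinstance(batch.input_ids[0], list)
    assert passthrough.input_ids == single.input_ids

The dispatch relies on vllm.utils.is_list_of to distinguish List[str] from List[int], which is why the batched and pre-tokenized cases can share one entry point without changing existing single-string callers.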