[Bugfix] Mistral tokenizer encode accept list of str (#12149)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
parent
58fd57ff1d
commit
54cacf008f
@ -18,6 +18,7 @@ from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
|
||||
Tekkenizer)
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
@ -27,7 +28,7 @@ logger = init_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class Encoding:
|
||||
input_ids: List[int]
|
||||
input_ids: Union[List[int], List[List[int]]]
|
||||
|
||||
|
||||
def maybe_serialize_tool_calls(request: ChatCompletionRequest):
|
||||
@ -223,17 +224,25 @@ class MistralTokenizer:
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt: Union[str, List[str], List[int]],
|
||||
add_special_tokens: bool = False,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
):
|
||||
# Mistral Tokenizers should not add special tokens
|
||||
input_ids = self.encode(prompt)
|
||||
|
||||
if truncation:
|
||||
input_ids = input_ids[:max_length]
|
||||
|
||||
input_ids: Union[List[int], List[List[int]]]
|
||||
# For List[str], original prompt text
|
||||
if is_list_of(prompt, str):
|
||||
input_ids_: List[List[int]] = []
|
||||
for p in prompt:
|
||||
each_input_ids = self.encode_one(p, truncation, max_length)
|
||||
input_ids_.append(each_input_ids)
|
||||
input_ids = input_ids_
|
||||
# For List[int], apply chat template output, already tokens.
|
||||
elif is_list_of(prompt, int):
|
||||
input_ids = prompt
|
||||
# For str, single prompt text
|
||||
else:
|
||||
input_ids = self.encode_one(prompt, truncation, max_length)
|
||||
return Encoding(input_ids=input_ids)
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
@ -245,6 +254,19 @@ class MistralTokenizer:
|
||||
# Mistral tokenizers have no added vocabulary
|
||||
return {}
|
||||
|
||||
def encode_one(
|
||||
self,
|
||||
prompt: str,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
) -> List[int]:
|
||||
# Mistral Tokenizers should not add special tokens
|
||||
input_ids = self.encode(prompt)
|
||||
|
||||
if truncation:
|
||||
input_ids = input_ids[:max_length]
|
||||
return input_ids
|
||||
|
||||
def encode(self, prompt: str) -> List[int]:
|
||||
# `encode` should only be used for prompt completion
|
||||
# it should never be used for chat_completion.
|
||||
|
Loading…
x
Reference in New Issue
Block a user