[Bugfix] Mistral tokenizer encode accept list of str (#12149)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
parent
58fd57ff1d
commit
54cacf008f
@ -18,6 +18,7 @@ from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
|
|||||||
Tekkenizer)
|
Tekkenizer)
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import is_list_of
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
@ -27,7 +28,7 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Encoding:
|
class Encoding:
|
||||||
input_ids: List[int]
|
input_ids: Union[List[int], List[List[int]]]
|
||||||
|
|
||||||
|
|
||||||
def maybe_serialize_tool_calls(request: ChatCompletionRequest):
|
def maybe_serialize_tool_calls(request: ChatCompletionRequest):
|
||||||
@ -223,17 +224,25 @@ class MistralTokenizer:
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: Union[str, List[str], List[int]],
|
||||||
add_special_tokens: bool = False,
|
add_special_tokens: bool = False,
|
||||||
truncation: bool = False,
|
truncation: bool = False,
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
):
|
):
|
||||||
# Mistral Tokenizers should not add special tokens
|
input_ids: Union[List[int], List[List[int]]]
|
||||||
input_ids = self.encode(prompt)
|
# For List[str]: a batch of prompt texts, each encoded separately.
|
||||||
|
if is_list_of(prompt, str):
|
||||||
if truncation:
|
input_ids_: List[List[int]] = []
|
||||||
input_ids = input_ids[:max_length]
|
for p in prompt:
|
||||||
|
each_input_ids = self.encode_one(p, truncation, max_length)
|
||||||
|
input_ids_.append(each_input_ids)
|
||||||
|
input_ids = input_ids_
|
||||||
|
# For List[int]: chat-template output that is already token IDs, used as-is.
|
||||||
|
elif is_list_of(prompt, int):
|
||||||
|
input_ids = prompt
|
||||||
|
# For str, single prompt text
|
||||||
|
else:
|
||||||
|
input_ids = self.encode_one(prompt, truncation, max_length)
|
||||||
return Encoding(input_ids=input_ids)
|
return Encoding(input_ids=input_ids)
|
||||||
|
|
||||||
def get_vocab(self) -> Dict[str, int]:
|
def get_vocab(self) -> Dict[str, int]:
|
||||||
@ -245,6 +254,19 @@ class MistralTokenizer:
|
|||||||
# Mistral tokenizers have no added vocabulary
|
# Mistral tokenizers have no added vocabulary
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def encode_one(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
truncation: bool = False,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
) -> List[int]:
|
||||||
|
# Mistral Tokenizers should not add special tokens
|
||||||
|
input_ids = self.encode(prompt)
|
||||||
|
|
||||||
|
if truncation:
|
||||||
|
input_ids = input_ids[:max_length]
|
||||||
|
return input_ids
|
||||||
|
|
||||||
def encode(self, prompt: str) -> List[int]:
|
def encode(self, prompt: str) -> List[int]:
|
||||||
# `encode` should only be used for prompt completion
|
# `encode` should only be used for prompt completion
|
||||||
# it should never be used for chat_completion.
|
# it should never be used for chat_completion.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user