[Bugfix] Fix edge cases for MistralTokenizer (#9625)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Parent: ba0d892074
Commit: 1dd4cb2935
Changes to the incremental detokenization tests:

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Generator, List, Optional
 
 import pytest
 
 from transformers import AutoTokenizer
@@ -7,11 +7,17 @@ from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
 from vllm.transformers_utils.detokenizer import (Detokenizer,
                                                  detokenize_incrementally)
 from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 
 TRUTH = [
     "Hello here, this is a simple test",
     "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
-    "我很感谢你的热情"
+    "我很感谢你的热情",
+    # Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg.
+    # for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
+    # incomplete UTF-8 characters
+    # see https://github.com/vllm-project/vllm/pull/9625
+    "ပုံပြင်လေးပြောပြပါ်",
 ]
 TOKENIZERS = [
     "facebook/opt-125m",
@@ -24,6 +30,7 @@ TOKENIZERS = [
     "tiiuae/falcon-7b",
     "meta-llama/Llama-2-7b-hf",
     "codellama/CodeLlama-7b-hf",
+    "mistralai/Pixtral-12B-2409",
 ]
 
 
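Why the Burmese string above exercises this bug: a Tekken-style tokenizer operates on UTF-8 bytes, so a single multi-byte character can be split across two token byte sequences, and decoding each piece on its own yields the replacement character "�". A minimal stdlib-only sketch (illustrative, not vLLM or mistral_common code):

# A Burmese letter is 3 UTF-8 bytes; pretend a byte-level tokenizer split them.
text = "ပ"                                         # U+1015
data = text.encode("utf-8")                        # b'\xe1\x80\x95'
piece_a, piece_b = data[:2], data[2:]              # two hypothetical token byte pieces

print(piece_a.decode("utf-8", errors="replace"))   # '�'  (incomplete character)
print(piece_b.decode("utf-8", errors="replace"))   # '�'
print((piece_a + piece_b).decode("utf-8"))         # 'ပ'  (correct once the bytes are joined)

This is why incremental detokenization has to keep working at the byte level (or detect "�") instead of decoding token pieces independently.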
@@ -49,15 +56,55 @@ def _run_incremental_decode(tokenizer, all_input_ids,
     return decoded_text
 
 
+@pytest.fixture
+def tokenizer(tokenizer_name):
+    return (MistralTokenizer.from_pretrained(tokenizer_name)
+            if "mistral" in tokenizer_name else
+            AutoTokenizer.from_pretrained(tokenizer_name))
+
+
+@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
+@pytest.mark.parametrize(
+    "truth",
+    [
+        # Burmese text triggers an edge-case where tokens may map to bytes with
+        # incomplete UTF-8 characters
+        "ပုံပြင်လေးပြောပြပါ",
+        # Using "URGENCY" since "CY" has token id 130282
+        "URGENCY🌶️",
+    ])
+def test_mistral_edge_case(tokenizer, truth):
+    """Test for a specific edge cases with V3-Tekken MistralTokenizer.
+
+    See https://github.com/vllm-project/vllm/pull/9625
+    """
+    starting_index = 0
+    all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
+
+    decoded_text = _run_incremental_decode(tokenizer,
+                                           all_input_ids,
+                                           skip_special_tokens=True,
+                                           starting_index=starting_index)
+    assert decoded_text == truth
+
+
+@pytest.fixture
+def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
+    if "mistral" in tokenizer_name:
+        yield (
+            bool(True) if request.param else
+            pytest.skip("mistral doesn't support skip_special_tokens=False"))
+    else:
+        yield bool(True) if request.param else bool(False)
+
+
 @pytest.mark.parametrize("truth", TRUTH)
 @pytest.mark.parametrize("with_prompt", [True, False])
-@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", (True, False))
-def test_decode_streaming(tokenizer_id, truth, with_prompt,
-                          skip_special_tokens):
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
+@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
+def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
     if with_prompt:
-        truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"]
+        truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
         prompt_input_ids = truth_tokens[:len(truth) // 2]
         generated_input_ids = truth_tokens[len(truth) // 2:]
         all_input_ids = prompt_input_ids + generated_input_ids
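Note that skip_special_tokens becomes an indirect parameter in the hunk above: the parametrized value is routed through the new skip_special_tokens fixture, which skips the False case for Mistral tokenizers. A toy sketch of the mechanism (assumed names, not the vLLM test):

import pytest

@pytest.fixture
def flag(request):
    # request.param carries the value from @pytest.mark.parametrize(..., indirect=True)
    if request.param == "unsupported":
        pytest.skip("this combination is not supported")
    return request.param

@pytest.mark.parametrize("flag", ["supported", "unsupported"], indirect=True)
def test_flag(flag):
    # the "unsupported" case is skipped by the fixture before the test body runs
    assert flag == "supported"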
@@ -68,7 +115,7 @@ def test_decode_streaming(tokenizer_id, truth, with_prompt,
     else:
         generated = truth
         starting_index = 0
-        all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
+        all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
     if skip_special_tokens:
         if tokenizer.bos_token_id is not None:
             all_input_ids = [tokenizer.bos_token_id] + all_input_ids
@@ -98,7 +145,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
         enable_lora=False,
         max_num_seqs=100,
         max_input_length=None,
-        tokenizer_mode="auto",
+        tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
         trust_remote_code=False,
         revision=None,
     )
@@ -113,9 +160,8 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
 
 @pytest.fixture(name="complete_sequence_token_ids")
 def create_complete_sequence_token_ids(complete_sequence: str,
-                                       tokenizer_name: str) -> List[int]:
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-    complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"]
+                                       tokenizer) -> List[int]:
+    complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
     return complete_sequence_token_ids
 
 
@@ -150,7 +196,7 @@ def create_dummy_prompt_logprobs(
 
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", [True, False])
+@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
 def test_decode_sequence_logprobs(complete_sequence: str,
                                   complete_sequence_token_ids: List[int],
                                   detokenizer: Detokenizer,
@@ -208,9 +254,9 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
 
     # decoded_prompt_logprobs doesn't contain the first token.
     token_ids = complete_sequence_token_ids
-    tokenzier = detokenizer.get_tokenizer_for_seq(seq)
-    text_full = tokenzier.decode(token_ids, skip_special_tokens=True)
-    text_first = tokenzier.decode(token_ids[0], skip_special_tokens=True)
+    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
+    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
+    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
     text = text_full[len(text_first):]
 
     # Text for logprobs for the chosen token should be the same as the
Changes to the MistralTokenizer implementation (vllm/transformers_utils/tokenizers/mistral.py):

@@ -16,9 +16,13 @@ from mistral_common.tokens.tokenizers.sentencepiece import (
 from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
                                                      Tekkenizer)
 
+from vllm.logger import init_logger
+
 if TYPE_CHECKING:
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 
+logger = init_logger(__name__)
+
 
 @dataclass
 class Encoding:
@@ -72,20 +76,21 @@ class MistralTokenizer:
             # Make sure special tokens will not raise
             tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
-
-            self._vocab = {
-                token: idx
-                for idx, token in enumerate(tokenizer_.vocab())
-            }
         elif isinstance(tokenizer_, SentencePieceTokenizer):
-            self._vocab = {
-                token: idx
-                for idx, token in enumerate(tokenizer_.vocab())
-            }
+            pass
         else:
             raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
 
+        self._vocab = tokenizer_.vocab()
+        # Convert to a Dict[str, int] to match protocol, but this is a lossy
+        # conversion. There may be multiple token ids that decode to the same
+        # string due to partial UTF-8 byte sequences being converted to �
+        self._vocab_dict = {
+            token: idx
+            for idx, token in enumerate(self._vocab)
+        }
         self.tokenizer = tokenizer_
-        self._max_token_id = max(self._vocab.values())
+        self._max_token_id = self.vocab_size - 1
 
     @classmethod
     def from_pretrained(cls,
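The comment in the hunk above calls the {str: int} form of the vocabulary lossy. A toy illustration of why (made-up vocabulary, not the real Tekken one): several token ids may render to the same string, for example distinct partial-UTF-8 byte pieces all becoming "�", so a dict keyed by string keeps only one of them.

vocab_list = ["Hello", "�", "�", "world"]   # ids 0..3; ids 1 and 2 collide as strings

vocab_dict = {token: idx for idx, token in enumerate(vocab_list)}
print(len(vocab_list))    # 4
print(len(vocab_dict))    # 3  -- the entry for id 1 was overwritten by id 2
print(vocab_dict["�"])    # 2

That is why the change keeps the full list-form vocabulary in self._vocab and exposes the collapsed dict only through get_vocab().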
@@ -182,7 +187,9 @@ class MistralTokenizer:
         return Encoding(input_ids=input_ids)
 
     def get_vocab(self) -> Dict[str, int]:
-        return self._vocab
+        # NB: the dictionary form of the vocabulary collapses token ids that map
+        # to the same string but have different bytes
+        return self._vocab_dict
 
     def get_added_vocab(self) -> Dict[str, int]:
         # Mistral tokenizers have no added vocabulary
@@ -220,14 +227,20 @@ class MistralTokenizer:
         if any(isinstance(t, bytes) for t in tokens):
             # we need to encode and decode all tokens again
             shift = self.tokenizer.num_special_tokens
-            byte_tokens = [
-                t.encode("utf-8") if not isinstance(t, bytes) else t
-                for t in tokens
-            ]
-            ids = [
-                self.tokenizer._tekken_token2id_nospecial[t] + shift
-                for t in byte_tokens
-            ]
+
+            def _token_to_id(t: str):
+                t_bytes = t.encode("utf-8") \
+                    if not isinstance(t, bytes) else t
+                try:
+                    return shift + \
+                        self.tokenizer._tekken_token2id_nospecial[t_bytes]
+                except KeyError:
+                    logger.warning(
+                        "Failed to convert token %s to id,"
+                        " replacing with <unk>", t_bytes)
+                    return self.tokenizer.unk_id
+
+            ids = [_token_to_id(t) for t in tokens]
             decoded = self.tokenizer.decode(ids)
         else:
             decoded = "".join(tokens)
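The new _token_to_id helper above degrades gracefully: if a token's bytes are not in the Tekken vocabulary it logs a warning and substitutes the unknown-token id instead of raising. A toy sketch of the same pattern (stand-in names and values, not the Tekkenizer API):

# Stand-ins for _tekken_token2id_nospecial, unk_id and num_special_tokens.
token2id = {b"Hello": 0, b"world": 1}
UNK_ID = 3
SHIFT = 10

def token_to_id(t):
    t_bytes = t if isinstance(t, bytes) else t.encode("utf-8")
    try:
        return SHIFT + token2id[t_bytes]
    except KeyError:
        return UNK_ID   # unknown byte sequence: substitute <unk> instead of raising

print([token_to_id(t) for t in ["Hello", b"world", "missing"]])   # [10, 11, 3]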
@@ -236,7 +249,13 @@ class MistralTokenizer:
 
         return decoded
 
-    def decode(self, ids: Union[List[int], int]) -> str:
+    def decode(self,
+               ids: Union[List[int], int],
+               skip_special_tokens: bool = True) -> str:
+        assert (
+            skip_special_tokens
+        ), "Skipping special tokens is not supported for Mistral tokenizers."
+
         if isinstance(ids, int):
             ids = [ids]
         return self.tokenizer.decode(ids)
@@ -257,10 +276,11 @@ class MistralTokenizer:
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
 
-        if any(t.strip() == "�" for t in tokens):
-            # if any stripped decoded token is undefined
-            # because it's invalid unicode then pass bytes
+        if any("�" in t for t in tokens):
+            # if a decoded token contains the replacement character, then the
+            # token has an incomplete UTF-8 character so we must use bytes
             # See: https://github.com/vllm-project/vllm/pull/8640
+            #      https://github.com/vllm-project/vllm/pull/9625
             tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
 
         return tokens
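The last hunk widens the replacement-character check. With the old condition, only tokens that decoded to nothing but "�" triggered the byte-level fallback; a token mixing valid text with a trailing partial byte slipped through. A stdlib-only sketch of the difference (illustrative, not tokenizer code):

# A token whose bytes end mid-character decodes to valid text plus "�".
token = b"\xe1\x80\xab\xe1".decode("utf-8", errors="replace")   # 'ါ�'

print(token.strip() == "�")   # False -> the old check misses this token
print("�" in token)           # True  -> the new check falls back to id_to_byte_piece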