Fix edge case Mistral tokenizer (#10152)
parent b489fc3c91
commit 0535e5fe6c
@@ -72,11 +72,12 @@ class MistralTokenizer:
         self.instruct = tokenizer.instruct_tokenizer

         tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
-        if isinstance(tokenizer_, Tekkenizer):
+        self.is_tekken = isinstance(tokenizer_, Tekkenizer)
+        self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
+        if self.is_tekken:
             # Make sure special tokens will not raise
             tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
-
-        elif isinstance(tokenizer_, SentencePieceTokenizer):
+        elif self.is_spm:
             pass
         else:
             raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
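The hunk above resolves the tokenizer type once in the constructor and caches the result as two booleans, so later methods branch on cheap flags instead of repeating isinstance() checks. A minimal, self-contained sketch of that dispatch pattern; FakeTekkenizer and FakeSPM are hypothetical stand-ins, not the real mistral_common classes:

# Hypothetical stand-in classes; the real ones live in mistral_common.
class FakeTekkenizer:
    pass

class FakeSPM:
    pass

class Wrapper:
    def __init__(self, tokenizer_) -> None:
        # Resolve the tokenizer type once and cache the result so later
        # methods can branch on booleans instead of isinstance() calls.
        self.is_tekken = isinstance(tokenizer_, FakeTekkenizer)
        self.is_spm = isinstance(tokenizer_, FakeSPM)
        if not (self.is_tekken or self.is_spm):
            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

assert Wrapper(FakeTekkenizer()).is_tekken
assert Wrapper(FakeSPM()).is_spm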
@@ -218,7 +219,7 @@ class MistralTokenizer:
         return encoded.tokens

     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        if isinstance(self.tokenizer, Tekkenizer):
+        if self.is_tekken:
             tokens = [
                 t for t in tokens
                 if t not in self.tokenizer._all_special_tokens
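In the Tekken branch above, pieces matching special tokens are filtered out before the string is assembled. A small self-contained sketch of that filtering step; the token strings and the plain set standing in for self.tokenizer._all_special_tokens are illustrative only:

# Illustrative special-token set; the real one is the Tekkenizer's
# _all_special_tokens attribute referenced in the diff.
all_special_tokens = {"<s>", "</s>", "[INST]", "[/INST]"}

tokens = ["<s>", "[INST]", "Hello", " world", "[/INST]", "</s>"]
tokens = [t for t in tokens if t not in all_special_tokens]
assert tokens == ["Hello", " world"]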
@@ -270,21 +271,20 @@ class MistralTokenizer:
             skip_special_tokens
         ), "skip_special_tokens=False is not supported for Mistral tokenizers."

-        assert isinstance(self.tokenizer,
-                          (Tekkenizer, SentencePieceTokenizer)), type(
-                              self.tokenizer)
+        assert self.is_tekken or self.is_spm, type(self.tokenizer)

-        if isinstance(self.tokenizer, Tekkenizer):
+        if self.is_tekken:
             # skip special tokens
             ids = [i for i in ids if i > self.tokenizer.num_special_tokens]

         tokens = [self.tokenizer.id_to_piece(id) for id in ids]

-        if any("�" in t for t in tokens):
+        if any("�" in t for t in tokens) and self.is_tekken:
             # if a decoded token contains the replacement character, then the
             # token has an incomplete UTF-8 character so we must use bytes
             # See: https://github.com/vllm-project/vllm/pull/8640
             #      https://github.com/vllm-project/vllm/pull/9625
+            # if the underlying tokenizer is sentencepiece, we just add "�"
             tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]

         return tokens
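The new "and self.is_tekken" guard limits the byte-piece fallback to the Tekken tokenizer; per the added comment, a SentencePiece tokenizer simply keeps the "�". A standard-library-only sketch of why U+FFFD in a decoded piece signals an incomplete UTF-8 sequence:

# "�" (U+FFFD) appears when bytes are decoded with errors="replace"
# and a multi-byte UTF-8 sequence is cut short mid-token.
REPLACEMENT_CHAR = "\ufffd"

# "€" encodes to three bytes (0xE2 0x82 0xAC); dropping the last byte
# leaves an incomplete sequence that decodes to the replacement char.
piece = b"\xe2\x82".decode("utf-8", errors="replace")
assert REPLACEMENT_CHAR in piece

# A complete sequence decodes cleanly, so no byte-level fallback is needed.
ok = "€".encode("utf-8").decode("utf-8", errors="replace")
assert REPLACEMENT_CHAR not in ok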