[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (#12518)
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
Parent: 72c2b68dc9
Commit: 3ee696a63d
@@ -1275,11 +1275,12 @@ if __name__ == "__main__":
         '--tokenizer-mode',
         type=str,
         default="auto",
-        choices=['auto', 'slow', 'mistral'],
+        choices=['auto', 'slow', 'mistral', 'custom'],
         help='The tokenizer mode.\n\n* "auto" will use the '
         'fast tokenizer if available.\n* "slow" will '
         'always use the slow tokenizer. \n* '
-        '"mistral" will always use the `mistral_common` tokenizer.')
+        '"mistral" will always use the `mistral_common` tokenizer. \n*'
+        '"custom" will use --tokenizer to select the preregistered tokenizer.')
     parser.add_argument("--served-model-name",
                         type=str,
tests/tokenization/test_tokenizer_registry.py (new file, 123 lines)
@@ -0,0 +1,123 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
+                                                    TokenizerRegistry)
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+
+
+class TestTokenizer(TokenizerBase):
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
+        return TestTokenizer()
+
+    @property
+    def all_special_tokens_extended(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    def all_special_ids(self) -> List[int]:
+        raise NotImplementedError()
+
+    @property
+    def bos_token_id(self) -> int:
+        return 0
+
+    @property
+    def eos_token_id(self) -> int:
+        return 1
+
+    @property
+    def sep_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def pad_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def is_fast(self) -> bool:
+        raise NotImplementedError()
+
+    @property
+    def vocab_size(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    def max_token_id(self) -> int:
+        raise NotImplementedError()
+
+    def __call__(
+        self,
+        text: Union[str, List[str], List[int]],
+        text_pair: Optional[str] = None,
+        add_special_tokens: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ):
+        raise NotImplementedError()
+
+    def get_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    def get_added_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    def encode_one(
+        self,
+        text: str,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ) -> List[int]:
+        raise NotImplementedError()
+
+    def encode(self,
+               text: str,
+               add_special_tokens: Optional[bool] = None) -> List[int]:
+        raise NotImplementedError()
+
+    def apply_chat_template(self,
+                            messages: List["ChatCompletionMessageParam"],
+                            tools: Optional[List[Dict[str, Any]]] = None,
+                            **kwargs) -> List[int]:
+        raise NotImplementedError()
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        raise NotImplementedError()
+
+    def decode(self,
+               ids: Union[List[int], int],
+               skip_special_tokens: bool = True) -> str:
+        raise NotImplementedError()
+
+    def convert_ids_to_tokens(
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
+        raise NotImplementedError()
+
+
+def test_customized_tokenizer():
+    TokenizerRegistry.register("test_tokenizer",
+                               "tests.tokenization.test_tokenizer_registry",
+                               "TestTokenizer")
+
+    tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
+    assert isinstance(tokenizer, TestTokenizer)
+    assert tokenizer.bos_token_id == 0
+    assert tokenizer.eos_token_id == 1
+
+    tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
+    assert isinstance(tokenizer, TestTokenizer)
+    assert tokenizer.bos_token_id == 0
+    assert tokenizer.eos_token_id == 1
@@ -102,8 +102,9 @@ class ModelConfig:
         it; otherwise, you must specify explicitly which task to use.
     tokenizer: Name or path of the huggingface tokenizer to use.
     tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
-        available, "slow" will always use the slow tokenizer, and
-        "mistral" will always use the tokenizer from `mistral_common`.
+        available, "slow" will always use the slow tokenizer,
+        "mistral" will always use the tokenizer from `mistral_common`, and
+        "custom" will use --tokenizer to select the preregistered tokenizer.
     trust_remote_code: Trust remote code (e.g., from HuggingFace) when
         downloading the model and tokenizer.
     allowed_local_media_path: Allowing API requests to read local images or
@@ -467,10 +468,10 @@ class ModelConfig:

     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
-        if tokenizer_mode not in ["auto", "slow", "mistral"]:
+        if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
             raise ValueError(
                 f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
-                "either 'auto', 'slow' or 'mistral'.")
+                "either 'auto', 'slow', 'mistral' or 'custom'.")
         self.tokenizer_mode = tokenizer_mode

     def _get_preferred_task(
@@ -284,11 +284,13 @@ class EngineArgs:
             '--tokenizer-mode',
             type=str,
             default=EngineArgs.tokenizer_mode,
-            choices=['auto', 'slow', 'mistral'],
+            choices=['auto', 'slow', 'mistral', 'custom'],
             help='The tokenizer mode.\n\n* "auto" will use the '
             'fast tokenizer if available.\n* "slow" will '
             'always use the slow tokenizer. \n* '
-            '"mistral" will always use the `mistral_common` tokenizer.')
+            '"mistral" will always use the `mistral_common` tokenizer. \n* '
+            '"custom" will use --tokenizer to select the '
+            'preregistered tokenizer.')
         parser.add_argument('--trust-remote-code',
                             action='store_true',
                             help='Trust remote code from huggingface.')
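To try the new mode from the Python API, a minimal offline sketch (assuming `LLM` forwards `tokenizer` and `tokenizer_mode` to `EngineArgs` as it does for the other flags; the model name is a placeholder, and "my_tokenizer" must already be registered in the same process, as sketched after tokenizer_base.py below):

    from vllm import LLM

    # "my_tokenizer" is a registry key, not a HuggingFace path; it must have
    # been registered via TokenizerRegistry.register(...) before this call.
    llm = LLM(model="facebook/opt-125m",
              tokenizer="my_tokenizer",
              tokenizer_mode="custom")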
@@ -1051,9 +1051,9 @@ class LLM:

     def _cross_encoding_score(
         self,
-        tokenizer: Union[AnyTokenizer],
-        text_1: List[Union[str, TextPrompt, TokensPrompt]],
-        text_2: List[Union[str, TextPrompt, TokensPrompt]],
+        tokenizer: AnyTokenizer,
+        text_1: List[str],
+        text_2: List[str],
         truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
@@ -1176,29 +1176,36 @@ class LLM:
         if isinstance(text_1, (str, dict)):
             # Convert a single prompt to a list.
             text_1 = [text_1]
-        text_1 = [ensure_str(t) for t in text_1]
+        input_text_1: List[str] = [ensure_str(t) for t in text_1]

         if isinstance(text_2, (str, dict)):
             # Convert a single prompt to a list.
             text_2 = [text_2]
-        text_2 = [ensure_str(t) for t in text_2]
+        input_text_2: List[str] = [ensure_str(t) for t in text_2]

-        if len(text_1) > 1 and len(text_1) != len(text_2):
+        if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
             raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
-        if len(text_1) == 0:
+        if len(input_text_1) == 0:
             raise ValueError("At least one text element must be given")
-        if len(text_2) == 0:
+        if len(input_text_2) == 0:
             raise ValueError("At least one text_pair element must be given")

         if self.llm_engine.model_config.is_cross_encoder:
-            return self._cross_encoding_score(tokenizer, text_1, text_2,
+            return self._cross_encoding_score(tokenizer, input_text_1,
+                                              input_text_2,
                                               truncate_prompt_tokens, use_tqdm,
                                               lora_request,
                                               prompt_adapter_request)
         else:
-            return self._embedding_score(tokenizer, text_1, text_2,
-                                         truncate_prompt_tokens, use_tqdm,
-                                         lora_request, prompt_adapter_request)
+            return self._embedding_score(
+                tokenizer,
+                input_text_1,  # type: ignore[arg-type]
+                input_text_2,  # type: ignore[arg-type]
+                truncate_prompt_tokens,
+                use_tqdm,
+                lora_request,
+                prompt_adapter_request)

     def start_profile(self) -> None:
         self.llm_engine.start_profile()
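The length validation above is easiest to read with concrete shapes. A hedged sketch of the public score() call this code backs (query and document strings are placeholders):

    # 1:N, one query scored against many documents: allowed.
    outputs = llm.score("what is vllm?", ["doc a", "doc b", "doc c"])
    # N:N, pairwise lists of equal length: allowed.
    outputs = llm.score(["q1", "q2"], ["d1", "d2"])
    # N:M with N, M > 1 and N != M raises the ValueError above.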
@@ -400,8 +400,7 @@ class OpenAIServing:
         _chat_template_kwargs.update(chat_template_kwargs or {})

         request_prompt: Union[str, List[int]]
-        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
-        if is_mistral_tokenizer:
+        if isinstance(tokenizer, MistralTokenizer):
             request_prompt = apply_mistral_chat_template(
                 tokenizer,
                 messages=messages,
@@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing):

         tokenize_async = make_async(tokenizer.__call__,
                                     executor=self._tokenizer_executor)
-        prompt_inputs = await tokenize_async(text=q,
+        prompt_inputs = await tokenize_async(q,
                                              text_pair=t,
                                              **tokenization_kwargs)
@@ -31,7 +31,7 @@ def get_bad_words_logits_processors(

     if isinstance(tokenizer, MistralTokenizer):
         # Mistral tokenizers should not add special tokens
-        prompt_token_ids = tokenizer.encode(prompt=prompt)
+        prompt_token_ids = tokenizer.encode(text=prompt)
     else:
         prompt_token_ids = tokenizer.encode(text=prompt,
                                             add_special_tokens=False)
@@ -14,6 +14,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
+                                                    TokenizerRegistry)
 from vllm.transformers_utils.tokenizers import MistralTokenizer
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import make_async
@@ -21,7 +23,7 @@ from vllm.utils import make_async
 logger = init_logger(__name__)

 AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
-                     MistralTokenizer]
+                     TokenizerBase]


 def decode_tokens(
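Narrowing the union is safe because, as the MistralTokenizer hunks further below show, MistralTokenizer now subclasses TokenizerBase, so it is still covered:

    # Holds after this commit, since MistralTokenizer inherits TokenizerBase:
    assert issubclass(MistralTokenizer, TokenizerBase)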
@@ -47,11 +49,7 @@ def encode_tokens(
     Backend-agnostic equivalent of HF's
     :code:`tokenizer.encode(text, add_special_tokens=...)`.
     """
-    if isinstance(tokenizer, MistralTokenizer):
-        return tokenizer.tokenizer.encode(text,
-                                          bos=add_special_tokens,
-                                          eos=add_special_tokens)
-    elif add_special_tokens is not None:
+    if add_special_tokens is not None:
         return tokenizer.encode(text, add_special_tokens=add_special_tokens)
     return tokenizer.encode(text)
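With the Mistral special case removed, encode_tokens() keeps a single backend-agnostic path. A small usage sketch (the tokenizer variable stands for any AnyTokenizer instance):

    # Explicit flag: forwarded to the backend's own encode().
    ids = encode_tokens(tokenizer, "Hello world", add_special_tokens=False)
    # Flag omitted: the backend's default applies (for MistralTokenizer,
    # bos=True and eos=False, per its updated encode() below).
    ids = encode_tokens(tokenizer, "Hello world")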
@@ -183,9 +181,17 @@ def get_tokenizer(
             'encoding and decoding.',
             FutureWarning,
             stacklevel=2)

+    tokenizer: AnyTokenizer
     if tokenizer_mode == "mistral":
         tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                      revision=revision)
+    elif tokenizer_mode == "custom":
+        tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
+                                                    *args,
+                                                    revision=revision,
+                                                    download_dir=download_dir,
+                                                    **kwargs)
     else:
         try:
             tokenizer = AutoTokenizer.from_pretrained(
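From user code, the resulting dispatch looks as follows (the registered name "my_tokenizer" is hypothetical; an unregistered name raises ValueError):

    from vllm.transformers_utils.tokenizer import get_tokenizer

    # "custom" mode resolves the name through TokenizerRegistry instead of
    # downloading from the HuggingFace Hub.
    tokenizer = get_tokenizer("my_tokenizer", tokenizer_mode="custom")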
vllm/transformers_utils/tokenizer_base.py (new file, 146 lines)
@@ -0,0 +1,146 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import importlib
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+
+
+class TokenizerBase(ABC):
+
+    @property
+    @abstractmethod
+    def all_special_tokens_extended(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def all_special_tokens(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def all_special_ids(self) -> List[int]:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def bos_token_id(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def eos_token_id(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def sep_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def pad_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def is_fast(self) -> bool:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def vocab_size(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def max_token_id(self) -> int:
+        raise NotImplementedError()
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
+    @abstractmethod
+    def __call__(
+        self,
+        text: Union[str, List[str], List[int]],
+        text_pair: Optional[str] = None,
+        add_special_tokens: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_added_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def encode_one(
+        self,
+        text: str,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ) -> List[int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def encode(self,
+               text: str,
+               add_special_tokens: Optional[bool] = None) -> List[int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def apply_chat_template(self,
+                            messages: List["ChatCompletionMessageParam"],
+                            tools: Optional[List[Dict[str, Any]]] = None,
+                            **kwargs) -> List[int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def decode(self,
+               ids: Union[List[int], int],
+               skip_special_tokens: bool = True) -> str:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def convert_ids_to_tokens(
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
+        raise NotImplementedError()
+
+
+class TokenizerRegistry:
+    # Tokenizer name -> (tokenizer module, tokenizer class)
+    REGISTRY: Dict[str, Tuple[str, str]] = {}
+
+    @staticmethod
+    def register(name: str, module: str, class_name: str) -> None:
+        TokenizerRegistry.REGISTRY[name] = (module, class_name)
+
+    @staticmethod
+    def get_tokenizer(
+        tokenizer_name: str,
+        *args,
+        **kwargs,
+    ) -> TokenizerBase:
+        tokenizer_cls = TokenizerRegistry.REGISTRY.get(tokenizer_name)
+        if tokenizer_cls is None:
+            raise ValueError(f"Tokenizer {tokenizer_name} not found.")
+
+        tokenizer_module = importlib.import_module(tokenizer_cls[0])
+        class_ = getattr(tokenizer_module, tokenizer_cls[1])
+        return class_.from_pretrained(*args, **kwargs)
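Taken together, a downstream package would plug in roughly like this. A sketch assuming a hypothetical module my_pkg.tokenizer; only TokenizerBase, TokenizerRegistry.register, and the import-by-name lookup are from this commit:

    # my_pkg/tokenizer.py (hypothetical plugin module)
    from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
                                                        TokenizerRegistry)


    class MyTokenizer(TokenizerBase):

        @classmethod
        def from_pretrained(cls, *args, **kwargs) -> "MyTokenizer":
            return cls()

        # ... implement the abstract methods listed above ...


    # Register once at import time. Afterwards, `--tokenizer my_tokenizer
    # --tokenizer-mode custom` (or tokenizer_mode="custom" in Python) resolves
    # to this class via importlib.
    TokenizerRegistry.register("my_tokenizer", "my_pkg.tokenizer",
                               "MyTokenizer")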
@@ -10,6 +10,7 @@ import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download

 from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer_base import TokenizerBase
 from vllm.utils import is_list_of

 if TYPE_CHECKING:
@@ -140,7 +141,7 @@ def make_mistral_chat_completion_request(
                 tools=tools)  # type: ignore[type-var]


-class MistralTokenizer:
+class MistralTokenizer(TokenizerBase):

     def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
         self.mistral = tokenizer
@@ -251,6 +252,14 @@ class MistralTokenizer:
     def eos_token_id(self) -> int:
         return self.tokenizer.eos_id

+    @property
+    def sep_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def pad_token(self) -> str:
+        raise NotImplementedError()
+
     @property
     def is_fast(self) -> bool:
         return True
@@ -268,25 +277,26 @@ class MistralTokenizer:

     def __call__(
         self,
-        prompt: Union[str, List[str], List[int]],
+        text: Union[str, List[str], List[int]],
+        text_pair: Optional[str] = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
         max_length: Optional[int] = None,
     ):
         input_ids: Union[List[int], List[List[int]]]
         # For List[str], original prompt text
-        if is_list_of(prompt, str):
+        if is_list_of(text, str):
             input_ids_: List[List[int]] = []
-            for p in prompt:
+            for p in text:
                 each_input_ids = self.encode_one(p, truncation, max_length)
                 input_ids_.append(each_input_ids)
             input_ids = input_ids_
         # For List[int], apply chat template output, already tokens.
-        elif is_list_of(prompt, int):
-            input_ids = prompt
+        elif is_list_of(text, int):
+            input_ids = text
         # For str, single prompt text
         else:
-            input_ids = self.encode_one(prompt, truncation, max_length)
+            input_ids = self.encode_one(text, truncation, max_length)
         return Encoding(input_ids=input_ids)

     def get_vocab(self) -> Dict[str, int]:
@@ -300,22 +310,29 @@ class MistralTokenizer:

     def encode_one(
         self,
-        prompt: str,
+        text: str,
         truncation: bool = False,
         max_length: Optional[int] = None,
     ) -> List[int]:
         # Mistral Tokenizers should not add special tokens
-        input_ids = self.encode(prompt)
+        input_ids = self.encode(text)

         if truncation:
             input_ids = input_ids[:max_length]
         return input_ids

-    def encode(self, prompt: str) -> List[int]:
+    def encode(self,
+               text: str,
+               add_special_tokens: Optional[bool] = None) -> List[int]:
         # `encode` should only be used for prompt completion
         # it should never be used for chat_completion.
         # For chat completion use `apply_chat_template`
-        return self.tokenizer.encode(prompt, bos=True, eos=False)
+        if add_special_tokens is not None:
+            return self.tokenizer.encode(text,
+                                         bos=add_special_tokens,
+                                         eos=add_special_tokens)
+        else:
+            return self.tokenizer.encode(text, bos=True, eos=False)

     def apply_chat_template(self,
                             messages: List["ChatCompletionMessageParam"],
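The new add_special_tokens parameter preserves the old default while matching the TokenizerBase.encode() signature. A short sketch of the semantics (mistral_tokenizer stands for any MistralTokenizer instance):

    # Flag omitted: historical behavior, BOS added, no EOS.
    ids = mistral_tokenizer.encode("Hello")
    # Explicit flag: applied to both BOS and EOS, replacing the special case
    # that was removed from encode_tokens() above.
    ids = mistral_tokenizer.encode("Hello", add_special_tokens=False)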