[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (#12518)

Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
This commit is contained in:
Keyun Tong 2025-02-11 20:25:58 -08:00 committed by GitHub
parent 72c2b68dc9
commit 3ee696a63d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 343 additions and 41 deletions

View File

@ -1275,11 +1275,12 @@ if __name__ == "__main__":
'--tokenizer-mode',
type=str,
default="auto",
choices=['auto', 'slow', 'mistral'],
choices=['auto', 'slow', 'mistral', 'custom'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
'"mistral" will always use the `mistral_common` tokenizer. \n*'
'"custom" will use --tokenizer to select the preregistered tokenizer.')
parser.add_argument("--served-model-name",
type=str,

View File

@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
TokenizerRegistry)
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
class TestTokenizer(TokenizerBase):
@classmethod
def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
return TestTokenizer()
@property
def all_special_tokens_extended(self) -> List[str]:
raise NotImplementedError()
@property
def all_special_tokens(self) -> List[str]:
raise NotImplementedError()
@property
def all_special_ids(self) -> List[int]:
raise NotImplementedError()
@property
def bos_token_id(self) -> int:
return 0
@property
def eos_token_id(self) -> int:
return 1
@property
def sep_token(self) -> str:
raise NotImplementedError()
@property
def pad_token(self) -> str:
raise NotImplementedError()
@property
def is_fast(self) -> bool:
raise NotImplementedError()
@property
def vocab_size(self) -> int:
raise NotImplementedError()
@property
def max_token_id(self) -> int:
raise NotImplementedError()
def __call__(
self,
text: Union[str, List[str], List[int]],
text_pair: Optional[str] = None,
add_special_tokens: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
):
raise NotImplementedError()
def get_vocab(self) -> Dict[str, int]:
raise NotImplementedError()
def get_added_vocab(self) -> Dict[str, int]:
raise NotImplementedError()
def encode_one(
self,
text: str,
truncation: bool = False,
max_length: Optional[int] = None,
) -> List[int]:
raise NotImplementedError()
def encode(self,
text: str,
add_special_tokens: Optional[bool] = None) -> List[int]:
raise NotImplementedError()
def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs) -> List[int]:
raise NotImplementedError()
def convert_tokens_to_string(self, tokens: List[str]) -> str:
raise NotImplementedError()
def decode(self,
ids: Union[List[int], int],
skip_special_tokens: bool = True) -> str:
raise NotImplementedError()
def convert_ids_to_tokens(
self,
ids: List[int],
skip_special_tokens: bool = True,
) -> List[str]:
raise NotImplementedError()
def test_customized_tokenizer():
TokenizerRegistry.register("test_tokenizer",
"tests.tokenization.test_tokenizer_registry",
"TestTokenizer")
tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
assert isinstance(tokenizer, TestTokenizer)
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1
tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
assert isinstance(tokenizer, TestTokenizer)
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1

View File

@ -102,8 +102,9 @@ class ModelConfig:
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and
"mistral" will always use the tokenizer from `mistral_common`.
available, "slow" will always use the slow tokenizer,
"mistral" will always use the tokenizer from `mistral_common`, and
"custom" will use --tokenizer to select the preregistered tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allowing API requests to read local images or
@ -467,10 +468,10 @@ class ModelConfig:
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow", "mistral"]:
if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto', 'slow' or 'mistral'.")
"either 'auto', 'slow', 'mistral' or 'custom'.")
self.tokenizer_mode = tokenizer_mode
def _get_preferred_task(

View File

@ -284,11 +284,13 @@ class EngineArgs:
'--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow', 'mistral'],
choices=['auto', 'slow', 'mistral', 'custom'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
'"mistral" will always use the `mistral_common` tokenizer. \n* '
'"custom" will use --tokenizer to select the '
'preregistered tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')

View File

@ -1051,9 +1051,9 @@ class LLM:
def _cross_encoding_score(
self,
tokenizer: Union[AnyTokenizer],
text_1: List[Union[str, TextPrompt, TokensPrompt]],
text_2: List[Union[str, TextPrompt, TokensPrompt]],
tokenizer: AnyTokenizer,
text_1: List[str],
text_2: List[str],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
@ -1176,29 +1176,36 @@ class LLM:
if isinstance(text_1, (str, dict)):
# Convert a single prompt to a list.
text_1 = [text_1]
text_1 = [ensure_str(t) for t in text_1]
input_text_1: List[str] = [ensure_str(t) for t in text_1]
if isinstance(text_2, (str, dict)):
# Convert a single prompt to a list.
text_2 = [text_2]
text_2 = [ensure_str(t) for t in text_2]
input_text_2: List[str] = [ensure_str(t) for t in text_2]
if len(text_1) > 1 and len(text_1) != len(text_2):
if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len(text_1) == 0:
if len(input_text_1) == 0:
raise ValueError("At least one text element must be given")
if len(text_2) == 0:
if len(input_text_2) == 0:
raise ValueError("At least one text_pair element must be given")
if self.llm_engine.model_config.is_cross_encoder:
return self._cross_encoding_score(tokenizer, text_1, text_2,
return self._cross_encoding_score(tokenizer, input_text_1,
input_text_2,
truncate_prompt_tokens, use_tqdm,
lora_request,
prompt_adapter_request)
else:
return self._embedding_score(tokenizer, text_1, text_2,
truncate_prompt_tokens, use_tqdm,
lora_request, prompt_adapter_request)
return self._embedding_score(
tokenizer,
input_text_1, # type: ignore[arg-type]
input_text_2, # type: ignore[arg-type]
truncate_prompt_tokens,
use_tqdm,
lora_request,
prompt_adapter_request)
def start_profile(self) -> None:
self.llm_engine.start_profile()

View File

@ -400,8 +400,7 @@ class OpenAIServing:
_chat_template_kwargs.update(chat_template_kwargs or {})
request_prompt: Union[str, List[int]]
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
if is_mistral_tokenizer:
if isinstance(tokenizer, MistralTokenizer):
request_prompt = apply_mistral_chat_template(
tokenizer,
messages=messages,

View File

@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing):
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(text=q,
prompt_inputs = await tokenize_async(q,
text_pair=t,
**tokenization_kwargs)

View File

@ -31,7 +31,7 @@ def get_bad_words_logits_processors(
if isinstance(tokenizer, MistralTokenizer):
# Mistral tokenizers should not add special tokens
prompt_token_ids = tokenizer.encode(prompt=prompt)
prompt_token_ids = tokenizer.encode(text=prompt)
else:
prompt_token_ids = tokenizer.encode(text=prompt,
add_special_tokens=False)

View File

@ -14,6 +14,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
TokenizerRegistry)
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import make_async
@ -21,7 +23,7 @@ from vllm.utils import make_async
logger = init_logger(__name__)
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
MistralTokenizer]
TokenizerBase]
def decode_tokens(
@ -47,11 +49,7 @@ def encode_tokens(
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`.
"""
if isinstance(tokenizer, MistralTokenizer):
return tokenizer.tokenizer.encode(text,
bos=add_special_tokens,
eos=add_special_tokens)
elif add_special_tokens is not None:
if add_special_tokens is not None:
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
return tokenizer.encode(text)
@ -183,9 +181,17 @@ def get_tokenizer(
'encoding and decoding.',
FutureWarning,
stacklevel=2)
tokenizer: AnyTokenizer
if tokenizer_mode == "mistral":
tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
revision=revision)
elif tokenizer_mode == "custom":
tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
*args,
revision=revision,
download_dir=download_dir,
**kwargs)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(

View File

@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
import importlib
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
class TokenizerBase(ABC):
@property
@abstractmethod
def all_special_tokens_extended(self) -> List[str]:
raise NotImplementedError()
@property
@abstractmethod
def all_special_tokens(self) -> List[str]:
raise NotImplementedError()
@property
@abstractmethod
def all_special_ids(self) -> List[int]:
raise NotImplementedError()
@property
@abstractmethod
def bos_token_id(self) -> int:
raise NotImplementedError()
@property
@abstractmethod
def eos_token_id(self) -> int:
raise NotImplementedError()
@property
@abstractmethod
def sep_token(self) -> str:
raise NotImplementedError()
@property
@abstractmethod
def pad_token(self) -> str:
raise NotImplementedError()
@property
@abstractmethod
def is_fast(self) -> bool:
raise NotImplementedError()
@property
@abstractmethod
def vocab_size(self) -> int:
raise NotImplementedError()
@property
@abstractmethod
def max_token_id(self) -> int:
raise NotImplementedError()
def __len__(self) -> int:
return self.vocab_size
@abstractmethod
def __call__(
self,
text: Union[str, List[str], List[int]],
text_pair: Optional[str] = None,
add_special_tokens: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
):
raise NotImplementedError()
@abstractmethod
def get_vocab(self) -> Dict[str, int]:
raise NotImplementedError()
@abstractmethod
def get_added_vocab(self) -> Dict[str, int]:
raise NotImplementedError()
@abstractmethod
def encode_one(
self,
text: str,
truncation: bool = False,
max_length: Optional[int] = None,
) -> List[int]:
raise NotImplementedError()
@abstractmethod
def encode(self,
text: str,
add_special_tokens: Optional[bool] = None) -> List[int]:
raise NotImplementedError()
@abstractmethod
def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs) -> List[int]:
raise NotImplementedError()
@abstractmethod
def convert_tokens_to_string(self, tokens: List[str]) -> str:
raise NotImplementedError()
@abstractmethod
def decode(self,
ids: Union[List[int], int],
skip_special_tokens: bool = True) -> str:
raise NotImplementedError()
@abstractmethod
def convert_ids_to_tokens(
self,
ids: List[int],
skip_special_tokens: bool = True,
) -> List[str]:
raise NotImplementedError()
class TokenizerRegistry:
# Tokenizer name -> (tokenizer module, tokenizer class)
REGISTRY: Dict[str, Tuple[str, str]] = {}
@staticmethod
def register(name: str, module: str, class_name: str) -> None:
TokenizerRegistry.REGISTRY[name] = (module, class_name)
@staticmethod
def get_tokenizer(
tokenizer_name: str,
*args,
**kwargs,
) -> TokenizerBase:
tokenizer_cls = TokenizerRegistry.REGISTRY.get(tokenizer_name)
if tokenizer_cls is None:
raise ValueError(f"Tokenizer {tokenizer_name} not found.")
tokenizer_module = importlib.import_module(tokenizer_cls[0])
class_ = getattr(tokenizer_module, tokenizer_cls[1])
return class_.from_pretrained(*args, **kwargs)

View File

@ -10,6 +10,7 @@ import huggingface_hub
from huggingface_hub import HfApi, hf_hub_download
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_base import TokenizerBase
from vllm.utils import is_list_of
if TYPE_CHECKING:
@ -140,7 +141,7 @@ def make_mistral_chat_completion_request(
tools=tools) # type: ignore[type-var]
class MistralTokenizer:
class MistralTokenizer(TokenizerBase):
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
self.mistral = tokenizer
@ -251,6 +252,14 @@ class MistralTokenizer:
def eos_token_id(self) -> int:
return self.tokenizer.eos_id
@property
def sep_token(self) -> str:
raise NotImplementedError()
@property
def pad_token(self) -> str:
raise NotImplementedError()
@property
def is_fast(self) -> bool:
return True
@ -268,25 +277,26 @@ class MistralTokenizer:
def __call__(
self,
prompt: Union[str, List[str], List[int]],
text: Union[str, List[str], List[int]],
text_pair: Optional[str] = None,
add_special_tokens: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
):
input_ids: Union[List[int], List[List[int]]]
# For List[str], original prompt text
if is_list_of(prompt, str):
if is_list_of(text, str):
input_ids_: List[List[int]] = []
for p in prompt:
for p in text:
each_input_ids = self.encode_one(p, truncation, max_length)
input_ids_.append(each_input_ids)
input_ids = input_ids_
# For List[int], apply chat template output, already tokens.
elif is_list_of(prompt, int):
input_ids = prompt
elif is_list_of(text, int):
input_ids = text
# For str, single prompt text
else:
input_ids = self.encode_one(prompt, truncation, max_length)
input_ids = self.encode_one(text, truncation, max_length)
return Encoding(input_ids=input_ids)
def get_vocab(self) -> Dict[str, int]:
@ -300,22 +310,29 @@ class MistralTokenizer:
def encode_one(
self,
prompt: str,
text: str,
truncation: bool = False,
max_length: Optional[int] = None,
) -> List[int]:
# Mistral Tokenizers should not add special tokens
input_ids = self.encode(prompt)
input_ids = self.encode(text)
if truncation:
input_ids = input_ids[:max_length]
return input_ids
def encode(self, prompt: str) -> List[int]:
def encode(self,
text: str,
add_special_tokens: Optional[bool] = None) -> List[int]:
# `encode` should only be used for prompt completion
# it should never be used for chat_completion.
# For chat completion use `apply_chat_template`
return self.tokenizer.encode(prompt, bos=True, eos=False)
if add_special_tokens is not None:
return self.tokenizer.encode(text,
bos=add_special_tokens,
eos=add_special_tokens)
else:
return self.tokenizer.encode(text, bos=True, eos=False)
def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"],