[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (#12518)

Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
Keyun Tong 2025-02-11 20:25:58 -08:00 committed by GitHub
parent 72c2b68dc9
commit 3ee696a63d
11 changed files with 343 additions and 41 deletions
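
Taken together, the changes below let a user plug any TokenizerBase subclass into vLLM by name: register the class once, then select it with the new "custom" tokenizer mode. A minimal sketch of the intended flow, assuming a hypothetical package my_pkg whose MyTokenizer subclasses TokenizerBase:

    from vllm.transformers_utils.tokenizer_base import TokenizerRegistry

    # Registration stores only (module, class) strings; my_pkg is not
    # imported until the tokenizer is first resolved.
    TokenizerRegistry.register("my_tokenizer", "my_pkg.tokenization",
                               "MyTokenizer")

    # Any entrypoint can now select it by name, e.g. with the CLI flags
    # added in this diff: --tokenizer my_tokenizer --tokenizer-mode custom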


@@ -1275,11 +1275,12 @@ if __name__ == "__main__":
         '--tokenizer-mode',
         type=str,
         default="auto",
-        choices=['auto', 'slow', 'mistral'],
+        choices=['auto', 'slow', 'mistral', 'custom'],
         help='The tokenizer mode.\n\n* "auto" will use the '
         'fast tokenizer if available.\n* "slow" will '
         'always use the slow tokenizer. \n* '
-        '"mistral" will always use the `mistral_common` tokenizer.')
+        '"mistral" will always use the `mistral_common` tokenizer. \n*'
+        '"custom" will use --tokenizer to select the preregistered tokenizer.')
     parser.add_argument("--served-model-name",
                         type=str,


@@ -0,0 +1,123 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
+                                                    TokenizerRegistry)
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+
+
+class TestTokenizer(TokenizerBase):
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
+        return TestTokenizer()
+
+    @property
+    def all_special_tokens_extended(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    def all_special_ids(self) -> List[int]:
+        raise NotImplementedError()
+
+    @property
+    def bos_token_id(self) -> int:
+        return 0
+
+    @property
+    def eos_token_id(self) -> int:
+        return 1
+
+    @property
+    def sep_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def pad_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def is_fast(self) -> bool:
+        raise NotImplementedError()
+
+    @property
+    def vocab_size(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    def max_token_id(self) -> int:
+        raise NotImplementedError()
+
+    def __call__(
+        self,
+        text: Union[str, List[str], List[int]],
+        text_pair: Optional[str] = None,
+        add_special_tokens: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ):
+        raise NotImplementedError()
+
+    def get_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    def get_added_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    def encode_one(
+        self,
+        text: str,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ) -> List[int]:
+        raise NotImplementedError()
+
+    def encode(self,
+               text: str,
+               add_special_tokens: Optional[bool] = None) -> List[int]:
+        raise NotImplementedError()
+
+    def apply_chat_template(self,
+                            messages: List["ChatCompletionMessageParam"],
+                            tools: Optional[List[Dict[str, Any]]] = None,
+                            **kwargs) -> List[int]:
+        raise NotImplementedError()
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        raise NotImplementedError()
+
+    def decode(self,
+               ids: Union[List[int], int],
+               skip_special_tokens: bool = True) -> str:
+        raise NotImplementedError()
+
+    def convert_ids_to_tokens(
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
+        raise NotImplementedError()
+
+
+def test_customized_tokenizer():
+    TokenizerRegistry.register("test_tokenizer",
+                               "tests.tokenization.test_tokenizer_registry",
+                               "TestTokenizer")
+
+    tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
+    assert isinstance(tokenizer, TestTokenizer)
+    assert tokenizer.bos_token_id == 0
+    assert tokenizer.eos_token_id == 1
+
+    tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
+    assert isinstance(tokenizer, TestTokenizer)
+    assert tokenizer.bos_token_id == 0
+    assert tokenizer.eos_token_id == 1


@@ -102,8 +102,9 @@ class ModelConfig:
         it; otherwise, you must specify explicitly which task to use.
     tokenizer: Name or path of the huggingface tokenizer to use.
     tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
-        available, "slow" will always use the slow tokenizer, and
-        "mistral" will always use the tokenizer from `mistral_common`.
+        available, "slow" will always use the slow tokenizer,
+        "mistral" will always use the tokenizer from `mistral_common`, and
+        "custom" will use --tokenizer to select the preregistered tokenizer.
     trust_remote_code: Trust remote code (e.g., from HuggingFace) when
         downloading the model and tokenizer.
     allowed_local_media_path: Allowing API requests to read local images or
@@ -467,10 +468,10 @@ class ModelConfig:
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
-        if tokenizer_mode not in ["auto", "slow", "mistral"]:
+        if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
             raise ValueError(
                 f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
-                "either 'auto', 'slow' or 'mistral'.")
+                "either 'auto', 'slow', 'mistral' or 'custom'.")
         self.tokenizer_mode = tokenizer_mode

     def _get_preferred_task(


@@ -284,11 +284,13 @@ class EngineArgs:
             '--tokenizer-mode',
             type=str,
             default=EngineArgs.tokenizer_mode,
-            choices=['auto', 'slow', 'mistral'],
+            choices=['auto', 'slow', 'mistral', 'custom'],
            help='The tokenizer mode.\n\n* "auto" will use the '
             'fast tokenizer if available.\n* "slow" will '
             'always use the slow tokenizer. \n* '
-            '"mistral" will always use the `mistral_common` tokenizer.')
+            '"mistral" will always use the `mistral_common` tokenizer. \n* '
+            '"custom" will use --tokenizer to select the '
+            'preregistered tokenizer.')
         parser.add_argument('--trust-remote-code',
                             action='store_true',
                             help='Trust remote code from huggingface.')


@@ -1051,9 +1051,9 @@ class LLM:
     def _cross_encoding_score(
         self,
-        tokenizer: Union[AnyTokenizer],
-        text_1: List[Union[str, TextPrompt, TokensPrompt]],
-        text_2: List[Union[str, TextPrompt, TokensPrompt]],
+        tokenizer: AnyTokenizer,
+        text_1: List[str],
+        text_2: List[str],
         truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
@@ -1176,29 +1176,36 @@ class LLM:
         if isinstance(text_1, (str, dict)):
             # Convert a single prompt to a list.
             text_1 = [text_1]
-        text_1 = [ensure_str(t) for t in text_1]
+        input_text_1: List[str] = [ensure_str(t) for t in text_1]

         if isinstance(text_2, (str, dict)):
             # Convert a single prompt to a list.
             text_2 = [text_2]
-        text_2 = [ensure_str(t) for t in text_2]
+        input_text_2: List[str] = [ensure_str(t) for t in text_2]

-        if len(text_1) > 1 and len(text_1) != len(text_2):
+        if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
             raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
-        if len(text_1) == 0:
+        if len(input_text_1) == 0:
             raise ValueError("At least one text element must be given")
-        if len(text_2) == 0:
+        if len(input_text_2) == 0:
             raise ValueError("At least one text_pair element must be given")

         if self.llm_engine.model_config.is_cross_encoder:
-            return self._cross_encoding_score(tokenizer, text_1, text_2,
+            return self._cross_encoding_score(tokenizer, input_text_1,
+                                              input_text_2,
                                               truncate_prompt_tokens, use_tqdm,
                                               lora_request,
                                               prompt_adapter_request)
         else:
-            return self._embedding_score(tokenizer, text_1, text_2,
-                                         truncate_prompt_tokens, use_tqdm,
-                                         lora_request, prompt_adapter_request)
+
+            return self._embedding_score(
+                tokenizer,
+                input_text_1,  # type: ignore[arg-type]
+                input_text_2,  # type: ignore[arg-type]
+                truncate_prompt_tokens,
+                use_tqdm,
+                lora_request,
+                prompt_adapter_request)

     def start_profile(self) -> None:
         self.llm_engine.start_profile()


@@ -400,8 +400,7 @@ class OpenAIServing:
         _chat_template_kwargs.update(chat_template_kwargs or {})

         request_prompt: Union[str, List[int]]
-        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
-        if is_mistral_tokenizer:
+        if isinstance(tokenizer, MistralTokenizer):
             request_prompt = apply_mistral_chat_template(
                 tokenizer,
                 messages=messages,


@@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing):
         tokenize_async = make_async(tokenizer.__call__,
                                     executor=self._tokenizer_executor)

-        prompt_inputs = await tokenize_async(text=q,
+        prompt_inputs = await tokenize_async(q,
                                              text_pair=t,
                                              **tokenization_kwargs)


@@ -31,7 +31,7 @@ def get_bad_words_logits_processors(
     if isinstance(tokenizer, MistralTokenizer):
         # Mistral tokenizers should not add special tokens
-        prompt_token_ids = tokenizer.encode(prompt=prompt)
+        prompt_token_ids = tokenizer.encode(text=prompt)
     else:
         prompt_token_ids = tokenizer.encode(text=prompt,
                                             add_special_tokens=False)


@@ -14,6 +14,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
+                                                    TokenizerRegistry)
 from vllm.transformers_utils.tokenizers import MistralTokenizer
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import make_async
@@ -21,7 +23,7 @@ from vllm.utils import make_async
 logger = init_logger(__name__)

 AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
-                     MistralTokenizer]
+                     TokenizerBase]


 def decode_tokens(
@@ -47,11 +49,7 @@ def encode_tokens(
     Backend-agnostic equivalent of HF's
     :code:`tokenizer.encode(text, add_special_tokens=...)`.
     """
-    if isinstance(tokenizer, MistralTokenizer):
-        return tokenizer.tokenizer.encode(text,
-                                          bos=add_special_tokens,
-                                          eos=add_special_tokens)
-    elif add_special_tokens is not None:
+    if add_special_tokens is not None:
         return tokenizer.encode(text, add_special_tokens=add_special_tokens)
     return tokenizer.encode(text)
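
Since MistralTokenizer now implements the shared encode signature (see the last file in this diff), encode_tokens above no longer needs a Mistral-specific branch. A sketch of the now-uniform call, assuming tokenizer came from get_tokenizer:

    # Works the same for HF and Mistral tokenizers after this change.
    ids = encode_tokens(tokenizer, "Hello world!", add_special_tokens=False)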
@@ -183,9 +181,17 @@ def get_tokenizer(
             'encoding and decoding.',
             FutureWarning,
             stacklevel=2)
+
+    tokenizer: AnyTokenizer
     if tokenizer_mode == "mistral":
         tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                      revision=revision)
+    elif tokenizer_mode == "custom":
+        tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
+                                                    *args,
+                                                    revision=revision,
+                                                    download_dir=download_dir,
+                                                    **kwargs)
     else:
         try:
             tokenizer = AutoTokenizer.from_pretrained(


@@ -0,0 +1,146 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import importlib
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+
+
+class TokenizerBase(ABC):
+
+    @property
+    @abstractmethod
+    def all_special_tokens_extended(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def all_special_tokens(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def all_special_ids(self) -> List[int]:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def bos_token_id(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def eos_token_id(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def sep_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def pad_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def is_fast(self) -> bool:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def vocab_size(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def max_token_id(self) -> int:
+        raise NotImplementedError()
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
+    @abstractmethod
+    def __call__(
+        self,
+        text: Union[str, List[str], List[int]],
+        text_pair: Optional[str] = None,
+        add_special_tokens: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_added_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def encode_one(
+        self,
+        text: str,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ) -> List[int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def encode(self,
+               text: str,
+               add_special_tokens: Optional[bool] = None) -> List[int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def apply_chat_template(self,
+                            messages: List["ChatCompletionMessageParam"],
+                            tools: Optional[List[Dict[str, Any]]] = None,
+                            **kwargs) -> List[int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def decode(self,
+               ids: Union[List[int], int],
+               skip_special_tokens: bool = True) -> str:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def convert_ids_to_tokens(
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
+        raise NotImplementedError()
+
+
+class TokenizerRegistry:
+    # Tokenizer name -> (tokenizer module, tokenizer class)
+    REGISTRY: Dict[str, Tuple[str, str]] = {}
+
+    @staticmethod
+    def register(name: str, module: str, class_name: str) -> None:
+        TokenizerRegistry.REGISTRY[name] = (module, class_name)
+
+    @staticmethod
+    def get_tokenizer(
+        tokenizer_name: str,
+        *args,
+        **kwargs,
+    ) -> TokenizerBase:
+        tokenizer_cls = TokenizerRegistry.REGISTRY.get(tokenizer_name)
+        if tokenizer_cls is None:
+            raise ValueError(f"Tokenizer {tokenizer_name} not found.")
+
+        tokenizer_module = importlib.import_module(tokenizer_cls[0])
+        class_ = getattr(tokenizer_module, tokenizer_cls[1])
+        return class_.from_pretrained(*args, **kwargs)
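
One caveat: get_tokenizer resolves the class lazily via importlib and then calls from_pretrained, which TokenizerBase does not declare as abstract, so a registered class must still provide it. A hypothetical minimal subclass:

    class MyTokenizer(TokenizerBase):

        @classmethod
        def from_pretrained(cls, *args, **kwargs) -> "MyTokenizer":
            # Load vocabulary/config from args/kwargs as needed.
            return cls()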


@@ -10,6 +10,7 @@ import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download

 from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer_base import TokenizerBase
 from vllm.utils import is_list_of

 if TYPE_CHECKING:
@@ -140,7 +141,7 @@ def make_mistral_chat_completion_request(
         tools=tools)  # type: ignore[type-var]


-class MistralTokenizer:
+class MistralTokenizer(TokenizerBase):

     def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
         self.mistral = tokenizer
@@ -251,6 +252,14 @@ class MistralTokenizer:
     def eos_token_id(self) -> int:
         return self.tokenizer.eos_id

+    @property
+    def sep_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def pad_token(self) -> str:
+        raise NotImplementedError()
+
     @property
     def is_fast(self) -> bool:
         return True
@@ -268,25 +277,26 @@ class MistralTokenizer:
     def __call__(
         self,
-        prompt: Union[str, List[str], List[int]],
+        text: Union[str, List[str], List[int]],
+        text_pair: Optional[str] = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
         max_length: Optional[int] = None,
     ):
         input_ids: Union[List[int], List[List[int]]]
         # For List[str], original prompt text
-        if is_list_of(prompt, str):
+        if is_list_of(text, str):
             input_ids_: List[List[int]] = []
-            for p in prompt:
+            for p in text:
                 each_input_ids = self.encode_one(p, truncation, max_length)
                 input_ids_.append(each_input_ids)
             input_ids = input_ids_
         # For List[int], apply chat template output, already tokens.
-        elif is_list_of(prompt, int):
-            input_ids = prompt
+        elif is_list_of(text, int):
+            input_ids = text
         # For str, single prompt text
         else:
-            input_ids = self.encode_one(prompt, truncation, max_length)
+            input_ids = self.encode_one(text, truncation, max_length)
         return Encoding(input_ids=input_ids)

     def get_vocab(self) -> Dict[str, int]:
@@ -300,22 +310,29 @@ class MistralTokenizer:
     def encode_one(
         self,
-        prompt: str,
+        text: str,
         truncation: bool = False,
         max_length: Optional[int] = None,
     ) -> List[int]:
         # Mistral Tokenizers should not add special tokens
-        input_ids = self.encode(prompt)
+        input_ids = self.encode(text)

         if truncation:
             input_ids = input_ids[:max_length]
         return input_ids

-    def encode(self, prompt: str) -> List[int]:
+    def encode(self,
+               text: str,
+               add_special_tokens: Optional[bool] = None) -> List[int]:
         # `encode` should only be used for prompt completion
         # it should never be used for chat_completion.
         # For chat completion use `apply_chat_template`
-        return self.tokenizer.encode(prompt, bos=True, eos=False)
+        if add_special_tokens is not None:
+            return self.tokenizer.encode(text,
+                                         bos=add_special_tokens,
+                                         eos=add_special_tokens)
+        else:
+            return self.tokenizer.encode(text, bos=True, eos=False)

     def apply_chat_template(self,
                             messages: List["ChatCompletionMessageParam"],