[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (#12518)
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
This commit is contained in:
parent
72c2b68dc9
commit
3ee696a63d
@ -1275,11 +1275,12 @@ if __name__ == "__main__":
|
||||
'--tokenizer-mode',
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=['auto', 'slow', 'mistral'],
|
||||
choices=['auto', 'slow', 'mistral', 'custom'],
|
||||
help='The tokenizer mode.\n\n* "auto" will use the '
|
||||
'fast tokenizer if available.\n* "slow" will '
|
||||
'always use the slow tokenizer. \n* '
|
||||
'"mistral" will always use the `mistral_common` tokenizer.')
|
||||
'"mistral" will always use the `mistral_common` tokenizer. \n*'
|
||||
'"custom" will use --tokenizer to select the preregistered tokenizer.')
|
||||
|
||||
parser.add_argument("--served-model-name",
|
||||
type=str,
|
||||
|
123
tests/tokenization/test_tokenizer_registry.py
Normal file
123
tests/tokenization/test_tokenizer_registry.py
Normal file
@ -0,0 +1,123 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
|
||||
TokenizerRegistry)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
|
||||
|
||||
class TestTokenizer(TokenizerBase):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
|
||||
return TestTokenizer()
|
||||
|
||||
@property
|
||||
def all_special_tokens_extended(self) -> List[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def all_special_tokens(self) -> List[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def all_special_ids(self) -> List[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> int:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return 1
|
||||
|
||||
@property
|
||||
def sep_token(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def is_fast(self) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def max_token_id(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[str, List[str], List[int]],
|
||||
text_pair: Optional[str] = None,
|
||||
add_special_tokens: bool = False,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_added_vocab(self) -> Dict[str, int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def encode_one(
|
||||
self,
|
||||
text: str,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
) -> List[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def encode(self,
|
||||
text: str,
|
||||
add_special_tokens: Optional[bool] = None) -> List[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def apply_chat_template(self,
|
||||
messages: List["ChatCompletionMessageParam"],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs) -> List[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def decode(self,
|
||||
ids: Union[List[int], int],
|
||||
skip_special_tokens: bool = True) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def convert_ids_to_tokens(
|
||||
self,
|
||||
ids: List[int],
|
||||
skip_special_tokens: bool = True,
|
||||
) -> List[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def test_customized_tokenizer():
|
||||
TokenizerRegistry.register("test_tokenizer",
|
||||
"tests.tokenization.test_tokenizer_registry",
|
||||
"TestTokenizer")
|
||||
|
||||
tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
|
||||
assert isinstance(tokenizer, TestTokenizer)
|
||||
assert tokenizer.bos_token_id == 0
|
||||
assert tokenizer.eos_token_id == 1
|
||||
|
||||
tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
|
||||
assert isinstance(tokenizer, TestTokenizer)
|
||||
assert tokenizer.bos_token_id == 0
|
||||
assert tokenizer.eos_token_id == 1
|
@ -102,8 +102,9 @@ class ModelConfig:
|
||||
it; otherwise, you must specify explicitly which task to use.
|
||||
tokenizer: Name or path of the huggingface tokenizer to use.
|
||||
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
|
||||
available, "slow" will always use the slow tokenizer, and
|
||||
"mistral" will always use the tokenizer from `mistral_common`.
|
||||
available, "slow" will always use the slow tokenizer,
|
||||
"mistral" will always use the tokenizer from `mistral_common`, and
|
||||
"custom" will use --tokenizer to select the preregistered tokenizer.
|
||||
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||
downloading the model and tokenizer.
|
||||
allowed_local_media_path: Allowing API requests to read local images or
|
||||
@ -467,10 +468,10 @@ class ModelConfig:
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
tokenizer_mode = self.tokenizer_mode.lower()
|
||||
if tokenizer_mode not in ["auto", "slow", "mistral"]:
|
||||
if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
|
||||
raise ValueError(
|
||||
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
|
||||
"either 'auto', 'slow' or 'mistral'.")
|
||||
"either 'auto', 'slow', 'mistral' or 'custom'.")
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
|
||||
def _get_preferred_task(
|
||||
|
@ -284,11 +284,13 @@ class EngineArgs:
|
||||
'--tokenizer-mode',
|
||||
type=str,
|
||||
default=EngineArgs.tokenizer_mode,
|
||||
choices=['auto', 'slow', 'mistral'],
|
||||
choices=['auto', 'slow', 'mistral', 'custom'],
|
||||
help='The tokenizer mode.\n\n* "auto" will use the '
|
||||
'fast tokenizer if available.\n* "slow" will '
|
||||
'always use the slow tokenizer. \n* '
|
||||
'"mistral" will always use the `mistral_common` tokenizer.')
|
||||
'"mistral" will always use the `mistral_common` tokenizer. \n* '
|
||||
'"custom" will use --tokenizer to select the '
|
||||
'preregistered tokenizer.')
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='Trust remote code from huggingface.')
|
||||
|
@ -1051,9 +1051,9 @@ class LLM:
|
||||
|
||||
def _cross_encoding_score(
|
||||
self,
|
||||
tokenizer: Union[AnyTokenizer],
|
||||
text_1: List[Union[str, TextPrompt, TokensPrompt]],
|
||||
text_2: List[Union[str, TextPrompt, TokensPrompt]],
|
||||
tokenizer: AnyTokenizer,
|
||||
text_1: List[str],
|
||||
text_2: List[str],
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
@ -1176,29 +1176,36 @@ class LLM:
|
||||
if isinstance(text_1, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
text_1 = [text_1]
|
||||
text_1 = [ensure_str(t) for t in text_1]
|
||||
input_text_1: List[str] = [ensure_str(t) for t in text_1]
|
||||
|
||||
if isinstance(text_2, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
text_2 = [text_2]
|
||||
text_2 = [ensure_str(t) for t in text_2]
|
||||
input_text_2: List[str] = [ensure_str(t) for t in text_2]
|
||||
|
||||
if len(text_1) > 1 and len(text_1) != len(text_2):
|
||||
if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
|
||||
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
|
||||
if len(text_1) == 0:
|
||||
if len(input_text_1) == 0:
|
||||
raise ValueError("At least one text element must be given")
|
||||
if len(text_2) == 0:
|
||||
if len(input_text_2) == 0:
|
||||
raise ValueError("At least one text_pair element must be given")
|
||||
|
||||
if self.llm_engine.model_config.is_cross_encoder:
|
||||
return self._cross_encoding_score(tokenizer, text_1, text_2,
|
||||
return self._cross_encoding_score(tokenizer, input_text_1,
|
||||
input_text_2,
|
||||
truncate_prompt_tokens, use_tqdm,
|
||||
lora_request,
|
||||
prompt_adapter_request)
|
||||
else:
|
||||
return self._embedding_score(tokenizer, text_1, text_2,
|
||||
truncate_prompt_tokens, use_tqdm,
|
||||
lora_request, prompt_adapter_request)
|
||||
|
||||
return self._embedding_score(
|
||||
tokenizer,
|
||||
input_text_1, # type: ignore[arg-type]
|
||||
input_text_2, # type: ignore[arg-type]
|
||||
truncate_prompt_tokens,
|
||||
use_tqdm,
|
||||
lora_request,
|
||||
prompt_adapter_request)
|
||||
|
||||
def start_profile(self) -> None:
|
||||
self.llm_engine.start_profile()
|
||||
|
@ -400,8 +400,7 @@ class OpenAIServing:
|
||||
_chat_template_kwargs.update(chat_template_kwargs or {})
|
||||
|
||||
request_prompt: Union[str, List[int]]
|
||||
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
|
||||
if is_mistral_tokenizer:
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
request_prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=messages,
|
||||
|
@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing):
|
||||
|
||||
tokenize_async = make_async(tokenizer.__call__,
|
||||
executor=self._tokenizer_executor)
|
||||
prompt_inputs = await tokenize_async(text=q,
|
||||
prompt_inputs = await tokenize_async(q,
|
||||
text_pair=t,
|
||||
**tokenization_kwargs)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def get_bad_words_logits_processors(
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# Mistral tokenizers should not add special tokens
|
||||
prompt_token_ids = tokenizer.encode(prompt=prompt)
|
||||
prompt_token_ids = tokenizer.encode(text=prompt)
|
||||
else:
|
||||
prompt_token_ids = tokenizer.encode(text=prompt,
|
||||
add_special_tokens=False)
|
||||
|
@ -14,6 +14,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
|
||||
TokenizerRegistry)
|
||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.utils import make_async
|
||||
@ -21,7 +23,7 @@ from vllm.utils import make_async
|
||||
logger = init_logger(__name__)
|
||||
|
||||
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
|
||||
MistralTokenizer]
|
||||
TokenizerBase]
|
||||
|
||||
|
||||
def decode_tokens(
|
||||
@ -47,11 +49,7 @@ def encode_tokens(
|
||||
Backend-agnostic equivalent of HF's
|
||||
:code:`tokenizer.encode(text, add_special_tokens=...)`.
|
||||
"""
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
return tokenizer.tokenizer.encode(text,
|
||||
bos=add_special_tokens,
|
||||
eos=add_special_tokens)
|
||||
elif add_special_tokens is not None:
|
||||
if add_special_tokens is not None:
|
||||
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
|
||||
return tokenizer.encode(text)
|
||||
|
||||
@ -183,9 +181,17 @@ def get_tokenizer(
|
||||
'encoding and decoding.',
|
||||
FutureWarning,
|
||||
stacklevel=2)
|
||||
|
||||
tokenizer: AnyTokenizer
|
||||
if tokenizer_mode == "mistral":
|
||||
tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
|
||||
revision=revision)
|
||||
elif tokenizer_mode == "custom":
|
||||
tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
|
||||
*args,
|
||||
revision=revision,
|
||||
download_dir=download_dir,
|
||||
**kwargs)
|
||||
else:
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
|
146
vllm/transformers_utils/tokenizer_base.py
Normal file
146
vllm/transformers_utils/tokenizer_base.py
Normal file
@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
|
||||
|
||||
class TokenizerBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_special_tokens_extended(self) -> List[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_special_tokens(self) -> List[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_special_ids(self) -> List[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def bos_token_id(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def eos_token_id(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def sep_token(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def pad_token(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_fast(self) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def vocab_size(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_token_id(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.vocab_size
|
||||
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[str, List[str], List[int]],
|
||||
text_pair: Optional[str] = None,
|
||||
add_special_tokens: bool = False,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_added_vocab(self) -> Dict[str, int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def encode_one(
|
||||
self,
|
||||
text: str,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
) -> List[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def encode(self,
|
||||
text: str,
|
||||
add_special_tokens: Optional[bool] = None) -> List[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def apply_chat_template(self,
|
||||
messages: List["ChatCompletionMessageParam"],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs) -> List[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def decode(self,
|
||||
ids: Union[List[int], int],
|
||||
skip_special_tokens: bool = True) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def convert_ids_to_tokens(
|
||||
self,
|
||||
ids: List[int],
|
||||
skip_special_tokens: bool = True,
|
||||
) -> List[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TokenizerRegistry:
|
||||
# Tokenizer name -> (tokenizer module, tokenizer class)
|
||||
REGISTRY: Dict[str, Tuple[str, str]] = {}
|
||||
|
||||
@staticmethod
|
||||
def register(name: str, module: str, class_name: str) -> None:
|
||||
TokenizerRegistry.REGISTRY[name] = (module, class_name)
|
||||
|
||||
@staticmethod
|
||||
def get_tokenizer(
|
||||
tokenizer_name: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> TokenizerBase:
|
||||
tokenizer_cls = TokenizerRegistry.REGISTRY.get(tokenizer_name)
|
||||
if tokenizer_cls is None:
|
||||
raise ValueError(f"Tokenizer {tokenizer_name} not found.")
|
||||
|
||||
tokenizer_module = importlib.import_module(tokenizer_cls[0])
|
||||
class_ = getattr(tokenizer_module, tokenizer_cls[1])
|
||||
return class_.from_pretrained(*args, **kwargs)
|
@ -10,6 +10,7 @@ import huggingface_hub
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer_base import TokenizerBase
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -140,7 +141,7 @@ def make_mistral_chat_completion_request(
|
||||
tools=tools) # type: ignore[type-var]
|
||||
|
||||
|
||||
class MistralTokenizer:
|
||||
class MistralTokenizer(TokenizerBase):
|
||||
|
||||
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
|
||||
self.mistral = tokenizer
|
||||
@ -251,6 +252,14 @@ class MistralTokenizer:
|
||||
def eos_token_id(self) -> int:
|
||||
return self.tokenizer.eos_id
|
||||
|
||||
@property
|
||||
def sep_token(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def is_fast(self) -> bool:
|
||||
return True
|
||||
@ -268,25 +277,26 @@ class MistralTokenizer:
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str], List[int]],
|
||||
text: Union[str, List[str], List[int]],
|
||||
text_pair: Optional[str] = None,
|
||||
add_special_tokens: bool = False,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
):
|
||||
input_ids: Union[List[int], List[List[int]]]
|
||||
# For List[str], original prompt text
|
||||
if is_list_of(prompt, str):
|
||||
if is_list_of(text, str):
|
||||
input_ids_: List[List[int]] = []
|
||||
for p in prompt:
|
||||
for p in text:
|
||||
each_input_ids = self.encode_one(p, truncation, max_length)
|
||||
input_ids_.append(each_input_ids)
|
||||
input_ids = input_ids_
|
||||
# For List[int], apply chat template output, already tokens.
|
||||
elif is_list_of(prompt, int):
|
||||
input_ids = prompt
|
||||
elif is_list_of(text, int):
|
||||
input_ids = text
|
||||
# For str, single prompt text
|
||||
else:
|
||||
input_ids = self.encode_one(prompt, truncation, max_length)
|
||||
input_ids = self.encode_one(text, truncation, max_length)
|
||||
return Encoding(input_ids=input_ids)
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
@ -300,22 +310,29 @@ class MistralTokenizer:
|
||||
|
||||
def encode_one(
|
||||
self,
|
||||
prompt: str,
|
||||
text: str,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
) -> List[int]:
|
||||
# Mistral Tokenizers should not add special tokens
|
||||
input_ids = self.encode(prompt)
|
||||
input_ids = self.encode(text)
|
||||
|
||||
if truncation:
|
||||
input_ids = input_ids[:max_length]
|
||||
return input_ids
|
||||
|
||||
def encode(self, prompt: str) -> List[int]:
|
||||
def encode(self,
|
||||
text: str,
|
||||
add_special_tokens: Optional[bool] = None) -> List[int]:
|
||||
# `encode` should only be used for prompt completion
|
||||
# it should never be used for chat_completion.
|
||||
# For chat completion use `apply_chat_template`
|
||||
return self.tokenizer.encode(prompt, bos=True, eos=False)
|
||||
if add_special_tokens is not None:
|
||||
return self.tokenizer.encode(text,
|
||||
bos=add_special_tokens,
|
||||
eos=add_special_tokens)
|
||||
else:
|
||||
return self.tokenizer.encode(text, bos=True, eos=False)
|
||||
|
||||
def apply_chat_template(self,
|
||||
messages: List["ChatCompletionMessageParam"],
|
||||
|
Loading…
x
Reference in New Issue
Block a user