Make initialization of tokenizer and detokenizer optional (#3748)

Co-authored-by: Yun Ding <yunding@nvidia.com>
Co-authored-by: Roger Wang <ywang@roblox.com>

parent 7f2593b164 · commit a37d815b83
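In short, the new `skip_tokenizer_init` flag lets the engine run purely on token ids: vLLM never loads a tokenizer or detokenizer, and callers encode and decode outside the engine. A minimal round-trip sketch based on the test and docstrings below (the model name and the external HuggingFace tokenizer are illustrative assumptions, not part of this change):

    from transformers import AutoTokenizer

    from vllm.entrypoints.llm import LLM
    from vllm.sampling_params import SamplingParams

    # Tokenize outside of vLLM; with skip_tokenizer_init=True the engine
    # itself never initializes a tokenizer or detokenizer.
    tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
    prompt_ids = tok("The capital of France is").input_ids

    llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)
    outputs = llm.generate(prompt_token_ids=[prompt_ids],
                           sampling_params=SamplingParams(max_tokens=16))

    # Outputs carry token ids only; .text stays empty, so decode externally.
    completion = outputs[0].outputs[0]
    print(tok.decode(completion.token_ids, skip_special_tokens=True))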
tests/engine/test_skip_tokenizer_init.py (new file, +23)

@@ -0,0 +1,23 @@
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_skip_tokenizer_initialization(model: str):
+    # This test checks if the flag skip_tokenizer_init skips the initialization
+    # of tokenizer and detokenizer. The generated output is expected to contain
+    # token ids.
+    llm = LLM(model=model, skip_tokenizer_init=True)
+    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
+    with pytest.raises(ValueError) as err:
+        llm.generate("abc", sampling_params)
+    assert "prompts must be None if" in str(err.value)
+    outputs = llm.generate(prompt_token_ids=[[1, 2, 3]],
+                           sampling_params=sampling_params)
+    assert len(outputs) > 0
+    completions = outputs[0].outputs
+    assert len(completions) > 0
+    assert completions[0].text == ""
+    assert completions[0].token_ids
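Assuming a working vLLM install with a GPU available, the test above should run standalone via `pytest tests/engine/test_skip_tokenizer_init.py`; it fetches facebook/opt-125m from the HuggingFace Hub on first use.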
@@ -66,6 +66,8 @@ class ModelConfig:
         max_context_len_to_capture: Maximum context len covered by CUDA graphs.
             When a sequence has context length larger than this, we fall back
             to eager mode.
+        skip_tokenizer_init: If true, skip initialization of tokenizer and
+            detokenizer.
     """
 
     def __init__(
@@ -85,6 +87,7 @@ class ModelConfig:
         enforce_eager: bool = False,
         max_context_len_to_capture: Optional[int] = None,
         max_logprobs: int = 5,
+        skip_tokenizer_init: bool = False,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -99,6 +102,7 @@ class ModelConfig:
         self.enforce_eager = enforce_eager
         self.max_context_len_to_capture = max_context_len_to_capture
         self.max_logprobs = max_logprobs
+        self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
                                     code_revision)
@@ -106,7 +110,8 @@ class ModelConfig:
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
                                                      max_model_len)
-        self._verify_tokenizer_mode()
+        if not self.skip_tokenizer_init:
+            self._verify_tokenizer_mode()
         self._verify_quantization()
         self._verify_cuda_graph()
 
@@ -16,6 +16,7 @@ class EngineArgs:
     """Arguments for vLLM engine."""
     model: str
     tokenizer: Optional[str] = None
+    skip_tokenizer_init: bool = False
     tokenizer_mode: str = 'auto'
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
@@ -93,6 +94,10 @@ class EngineArgs:
             type=str,
             default=EngineArgs.tokenizer,
             help='Name or path of the huggingface tokenizer to use.')
+        parser.add_argument(
+            '--skip-tokenizer-init',
+            action='store_true',
+            help='Skip initialization of tokenizer and detokenizer')
         parser.add_argument(
             '--revision',
             type=str,
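The same switch is exposed on the command line for any entrypoint whose parser is built from `EngineArgs.add_cli_args`, e.g. passing `--skip-tokenizer-init` next to `--model` when launching a server; as a store_true flag it defaults to off.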
@@ -453,7 +458,7 @@ class EngineArgs:
             self.code_revision, self.tokenizer_revision, self.max_model_len,
             self.quantization, self.quantization_param_path,
             self.enforce_eager, self.max_context_len_to_capture,
-            self.max_logprobs)
+            self.max_logprobs, self.skip_tokenizer_init)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
@@ -100,6 +100,7 @@ class LLMEngine:
             f"model={model_config.model!r}, "
             f"speculative_config={speculative_config!r}, "
             f"tokenizer={model_config.tokenizer!r}, "
+            f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
             f"tokenizer_mode={model_config.tokenizer_mode}, "
             f"revision={model_config.revision}, "
             f"tokenizer_revision={model_config.tokenizer_revision}, "
@@ -132,8 +133,14 @@ class LLMEngine:
         self.decoding_config = decoding_config or DecodingConfig()
         self.log_stats = log_stats
 
-        self._init_tokenizer()
-        self.detokenizer = Detokenizer(self.tokenizer)
+        if not self.model_config.skip_tokenizer_init:
+            self.tokenizer: BaseTokenizerGroup
+            self._init_tokenizer()
+            self.detokenizer = Detokenizer(self.tokenizer)
+        else:
+            self.detokenizer = None
+            self.tokenizer = None
 
         self.seq_counter = Counter()
         self.generation_config_fields = _load_generation_config_dict(
             model_config)
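With the flag set, `LLMEngine.tokenizer` and `LLMEngine.detokenizer` are simply `None`, so anything downstream (including user code reaching into the engine) must treat them as optional. A small sketch of what that looks like through the `LLM` wrapper (model name illustrative):

    from vllm.entrypoints.llm import LLM

    llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)
    # Both attributes are None rather than initialized tokenizer objects.
    assert llm.llm_engine.tokenizer is None
    assert llm.llm_engine.detokenizer is None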
@@ -187,9 +194,10 @@ class LLMEngine:
             parallel_config.disable_custom_all_reduce,
         })
 
-        # Ping the tokenizer to ensure liveness if it runs in a
-        # different process.
-        self.tokenizer.ping()
+        if self.tokenizer:
+            # Ping the tokenizer to ensure liveness if it runs in a
+            # different process.
+            self.tokenizer.ping()
 
         # Create the scheduler.
         # NOTE: the cache_config here have been updated with the numbers of
@@ -296,7 +304,7 @@ class LLMEngine:
             trust_remote_code=self.model_config.trust_remote_code,
             revision=self.model_config.tokenizer_revision)
         init_kwargs.update(tokenizer_init_kwargs)
-        self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
+        self.tokenizer = get_tokenizer_group(
             self.parallel_config.tokenizer_pool_config, **init_kwargs)
 
     def _verify_args(self) -> None:
@@ -393,8 +401,13 @@ class LLMEngine:
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
-        eos_token_id = self.tokenizer.get_lora_tokenizer(
-            lora_request).eos_token_id
+        eos_token_id = None
+        if self.tokenizer:
+            eos_token_id = self.tokenizer.get_lora_tokenizer(
+                lora_request).eos_token_id
+        else:
+            logger.warning("Use None for EOS token id because tokenizer is "
+                           "not initialized")
         seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
                        eos_token_id, lora_request)
 
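A consequence of the hunk above: without a tokenizer the engine cannot look up an EOS id, so sequences are created with `eos_token_id=None` and generation will not stop on EOS by itself. A hedged sketch of bounding generation explicitly in that case (the token ids are placeholder values):

    from vllm.entrypoints.llm import LLM
    from vllm.sampling_params import SamplingParams

    llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)
    # No EOS id is known to the engine, so cap the length and/or pass the
    # caller's own stop token ids.
    params = SamplingParams(max_tokens=32, stop_token_ids=[2])
    outputs = llm.generate(prompt_token_ids=[[1, 2, 3]], sampling_params=params)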
@@ -59,7 +59,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
 
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
-        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
+        if prompt_logprobs is not None and \
+                seq_group.sampling_params.detokenize and self.detokenizer:
             self.detokenizer.decode_prompt_logprobs_inplace(
                 seq_group, prompt_logprobs)
             seq_group.prompt_logprobs = prompt_logprobs
@@ -105,7 +106,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            if seq_group.sampling_params.detokenize:
+            if seq_group.sampling_params.detokenize and self.detokenizer:
                 new_char_count = self.detokenizer.decode_sequence_inplace(
                     seq, seq_group.sampling_params)
             else:
@@ -32,6 +32,9 @@ class LLM:
         tokenizer: The name or path of a HuggingFace Transformers tokenizer.
         tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
             if available, and "slow" will always use the slow tokenizer.
+        skip_tokenizer_init: If true, skip initialization of tokenizer and
+            detokenizer. Expect valid prompt_token_ids and None for prompt
+            from the input.
         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
             downloading the model and tokenizer.
         tensor_parallel_size: The number of GPUs to use for distributed
@@ -76,6 +79,7 @@ class LLM:
         model: str,
         tokenizer: Optional[str] = None,
         tokenizer_mode: str = "auto",
+        skip_tokenizer_init: bool = False,
         trust_remote_code: bool = False,
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
@@ -96,6 +100,7 @@ class LLM:
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
+            skip_tokenizer_init=skip_tokenizer_init,
             trust_remote_code=trust_remote_code,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
@@ -160,6 +165,10 @@ class LLM:
         if prompts is None and prompt_token_ids is None:
             raise ValueError("Either prompts or prompt_token_ids must be "
                              "provided.")
+        if self.llm_engine.model_config.skip_tokenizer_init \
+            and prompts is not None:
+            raise ValueError("prompts must be None if skip_tokenizer_init "
+                             "is True")
         if isinstance(prompts, str):
             # Convert a single prompt to a list.
             prompts = [prompts]
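The check added to `LLM.generate` above makes the contract explicit: once tokenizer initialization is skipped, plain-text prompts are rejected instead of being silently mishandled. A tiny sketch of the failure mode (model name illustrative):

    from vllm.entrypoints.llm import LLM

    llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)
    try:
        llm.generate("abc")  # a string prompt, but no tokenizer to encode it
    except ValueError as err:
        print(err)  # "prompts must be None if skip_tokenizer_init is True"

As the diff shows, token-id prompts passed via `prompt_token_ids` remain the supported path.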