[V1][Perf] Faster incremental detokenization (#15137)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-04-17 07:45:24 -07:00 committed by GitHub
parent 7c02d6a137
commit 05fcd1b430
7 changed files with 317 additions and 145 deletions

View File

@@ -8,7 +8,7 @@ blake3
py-cpuinfo
transformers >= 4.51.1
huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
tokenizers >= 0.19.1 # Required for Llama 3.
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
aiohttp
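
The tokenizers floor moves to 0.21.1 because the new fast path is built on
tokenizers.decoders.DecodeStream. A minimal, hedged sketch of that API on its own
(the model name and input text below are illustrative, not part of this change):

# Sketch only: exercises the DecodeStream API this change depends on.
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tok = Tokenizer.from_pretrained("gpt2")          # illustrative fast tokenizer
stream = DecodeStream(skip_special_tokens=True)
for token_id in tok.encode("Hello here, this is a simple test").ids:
    piece = stream.step(tok, token_id)           # may return None until text is decodable
    if piece:
        print(piece, end="")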

View File

@@ -35,6 +35,7 @@ opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
transformers==4.51.1
tokenizers==0.21.1
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
# quantization
bitsandbytes>=0.45.3

View File

@@ -624,8 +624,10 @@ tiktoken==0.7.0
# mistral-common
timm==1.0.11
# via -r requirements/test.in
tokenizers==0.21.0
# via transformers
tokenizers==0.21.1
# via
# -r requirements/test.in
# transformers
torch==2.6.0
# via
# -r requirements/test.in

View File

@@ -47,6 +47,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
skip_special_tokens=False,
stop=["[/assistant]"])
outputs = llm.generate(
prompts,

View File

@@ -4,14 +4,22 @@ from collections.abc import Generator
from typing import Any, Optional
import pytest
from transformers import AutoTokenizer
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.inputs import token_inputs
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import (Detokenizer,
detokenize_incrementally)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
IncrementalDetokenizer,
SlowIncrementalDetokenizer)
SPECIAL_TOKS_TRUTH = [
"Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>", # noqa
]
TRUTH = [
"Hello here, this is a simple test",
@@ -22,7 +30,8 @@ TRUTH = [
# incomplete UTF-8 characters
# see https://github.com/vllm-project/vllm/pull/9625
"ပုံပြင်လေးပြောပြပါ်",
]
] + SPECIAL_TOKS_TRUTH
TOKENIZERS = [
"facebook/opt-125m",
"gpt2",
@@ -38,26 +47,37 @@ TOKENIZERS = [
]
def _run_incremental_decode(tokenizer, all_input_ids,
skip_special_tokens: bool, starting_index: int):
decoded_text = ""
offset = 0
token_offset = 0
prev_tokens = None
for i in range(starting_index, len(all_input_ids)):
new_tokens, text, offset, token_offset = detokenize_incrementally(
tokenizer,
all_input_ids[:i + 1],
prev_tokens,
offset,
token_offset,
skip_special_tokens=skip_special_tokens)
decoded_text += text
if prev_tokens is None:
prev_tokens = new_tokens
else:
prev_tokens += new_tokens
return decoded_text
def _run_incremental_decode(tokenizer,
all_input_ids,
skip_special_tokens: bool,
starting_index: int,
spaces_between_special_tokens: bool = True,
fast: Optional[bool] = None):
prompt_token_ids = all_input_ids[:starting_index]
params = SamplingParams(
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
request = EngineCoreRequest("", "", prompt_token_ids, None, None, None,
params, None, 0.0, None)
if fast is None:
detokenizer = IncrementalDetokenizer.from_new_request(
tokenizer, request)
elif fast:
detokenizer = FastIncrementalDetokenizer(tokenizer, request)
else:
detokenizer = SlowIncrementalDetokenizer(tokenizer, request)
output_text = ""
for i, token_id in enumerate(all_input_ids[starting_index:]):
detokenizer.update([token_id], False)
finished = i == len(all_input_ids) - 1
output_text += detokenizer.get_next_output_text(finished, delta=True)
return output_text, detokenizer.output_token_ids
@pytest.fixture
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
decoded_text = _run_incremental_decode(tokenizer,
all_input_ids,
skip_special_tokens=True,
starting_index=starting_index)
decoded_text, out_ids = _run_incremental_decode(
tokenizer,
all_input_ids,
skip_special_tokens=True,
starting_index=starting_index)
assert decoded_text == truth
assert out_ids == all_input_ids[starting_index:]
@pytest.fixture
@@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
spaces_between_special_tokens, fast):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()
if skip_special_tokens and not spaces_between_special_tokens:
pytest.skip()
if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
# Fix up inconsistency in fast/slow tokenizer behaviour.
tokenizer.add_special_tokens({
"additional_special_tokens": [
at for at in
tokenizer._tokenizer.get_added_tokens_decoder().values()
if at.special
]
})
extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
else {"spaces_between_special_tokens": spaces_between_special_tokens}
truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
if tokenizer.bos_token_id is not None:
truth_tokens.insert(0, tokenizer.bos_token_id)
truth_tokens.append(tokenizer.eos_token_id)
new_truth = tokenizer.decode(truth_tokens,
skip_special_tokens=skip_special_tokens,
**extra_decode_args)
if with_prompt:
truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
prompt_input_ids = truth_tokens[:len(truth) // 2]
generated_input_ids = truth_tokens[len(truth) // 2:]
num_prompt_tokens = len(
tokenizer(truth[:len(truth) // 2],
add_special_tokens=False).input_ids)
if tokenizer.bos_token_id is not None:
num_prompt_tokens += 1
prompt_input_ids = truth_tokens[:num_prompt_tokens]
generated_input_ids = truth_tokens[num_prompt_tokens:]
all_input_ids = prompt_input_ids + generated_input_ids
starting_index = len(prompt_input_ids)
prompt = tokenizer.decode(prompt_input_ids,
skip_special_tokens=skip_special_tokens)
generated = truth[len(prompt):]
else:
generated = truth
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
if skip_special_tokens:
if tokenizer.bos_token_id is not None:
all_input_ids = [tokenizer.bos_token_id] + all_input_ids
starting_index += 1
all_input_ids = all_input_ids + [tokenizer.eos_token_id]
skip_special_tokens=skip_special_tokens,
**extra_decode_args)
decoded_text = _run_incremental_decode(
generated = new_truth[len(prompt):]
else:
generated = new_truth
starting_index = 0
all_input_ids = truth_tokens
decoded_text, out_ids = _run_incremental_decode(
tokenizer,
all_input_ids,
skip_special_tokens=skip_special_tokens,
starting_index=starting_index)
starting_index=starting_index,
spaces_between_special_tokens=spaces_between_special_tokens,
fast=fast)
assert decoded_text == generated
assert out_ids == all_input_ids[starting_index:]
decoded_text = _run_incremental_decode(
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()
decoded_text, out_ids = _run_incremental_decode(
tokenizer, [len(tokenizer)],
skip_special_tokens=skip_special_tokens,
starting_index=starting_index)
skip_special_tokens=True,
starting_index=0,
spaces_between_special_tokens=True,
fast=fast)
assert decoded_text == ''
assert out_ids == [len(tokenizer)]
@pytest.fixture
@@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer) -> list[int]:
complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
return complete_sequence_token_ids
return tokenizer(complete_sequence, add_special_tokens=False).input_ids
def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1]
prompt_token_ids = prompt_token_ids or []
return Sequence(
seq_id=0,
inputs=token_inputs(prompt_token_ids, prompt="<s>"),
inputs=token_inputs(prompt_token_ids),
block_size=16,
)
@@ -224,7 +291,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
assert sequential_result != "".join(sequential_logprobs_text_other_token)
if skip_special_tokens:
if not skip_special_tokens:
# Text for logprobs for the chosen token should be the same as the
# generated text. Note that this will only be true if we skip
# special tokens.
@@ -233,10 +300,23 @@
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
def test_decode_prompt_logprobs(complete_sequence: str,
complete_sequence_token_ids: list[int],
detokenizer: Detokenizer):
# We want to use skip_special_tokens=False here but Mistral tokenizers
# don't support that.
if complete_sequence not in SPECIAL_TOKS_TRUTH:
skip_special_tokens = True
elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
MistralTokenizer):
skip_special_tokens = False
else:
pytest.skip("MistralTokenizers don't support "
"skip_special_tokens=False")
return
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=True,
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
prompt_logprobs=1)
# Run sequentially.
@@ -256,8 +336,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
# decoded_prompt_logprobs doesn't contain the first token.
token_ids = complete_sequence_token_ids
tokenizer = detokenizer.get_tokenizer_for_seq(seq)
text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
text_full = tokenizer.decode(token_ids,
skip_special_tokens=skip_special_tokens)
text_first = tokenizer.decode(token_ids[0],
skip_special_tokens=skip_special_tokens)
text = text_full[len(text_first):]
# Text for logprobs for the chosen token should be the same as the
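
The rewritten tests all assert the same invariant: feeding tokens one at a time
through the V1 detokenizer and concatenating the streamed deltas must reproduce a
one-shot tokenizer.decode of the generated ids. A condensed, hedged sketch of that
loop outside pytest (the tokenizer choice is illustrative; the request construction
simply mirrors the test helper above):

from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import IncrementalDetokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative fast tokenizer => fast path
gen_ids = tok("Hello here, this is a simple test", add_special_tokens=False).input_ids

params = SamplingParams(skip_special_tokens=True)
# Empty prompt; positional argument layout copied from the test helper in this diff.
request = EngineCoreRequest("", "", [], None, None, None, params, None, 0.0, None)
detok = IncrementalDetokenizer.from_new_request(tok, request)

streamed = ""
for token_id in gen_ids:
    detok.update([token_id], False)
    streamed += detok.get_next_output_text(finished=False, delta=True)
streamed += detok.get_next_output_text(finished=True, delta=True)

assert streamed == tok.decode(gen_ids, skip_special_tokens=True)
assert detok.output_token_ids == gen_ids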

View File

@@ -70,6 +70,15 @@ class MistralToolParser(ToolParser):
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!")
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if request.tools and request.tool_choice != 'none':
# do not skip special tokens because mistral uses the special
# tokens to indicate the start and end of the tool calls
# information.
request.skip_special_tokens = False
return request
def extract_tool_calls(
self,
model_output: str,
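
The new adjust_request hook is needed because Mistral models wrap tool calls in
special tokens; if skip_special_tokens stays True, the marker never reaches the
parser and the tool-call block cannot be located. A small, hedged illustration of
that effect, using gpt2's <|endoftext|> as a stand-in for Mistral's tool-call marker:

from transformers import AutoTokenizer

# <|endoftext|> stands in here for a [TOOL_CALLS]-style marker; text is illustrative.
tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok('<|endoftext|>{"name": "get_weather"}', add_special_tokens=False).input_ids

print(tok.decode(ids, skip_special_tokens=True))   # marker stripped: parser has nothing to find
print(tok.decode(ids, skip_special_tokens=False))  # marker preserved for the tool parser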

View File

@@ -1,8 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from typing import Optional
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
@@ -12,39 +15,22 @@ from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)
@dataclass
class IncrementalDetokenizer:
# Generation data
token_ids: list[int]
output_text: str = ""
tokens: list[str] = field(default_factory=list)
prompt_len: int = 0
# Stop strings
stop: list[str] = field(default_factory=list)
include_stop_str_in_output: bool = False
# Metadata for incremental detokenization
prefix_offset: int = 0
read_offset: int = 0
# Parameters for detokenization
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer: Optional[AnyTokenizer] = None
# Accounting for stop string buffering
stop_buffer_length: int = 0
_last_output_text_offset: int = 0
def __init__(self):
self.token_ids: list[int] = []
@property
def output_token_ids(self) -> list[int]:
return self.token_ids if not self.prompt_len else (
self.token_ids[self.prompt_len:])
return self.token_ids
def update(self, new_token_ids: list[int],
stop_terminated: bool) -> Optional[str]:
self.token_ids.extend(new_token_ids)
return None
def get_next_output_text(self, finished: bool, delta: bool) -> str:
return ""
@classmethod
def from_new_request(
@@ -54,39 +40,37 @@ class IncrementalDetokenizer:
) -> "IncrementalDetokenizer":
if tokenizer is None:
return cls(token_ids=[])
# No tokenizer => skipping detokenization.
return IncrementalDetokenizer()
tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=request.sampling_params.skip_special_tokens,
)
if isinstance(tokenizer, PreTrainedTokenizerFast):
# Fast tokenizer => use tokenizers library DecodeStream.
return FastIncrementalDetokenizer(tokenizer, request)
# Fall back to slow python-based incremental detokenization.
return SlowIncrementalDetokenizer(tokenizer, request)
class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
def __init__(self, request: EngineCoreRequest):
super().__init__()
# Stop strings
params = request.sampling_params
self.stop = stop = params.stop
self.include_stop_str_in_output = params.include_stop_str_in_output
stops = request.sampling_params.stop
# Number of chars to hold back when stop strings are to be excluded
# from streamed output.
if stops and not request.sampling_params.include_stop_str_in_output:
stop_buffer_length = max(len(s) for s in stops) - 1
if stop and not self.include_stop_str_in_output:
self.stop_buffer_length = max(len(s) for s in stop) - 1
else:
stop_buffer_length = 0
self.stop_buffer_length = 0
self._last_output_text_offset: int = 0
return cls(
tokens=tokens,
# Detokenizer mutates this list, so need a unique copy.
# NOTE(Nick): could we take ownership of it though?
token_ids=request.prompt_token_ids.copy(),
stop=stops,
include_stop_str_in_output=request.sampling_params.
include_stop_str_in_output,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=request.sampling_params.skip_special_tokens,
spaces_between_special_tokens=request.sampling_params.
spaces_between_special_tokens,
prompt_len=len(request.prompt_token_ids),
tokenizer=tokenizer,
stop_buffer_length=stop_buffer_length,
)
# Generation data
self.output_text = ""
def update(self, new_token_ids: list[int],
stop_terminated: bool) -> Optional[str]:
@@ -98,11 +82,7 @@ class IncrementalDetokenizer:
Return matched stop string or None.
"""
if not new_token_ids:
# Skip detokenization if no new token ids
return None
if self.tokenizer is None:
# Skip detokenization if no tokenizer
self.token_ids.extend(new_token_ids)
# Skip detokenization if no new token ids.
return None
if stop_terminated and not self.include_stop_str_in_output:
@@ -116,34 +96,16 @@ class IncrementalDetokenizer:
# 1) Detokenize the new token ids incrementally.
# TODO(woosuk): This method becomes very inefficient when the number of
# new_token_ids is more than 1. We need to optimize this.
decoded_text = ""
offset_before = len(self.output_text)
for new_token_id in new_token_ids:
self.token_ids.append(new_token_id)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=self.token_ids,
prev_tokens=self.tokens,
prefix_offset=self.prefix_offset,
read_offset=self.read_offset,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.
spaces_between_special_tokens,
)
self.tokens.extend(new_tokens)
self.prefix_offset = prefix_offset
self.read_offset = read_offset
decoded_text += new_decoded_token_text
self.output_text += decoded_text
self.output_text += self.decode_next(new_token_id)
if stop_terminated:
if skipped_stop_token_id is not None:
# Cleanup after skipping detokenization
# Cleanup after skipping detokenization.
self.token_ids.append(skipped_stop_token_id)
# Stop token triggered; skip stop string check
# Stop token triggered; skip stop string check.
return None
# 2) Evaluate stop strings.
@@ -151,7 +113,7 @@ class IncrementalDetokenizer:
if self.stop:
stop = StopChecker.check_stop_strings(
output_text=self.output_text,
new_char_count=len(decoded_text),
new_char_count=len(self.output_text) - offset_before,
stop=self.stop,
include_in_output=self.include_stop_str_in_output,
)
@@ -162,6 +124,10 @@ class IncrementalDetokenizer:
return stop_string
@abstractmethod
def decode_next(self, next_token_id: int) -> str:
raise NotImplementedError
def get_next_output_text(self, finished: bool, delta: bool) -> str:
"""If delta is True, only new text since the last call to
this method is returned"""
@@ -177,3 +143,114 @@ class IncrementalDetokenizer:
self._last_output_text_offset = length
return self.output_text[last_offset:length]
return ""
class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
def __init__(self, tokenizer: PreTrainedTokenizerFast,
request: EngineCoreRequest):
super().__init__(request)
sampling_params = request.sampling_params
self.stream = DecodeStream(
skip_special_tokens=sampling_params.skip_special_tokens)
self.tokenizer: Tokenizer = tokenizer._tokenizer
# Find a safe place to start.
prompt_suffix = request.prompt_token_ids
prompt_len = len(prompt_suffix)
if prompt_len > 4:
for i in range(4, max(prompt_len + 1, 32)):
suffix = request.prompt_token_ids[-i:]
if '�' not in self.tokenizer.decode(suffix):
prompt_suffix = suffix
break
# Prime the stream.
for tid in prompt_suffix:
self.stream.step(self.tokenizer, tid)
self.spaces_between_special_tokens = (
sampling_params.skip_special_tokens
or sampling_params.spaces_between_special_tokens)
if not self.spaces_between_special_tokens:
# Store dict of added token ids so that we can suppress
# the spaces between them.
if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
None)) is None:
self.tokenizer.added_token_ids = added_token_ids = {
tid: tok.content
for tid, tok in
self.tokenizer.get_added_tokens_decoder().items()
}
if added_token_ids:
self.last_special = False
self.added_token_ids = added_token_ids
else:
# No added tokens.
self.spaces_between_special_tokens = True
def decode_next(self, next_token_id: int) -> str:
token = self.stream.step(self.tokenizer, next_token_id)
if not self.spaces_between_special_tokens:
special_token = self.added_token_ids.get(next_token_id)
is_special = special_token is not None
if is_special and self.last_special:
# Return raw token string without any prefixed spaces.
token = special_token
self.last_special = is_special
return token or ""
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
super().__init__(request)
self.tokenizer = tokenizer
# Metadata for incremental detokenization.
self.tokens, self.prefix_offset, self.read_offset = (
convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=request.sampling_params.
skip_special_tokens,
))
self.token_ids.extend(request.prompt_token_ids)
self.prompt_len = len(request.prompt_token_ids)
params = request.sampling_params
self.skip_special_tokens = params.skip_special_tokens
self.spaces_between_special_tokens = (
params.spaces_between_special_tokens)
@property
def output_token_ids(self) -> list[int]:
return self.token_ids if not self.prompt_len else (
self.token_ids[self.prompt_len:])
def decode_next(self, next_token_id: int) -> str:
new_tokens, decoded_text, prefix_offset, read_offset = (
detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=self.token_ids,
prev_tokens=self.tokens,
prefix_offset=self.prefix_offset,
read_offset=self.read_offset,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.
spaces_between_special_tokens,
))
self.tokens.extend(new_tokens)
self.prefix_offset = prefix_offset
self.read_offset = read_offset
return decoded_text
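
Taken together, this change replaces the single dataclass with a small hierarchy: a
no-op IncrementalDetokenizer used when detokenization is disabled, a
BaseIncrementalDetokenizer that owns stop-string handling and output buffering, and
Fast/Slow subclasses that differ only in how decode_next turns a token id into text.
A hedged sketch of the resulting dispatch (the request construction mirrors the
tests above; gpt2 is an illustrative model):

from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
                                        IncrementalDetokenizer,
                                        SlowIncrementalDetokenizer)

params = SamplingParams()
request = EngineCoreRequest("", "", [], None, None, None, params, None, 0.0, None)

fast_tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
slow_tok = AutoTokenizer.from_pretrained("gpt2", use_fast=False)

assert isinstance(IncrementalDetokenizer.from_new_request(fast_tok, request),
                  FastIncrementalDetokenizer)   # DecodeStream-backed fast path
assert isinstance(IncrementalDetokenizer.from_new_request(slow_tok, request),
                  SlowIncrementalDetokenizer)   # python fallback path
assert type(IncrementalDetokenizer.from_new_request(None, request)) \
    is IncrementalDetokenizer                   # detokenization disabled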