[V1][Perf] Faster incremental detokenization (#15137)

Signed-off-by: Nick Hill <nhill@redhat.com>

commit 05fcd1b430, parent 7c02d6a137
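
The fast path added here builds on the DecodeStream streaming decoder from the
Hugging Face tokenizers library, which is why the diff below bumps the
tokenizers requirement to >= 0.21.1. As a minimal standalone sketch of that
API only (the "gpt2" model name and the input text are illustrative, not part
of this change):

    from tokenizers import Tokenizer
    from tokenizers.decoders import DecodeStream

    tok = Tokenizer.from_pretrained("gpt2")
    stream = DecodeStream(skip_special_tokens=True)

    text = ""
    for token_id in tok.encode("Hello here, this is a simple test").ids:
        # step() returns the newly decoded text fragment, or None when more
        # tokens are needed (e.g. in the middle of a multi-byte character).
        piece = stream.step(tok, token_id)
        if piece is not None:
            text += piece
    print(text)  # "Hello here, this is a simple test"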
@@ -8,7 +8,7 @@ blake3
 py-cpuinfo
 transformers >= 4.51.1
 huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
-tokenizers >= 0.19.1 # Required for Llama 3.
+tokenizers >= 0.21.1 # Required for fast incremental detokenization.
 protobuf # Required by LlamaTokenizer.
 fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
 aiohttp
@@ -35,6 +35,7 @@ opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.8 # required for model evaluation test
 transformers==4.51.1
+tokenizers==0.21.1
 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
 # quantization
 bitsandbytes>=0.45.3
@@ -624,8 +624,10 @@ tiktoken==0.7.0
     # mistral-common
 timm==1.0.11
     # via -r requirements/test.in
-tokenizers==0.21.0
-    # via transformers
+tokenizers==0.21.1
+    # via
+    #   -r requirements/test.in
+    #   transformers
 torch==2.6.0
     # via
     #   -r requirements/test.in
@@ -47,6 +47,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
     ]
     sampling_params = vllm.SamplingParams(temperature=0,
                                           max_tokens=256,
+                                          skip_special_tokens=False,
                                           stop=["[/assistant]"])
     outputs = llm.generate(
         prompts,
@@ -4,14 +4,22 @@ from collections.abc import Generator
 from typing import Any, Optional
 
 import pytest
-from transformers import AutoTokenizer
+from transformers import (AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
 
 from vllm.inputs import token_inputs
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.detokenizer import (Detokenizer,
-                                                 detokenize_incrementally)
+from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
+from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
+                                        IncrementalDetokenizer,
+                                        SlowIncrementalDetokenizer)
 
+SPECIAL_TOKS_TRUTH = [
+    "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
+]
+
 TRUTH = [
     "Hello here, this is a simple test",
@@ -22,7 +30,8 @@ TRUTH = [
     # incomplete UTF-8 characters
     # see https://github.com/vllm-project/vllm/pull/9625
    "ပုံပြင်လေးပြောပြပါ်",
-]
+] + SPECIAL_TOKS_TRUTH
+
 TOKENIZERS = [
     "facebook/opt-125m",
     "gpt2",
@@ -38,26 +47,37 @@ TOKENIZERS = [
 ]
 
 
-def _run_incremental_decode(tokenizer, all_input_ids,
-                            skip_special_tokens: bool, starting_index: int):
-    decoded_text = ""
-    offset = 0
-    token_offset = 0
-    prev_tokens = None
-    for i in range(starting_index, len(all_input_ids)):
-        new_tokens, text, offset, token_offset = detokenize_incrementally(
-            tokenizer,
-            all_input_ids[:i + 1],
-            prev_tokens,
-            offset,
-            token_offset,
-            skip_special_tokens=skip_special_tokens)
-        decoded_text += text
-        if prev_tokens is None:
-            prev_tokens = new_tokens
-        else:
-            prev_tokens += new_tokens
-    return decoded_text
+def _run_incremental_decode(tokenizer,
+                            all_input_ids,
+                            skip_special_tokens: bool,
+                            starting_index: int,
+                            spaces_between_special_tokens: bool = True,
+                            fast: Optional[bool] = None):
+
+    prompt_token_ids = all_input_ids[:starting_index]
+
+    params = SamplingParams(
+        skip_special_tokens=skip_special_tokens,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+    )
+    request = EngineCoreRequest("", "", prompt_token_ids, None, None, None,
+                                params, None, 0.0, None)
+
+    if fast is None:
+        detokenizer = IncrementalDetokenizer.from_new_request(
+            tokenizer, request)
+    elif fast:
+        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
+    else:
+        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)
+
+    output_text = ""
+    for i, token_id in enumerate(all_input_ids[starting_index:]):
+        detokenizer.update([token_id], False)
+        finished = i == len(all_input_ids) - 1
+        output_text += detokenizer.get_next_output_text(finished, delta=True)
+
+    return output_text, detokenizer.output_token_ids
 
 
 @pytest.fixture
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
     starting_index = 0
     all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
 
-    decoded_text = _run_incremental_decode(tokenizer,
-                                           all_input_ids,
-                                           skip_special_tokens=True,
-                                           starting_index=starting_index)
+    decoded_text, out_ids = _run_incremental_decode(
+        tokenizer,
+        all_input_ids,
+        skip_special_tokens=True,
+        starting_index=starting_index)
     assert decoded_text == truth
+    assert out_ids == all_input_ids[starting_index:]
 
 
 @pytest.fixture
@@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
 @pytest.mark.parametrize("with_prompt", [True, False])
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
-def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
+@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
+@pytest.mark.parametrize("fast", (True, False))
+def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
+                          spaces_between_special_tokens, fast):
+    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
+        pytest.skip()
+
+    if skip_special_tokens and not spaces_between_special_tokens:
+        pytest.skip()
+
+    if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
+        # Fix up inconsistency in fast/slow tokenizer behaviour.
+        tokenizer.add_special_tokens({
+            "additional_special_tokens": [
+                at for at in
+                tokenizer._tokenizer.get_added_tokens_decoder().values()
+                if at.special
+            ]
+        })
+
+    extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
+        else {"spaces_between_special_tokens": spaces_between_special_tokens}
+
+    truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
+    if tokenizer.bos_token_id is not None:
+        truth_tokens.insert(0, tokenizer.bos_token_id)
+    truth_tokens.append(tokenizer.eos_token_id)
+
+    new_truth = tokenizer.decode(truth_tokens,
+                                 skip_special_tokens=skip_special_tokens,
+                                 **extra_decode_args)
+
     if with_prompt:
-        truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
-        prompt_input_ids = truth_tokens[:len(truth) // 2]
-        generated_input_ids = truth_tokens[len(truth) // 2:]
+        num_prompt_tokens = len(
+            tokenizer(truth[:len(truth) // 2],
+                      add_special_tokens=False).input_ids)
+        if tokenizer.bos_token_id is not None:
+            num_prompt_tokens += 1
+
+        prompt_input_ids = truth_tokens[:num_prompt_tokens]
+        generated_input_ids = truth_tokens[num_prompt_tokens:]
         all_input_ids = prompt_input_ids + generated_input_ids
         starting_index = len(prompt_input_ids)
         prompt = tokenizer.decode(prompt_input_ids,
-                                  skip_special_tokens=skip_special_tokens)
-        generated = truth[len(prompt):]
+                                  skip_special_tokens=skip_special_tokens,
+                                  **extra_decode_args)
+
+        generated = new_truth[len(prompt):]
     else:
-        generated = truth
+        generated = new_truth
         starting_index = 0
-        all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
-        if skip_special_tokens:
-            if tokenizer.bos_token_id is not None:
-                all_input_ids = [tokenizer.bos_token_id] + all_input_ids
-                starting_index += 1
-            all_input_ids = all_input_ids + [tokenizer.eos_token_id]
+        all_input_ids = truth_tokens
 
-    decoded_text = _run_incremental_decode(
+    decoded_text, out_ids = _run_incremental_decode(
         tokenizer,
         all_input_ids,
         skip_special_tokens=skip_special_tokens,
-        starting_index=starting_index)
+        starting_index=starting_index,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+        fast=fast)
 
     assert decoded_text == generated
+    assert out_ids == all_input_ids[starting_index:]
 
-    decoded_text = _run_incremental_decode(
+
+@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
+@pytest.mark.parametrize("fast", (True, False))
+def test_oov_decode(tokenizer, fast):
+    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
+        pytest.skip()
+
+    decoded_text, out_ids = _run_incremental_decode(
         tokenizer, [len(tokenizer)],
-        skip_special_tokens=skip_special_tokens,
-        starting_index=starting_index)
+        skip_special_tokens=True,
+        starting_index=0,
+        spaces_between_special_tokens=True,
+        fast=fast)
 
     assert decoded_text == ''
+    assert out_ids == [len(tokenizer)]
 
 
 @pytest.fixture
@@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
 @pytest.fixture(name="complete_sequence_token_ids")
 def create_complete_sequence_token_ids(complete_sequence: str,
                                        tokenizer) -> list[int]:
-    complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
-    return complete_sequence_token_ids
+    return tokenizer(complete_sequence, add_special_tokens=False).input_ids
 
 
 def create_sequence(prompt_token_ids=None):
-    prompt_token_ids = prompt_token_ids or [1]
+    prompt_token_ids = prompt_token_ids or []
     return Sequence(
         seq_id=0,
-        inputs=token_inputs(prompt_token_ids, prompt="<s>"),
+        inputs=token_inputs(prompt_token_ids),
         block_size=16,
     )
 
@@ -224,7 +291,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
     assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
     assert sequential_result != "".join(sequential_logprobs_text_other_token)
 
-    if skip_special_tokens:
+    if not skip_special_tokens:
         # Text for logprobs for the chosen token should be the same as the
         # generated text. Note that this will only be true if we skip
         # special tokens.
@@ -233,10 +300,23 @@
 
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
+def test_decode_prompt_logprobs(complete_sequence: str,
+                                complete_sequence_token_ids: list[int],
                                 detokenizer: Detokenizer):
+
+    # We want to use skip_special_tokens=False here but Mistral tokenizers
+    # don't support that.
+    if complete_sequence not in SPECIAL_TOKS_TRUTH:
+        skip_special_tokens = True
+    elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
+                        MistralTokenizer):
+        skip_special_tokens = False
+    else:
+        pytest.skip("MistralTokenizers don't support "
+                    "skip_special_tokens=False")
+        return
     """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=True,
+    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                      prompt_logprobs=1)
 
     # Run sequentially.
@@ -256,8 +336,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
     # decoded_prompt_logprobs doesn't contain the first token.
     token_ids = complete_sequence_token_ids
     tokenizer = detokenizer.get_tokenizer_for_seq(seq)
-    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
-    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
+    text_full = tokenizer.decode(token_ids,
+                                 skip_special_tokens=skip_special_tokens)
+    text_first = tokenizer.decode(token_ids[0],
+                                  skip_special_tokens=skip_special_tokens)
     text = text_full[len(text_first):]
 
     # Text for logprobs for the chosen token should be the same as the
@@ -70,6 +70,15 @@ class MistralToolParser(ToolParser):
                 "Mistral Tool Parser could not locate the tool call token in "
                 "the tokenizer!")
 
+    def adjust_request(
+            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+        if request.tools and request.tool_choice != 'none':
+            # do not skip special tokens because mistral uses the special
+            # tokens to indicate the start and end of the tool calls
+            # information.
+            request.skip_special_tokens = False
+        return request
+
     def extract_tool_calls(
         self,
         model_output: str,
@@ -1,8 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import dataclass, field
+from abc import ABC, abstractmethod
 from typing import Optional
 
+from tokenizers import Tokenizer
+from tokenizers.decoders import DecodeStream
+from transformers import PreTrainedTokenizerFast
+
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
 from vllm.transformers_utils.detokenizer_utils import (
@@ -12,39 +15,22 @@ from vllm.v1.engine import EngineCoreRequest
 logger = init_logger(__name__)
 
 
-@dataclass
 class IncrementalDetokenizer:
 
-    # Generation data
-    token_ids: list[int]
-    output_text: str = ""
-    tokens: list[str] = field(default_factory=list)
-    prompt_len: int = 0
-
-    # Stop strings
-    stop: list[str] = field(default_factory=list)
-    include_stop_str_in_output: bool = False
-
-    # Metadata for incremental detokenization
-    prefix_offset: int = 0
-    read_offset: int = 0
-
-    # Parameters for detokenization
-    skip_special_tokens: bool = True
-    spaces_between_special_tokens: bool = True
-
-    # Tokenizer for this request,
-    # None if detokenization is disabled.
-    tokenizer: Optional[AnyTokenizer] = None
-
-    # Accounting for stop string buffering
-    stop_buffer_length: int = 0
-    _last_output_text_offset: int = 0
+    def __init__(self):
+        self.token_ids: list[int] = []
 
     @property
     def output_token_ids(self) -> list[int]:
-        return self.token_ids if not self.prompt_len else (
-            self.token_ids[self.prompt_len:])
+        return self.token_ids
+
+    def update(self, new_token_ids: list[int],
+               stop_terminated: bool) -> Optional[str]:
+        self.token_ids.extend(new_token_ids)
+        return None
+
+    def get_next_output_text(self, finished: bool, delta: bool) -> str:
+        return ""
 
     @classmethod
     def from_new_request(
@@ -54,39 +40,37 @@ class IncrementalDetokenizer:
     ) -> "IncrementalDetokenizer":
 
         if tokenizer is None:
-            return cls(token_ids=[])
-
-        tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
-            tokenizer=tokenizer,
-            prompt_ids=request.prompt_token_ids,
-            skip_special_tokens=request.sampling_params.skip_special_tokens,
-        )
+            # No tokenizer => skipping detokenization.
+            return IncrementalDetokenizer()
+
+        if isinstance(tokenizer, PreTrainedTokenizerFast):
+            # Fast tokenizer => use tokenizers library DecodeStream.
+            return FastIncrementalDetokenizer(tokenizer, request)
+
+        # Fall back to slow python-based incremental detokenization.
+        return SlowIncrementalDetokenizer(tokenizer, request)
+
+
+class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
+
+    def __init__(self, request: EngineCoreRequest):
+        super().__init__()
+
+        # Stop strings
+        params = request.sampling_params
+        self.stop = stop = params.stop
+        self.include_stop_str_in_output = params.include_stop_str_in_output
 
-        stops = request.sampling_params.stop
         # Number of chars to hold back when stop strings are to be excluded
         # from streamed output.
-        if stops and not request.sampling_params.include_stop_str_in_output:
-            stop_buffer_length = max(len(s) for s in stops) - 1
+        if stop and not self.include_stop_str_in_output:
+            self.stop_buffer_length = max(len(s) for s in stop) - 1
         else:
-            stop_buffer_length = 0
-
-        return cls(
-            tokens=tokens,
-            # Detokenizer mutates this list, so need a unique copy.
-            # NOTE(Nick): could we take ownership of it though?
-            token_ids=request.prompt_token_ids.copy(),
-            stop=stops,
-            include_stop_str_in_output=request.sampling_params.
-            include_stop_str_in_output,
-            prefix_offset=prefix_offset,
-            read_offset=read_offset,
-            skip_special_tokens=request.sampling_params.skip_special_tokens,
-            spaces_between_special_tokens=request.sampling_params.
-            spaces_between_special_tokens,
-            prompt_len=len(request.prompt_token_ids),
-            tokenizer=tokenizer,
-            stop_buffer_length=stop_buffer_length,
-        )
+            self.stop_buffer_length = 0
+        self._last_output_text_offset: int = 0
+
+        # Generation data
+        self.output_text = ""
 
     def update(self, new_token_ids: list[int],
                stop_terminated: bool) -> Optional[str]:
@@ -98,11 +82,7 @@ class IncrementalDetokenizer:
         Return matched stop string or None.
         """
         if not new_token_ids:
-            # Skip detokenization if no new token ids
-            return None
-        if self.tokenizer is None:
-            # Skip detokenization if no tokenizer
-            self.token_ids.extend(new_token_ids)
+            # Skip detokenization if no new token ids.
             return None
 
         if stop_terminated and not self.include_stop_str_in_output:
@@ -116,34 +96,16 @@ class IncrementalDetokenizer:
         # 1) Detokenize the new token ids incrementally.
         # TODO(woosuk): This method becomes very inefficient when the number of
         # new_token_ids is more than 1. We need to optimize this.
-        decoded_text = ""
+        offset_before = len(self.output_text)
         for new_token_id in new_token_ids:
             self.token_ids.append(new_token_id)
-            (new_tokens, new_decoded_token_text, prefix_offset,
-             read_offset) = detokenize_incrementally(
-                 tokenizer=self.tokenizer,
-                 all_input_ids=self.token_ids,
-                 prev_tokens=self.tokens,
-                 prefix_offset=self.prefix_offset,
-                 read_offset=self.read_offset,
-                 skip_special_tokens=self.skip_special_tokens,
-                 spaces_between_special_tokens=self.
-                 spaces_between_special_tokens,
-             )
-
-            self.tokens.extend(new_tokens)
-            self.prefix_offset = prefix_offset
-            self.read_offset = read_offset
-
-            decoded_text += new_decoded_token_text
-
-        self.output_text += decoded_text
+            self.output_text += self.decode_next(new_token_id)
 
         if stop_terminated:
             if skipped_stop_token_id is not None:
-                # Cleanup after skipping detokenization
+                # Cleanup after skipping detokenization.
                 self.token_ids.append(skipped_stop_token_id)
-                # Stop token triggered; skip stop string check
+                # Stop token triggered; skip stop string check.
                 return None
 
         # 2) Evaluate stop strings.
@@ -151,7 +113,7 @@ class IncrementalDetokenizer:
         if self.stop:
             stop = StopChecker.check_stop_strings(
                 output_text=self.output_text,
-                new_char_count=len(decoded_text),
+                new_char_count=len(self.output_text) - offset_before,
                 stop=self.stop,
                 include_in_output=self.include_stop_str_in_output,
             )
@@ -162,6 +124,10 @@ class IncrementalDetokenizer:
 
         return stop_string
 
+    @abstractmethod
+    def decode_next(self, next_token_id: int) -> str:
+        raise NotImplementedError
+
     def get_next_output_text(self, finished: bool, delta: bool) -> str:
         """If delta is True, only new text since the last call to
         this method is returned"""
@@ -177,3 +143,114 @@ class IncrementalDetokenizer:
             self._last_output_text_offset = length
             return self.output_text[last_offset:length]
         return ""
+
+
+class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
+
+    def __init__(self, tokenizer: PreTrainedTokenizerFast,
+                 request: EngineCoreRequest):
+        super().__init__(request)
+
+        sampling_params = request.sampling_params
+        self.stream = DecodeStream(
+            skip_special_tokens=sampling_params.skip_special_tokens)
+
+        self.tokenizer: Tokenizer = tokenizer._tokenizer
+
+        # Find a safe place to start.
+        prompt_suffix = request.prompt_token_ids
+        prompt_len = len(prompt_suffix)
+        if prompt_len > 4:
+            for i in range(4, max(prompt_len + 1, 32)):
+                suffix = request.prompt_token_ids[-i:]
+                if '�' not in self.tokenizer.decode(suffix):
+                    prompt_suffix = suffix
+                    break
+
+        # Prime the stream.
+        for tid in prompt_suffix:
+            self.stream.step(self.tokenizer, tid)
+
+        self.spaces_between_special_tokens = (
+            sampling_params.skip_special_tokens
+            or sampling_params.spaces_between_special_tokens)
+
+        if not self.spaces_between_special_tokens:
+            # Store dict of added token ids so that we can suppress
+            # the spaces between them.
+            if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
+                                           None)) is None:
+                self.tokenizer.added_token_ids = added_token_ids = {
+                    tid: tok.content
+                    for tid, tok in
+                    self.tokenizer.get_added_tokens_decoder().items()
+                }
+
+            if added_token_ids:
+                self.last_special = False
+                self.added_token_ids = added_token_ids
+            else:
+                # No added tokens.
+                self.spaces_between_special_tokens = True
+
+    def decode_next(self, next_token_id: int) -> str:
+        token = self.stream.step(self.tokenizer, next_token_id)
+
+        if not self.spaces_between_special_tokens:
+            special_token = self.added_token_ids.get(next_token_id)
+            is_special = special_token is not None
+            if is_special and self.last_special:
+                # Return raw token string without any prefixed spaces.
+                token = special_token
+            self.last_special = is_special
+
+        return token or ""
+
+
+class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
+
+    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
+        super().__init__(request)
+
+        self.tokenizer = tokenizer
+
+        # Metadata for incremental detokenization.
+        self.tokens, self.prefix_offset, self.read_offset = (
+            convert_prompt_ids_to_tokens(
+                tokenizer=tokenizer,
+                prompt_ids=request.prompt_token_ids,
+                skip_special_tokens=request.sampling_params.
+                skip_special_tokens,
+            ))
+
+        self.token_ids.extend(request.prompt_token_ids)
+        self.prompt_len = len(request.prompt_token_ids)
+
+        params = request.sampling_params
+        self.skip_special_tokens = params.skip_special_tokens
+        self.spaces_between_special_tokens = (
+            params.spaces_between_special_tokens)
+
+    @property
+    def output_token_ids(self) -> list[int]:
+        return self.token_ids if not self.prompt_len else (
+            self.token_ids[self.prompt_len:])
+
+    def decode_next(self, next_token_id: int) -> str:
+        new_tokens, decoded_text, prefix_offset, read_offset = (
+            detokenize_incrementally(
+                tokenizer=self.tokenizer,
+                all_input_ids=self.token_ids,
+                prev_tokens=self.tokens,
+                prefix_offset=self.prefix_offset,
+                read_offset=self.read_offset,
+                skip_special_tokens=self.skip_special_tokens,
+                spaces_between_special_tokens=self.
+                spaces_between_special_tokens,
+            ))
+
+        self.tokens.extend(new_tokens)
+        self.prefix_offset = prefix_offset
+        self.read_offset = read_offset
+
+        return decoded_text