[ BugFix ] Prompt Logprobs Detokenization (#6223)
Co-authored-by: Zifei Tong <zifeitong@gmail.com>
This commit is contained in:
parent
a4feba929b
commit
7ed6a4f0e1
@ -87,7 +87,10 @@ steps:
|
|||||||
|
|
||||||
- label: Engine Test
|
- label: Engine Test
|
||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
|
commands:
|
||||||
|
- pytest -v -s engine test_sequence.py test_config.py test_logger.py
|
||||||
|
# OOM in the CI unless we run this separately
|
||||||
|
- pytest -v -s tokenization
|
||||||
|
|
||||||
- label: Entrypoints Test
|
- label: Entrypoints Test
|
||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
@ -139,6 +139,15 @@ def create_dummy_logprobs(
|
|||||||
} for token_id in complete_sequence_token_ids]
|
} for token_id in complete_sequence_token_ids]
|
||||||
|
|
||||||
|
|
||||||
|
def create_dummy_prompt_logprobs(
|
||||||
|
complete_sequence_token_ids: List[int]
|
||||||
|
) -> List[Optional[Dict[int, Any]]]:
|
||||||
|
# logprob for the first prompt token is None.
|
||||||
|
logprobs: List[Optional[Dict[int, Any]]] = [None]
|
||||||
|
logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
|
||||||
|
return logprobs
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("complete_sequence", TRUTH)
|
@pytest.mark.parametrize("complete_sequence", TRUTH)
|
||||||
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
|
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
|
||||||
@pytest.mark.parametrize("skip_special_tokens", [True, False])
|
@pytest.mark.parametrize("skip_special_tokens", [True, False])
|
||||||
@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,
|
|||||||
|
|
||||||
@pytest.mark.parametrize("complete_sequence", TRUTH)
|
@pytest.mark.parametrize("complete_sequence", TRUTH)
|
||||||
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
|
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
|
||||||
@pytest.mark.parametrize("skip_special_tokens", [True])
|
def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
|
||||||
def test_decode_prompt_logprobs(complete_sequence: str,
|
detokenizer: Detokenizer):
|
||||||
complete_sequence_token_ids: List[int],
|
|
||||||
detokenizer: Detokenizer,
|
|
||||||
skip_special_tokens: bool):
|
|
||||||
"""Verify Detokenizer decodes prompt logprobs correctly."""
|
"""Verify Detokenizer decodes prompt logprobs correctly."""
|
||||||
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
|
sampling_params = SamplingParams(skip_special_tokens=True,
|
||||||
prompt_logprobs=1)
|
prompt_logprobs=1)
|
||||||
|
|
||||||
# Run sequentially.
|
# Run sequentially.
|
||||||
@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
|
|||||||
seqs=[seq],
|
seqs=[seq],
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
arrival_time=0.0)
|
arrival_time=0.0)
|
||||||
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
|
dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
|
||||||
detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
|
detokenizer.decode_prompt_logprobs_inplace(seq_group,
|
||||||
decoded_prompt_logprobs = dummy_logprobs
|
dummy_logprobs,
|
||||||
|
position_offset=0)
|
||||||
|
# First logprob is None.
|
||||||
|
decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
|
||||||
|
1:] # type: ignore
|
||||||
|
|
||||||
if skip_special_tokens:
|
# decoded_prompt_logprobs doesn't contain the first token.
|
||||||
# Text for logprobs for the chosen token should be the same as the
|
token_ids = complete_sequence_token_ids
|
||||||
# prompt text. Note that this will only be true if we skip
|
tokenzier = detokenizer.get_tokenizer_for_seq(seq)
|
||||||
# special tokens.
|
text_full = tokenzier.decode(token_ids, skip_special_tokens=True)
|
||||||
assert complete_sequence == "".join([
|
text_first = tokenzier.decode(token_ids[0], skip_special_tokens=True)
|
||||||
logprobs[token_id].decoded_token for token_id, logprobs in zip(
|
text = text_full[len(text_first):]
|
||||||
complete_sequence_token_ids, decoded_prompt_logprobs)
|
|
||||||
])
|
# Text for logprobs for the chosen token should be the same as the
|
||||||
assert complete_sequence != "".join([
|
# prompt text. Note that the first logprob is None.
|
||||||
logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
|
assert text == "".join([
|
||||||
complete_sequence_token_ids, decoded_prompt_logprobs)
|
logprobs[token_id].decoded_token
|
||||||
])
|
for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
|
||||||
|
])
|
||||||
|
assert text != "".join([
|
||||||
|
logprobs[token_id + 1].decoded_token
|
||||||
|
for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
|
||||||
|
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
|
||||||
|
def test_decode_prompt_logprobs_chunked_prefill(
|
||||||
|
vllm_runner,
|
||||||
|
model,
|
||||||
|
chunked_prefill_token_size: int,
|
||||||
|
example_prompts,
|
||||||
|
):
|
||||||
|
max_num_seqs = 256
|
||||||
|
enable_chunked_prefill = False
|
||||||
|
max_num_batched_tokens = None
|
||||||
|
if chunked_prefill_token_size != -1:
|
||||||
|
enable_chunked_prefill = True
|
||||||
|
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
|
||||||
|
max_num_batched_tokens = chunked_prefill_token_size
|
||||||
|
|
||||||
|
with vllm_runner(model,
|
||||||
|
dtype="half",
|
||||||
|
max_logprobs=5,
|
||||||
|
gpu_memory_utilization=0.5,
|
||||||
|
enable_chunked_prefill=enable_chunked_prefill,
|
||||||
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
|
max_num_seqs=max_num_seqs) as vllm_model:
|
||||||
|
|
||||||
|
vllm_sampling_params = SamplingParams(max_tokens=10,
|
||||||
|
logprobs=5,
|
||||||
|
prompt_logprobs=5,
|
||||||
|
temperature=0.0)
|
||||||
|
vllm_results = vllm_model.model.generate(
|
||||||
|
example_prompts, sampling_params=vllm_sampling_params)
|
||||||
|
|
||||||
|
for idx, result in enumerate(vllm_results):
|
||||||
|
assert result.prompt_logprobs is not None
|
||||||
|
assert result.prompt_logprobs[0] is None
|
||||||
|
|
||||||
|
# Compared detokenized prompts ids to original prompt.
|
||||||
|
generated_string = ""
|
||||||
|
for (prompt_token,
|
||||||
|
prompt_logprobs) in zip(result.prompt_token_ids[1:],
|
||||||
|
result.prompt_logprobs[1:]):
|
||||||
|
# prompt_logprobs is a dict of the token_id: logprob
|
||||||
|
# We select the token_id corresponding to the actual prompt
|
||||||
|
# Decoded token in the detokenized string corresponding to this
|
||||||
|
# prompt token.
|
||||||
|
generated_string += prompt_logprobs[prompt_token].decoded_token
|
||||||
|
|
||||||
|
assert generated_string == example_prompts[idx], (
|
||||||
|
"Detokenized prompt logprobs do not match original prompt")
|
||||||
|
@ -60,14 +60,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
assert len(outputs) == 1, ("Single step should only has 1 output.")
|
assert len(outputs) == 1, ("Single step should only has 1 output.")
|
||||||
output = outputs[0]
|
output = outputs[0]
|
||||||
prompt_logprobs = output.prompt_logprobs
|
prompt_logprobs = output.prompt_logprobs
|
||||||
|
|
||||||
|
# If this is the first (or only) "chunk" of the prefill, we need
|
||||||
|
# to prepend None to the list of prompt logprobs. The reason for this
|
||||||
|
# is that for N prompt tokens, the Sampler will generate N-1 total
|
||||||
|
# prompt logprobs during prefill since the token at idx 0 will not
|
||||||
|
# have a logprob associated with it.
|
||||||
if prompt_logprobs is not None:
|
if prompt_logprobs is not None:
|
||||||
|
if not seq_group.prompt_logprobs:
|
||||||
|
prompt_logprobs = [None] + prompt_logprobs
|
||||||
|
seq_group.prompt_logprobs = []
|
||||||
|
|
||||||
if seq_group.sampling_params.detokenize and self.detokenizer:
|
if seq_group.sampling_params.detokenize and self.detokenizer:
|
||||||
self.detokenizer.decode_prompt_logprobs_inplace(
|
self.detokenizer.decode_prompt_logprobs_inplace(
|
||||||
seq_group, prompt_logprobs)
|
seq_group,
|
||||||
if not seq_group.prompt_logprobs:
|
prompt_logprobs,
|
||||||
# The first prompt token's logprob is None because it doesn't
|
position_offset=len(seq_group.prompt_logprobs))
|
||||||
# have tokens that are precedent.
|
|
||||||
seq_group.prompt_logprobs = [None]
|
|
||||||
seq_group.prompt_logprobs.extend(prompt_logprobs)
|
seq_group.prompt_logprobs.extend(prompt_logprobs)
|
||||||
|
|
||||||
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
||||||
|
@ -21,14 +21,17 @@ class Detokenizer:
|
|||||||
"""Returns the HF tokenizer to use for a given sequence."""
|
"""Returns the HF tokenizer to use for a given sequence."""
|
||||||
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
|
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
|
||||||
|
|
||||||
def decode_prompt_logprobs_inplace(
|
def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
|
||||||
self, seq_group: SequenceGroup,
|
prompt_logprobs: List[Optional[Dict[
|
||||||
prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
|
int, Logprob]]],
|
||||||
|
position_offset: int) -> None:
|
||||||
"""Decodes the logprobs for the prompt of a sequence group.
|
"""Decodes the logprobs for the prompt of a sequence group.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seq_group: The sequence group to decode.
|
seq_group: The sequence group to decode.
|
||||||
prompt_logprobs: The logprobs to decode.
|
prompt_logprobs: The logprobs to decode.
|
||||||
|
position_offset: Offset of the first index of the logprobs
|
||||||
|
relative to the start of the sequence (for chunked prefill).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The prompt logprobs with the decoded tokens.
|
The prompt logprobs with the decoded tokens.
|
||||||
@ -47,8 +50,13 @@ class Detokenizer:
|
|||||||
next_iter_tokens: List[str] = []
|
next_iter_tokens: List[str] = []
|
||||||
prev_tokens = None
|
prev_tokens = None
|
||||||
|
|
||||||
for token_position, prompt_logprobs_for_token in enumerate(
|
for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
|
||||||
prompt_logprobs):
|
prompt_logprobs):
|
||||||
|
|
||||||
|
# Absolute token position equals the index in the logprobs
|
||||||
|
# list plus the offset of the entire logprobs list relative
|
||||||
|
# to the start of the sequence.
|
||||||
|
token_position = token_position_in_logprob + position_offset
|
||||||
if not prompt_logprobs_for_token:
|
if not prompt_logprobs_for_token:
|
||||||
continue
|
continue
|
||||||
for token_id, sample_logprob in prompt_logprobs_for_token.items():
|
for token_id, sample_logprob in prompt_logprobs_for_token.items():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user