[ BugFix ] Prompt Logprobs Detokenization (#6223)
Co-authored-by: Zifei Tong <zifeitong@gmail.com>
parent a4feba929b
commit 7ed6a4f0e1
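For context, a minimal usage sketch of the behavior this fix targets: requesting prompt logprobs and re-assembling the prompt text from their decoded tokens. The offline `LLM` entrypoint, the model name, and the example prompt are illustrative assumptions; `SamplingParams(prompt_logprobs=...)`, the `prompt_logprobs[0] is None` invariant, and `decoded_token` are taken from the diff below.

# Hedged sketch, not part of the commit: request prompt logprobs and rebuild
# the prompt text from their decoded tokens. Model name and prompt are
# illustrative assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=10, prompt_logprobs=5, temperature=0.0)
outputs = llm.generate(["Hello, my name is"], params)

for output in outputs:
    # The first prompt token has no preceding context, so its logprob is None.
    assert output.prompt_logprobs is not None
    assert output.prompt_logprobs[0] is None
    rebuilt = "".join(
        logprobs[token_id].decoded_token
        for token_id, logprobs in zip(output.prompt_token_ids[1:],
                                      output.prompt_logprobs[1:]))
    # After this fix, the rebuilt text matches the original prompt, even when
    # the prefill is split into chunks.
    print(rebuilt)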
@@ -87,7 +87,10 @@ steps:
 
 - label: Engine Test
   mirror_hardwares: [amd]
-  command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
+  commands:
+  - pytest -v -s engine test_sequence.py test_config.py test_logger.py
+  # OOM in the CI unless we run this separately
+  - pytest -v -s tokenization
 
 - label: Entrypoints Test
   mirror_hardwares: [amd]
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Any, Dict, List, Optional
 
 import pytest
 from transformers import AutoTokenizer
@@ -139,6 +139,15 @@ def create_dummy_logprobs(
     } for token_id in complete_sequence_token_ids]
 
 
+def create_dummy_prompt_logprobs(
+        complete_sequence_token_ids: List[int]
+) -> List[Optional[Dict[int, Any]]]:
+    # logprob for the first prompt token is None.
+    logprobs: List[Optional[Dict[int, Any]]] = [None]
+    logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
+    return logprobs
+
+
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", [True, False])
@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,
 
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", [True])
-def test_decode_prompt_logprobs(complete_sequence: str,
-                                complete_sequence_token_ids: List[int],
-                                detokenizer: Detokenizer,
-                                skip_special_tokens: bool):
+def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
+                                detokenizer: Detokenizer):
     """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
+    sampling_params = SamplingParams(skip_special_tokens=True,
                                      prompt_logprobs=1)
 
     # Run sequentially.
@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
         seqs=[seq],
         sampling_params=sampling_params,
         arrival_time=0.0)
-    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
-    detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
-    decoded_prompt_logprobs = dummy_logprobs
+    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
+    detokenizer.decode_prompt_logprobs_inplace(seq_group,
+                                               dummy_logprobs,
+                                               position_offset=0)
+    # First logprob is None.
+    decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
+        1:]  # type: ignore
+
+    # decoded_prompt_logprobs doesn't contain the first token.
+    token_ids = complete_sequence_token_ids
+    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
+    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
+    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
+    text = text_full[len(text_first):]
 
-    if skip_special_tokens:
-        # Text for logprobs for the chosen token should be the same as the
-        # prompt text. Note that this will only be true if we skip
-        # special tokens.
-        assert complete_sequence == "".join([
-            logprobs[token_id].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
-        ])
-        assert complete_sequence != "".join([
-            logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
-        ])
+    # Text for logprobs of the chosen tokens should be the same as the
+    # prompt text. Note that the first logprob is None.
+    assert text == "".join([
+        logprobs[token_id].decoded_token
+        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
+    ])
+    assert text != "".join([
+        logprobs[token_id + 1].decoded_token
+        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
+    ])
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
+def test_decode_prompt_logprobs_chunked_prefill(
+        vllm_runner,
+        model,
+        chunked_prefill_token_size: int,
+        example_prompts,
+):
+    max_num_seqs = 256
+    enable_chunked_prefill = False
+    max_num_batched_tokens = None
+    if chunked_prefill_token_size != -1:
+        enable_chunked_prefill = True
+        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
+        max_num_batched_tokens = chunked_prefill_token_size
+
+    with vllm_runner(model,
+                     dtype="half",
+                     max_logprobs=5,
+                     gpu_memory_utilization=0.5,
+                     enable_chunked_prefill=enable_chunked_prefill,
+                     max_num_batched_tokens=max_num_batched_tokens,
+                     max_num_seqs=max_num_seqs) as vllm_model:
+
+        vllm_sampling_params = SamplingParams(max_tokens=10,
+                                              logprobs=5,
+                                              prompt_logprobs=5,
+                                              temperature=0.0)
+        vllm_results = vllm_model.model.generate(
+            example_prompts, sampling_params=vllm_sampling_params)
+
+        for idx, result in enumerate(vllm_results):
+            assert result.prompt_logprobs is not None
+            assert result.prompt_logprobs[0] is None
+
+            # Compare detokenized prompt token ids to the original prompt.
+            generated_string = ""
+            for (prompt_token,
+                 prompt_logprobs) in zip(result.prompt_token_ids[1:],
+                                         result.prompt_logprobs[1:]):
+                # prompt_logprobs is a dict mapping token_id to Logprob.
+                # Select the entry for the actual prompt token to get the
+                # decoded token of the detokenized string at this prompt
+                # position.
+                generated_string += prompt_logprobs[prompt_token].decoded_token
+
+            assert generated_string == example_prompts[idx], (
+                "Detokenized prompt logprobs do not match original prompt")
@@ -60,14 +60,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         assert len(outputs) == 1, ("Single step should only has 1 output.")
         output = outputs[0]
         prompt_logprobs = output.prompt_logprobs
+
+        # If this is the first (or only) "chunk" of the prefill, we need
+        # to prepend None to the list of prompt logprobs. The reason for this
+        # is that for N prompt tokens, the Sampler will generate N-1 total
+        # prompt logprobs during prefill since the token at idx 0 will not
+        # have a logprob associated with it.
         if prompt_logprobs is not None:
+            if not seq_group.prompt_logprobs:
+                prompt_logprobs = [None] + prompt_logprobs
+                seq_group.prompt_logprobs = []
+
             if seq_group.sampling_params.detokenize and self.detokenizer:
                 self.detokenizer.decode_prompt_logprobs_inplace(
-                    seq_group, prompt_logprobs)
-            if not seq_group.prompt_logprobs:
-                # The first prompt token's logprob is None because it doesn't
-                # have tokens that are precedent.
-                seq_group.prompt_logprobs = [None]
+                    seq_group,
+                    prompt_logprobs,
+                    position_offset=len(seq_group.prompt_logprobs))
+
             seq_group.prompt_logprobs.extend(prompt_logprobs)
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
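To make the accumulation logic in the hunk above easier to follow, here is a minimal standalone sketch. The function name `accumulate_prompt_logprobs` and the plain float logprob values are hypothetical (vLLM uses Logprob objects); the bookkeeping mirrors the diff: the first prefill chunk gets a leading None for token 0, and each chunk is decoded at an offset equal to the number of prompt positions already processed.

# Minimal sketch of the accumulation above; accumulate_prompt_logprobs and the
# float logprob values are illustrative, not vLLM APIs.
from typing import Dict, List, Optional

def accumulate_prompt_logprobs(
        accumulated: List[Optional[Dict[int, float]]],
        chunk: List[Optional[Dict[int, float]]]) -> int:
    """Appends one prefill chunk and returns the position_offset used for it."""
    if not accumulated:
        # First chunk of the prefill: prompt token 0 has no logprob.
        chunk = [None] + chunk
    offset = len(accumulated)  # prompt positions covered by earlier chunks
    accumulated.extend(chunk)
    return offset

acc: List[Optional[Dict[int, float]]] = []
print(accumulate_prompt_logprobs(acc, [{11: -0.1}, {12: -0.2}, {13: -0.3}]))  # 0
print(accumulate_prompt_logprobs(acc, [{14: -0.4}]))                          # 4
print(len(acc))  # 5 entries for a 5-token prompt; acc[0] is None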
@@ -21,14 +21,17 @@ class Detokenizer:
         """Returns the HF tokenizer to use for a given sequence."""
         return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
 
-    def decode_prompt_logprobs_inplace(
-            self, seq_group: SequenceGroup,
-            prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
+    def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
+                                       prompt_logprobs: List[Optional[Dict[
+                                           int, Logprob]]],
+                                       position_offset: int) -> None:
         """Decodes the logprobs for the prompt of a sequence group.
 
         Args:
             seq_group: The sequence group to decode.
             prompt_logprobs: The logprobs to decode.
+            position_offset: Offset of the first index of the logprobs
+                relative to the start of the sequence (for chunked prefill).
 
         Returns:
             The prompt logprobs with the decoded tokens.
@@ -47,8 +50,13 @@ class Detokenizer:
         next_iter_tokens: List[str] = []
         prev_tokens = None
 
-        for token_position, prompt_logprobs_for_token in enumerate(
+        for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
                 prompt_logprobs):
+
+            # Absolute token position equals the index in the logprobs
+            # list plus the offset of the entire logprobs list relative
+            # to the start of the sequence.
+            token_position = token_position_in_logprob + position_offset
             if not prompt_logprobs_for_token:
                 continue
             for token_id, sample_logprob in prompt_logprobs_for_token.items():
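As a quick worked example of the offset arithmetic above (the chunk sizes are illustrative, not taken from the commit): a 7-token prompt prefilled in chunks of 4 and 3 tokens yields a second logprobs list of length 3 that is decoded with position_offset=4, so its entries map to absolute prompt positions 4 through 6.

# Worked example of token_position = index + position_offset; chunk sizes
# are illustrative.
chunk_lengths = [4, 3]
position_offset = 0
for chunk_len in chunk_lengths:
    absolute_positions = [i + position_offset for i in range(chunk_len)]
    print(position_offset, absolute_positions)
    position_offset += chunk_len
# 0 [0, 1, 2, 3]
# 4 [4, 5, 6]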