[Bugfix] Fix prompt_logprobs when SamplingParams.detokenize is set to True (#5226)
parent fee4dcc33a
commit 974fc9b845
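In brief: the single-step output processor accumulated prompt logprobs onto the sequence group only inside its detokenize branch, so a request that set prompt_logprobs together with detokenize=False never got them back. This commit decouples the two steps: decoding stays conditional on SamplingParams.detokenize, while accumulation now runs whenever prompt logprobs are present, and test_get_prompt_logprobs gains a detokenize parameter to cover both paths.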
tests/samplers/test_logprobs.py

@@ -12,6 +12,7 @@ MODELS = ["facebook/opt-125m"]
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
 @pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
+@pytest.mark.parametrize("detokenize", [True, False])
 def test_get_prompt_logprobs(
     hf_runner,
     vllm_runner,
@@ -19,6 +20,7 @@ def test_get_prompt_logprobs(
     dtype,
     chunked_prefill_token_size: int,
     num_top_logprobs: int,
+    detokenize: bool,
     example_prompts,
 ):
     max_num_seqs = 256
@@ -48,7 +50,8 @@ def test_get_prompt_logprobs(
     vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                           logprobs=num_top_logprobs,
                                           prompt_logprobs=num_top_logprobs,
-                                          temperature=0.0)
+                                          temperature=0.0,
+                                          detokenize=detokenize)
     vllm_results = vllm_model.model.generate(
         example_prompts, sampling_params=vllm_sampling_params)
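For context, a minimal sketch of the request shape the new detokenize parameter exercises. This is illustrative, not part of the commit; it assumes vLLM's public LLM/SamplingParams entry points and reuses the test's facebook/opt-125m model:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=5,
                        logprobs=6,           # top logprobs per output token
                        prompt_logprobs=6,    # top logprobs per prompt token
                        temperature=0.0,      # greedy, as in the test
                        detokenize=False)     # skip token-to-text decoding
result = llm.generate(["The capital of France is"], params)[0]
# With the fix, prompt logprobs arrive even though detokenize=False.
assert result.prompt_logprobs is not None
assert result.prompt_logprobs[0] is None  # first prompt token has no logprob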
@@ -65,11 +68,16 @@ def test_get_prompt_logprobs(
             top_logprob = next(iter(top_logprobs.values()))
             output_string_from_most_likely_tokens.append(
                 top_logprob.decoded_token)
-        output_string_from_most_likely_tokens = "".join(
-            output_string_from_most_likely_tokens)
-        assert output_text == output_string_from_most_likely_tokens, (
-            "The output text from the top logprob for each token position "
-            "should be the same as the output text in the result.")
+
+        if detokenize:
+            output_string_from_most_likely_tokens = "".join(
+                output_string_from_most_likely_tokens)
+            assert output_text == output_string_from_most_likely_tokens, (
+                "The output text from the top logprob for each token position "
+                "should be the same as the output text in the result.")
+        else:
+            assert output_text == ''
+            assert output_string_from_most_likely_tokens == [None] * max_tokens

         # The first prompt logprob is always None
         assert result.prompt_logprobs[0] is None
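The else branch pins down what a caller sees with detokenize=False: no output text and no decoded tokens, only token ids and log-probabilities. A hedged sketch of consuming such a result, where completion stands for one CompletionOutput from the call above:

# completion.logprobs holds one {token_id: Logprob} dict per generated position.
for position_logprobs in completion.logprobs:
    token_id, top = max(position_logprobs.items(),
                        key=lambda kv: kv[1].logprob)
    assert top.decoded_token is None    # nothing was detokenized
    print(token_id, top.logprob)        # decode later with a tokenizer if needed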
@@ -98,9 +106,10 @@ def test_get_prompt_logprobs(
                     hf_logprob[i][-1][token_id].item(),
                     atol=1e-2,
                     rtol=1e-2)
-            assert isinstance(sample_logprob.decoded_token, str), (
-                "The token should be decoded by the time it is returned "
-                " to the user.")
+            if detokenize:
+                assert isinstance(sample_logprob.decoded_token, str), (
+                    "The token should be decoded by the time it is returned"
+                    " to the user.")

     # Test if prompt logprobs are correctly set.
     for vllm_result in vllm_results:
vllm/engine/output_processor/single_step.py

@@ -60,10 +60,10 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         assert len(outputs) == 1, ("Single step should only has 1 output.")
         output = outputs[0]
         prompt_logprobs = output.prompt_logprobs
-        if (prompt_logprobs is not None
-                and seq_group.sampling_params.detokenize and self.detokenizer):
-            self.detokenizer.decode_prompt_logprobs_inplace(
-                seq_group, prompt_logprobs)
+        if prompt_logprobs is not None:
+            if seq_group.sampling_params.detokenize and self.detokenizer:
+                self.detokenizer.decode_prompt_logprobs_inplace(
+                    seq_group, prompt_logprobs)
             if not seq_group.prompt_logprobs:
                 # The first prompt token's logprob is None because it doesn't
                 # have tokens that are precedent.
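A hedged distillation of the control-flow change in the hunk above, written as a free function rather than the actual SingleStepOutputProcessor method; the trailing [None]/extend lines are assumed from the truncated context around the final comment:

def process_prompt_logprobs(seq_group, prompt_logprobs, detokenizer):
    if prompt_logprobs is None:
        return
    # Decoding token ids into text stays optional...
    if seq_group.sampling_params.detokenize and detokenizer:
        detokenizer.decode_prompt_logprobs_inplace(seq_group, prompt_logprobs)
    # ...but accumulation now runs unconditionally. Before the fix it was
    # nested under the detokenize check, so detokenize=False silently
    # dropped the requested prompt logprobs.
    if not seq_group.prompt_logprobs:
        # The first prompt token has no preceding tokens, hence no logprob.
        seq_group.prompt_logprobs = [None]              # assumed from context
    seq_group.prompt_logprobs.extend(prompt_logprobs)   # assumed from context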