[Bugfix] Fix prompt_logprobs when SamplingParams.detokenize is set to True (#5226)
commit 974fc9b845
parent fee4dcc33a
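The diff below parametrizes the existing prompt-logprobs test over SamplingParams.detokenize and restructures the single-step output processor so that prompt logprobs are recorded whether or not detokenization runs. A minimal sketch of the behavior the new test branch pins down, assuming vLLM's public LLM/SamplingParams API; the prompt and token counts are illustrative, and the assertions mirror the test rather than this exact script:

from vllm import LLM, SamplingParams

# Illustrative prompt and model; the test also uses facebook/opt-125m.
llm = LLM(model="facebook/opt-125m")

params = SamplingParams(
    max_tokens=5,
    temperature=0.0,
    logprobs=6,          # top logprobs for each sampled token
    prompt_logprobs=6,   # top logprobs for each prompt token
    detokenize=False,    # skip decoding token ids back into strings
)

result = llm.generate(["Hello, my name is"], params)[0]

# Prompt logprobs are expected even without detokenization; only the
# decoded_token fields stay None and no output text is produced.
assert result.prompt_logprobs is not None
assert result.prompt_logprobs[0] is None  # first prompt logprob is always None
assert result.outputs[0].text == ''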
@@ -12,6 +12,7 @@ MODELS = ["facebook/opt-125m"]
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
 @pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
+@pytest.mark.parametrize("detokenize", [True, False])
 def test_get_prompt_logprobs(
     hf_runner,
     vllm_runner,
@@ -19,6 +20,7 @@ def test_get_prompt_logprobs(
     dtype,
     chunked_prefill_token_size: int,
     num_top_logprobs: int,
+    detokenize: bool,
     example_prompts,
 ):
     max_num_seqs = 256
@@ -48,7 +50,8 @@ def test_get_prompt_logprobs(
     vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                           logprobs=num_top_logprobs,
                                           prompt_logprobs=num_top_logprobs,
-                                          temperature=0.0)
+                                          temperature=0.0,
+                                          detokenize=detokenize)
     vllm_results = vllm_model.model.generate(
         example_prompts, sampling_params=vllm_sampling_params)
 
@@ -65,11 +68,16 @@ def test_get_prompt_logprobs(
             top_logprob = next(iter(top_logprobs.values()))
             output_string_from_most_likely_tokens.append(
                 top_logprob.decoded_token)
 
-        output_string_from_most_likely_tokens = "".join(
-            output_string_from_most_likely_tokens)
-        assert output_text == output_string_from_most_likely_tokens, (
-            "The output text from the top logprob for each token position "
-            "should be the same as the output text in the result.")
+        if detokenize:
+            output_string_from_most_likely_tokens = "".join(
+                output_string_from_most_likely_tokens)
+            assert output_text == output_string_from_most_likely_tokens, (
+                "The output text from the top logprob for each token position "
+                "should be the same as the output text in the result.")
+        else:
+            assert output_text == ''
+            assert output_string_from_most_likely_tokens == [None] * max_tokens
 
         # The first prompt logprob is always None
         assert result.prompt_logprobs[0] is None
@@ -98,8 +106,9 @@ def test_get_prompt_logprobs(
                                            hf_logprob[i][-1][token_id].item(),
                                            atol=1e-2,
                                            rtol=1e-2)
-            assert isinstance(sample_logprob.decoded_token, str), (
-                "The token should be decoded by the time it is returned "
-                " to the user.")
+            if detokenize:
+                assert isinstance(sample_logprob.decoded_token, str), (
+                    "The token should be decoded by the time it is returned"
+                    " to the user.")
 
     # Test if prompt logprobs are correctly set.
@@ -60,8 +60,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         assert len(outputs) == 1, ("Single step should only has 1 output.")
         output = outputs[0]
         prompt_logprobs = output.prompt_logprobs
-        if (prompt_logprobs is not None
-                and seq_group.sampling_params.detokenize and self.detokenizer):
-            self.detokenizer.decode_prompt_logprobs_inplace(
-                seq_group, prompt_logprobs)
+        if prompt_logprobs is not None:
+            if seq_group.sampling_params.detokenize and self.detokenizer:
+                self.detokenizer.decode_prompt_logprobs_inplace(
+                    seq_group, prompt_logprobs)
             if not seq_group.prompt_logprobs:
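The old condition gated both the in-place decoding and all the bookkeeping under it, so a request with detokenize=False never had its prompt logprobs recorded on the sequence group; the rewrite keeps only the decoding step behind the detokenize check. A runnable toy sketch of that behavioral difference, with the detokenizer-availability check folded into a single flag, and the [None] seed plus the extend assumed from the surrounding context (they are not shown in this hunk):

from typing import Dict, List, Optional

# Stand-in for a prompt's per-token logprob dicts (token_id -> logprob value).
PromptLogprobs = List[Optional[Dict[int, float]]]

def old_flow(prompt_logprobs: Optional[PromptLogprobs],
             detokenize: bool) -> Optional[PromptLogprobs]:
    """Old behavior: recording prompt logprobs was tied to detokenization."""
    recorded: Optional[PromptLogprobs] = None
    if prompt_logprobs is not None and detokenize:
        # decode token ids into strings in place (omitted in this sketch)
        recorded = [None]  # assumed: first prompt logprob is always None
        recorded.extend(prompt_logprobs)
    return recorded

def new_flow(prompt_logprobs: Optional[PromptLogprobs],
             detokenize: bool) -> Optional[PromptLogprobs]:
    """New behavior: only the decoding step depends on detokenize."""
    recorded: Optional[PromptLogprobs] = None
    if prompt_logprobs is not None:
        if detokenize:
            pass  # decode token ids into strings in place (omitted here)
        recorded = [None]  # assumed: first prompt logprob is always None
        recorded.extend(prompt_logprobs)
    return recorded

logprobs: PromptLogprobs = [{42: -0.1}, {7: -2.3}]
assert old_flow(logprobs, detokenize=False) is None               # bug: dropped
assert new_flow(logprobs, detokenize=False) == [None] + logprobs  # fixed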