[Bugfix] spec decode handle None entries in topk args in create_sequence_group_output (#7232)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
parent 955b5191c9
commit cc0eaf12b1
@@ -343,3 +343,78 @@ def run_greedy_logprobs_correctness_test(baseline_llm_generator,
                         b=baseline_rank_to_logprob[rank],
                         abs_tol=1e-1,
                     )
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "max_logprobs": 6,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "JackFram/llama-68m",
+                             "num_speculative_tokens": 3,
+                             "disable_logprobs_during_spec_decoding": True,
+                         }])
+@pytest.mark.parametrize("seed", [1])
+def test_logprobs_disabled(baseline_llm_generator, test_llm_generator):
+    """Check the behavior when logprobs are disabled.
+    Token choices should match with the base model.
+    """
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+        "San Francisco is know for its",
+        "Facebook was created in 2004 by",
+        "Curious George is a",
+        "Python 3.11 brings improvements to its",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(4))]
+
+    sampling_params = SamplingParams(
+        # Use smaller output len for fast test
+        max_tokens=7,
+        ignore_eos=True,
+        temperature=0.0,
+        logprobs=2,
+    )
+
+    spec_batch_logprobs = get_logprobs_from_llm_generator(
+        test_llm_generator, prompts, sampling_params)
+    baseline_batch_logprobs = get_logprobs_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    assert len(baseline_batch_logprobs) == len(prompts)
+    assert len(spec_batch_logprobs) == len(prompts)
+
+    # For each sequence in the batch.
+    for _, (baseline_logprobs, spec_logprobs) in enumerate(
+            zip(baseline_batch_logprobs, spec_batch_logprobs)):
+        assert len(spec_logprobs) == len(baseline_logprobs)
+
+        # For each generated position of the sequence.
+        for _, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
+                zip(spec_logprobs, baseline_logprobs)):
+
+            assert len(spec_pos_logprobs) == 1
+            spec_top_token_id = list(spec_pos_logprobs)[0]
+
+            spec_top_logprob = spec_pos_logprobs[spec_top_token_id]
+            assert spec_top_logprob.logprob == 0.0
+            assert spec_top_logprob.rank == -1
+
+            # check that the chosen token matches the base model
+            baseline_logprob = baseline_pos_logprobs[spec_top_token_id]
+            assert baseline_logprob.rank == 1
+            assert spec_top_logprob.decoded_token \
+                == baseline_logprob.decoded_token
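Note: the per-position structure the test above expects when logprobs are disabled can be sketched as follows. This is a minimal, illustrative stand-in (DummyLogprob, the token id 1234, and the decoded text are made up here, not vLLM objects): with disable_logprobs_during_spec_decoding=True, each position carries exactly one entry, keyed by the sampled token id, with a dummy logprob of 0.0 and rank -1.

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class DummyLogprob:  # hypothetical stand-in for vLLM's Logprob object
    logprob: float
    rank: int
    decoded_token: Optional[str] = None


# One generated position with logprobs disabled: a single entry for the
# sampled token, carrying the dummy values the test above asserts.
spec_pos_logprobs: Dict[int, DummyLogprob] = {
    1234: DummyLogprob(logprob=0.0, rank=-1, decoded_token=" Paris"),
}

assert len(spec_pos_logprobs) == 1
top_token_id = list(spec_pos_logprobs)[0]
assert spec_pos_logprobs[top_token_id].logprob == 0.0
assert spec_pos_logprobs[top_token_id].rank == -1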
@@ -64,23 +64,25 @@ def create_sequence_group_output(
         token_id_logprob_rank (int): The logprob rank of the sampled token.
         token_id_logprob (float): The logprob value of the sampled token.
         seq_id (int): The sequence id.
-        topk_token_ids (List[int]): The list of top-k token ids.
-        topk_logprobs (List[float]): The list of top-k logprobs.
+        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
+        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
     """
     # vLLM logprobs always include the sampled token. In addition, the user may
     # request topk-logprobs (where top-k varies per user up to max_logprobs).
-    logprobs: Dict[Optional[int], Logprob] = {
+    logprobs: Dict[int, Logprob] = {
         token_id: Logprob(
             logprob=token_id_logprob,
             rank=token_id_logprob_rank,
         ),
     }
     logprobs.update({
-        topk_token_ids[topk_logprob_index]: Logprob(
-            logprob=topk_logprobs[topk_logprob_index],
-            rank=topk_logprob_index + 1,
+        topk_token_id: Logprob(
+            logprob=topk_logprob if topk_logprob is not None else 0.0,
+            rank=topk_index + 1,
         )
-        for topk_logprob_index, _ in enumerate(topk_token_ids)
+        for topk_index, (topk_token_id, topk_logprob) \
+            in enumerate(zip(topk_token_ids, topk_logprobs)) \
+            if topk_token_id is not None
     })
 
     return CompletionSequenceGroupOutput(