[Bugfix][SpecDecode] apply sampling parameters to target probabilities for consistency in rejection sampling. (#10198)
Signed-off-by: jeongin601 <0200angela@gmail.com> Signed-off-by: jeong_in.bae <jeong_in.bae@navercorp.com>
This commit is contained in:
parent
0a4d968500
commit
1bf905ddaa
@ -203,7 +203,7 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
|
||||
@pytest.mark.parametrize("output_len", [64])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("temperature", [0.1, 1.0])
|
||||
@pytest.mark.parametrize("temperature", [1.0])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
|
@ -90,6 +90,14 @@ def test_create_single_target_seq_group_metadata(k: int):
|
||||
)
|
||||
|
||||
assert output.request_id == input_seq_group_metadata.request_id
|
||||
assert output.sampling_params.repetition_penalty == \
|
||||
input_seq_group_metadata.sampling_params.repetition_penalty
|
||||
assert output.sampling_params.temperature == \
|
||||
input_seq_group_metadata.sampling_params.temperature
|
||||
assert output.sampling_params.top_p == \
|
||||
input_seq_group_metadata.sampling_params.top_p
|
||||
assert output.sampling_params.top_k == \
|
||||
input_seq_group_metadata.sampling_params.top_k
|
||||
assert len(output.seq_data) == 1
|
||||
assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple(
|
||||
prompt_tokens)
|
||||
|
@ -307,28 +307,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
token_ids_to_score = self._get_token_ids_to_score(
|
||||
proposal_token_ids[batch_index])
|
||||
|
||||
# Use simpler sampling parameters apart from for final token
|
||||
# (in particular don't do seeded sampling) since those sampled tokens
|
||||
# aren't used.
|
||||
# We don't replace the sampling_params in the greedy case because
|
||||
# this also controls whether the probs get modified in the sampler
|
||||
# (see use of _modify_greedy_probs_inplace there).
|
||||
sampling_params = input_seq_group_metadata.sampling_params
|
||||
non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \
|
||||
if sampling_params.temperature else sampling_params
|
||||
|
||||
target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
last_index = len(token_ids_to_score) - 1
|
||||
for i, token_ids in enumerate(token_ids_to_score):
|
||||
target_sampling_params = sampling_params if i == last_index \
|
||||
else non_bonus_sampling_params
|
||||
target_seq_group_metadata_list.append(
|
||||
self._create_single_target_seq_group_metadata(
|
||||
input_seq_group_metadata,
|
||||
input_seq_id,
|
||||
next(target_seq_ids_iter),
|
||||
token_ids,
|
||||
sampling_params=target_sampling_params,
|
||||
sampling_params=sampling_params,
|
||||
))
|
||||
|
||||
return target_seq_group_metadata_list
|
||||
|
Loading…
x
Reference in New Issue
Block a user