[Bugfix] Fix speculative decoding with MLPSpeculator with padded vocabulary (#7218)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
commit 99b4cf5f23 (parent e02ac55617)
@@ -19,8 +19,12 @@ With those tests, we can say at least, MLPSpeculator would not break the
 correctess for the target model outputs.
 """
 
+from unittest.mock import patch
+
 import pytest
 
+from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
+
 from .conftest import (run_equality_correctness_test,
                        run_greedy_equality_correctness_test)
 
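For context on the new test below: `unittest.mock.patch` swaps the named attribute only for the duration of the `with` block, and the test targets `pad_vocab_size` in its defining module, so the embedding layer picks up the replacement when it resolves the name at call time. A minimal sketch of that pattern (the lambda replacement is illustrative, not the test's actual helper):

import importlib
from unittest.mock import patch

# Sketch only: force every vocab-size query to report a padded value
# while the block is active; the real function is restored on exit.
with patch(
        "vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
        lambda vocab_size, pad_to=None: 32064):
    mod = importlib.import_module(
        "vllm.model_executor.layers.vocab_parallel_embedding")
    assert mod.pad_vocab_size(32000) == 32064  # patched view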
@@ -178,6 +182,62 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                          force_output_len=True)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "block_size": 8,
+        # 2 for small prompt, 256//8 for generated.
+        "num_gpu_blocks_override": 2 + 256 // 8,
+        "max_model_len": (2 + 256 // 8) * 8,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use small output len for fast test.
+        128,
+    ])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
+                                                 test_llm_generator,
+                                                 batch_size: int,
+                                                 output_len: int):
+    """Verify greedy equality when the vocab dimension is padded
+    """
+
+    # Default pad_to is 64, test model has vocab_size of 32000
+    def patched_pad_vocab_size(vocab_size, pad_to=None):
+        return pad_vocab_size(vocab_size, pad_to=32064)
+
+    with patch(
+            "vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
+            patched_pad_vocab_size):
+        run_greedy_equality_correctness_test(baseline_llm_generator,
+                                             test_llm_generator,
+                                             batch_size,
+                                             max_output_len=output_len,
+                                             force_output_len=True)
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
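The patched `pad_to=32064` is the interesting part: the default `pad_to` of 64 evenly divides the test model's 32000-entry vocabulary, so default padding is a no-op and the bug would never trigger. A minimal sketch, assuming `pad_vocab_size` rounds the vocabulary up to the nearest multiple of `pad_to` (assumed behavior, not copied from the source):

def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Assumed behavior: round up to the nearest multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

assert pad_vocab_size(32000) == 32000                # multiple of 64: no-op
assert pad_vocab_size(32000, pad_to=32064) == 32064  # forces 64 padded slots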
@@ -91,7 +91,7 @@ class LogitsProcessor(nn.Module):
         logits = tensor_model_parallel_all_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
+            logits = logits[..., :self.org_vocab_size]
         return logits
 
     def extra_repr(self) -> str:
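The one-character change in this `LogitsProcessor` hunk is the heart of the fix. MLPSpeculator logits can be rank-3 (`[batch, k, padded_vocab]`), and for such a tensor `[:, :org_vocab_size]` trims dimension 1, the k candidate positions; since k is far smaller than the vocabulary, that slice is a silent no-op and the padded logit slots survive into sampling. `[..., :org_vocab_size]` always trims the last (vocab) dimension, whatever the rank. A toy demonstration with assumed shapes:

import torch

logits = torch.randn(2, 3, 8)  # batch=2, k=3, padded vocab=8; real vocab=6

# Old slice targets dim 1 (the k dimension); k < org_vocab_size, so
# nothing is trimmed and the vocab padding survives.
print(logits[:, :6].shape)    # torch.Size([2, 3, 8])

# New slice targets the last dimension regardless of rank.
print(logits[..., :6].shape)  # torch.Size([2, 3, 6])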
@@ -78,8 +78,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         # Only perform shape/dtype/device checking in strict mode, as it adds
         # overhead.
         if self._strict_mode:
-            self._raise_if_incorrect_input(target_probs, bonus_token_ids,
-                                           draft_probs, draft_token_ids)
+            self._raise_if_incorrect_input(target_probs, draft_token_ids,
+                                           bonus_token_ids, draft_probs)
 
         accepted, recovered_token_ids = (
             self._batch_modified_rejection_sampling(
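The `RejectionSampler` hunk is a pure argument-order fix: the old call passed `bonus_token_ids` where the checker expects `draft_token_ids` and transposed `draft_probs` with `draft_token_ids`, so strict mode validated the wrong tensors against the wrong shape constraints. A defensive sketch using keyword arguments, assuming the parameter names match the corrected positional order in the hunk; this style makes such a transposition impossible at the call site:

# Sketch: parameter names are assumed from the corrected call order.
if self._strict_mode:
    self._raise_if_incorrect_input(target_probs,
                                   draft_token_ids=draft_token_ids,
                                   bonus_token_ids=bonus_token_ids,
                                   draft_probs=draft_probs)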
@@ -175,13 +175,14 @@ class MLPSpeculator(nn.Module):
             states.add_(z, alpha=self.emb_weight / self.state_weight)
 
             states = self.activation(self.ln[head_index](states))  # b k d
-            # TODO: not yet supporting top_k_tokens_per_head
             previous_hidden_states = states
+            # TODO: not yet supporting top_k_tokens_per_head
+            states = states.flatten(0, 1)
 
             logits = self.logits_processor(self.head[head_index], states,
                                            sampling_metadata)
 
-            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
+            output = self.sampler(logits, sampling_metadata)
             last_tokens = output.sampled_token_ids
             next_tokens.append(output)
 
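The `MLPSpeculator` hunk moves the `flatten(0, 1)` from after the logits computation to before it. Flattening the `[batch, k, hidden]` states into `[batch*k, hidden]` rows first means the logits processor produces 2-D logits with the vocabulary on the last dimension, so the padding slice actually removes the padded slots and the sampler receives the shape it expects; previously the rank-3 logits reached the old 2-D slice and the padding survived. A toy walk-through with assumed sizes (the matmul stands in for the real lm-head):

import torch

states = torch.randn(2, 3, 4)  # b=2, k=3, d=4
flat = states.flatten(0, 1)    # (6, 4): one row per (sequence, candidate)
head = torch.randn(8, 4)       # stand-in lm-head over a padded vocab of 8
logits = flat @ head.t()       # (6, 8): 2-D, vocab on the last dim
logits = logits[..., :6]       # (6, 6): padding removed before sampling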