[Bugfix] Fix speculative decoding with MLPSpeculator with padded vocabulary (#7218)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Travis Johnson 2024-08-08 23:08:46 -06:00 committed by GitHub
parent e02ac55617
commit 99b4cf5f23
GPG Key ID: B5690EEEBB952194
4 changed files with 66 additions and 5 deletions


@@ -19,8 +19,12 @@ With those tests, we can say at least, MLPSpeculator would not break the
 correctess for the target model outputs.
 """
+from unittest.mock import patch
+
 import pytest
 
+from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
+
 from .conftest import (run_equality_correctness_test,
                        run_greedy_equality_correctness_test)
@@ -178,6 +182,62 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                          force_output_len=True)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "block_size": 8,
+        # 2 for small prompt, 256//8 for generated.
+        "num_gpu_blocks_override": 2 + 256 // 8,
+        "max_model_len": (2 + 256 // 8) * 8,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use small output len for fast test.
+        128,
+    ])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
+                                                 test_llm_generator,
+                                                 batch_size: int,
+                                                 output_len: int):
+    """Verify greedy equality when the vocab dimension is padded
+    """
+
+    # Default pad_to is 64, test model has vocab_size of 32000
+    def patched_pad_vocab_size(vocab_size, pad_to=None):
+        return pad_vocab_size(vocab_size, pad_to=32064)
+
+    with patch(
+            "vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
+            patched_pad_vocab_size):
+        run_greedy_equality_correctness_test(baseline_llm_generator,
+                                             test_llm_generator,
+                                             batch_size,
+                                             max_output_len=output_len,
+                                             force_output_len=True)
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
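For context: pad_vocab_size rounds the vocabulary up to the next multiple of pad_to, and the test model's 32000-entry vocab is already a multiple of the default 64, so the test patches pad_to to 32064 to force real padding. A minimal standalone sketch of that rounding behaviour (the helper below is written out here as an assumption about what vLLM's pad_vocab_size computes, for illustration only):

# Standalone sketch of ceil-to-multiple vocab padding, assumed to mirror
# vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size.
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    """Round vocab_size up to the nearest multiple of pad_to."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

# 32000 is already a multiple of 64, so the default adds no padding...
assert pad_vocab_size(32000) == 32000
# ...while the patched pad_to of 32064 pads 32000 -> 32064, leaving extra
# logit columns that the model must slice away.
assert pad_vocab_size(32000, pad_to=32064) == 32064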


@@ -91,7 +91,7 @@ class LogitsProcessor(nn.Module):
             logits = tensor_model_parallel_all_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
+            logits = logits[..., :self.org_vocab_size]
         return logits
 
     def extra_repr(self) -> str:
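The padding trim previously assumed 2-D [tokens, vocab] logits; on a 3-D tensor (as the MLPSpeculator produced before the change further below) [:, :org_vocab_size] slices the speculative-token axis instead of the vocab axis, so the padded columns survive. Ellipsis indexing trims the last dimension regardless of rank. A small PyTorch sketch with illustrative shapes:

import torch

# Illustrative shapes: batch=2, k=3 speculative positions, vocab padded
# from 32000 to 32064.
logits = torch.randn(2, 3, 32064)
org_vocab_size = 32000

# The old slice targets dim 1; on a 3-D tensor the padded vocab is untouched.
assert logits[:, :org_vocab_size].shape == (2, 3, 32064)

# Ellipsis indexing trims the last axis whatever the rank.
assert logits[..., :org_vocab_size].shape == (2, 3, 32000)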


@@ -78,8 +78,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         # Only perform shape/dtype/device checking in strict mode, as it adds
         # overhead.
         if self._strict_mode:
-            self._raise_if_incorrect_input(target_probs, bonus_token_ids,
-                                           draft_probs, draft_token_ids)
+            self._raise_if_incorrect_input(target_probs, draft_token_ids,
+                                           bonus_token_ids, draft_probs)
 
         accepted, recovered_token_ids = (
             self._batch_modified_rejection_sampling(
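This hunk only reorders the positional arguments so they line up with the validator's parameter list; the old order handed tensors to the wrong parameters, so strict-mode validation checked them against the wrong expectations. As a general illustration (the helper and parameter names below are hypothetical, not vLLM's actual signature), keyword arguments make this class of swap impossible:

import torch

# Hypothetical strict-mode style validator: probability tensors must be
# floating point, token-id tensors must be integer.
def check_inputs(target_probs, draft_token_ids, bonus_token_ids, draft_probs):
    assert target_probs.is_floating_point()
    assert draft_probs.is_floating_point()
    assert not draft_token_ids.is_floating_point()
    assert not bonus_token_ids.is_floating_point()

target = torch.rand(4, 3, 32000)
draft = torch.rand(4, 3, 32000)
draft_ids = torch.zeros(4, 3, dtype=torch.long)
bonus_ids = torch.zeros(4, 1, dtype=torch.long)

# Passing ids where probabilities are expected trips the check even though
# every tensor is individually valid.
try:
    check_inputs(target, bonus_ids, draft, draft_ids)
except AssertionError:
    pass

# Keyword arguments make the pairing explicit and immune to reordering.
check_inputs(target_probs=target, draft_token_ids=draft_ids,
             bonus_token_ids=bonus_ids, draft_probs=draft)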


@@ -175,13 +175,14 @@ class MLPSpeculator(nn.Module):
             states.add_(z, alpha=self.emb_weight / self.state_weight)
             states = self.activation(self.ln[head_index](states))  # b k d
-            # TODO: not yet supporting top_k_tokens_per_head
             previous_hidden_states = states
+            # TODO: not yet supporting top_k_tokens_per_head
+            states = states.flatten(0, 1)
 
             logits = self.logits_processor(self.head[head_index], states,
                                            sampling_metadata)
 
-            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
+            output = self.sampler(logits, sampling_metadata)
             last_tokens = output.sampled_token_ids
             next_tokens.append(output)
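With this reordering, the speculator flattens its [batch, k, hidden] states to [(batch * k), hidden] before the logits processor, so the per-head logits (and the padding trim above) see 2-D tensors and the extra flatten at sampling time is no longer needed. A shape-only sketch of the intended data flow, with illustrative dimensions and a plain Linear standing in for the speculator head:

import torch

batch, k, hidden, padded_vocab, org_vocab = 2, 3, 256, 32064, 32000

states = torch.randn(batch, k, hidden)   # b k d, as in the comment above
states = states.flatten(0, 1)            # (b*k) d, flattened before the head

head = torch.nn.Linear(hidden, padded_vocab, bias=False)
logits = head(states)                    # (b*k) padded_vocab
logits = logits[..., :org_vocab]         # padding trimmed on the last axis

# 2-D logits go straight to sampling; no second flatten is required.
assert logits.shape == (batch * k, org_vocab)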