[Bugfix] Fix speculative decoding with MLPSpeculator with padded vocabulary (#7218)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Travis Johnson 2024-08-08 23:08:46 -06:00 committed by GitHub
parent e02ac55617
commit 99b4cf5f23
4 changed files with 66 additions and 5 deletions
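Summary, as read from the diffs below: the MLPSpeculator now flattens its (batch, k, hidden) states before they reach the logits processor, the logits processor strips vocab padding on the last dimension instead of dimension 1, the rejection sampler's strict-mode check receives its arguments in the order the checker expects, and a new end-to-end test forces a padded vocabulary so the path is exercised. The indexing change is the heart of it: a [:, :n] slice only removes vocab padding when the logits are 2-D. A minimal sketch of the pitfall (illustration only, with assumed toy shapes, not vLLM code):

import torch

org_vocab, padded_vocab = 32000, 32064
logits = torch.randn(2, 3, padded_vocab)   # (batch, k, padded_vocab)

# Slicing dim 1 trims the k dimension (a no-op here) and keeps the padding.
assert logits[:, :org_vocab].shape == (2, 3, padded_vocab)
# Slicing the last dimension removes the padded vocab entries.
assert logits[..., :org_vocab].shape == (2, 3, org_vocab)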

@@ -19,8 +19,12 @@
 With those tests, we can say at least, MLPSpeculator would not break the
 correctness for the target model outputs.
 """
+from unittest.mock import patch
+
 import pytest
 
+from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
+
 from .conftest import (run_equality_correctness_test,
                        run_greedy_equality_correctness_test)
@@ -178,6 +182,62 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                          force_output_len=True)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "block_size": 8,
+        # 2 for small prompt, 256//8 for generated.
+        "num_gpu_blocks_override": 2 + 256 // 8,
+        "max_model_len": (2 + 256 // 8) * 8,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use small output len for fast test.
+        128,
+    ])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
+                                                 test_llm_generator,
+                                                 batch_size: int,
+                                                 output_len: int):
+    """Verify greedy equality when the vocab dimension is padded
+    """
+    # Default pad_to is 64, test model has vocab_size of 32000
+    def patched_pad_vocab_size(vocab_size, pad_to=None):
+        return pad_vocab_size(vocab_size, pad_to=32064)
+
+    with patch(
+            "vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
+            patched_pad_vocab_size):
+        run_greedy_equality_correctness_test(baseline_llm_generator,
+                                             test_llm_generator,
+                                             batch_size,
+                                             max_output_len=output_len,
+                                             force_output_len=True)
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
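The patched pad_to of 32064 is what forces the padding path: the test model's vocab_size of 32000 is already a multiple of the default padding of 64, so the unpatched helper would leave it unpadded. A sketch of the padding arithmetic, assumed to mirror vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size:

def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Round vocab_size up to the next multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

assert pad_vocab_size(32000) == 32000                  # 500 * 64: no padding added
assert pad_vocab_size(32000, pad_to=32064) == 32064    # forced padding, as in the test above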

@@ -91,7 +91,7 @@ class LogitsProcessor(nn.Module):
         logits = tensor_model_parallel_all_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
+            logits = logits[..., :self.org_vocab_size]
         return logits
 
     def extra_repr(self) -> str:
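For ordinary 2-D (num_tokens, vocab_size) logits the two slices select the same elements, so existing callers are unaffected; the ellipsis form just also does the right thing when extra leading dimensions are present. A quick check (illustration only):

import torch

logits_2d = torch.randn(4, 32064)
assert torch.equal(logits_2d[:, :32000], logits_2d[..., :32000])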

@@ -78,8 +78,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         # Only perform shape/dtype/device checking in strict mode, as it adds
         # overhead.
         if self._strict_mode:
-            self._raise_if_incorrect_input(target_probs, bonus_token_ids,
-                                           draft_probs, draft_token_ids)
+            self._raise_if_incorrect_input(target_probs, draft_token_ids,
+                                           bonus_token_ids, draft_probs)
 
         accepted, recovered_token_ids = (
             self._batch_modified_rejection_sampling(
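This hunk only reorders the arguments of the strict-mode validation call so each tensor lands in the slot the checker expects; the sampling itself is unchanged. A defensive variant (a sketch only; it assumes the checker's parameters are named after the same quantities as the call-site variables) would pass them by keyword so the order cannot silently drift:

if self._strict_mode:
    self._raise_if_incorrect_input(target_probs,
                                   draft_token_ids=draft_token_ids,
                                   bonus_token_ids=bonus_token_ids,
                                   draft_probs=draft_probs)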

@@ -175,13 +175,14 @@ class MLPSpeculator(nn.Module):
             states.add_(z, alpha=self.emb_weight / self.state_weight)
 
             states = self.activation(self.ln[head_index](states))  # b k d
-            # TODO: not yet supporting top_k_tokens_per_head
             previous_hidden_states = states
+            # TODO: not yet supporting top_k_tokens_per_head
+            states = states.flatten(0, 1)
 
             logits = self.logits_processor(self.head[head_index], states,
                                            sampling_metadata)
 
-            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
+            output = self.sampler(logits, sampling_metadata)
             last_tokens = output.sampled_token_ids
             next_tokens.append(output)
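Flattening before the projection is the core of the fix: the speculator's per-head states are (batch, k, hidden), and sending them through the logits processor in that shape meant padding removal and sampling saw 3-D tensors. With states.flatten(0, 1) moved ahead of the logits processor, both it and the sampler see the same 2-D (tokens, vocab) layout the target model produces, which is why the sampler call drops its own flatten. A shape walkthrough with assumed toy sizes (illustration only; nn.Linear stands in for the speculator head plus logits processor):

import torch

b, k, d, padded_v, org_v = 2, 3, 8, 32064, 32000
head = torch.nn.Linear(d, padded_v, bias=False)    # stand-in for self.head[head_index]
states = torch.randn(b, k, d)                      # per-head speculator states

# Flatten (b, k, d) -> (b*k, d) before projecting, so padding removal and
# sampling both operate on ordinary 2-D (tokens, vocab) logits.
logits = head(states.flatten(0, 1))                # (b * k, padded_v)
logits = logits[..., :org_v]                       # (b * k, org_v): padding removed
assert logits.shape == (b * k, org_v)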