[Bugfix] Fix speculative decoding with MLPSpeculator with padded vocabulary (#7218)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
parent
e02ac55617
commit
99b4cf5f23
@ -19,8 +19,12 @@ With those tests, we can say at least, MLPSpeculator would not break the
|
||||
correctess for the target model outputs.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
|
||||
|
||||
from .conftest import (run_equality_correctness_test,
|
||||
run_greedy_equality_correctness_test)
|
||||
|
||||
@ -178,6 +182,62 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 8,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality when the vocab dimension is padded
|
||||
"""
|
||||
|
||||
# Default pad_to is 64, test model has vocab_size of 32000
|
||||
def patched_pad_vocab_size(vocab_size, pad_to=None):
|
||||
return pad_vocab_size(vocab_size, pad_to=32064)
|
||||
|
||||
with patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
|
||||
patched_pad_vocab_size):
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
|
@ -91,7 +91,7 @@ class LogitsProcessor(nn.Module):
|
||||
logits = tensor_model_parallel_all_gather(logits)
|
||||
# Remove paddings in vocab (if any).
|
||||
if logits is not None:
|
||||
logits = logits[:, :self.org_vocab_size]
|
||||
logits = logits[..., :self.org_vocab_size]
|
||||
return logits
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
|
@ -78,8 +78,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
|
||||
# Only perform shape/dtype/device checking in strict mode, as it adds
|
||||
# overhead.
|
||||
if self._strict_mode:
|
||||
self._raise_if_incorrect_input(target_probs, bonus_token_ids,
|
||||
draft_probs, draft_token_ids)
|
||||
self._raise_if_incorrect_input(target_probs, draft_token_ids,
|
||||
bonus_token_ids, draft_probs)
|
||||
|
||||
accepted, recovered_token_ids = (
|
||||
self._batch_modified_rejection_sampling(
|
||||
|
@ -175,13 +175,14 @@ class MLPSpeculator(nn.Module):
|
||||
states.add_(z, alpha=self.emb_weight / self.state_weight)
|
||||
|
||||
states = self.activation(self.ln[head_index](states)) # b k d
|
||||
# TODO: not yet supporting top_k_tokens_per_head
|
||||
previous_hidden_states = states
|
||||
# TODO: not yet supporting top_k_tokens_per_head
|
||||
states = states.flatten(0, 1)
|
||||
|
||||
logits = self.logits_processor(self.head[head_index], states,
|
||||
sampling_metadata)
|
||||
|
||||
output = self.sampler(logits.flatten(0, 1), sampling_metadata)
|
||||
output = self.sampler(logits, sampling_metadata)
|
||||
last_tokens = output.sampled_token_ids
|
||||
next_tokens.append(output)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user