diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py
index 20f50888..25067e7a 100644
--- a/tests/spec_decode/e2e/test_mlp_correctness.py
+++ b/tests/spec_decode/e2e/test_mlp_correctness.py
@@ -19,8 +19,12 @@ With those tests, we can say at least, MLPSpeculator would not break the
 correctess for the target model outputs.
 """
 
+from unittest.mock import patch
+
 import pytest
 
+from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
+
 from .conftest import (run_equality_correctness_test,
                        run_greedy_equality_correctness_test)
 
@@ -178,6 +182,62 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                          force_output_len=True)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "block_size": 8,
+        # 2 for small prompt, 256//8 for generated.
+        "num_gpu_blocks_override": 2 + 256 // 8,
+        "max_model_len": (2 + 256 // 8) * 8,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use small output len for fast test.
+        128,
+    ])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
+                                                 test_llm_generator,
+                                                 batch_size: int,
+                                                 output_len: int):
+    """Verify greedy equality when the vocab dimension is padded.
+    """
+
+    # Default pad_to is 64, test model has vocab_size of 32000.
+    def patched_pad_vocab_size(vocab_size, pad_to=None):
+        return pad_vocab_size(vocab_size, pad_to=32064)
+
+    with patch(
+            "vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
+            patched_pad_vocab_size):
+        run_greedy_equality_correctness_test(baseline_llm_generator,
+                                             test_llm_generator,
+                                             batch_size,
+                                             max_output_len=output_len,
+                                             force_output_len=True)
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index bd3e7e11..80534acd 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -91,7 +91,7 @@ class LogitsProcessor(nn.Module):
             logits = tensor_model_parallel_all_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
+            logits = logits[..., :self.org_vocab_size]
         return logits
 
     def extra_repr(self) -> str:
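Why `[..., :self.org_vocab_size]` rather than `[:, :self.org_vocab_size]`: the ellipsis slice trims the last (vocab) dimension whatever the tensor's rank, while the original form assumes a 2-D `[tokens, vocab]` layout. The new test above forces this padded path by patching `pad_vocab_size` with `pad_to=32064`; with the default `pad_to=64`, a vocab of 32000 is already a multiple of 64 and never gets padded (assuming `pad_vocab_size` rounds up to the next multiple of `pad_to`). A minimal sketch of the indexing difference, with illustrative shapes that are not taken from the patch:

```python
import torch

org_vocab_size, padded_vocab_size = 32000, 32064   # illustrative sizes

logits_2d = torch.randn(4, padded_vocab_size)       # [tokens, padded_vocab]
logits_3d = torch.randn(4, 3, padded_vocab_size)    # e.g. [batch, k, padded_vocab]

# The old slice only trims dim 1, so a 3-D tensor keeps its vocab padding.
assert logits_2d[:, :org_vocab_size].shape == (4, org_vocab_size)
assert logits_3d[:, :org_vocab_size].shape == (4, 3, padded_vocab_size)

# The new slice always trims the last dimension, regardless of rank.
assert logits_2d[..., :org_vocab_size].shape == (4, org_vocab_size)
assert logits_3d[..., :org_vocab_size].shape == (4, 3, org_vocab_size)
```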
diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py
index 533b4363..2124196d 100644
--- a/vllm/model_executor/layers/rejection_sampler.py
+++ b/vllm/model_executor/layers/rejection_sampler.py
@@ -78,8 +78,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         # Only perform shape/dtype/device checking in strict mode, as it adds
         # overhead.
         if self._strict_mode:
-            self._raise_if_incorrect_input(target_probs, bonus_token_ids,
-                                           draft_probs, draft_token_ids)
+            self._raise_if_incorrect_input(target_probs, draft_token_ids,
+                                           bonus_token_ids, draft_probs)
 
         accepted, recovered_token_ids = (
             self._batch_modified_rejection_sampling(
diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py
index 95a655fb..9b96ecb7 100644
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -175,13 +175,14 @@ class MLPSpeculator(nn.Module):
             states.add_(z, alpha=self.emb_weight / self.state_weight)
             states = self.activation(self.ln[head_index](states))  # b k d
-            # TODO: not yet supporting top_k_tokens_per_head
             previous_hidden_states = states
+            # TODO: not yet supporting top_k_tokens_per_head
+            states = states.flatten(0, 1)
 
             logits = self.logits_processor(self.head[head_index], states,
                                            sampling_metadata)
 
-            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
+            output = self.sampler(logits, sampling_metadata)
             last_tokens = output.sampled_token_ids
             next_tokens.append(output)
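The `mlp_speculator.py` hunk moves the flatten from the sampler call up to just before the logits processor, so vocab-padding removal and sampling both see the same flattened `[batch * k, ...]` layout. A rough shape walk-through under assumed toy dimensions, using a plain `nn.Linear` as a stand-in for `head[head_index]` plus the logits processor (not the actual vLLM modules):

```python
import torch
import torch.nn as nn

batch, k, hidden = 2, 3, 8                           # assumed toy sizes
org_vocab, padded_vocab = 32000, 32064               # illustrative padded vocab

states = torch.randn(batch, k, hidden)               # b k d, as in the model code
head = nn.Linear(hidden, padded_vocab, bias=False)   # stand-in for head[head_index]

# New order: flatten once, then project and strip the vocab padding.
states = states.flatten(0, 1)                        # [batch * k, hidden]
logits = head(states)                                # [batch * k, padded_vocab]
logits = logits[..., :org_vocab]                     # [batch * k, org_vocab]

# The sampler now receives logits that are already flattened and unpadded,
# so the old logits.flatten(0, 1) at the sampler call is no longer needed.
assert logits.shape == (batch * k, org_vocab)
```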