[Model][Speculative Decoding] DeepSeek MTP spec decode (#12755)

Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
Lucia Fang 2025-02-19 01:06:23 -08:00 committed by GitHub
parent 983a40a8bb
commit f525c0be8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 727 additions and 46 deletions

View File

@ -2,7 +2,7 @@
# adding a new command to an existing step. See different options here for examples.
# This script will be feed into Jinja template in `test-template-aws.j2` at
# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2
# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2
# to generate the final pipeline yaml file.
# Documentation
@ -15,7 +15,7 @@
# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd]
# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100
# num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4.
# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host,
# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host,
# in this case, commands must be specified. the first command runs on first host, the second
# command runs on the second host.
# working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests
@ -24,8 +24,8 @@
# When adding a test
# - If the test belong to an existing group, add it there
# - If the test is short, add to any existing step
# - If the test takes more than 10min, then it is okay to create a new step.
# Note that all steps execute in parallel.
# - If the test takes more than 10min, then it is okay to create a new step.
# Note that all steps execute in parallel.
steps:
##### fast check tests #####
@ -145,14 +145,14 @@ steps:
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
- label: Metrics, Tracing Test # 10min
num_gpus: 2
num_gpus: 2
fast_check: true
source_file_dependencies:
- vllm/
- tests/metrics
- tests/tracing
commands:
- pytest -v -s metrics
- pytest -v -s metrics
- "pip install \
'opentelemetry-sdk>=1.26.0,<1.27.0' \
'opentelemetry-api>=1.26.0,<1.27.0' \
@ -254,7 +254,7 @@ steps:
- vllm/model_executor/guided_decoding
- tests/test_logits_processor
- tests/model_executor/test_guided_processors
commands:
commands:
- pytest -v -s test_logits_processor.py
- pytest -v -s model_executor/test_guided_processors.py
@ -265,7 +265,7 @@ steps:
- vllm/model_executor/models/eagle.py
commands:
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_mtp_correctness.py
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py
- label: LoRA Test %N # 15min each
@ -580,7 +580,7 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# This test runs llama 13B, so it is required to run on 4 GPUs.
- pytest -v -s -x lora/test_long_context.py
# There is some Tensor Parallelism related processing logic in LoRA that
# There is some Tensor Parallelism related processing logic in LoRA that
# requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py
@ -605,7 +605,7 @@ steps:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
##### multi gpus test #####
@ -617,7 +617,7 @@ steps:
num_gpus: 4
source_file_dependencies:
- vllm/
commands:
commands:
# NOTE: don't test llama model here, it seems hf implementation is buggy
# see https://github.com/vllm-project/vllm/pull/5689 for details
- pytest -v -s distributed/test_custom_all_reduce.py

View File

@ -296,6 +296,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501
"MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501
"DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
trust_remote_code=True),
}
_FALLBACK_MODEL = {

View File

@ -0,0 +1,318 @@
# SPDX-License-Identifier: Apache-2.0
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, mtp would not break the
correctess for the target model outputs.
"""
import pytest
from .conftest import run_equality_correctness_test
# main model
MAIN_MODEL = "luccafong/deepseek_mtp_main_random"
# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1
# precision
PRECISION = "bfloat16"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
},
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
output_len: int, seed: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"num_speculative_tokens": k,
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
if __name__ == "__main__":
import pytest
pytest.main([__file__])

View File

@ -763,7 +763,7 @@ class ModelConfig:
def is_deepseek_mla(self) -> bool:
return (hasattr(self.hf_text_config, "model_type")) \
and (self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3'))\
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\
and (self.hf_text_config.kv_lora_rank is not None)
def get_head_size(self) -> int:
@ -856,8 +856,12 @@ class ModelConfig:
def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
if self.hf_text_config.model_type == "deepseek_mtp":
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
@ -1689,6 +1693,18 @@ class SpeculativeConfig:
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_v3":
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["DeepSeekMTPModel"]
})
return hf_config
@staticmethod
def maybe_create_spec_config(
target_model_config: ModelConfig,
@ -1771,12 +1787,18 @@ class SpeculativeConfig:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
"""
if speculative_model is None:
if num_speculative_tokens is not None:
raise ValueError("num_speculative_tokens was provided without "
"speculative_model.")
return None
if target_model_config.hf_text_config.model_type \
== "deepseek_v3":
# use the draft model from the same model:
speculative_model = target_model_config.model
else:
raise ValueError(
"num_speculative_tokens was provided without "
"speculative_model.")
else:
return None
if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2):
@ -1830,6 +1852,7 @@ class SpeculativeConfig:
max_seq_len_to_capture=target_model_config.
max_seq_len_to_capture,
max_logprobs=target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)
draft_hf_config = draft_model_config.hf_config
@ -1846,7 +1869,6 @@ class SpeculativeConfig:
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
n_predict = getattr(draft_hf_config, "n_predict", None)
if n_predict is not None:
if num_speculative_tokens is None:
@ -1960,8 +1982,9 @@ class SpeculativeConfig:
speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1:
logger.warning(
"MLPSpeculator cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1")
"%s cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1",
draft_hf_config.model_type)
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size

View File

@ -0,0 +1,284 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .deepseek_v2 import (DeepseekV2DecoderLayer,
get_spec_layer_idx_from_weight_name)
from .utils import maybe_prefix
class SharedHead(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.norm(hidden_states)
class DeepSeekMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
cache_config, quant_config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds[positions == 0] = 0
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=None)
hidden_states = residual + hidden_states
return self.shared_head(hidden_states)
class DeepSeekMultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers
# to map the exact layer index from weights
self.layers = torch.nn.ModuleDict({
str(idx):
DeepSeekMultiTokenPredictorLayer(
config,
f"{prefix}.layers.{idx}",
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
input_ids,
positions,
kv_caches[spec_step_idx],
attn_metadata,
previous_hidden_states,
inputs_embeds,
spec_step_idx,
)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> torch.Tensor:
mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
logits = self.logits_processor(mtp_layer.shared_head.head,
hidden_states, sampling_metadata)
return logits
class DeepSeekMTP(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, previous_hidden_states,
inputs_embeds, spec_step_idx)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
return self.model.compute_logits(hidden_states, sampling_metadata,
spec_step_idx)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
"""
Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in transformer layer block for spec layer
"""
spec_layer_weight_names = [
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
]
spec_layer_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
break
if not spec_layer_weight:
# treat rest weights as weights for transformer layer block
name = name.replace(f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.")
return name

View File

@ -732,13 +732,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
if "rotary_emb.inv_freq" in name:
continue
# TODO(simon): support nextn predict layers
if hasattr(self.config, "num_nextn_predict_layers"
) and self.config.num_nextn_predict_layers > 0:
assert self.config.num_nextn_predict_layers == 1
layer_idx = self.config.num_hidden_layers
if name.startswith(f"model.layers.{layer_idx}"):
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
@ -805,3 +801,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass
def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
weight_name: str) -> Optional[int]:
if hasattr(config,
"num_nextn_predict_layers") and (config.num_nextn_predict_layers
> 0):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx+i}."):
return layer_idx + i
return None

View File

@ -187,6 +187,7 @@ _MULTIMODAL_MODELS = {
_SPECULATIVE_DECODING_MODELS = {
"EAGLEModel": ("eagle", "EAGLE"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

View File

@ -1307,6 +1307,8 @@ class ExecuteModelRequest(
previous_hidden_states: Optional[HiddenStates] = None
# The number of forward steps to run.
num_steps: int = 1
# The step index for spec model input.
spec_step_idx: Optional[int] = None
# Finished request ids since last step.
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.

View File

@ -153,7 +153,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
return False
# TODO: Add support for other attn backends
if self.attn_backend.get_name() != "FLASH_ATTN":
if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
return False
# TODO: Add support for LORA
@ -175,6 +175,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
**kwargs,
) -> Optional[List[SamplerOutput]]:
"""Executes num_steps forward passes with advacement of input tensors
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
@ -271,10 +272,17 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
for step in range(num_steps):
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
kwargs = {"previous_hidden_states": hidden_states} \
model_execute_kwargs = {"previous_hidden_states": hidden_states} \
if previous_hidden_states is not None else {}
compute_logits_kwargs = {}
# Run model
if hasattr(self.model.config, "num_nextn_predict_layers"):
# for DeepSeek MTP only to use the corresponding layer for
# each step
spec_step_idx = kwargs.get("spec_step_idx", step)
model_execute_kwargs["spec_step_idx"] = spec_step_idx
compute_logits_kwargs["spec_step_idx"] = spec_step_idx
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
hidden_states = model_executable(
@ -285,13 +293,15 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**kwargs,
**model_execute_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
model_input.sampling_metadata,
**compute_logits_kwargs)
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.model.sample(
logits=logits,

View File

@ -108,6 +108,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats,
num_speculative_tokens=speculative_config.num_speculative_tokens,
)
return spec_decode_worker
@ -153,10 +154,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
num_speculative_tokens: int,
) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True
enable_lm_head_weight_load = False
num_spec_prefill_steps = 1
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
@ -179,14 +182,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
elif draft_model_config.hf_config.model_type == "medusa":
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
if draft_tp == 1 or draft_model_config.hf_config.model_type ==\
"deepseek_mtp":
if current_platform.is_cuda_alike():
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
"EAGLE does not support TP > 1 yet")
f"{draft_model_config.hf_config.model_type} "
"does not support TP > 1 yet")
allow_zero_draft_token_step = False
@ -195,6 +200,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
enable_lm_head_weight_load = True
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
if draft_model_config.hf_config.model_type == "deepseek_mtp":
num_spec_prefill_steps = num_speculative_tokens
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
proposer_worker, draft_tp, target_tp)
@ -247,7 +254,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load)
enable_lm_head_weight_load=enable_lm_head_weight_load,
num_spec_prefill_steps=num_spec_prefill_steps)
def __init__(
self,
@ -261,6 +269,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
enable_lm_head_weight_load: Optional[bool] = False,
num_spec_prefill_steps: int = 1,
):
"""
Create a SpecDecodeWorker.
@ -293,6 +302,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft model is larger than 1 (TODO: #5814)
enable_lm_head_weight_load: whether to load lm_head weight for
draft models like eagle.
num_spec_prefill_steps: number of speculative prefill steps to run
before the speculative decoding starts. This is only used when
the draft model is a deepseek_mtp model that requires prefill
kv cache separately for each MTP layer.
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
@ -326,6 +339,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.previous_hidden_states: Optional[HiddenStates] = None
self._disable_logprobs = disable_logprobs
self._disable_log_stats = disable_log_stats
self._num_spec_prefill_steps = num_spec_prefill_steps
def init_device(self) -> None:
"""Initialize both scorer and proposer models.
@ -685,8 +699,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req.previous_hidden_states = \
prepare_prefill_hidden_states(
sampler_output.prefill_hidden_states)
self.proposer_worker.execute_model(execute_model_req)
for i in range(self._num_spec_prefill_steps):
execute_model_req.spec_step_idx = i
self.proposer_worker.execute_model(execute_model_req)
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
execute_model_req=execute_model_req, sampler_output=sampler_output)

View File

@ -99,6 +99,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
virtual_engine: int = 0
async_callback: Optional[Callable] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
previous_hidden_states: Optional[torch.Tensor] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
@ -1649,6 +1650,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
**kwargs,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")
@ -1706,6 +1708,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
previous_hidden_states = kwargs.get("previous_hidden_states")
model_kwargs = {}
if previous_hidden_states is not None:
model_kwargs["previous_hidden_states"] = previous_hidden_states
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start = torch.cuda.Event(enable_timing=True)
@ -1723,7 +1729,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
**seqlen_agnostic_kwargs,
**model_kwargs,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
@ -1815,7 +1823,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
1. current vLLM instance is KV cache consumer/decode vLLM instance
2. this batch is not a profiling run
3. this batch is a prefill run
Args:
model_input: input to the model executable
kv_caches: vLLM's paged memory
@ -1840,7 +1848,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
1. current vLLM instance is KV cache producer/prefill vLLM instance
2. this batch is not a profiling run
3. this batch is a prefill run
Args:
model_input: input to the model executable
kv_caches: vLLM's paged memory
@ -1976,7 +1984,11 @@ class CUDAGraphRunner(nn.Module):
# Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
if positions is not None:
self.input_buffers["positions"].copy_(positions, non_blocking=True)
# in some case like MLA, it will reuse positions in metadata
# but truncate them to the original size
# so the shape is not padded, we need to copy partial only
self.input_buffers["positions"][:positions.shape[0]].copy_(
positions, non_blocking=True)
if self.backend_name != "NO_ATTENTION":
self.input_buffers["slot_mapping"].copy_(

View File

@ -46,7 +46,10 @@ def _init_attn_metadata_from_tensor_dict(
valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
if field.name in tensor_dict:
valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
if field.name == "input_positions":
valid_attn_kwargs[field.name] = tensor_dict[field.name]
else:
valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
tensor_dict["attn_metadata"] = attn_metadata

View File

@ -68,10 +68,10 @@ class Worker(LocalOrDistributedWorkerBase):
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type ==
model_config.hf_config.model_type) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"]) \
not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner

View File

@ -397,6 +397,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_input, worker_input, kwargs = inputs
num_steps = worker_input.num_steps
if (execute_model_req is not None and execute_model_req.spec_step_idx):
kwargs["spec_step_idx"] = execute_model_req.spec_step_idx
self.execute_worker(worker_input)