[Model][Speculative Decoding] DeepSeek MTP spec decode (#12755)
Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Parent: 983a40a8bb
Commit: f525c0be8b
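Before the diff itself, a quick orientation sketch (editor's addition, not part of the PR): how the DeepSeek MTP speculative decoding added here is expected to be enabled through the offline LLM API of this vLLM generation. The tiny random-weight checkpoint and num_speculative_tokens=1 mirror the values used by the new test file; per the SpeculativeConfig change further down, a DeepSeek-V3 target reuses its own checkpoint as the draft, so no separate speculative_model is passed.

# Editor's sketch, not code from this PR.
from vllm import LLM, SamplingParams

llm = LLM(
    model="luccafong/deepseek_mtp_main_random",  # DeepSeek checkpoint carrying MTP layers
    trust_remote_code=True,
    dtype="bfloat16",
    num_speculative_tokens=1,  # matches num_nextn_predict_layers in config.json
    enforce_eager=True,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)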
@@ -2,7 +2,7 @@
# adding a new command to an existing step. See different options here for examples.

# This script will be fed into the Jinja template in `test-template-aws.j2` at
# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2
# to generate the final pipeline yaml file.

# Documentation
@@ -15,7 +15,7 @@
# mirror_hardwares(list): the list of hardware to also run the test on. currently only supports [amd]
# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100
# num_gpus(int): override the number of GPUs for the test. defaults to 1 GPU. currently supports 2,4.
# num_nodes(int): whether to simulate a multi-node setup by launching multiple containers on one host,
# in this case, commands must be specified. the first command runs on the first host, the second
# command runs on the second host.
# working_dir(str): specify where the command should execute, defaults to /vllm-workspace/tests
@@ -24,8 +24,8 @@
# When adding a test
# - If the test belongs to an existing group, add it there
# - If the test is short, add to any existing step
# - If the test takes more than 10min, then it is okay to create a new step.
# Note that all steps execute in parallel.

steps:
##### fast check tests #####
@@ -145,14 +145,14 @@ steps:
  - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py

- label: Metrics, Tracing Test # 10min
  num_gpus: 2
  fast_check: true
  source_file_dependencies:
  - vllm/
  - tests/metrics
  - tests/tracing
  commands:
  - pytest -v -s metrics
  - "pip install \
    'opentelemetry-sdk>=1.26.0,<1.27.0' \
    'opentelemetry-api>=1.26.0,<1.27.0' \
@@ -254,7 +254,7 @@ steps:
  - vllm/model_executor/guided_decoding
  - tests/test_logits_processor
  - tests/model_executor/test_guided_processors
  commands:
  - pytest -v -s test_logits_processor.py
  - pytest -v -s model_executor/test_guided_processors.py

@@ -265,7 +265,7 @@ steps:
  - vllm/model_executor/models/eagle.py
  commands:
  - pytest -v -s spec_decode/e2e/test_multistep_correctness.py
-  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
+  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_mtp_correctness.py
  - pytest -v -s spec_decode/e2e/test_eagle_correctness.py

- label: LoRA Test %N # 15min each
@@ -580,7 +580,7 @@ steps:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
  # This test runs llama 13B, so it is required to run on 4 GPUs.
  - pytest -v -s -x lora/test_long_context.py
  # There is some Tensor Parallelism related processing logic in LoRA that
  # requires multi-GPU testing for validation.
  - pytest -v -s -x lora/test_chatglm3_tp.py
  - pytest -v -s -x lora/test_llama_tp.py
@@ -605,7 +605,7 @@
  - vllm/
  - tests/weight_loading
  commands:
  - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt


##### multi gpus test #####
@@ -617,7 +617,7 @@
  num_gpus: 4
  source_file_dependencies:
  - vllm/
  commands:
  # NOTE: don't test the llama model here, it seems the hf implementation is buggy
  # see https://github.com/vllm-project/vllm/pull/5689 for details
  - pytest -v -s distributed/test_custom_all_reduce.py
@@ -296,6 +296,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
                            speculative_model="abhigoyal/vllm-medusa-llama-68m-random"),  # noqa: E501
    "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
                            speculative_model="ibm-ai-platform/llama-160m-accelerator"),  # noqa: E501
    "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
                            speculative_model="luccafong/deepseek_mtp_draft_random",  # noqa: E501
                            trust_remote_code=True),
}

_FALLBACK_MODEL = {
tests/spec_decode/e2e/test_mtp_correctness.py  (new file, 318 lines)
@@ -0,0 +1,318 @@
# SPDX-License-Identifier: Apache-2.0
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal
non-speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Greedy equality under preemption
* Greedy equality under various numbers of speculative tokens.

With those tests, we can say at least that MTP does not break the
correctness of the target model outputs.
"""

import pytest

from .conftest import run_equality_correctness_test

# main model
MAIN_MODEL = "luccafong/deepseek_mtp_main_random"

# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1

# precision
PRECISION = "bfloat16"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                    per_test_common_llm_kwargs,
                                    baseline_llm_kwargs, test_llm_kwargs,
                                    batch_size: int, output_len: int,
                                    seed: int):

    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs_during_spec_decoding": False,
    },
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs_during_spec_decoding": True,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs, test_llm_kwargs,
                                 batch_size: int, output_len: int, seed: int,
                                 logprobs: int):

    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  logprobs=logprobs,
                                  prompt_logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": False,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
                                               per_test_common_llm_kwargs,
                                               baseline_llm_kwargs,
                                               test_llm_kwargs,
                                               batch_size: int,
                                               output_len: int, seed: int):
    """Verify greedy equality with cuda graph enabled and different
    batch sizes."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        128,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_with_preemption(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "num_speculative_tokens": k,
        }
        # Try a range of num. speculative tokens
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
                         per_test_common_llm_kwargs, baseline_llm_kwargs,
                         test_llm_kwargs, batch_size: int, output_len: int,
                         seed: int):
    """Verify that MTP speculative decoding produces exact equality
    to the run without spec decode for different values of
    num_speculative_tokens.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "num_speculative_tokens": MAX_SPEC_TOKENS,
                             "speculative_disable_by_batch_size": 4
                         }])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int):
    """Verify that MTP speculative decoding produces exact equality
    to the run without spec decode when speculation is disabled for large
    batch sizes.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])
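For readers unfamiliar with the conftest helper, the sketch below (editor's addition, NOT the actual run_equality_correctness_test implementation) shows the greedy-equality idea these tests rely on: generate with and without speculation at temperature 0 and require identical token ids. The literal values mirror MAIN_MODEL, PRECISION and MAX_SPEC_TOKENS above.

# Editor's sketch of the greedy-equality check; hypothetical helper, not PR code.
from vllm import LLM, SamplingParams

def check_greedy_equality(prompts, output_len=128):
    greedy = SamplingParams(temperature=0.0, max_tokens=output_len)

    baseline = LLM(model="luccafong/deepseek_mtp_main_random",
                   trust_remote_code=True, dtype="bfloat16")
    expected = [o.outputs[0].token_ids for o in baseline.generate(prompts, greedy)]
    del baseline  # release the baseline engine before starting the speculative one

    spec = LLM(model="luccafong/deepseek_mtp_main_random",
               trust_remote_code=True, dtype="bfloat16",
               num_speculative_tokens=1)  # MTP draft reuses the target checkpoint
    actual = [o.outputs[0].token_ids for o in spec.generate(prompts, greedy)]
    assert actual == expected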
@@ -763,7 +763,7 @@ class ModelConfig:
    def is_deepseek_mla(self) -> bool:
        return (hasattr(self.hf_text_config, "model_type")) \
            and (self.hf_text_config.model_type in \
-               ('deepseek_v2', 'deepseek_v3'))\
+               ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\
            and (self.hf_text_config.kv_lora_rank is not None)

    def get_head_size(self) -> int:
@@ -856,8 +856,12 @@ class ModelConfig:
    def get_layers_start_end_indices(
            self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
        from vllm.distributed.utils import get_pp_indices
-       total_num_hidden_layers = getattr(self.hf_text_config,
-                                         "num_hidden_layers", 0)
+       if self.hf_text_config.model_type == "deepseek_mtp":
+           total_num_hidden_layers = getattr(self.hf_text_config,
+                                             "num_nextn_predict_layers", 0)
+       else:
+           total_num_hidden_layers = getattr(self.hf_text_config,
+                                             "num_hidden_layers", 0)
        pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
        pp_size = parallel_config.pipeline_parallel_size
        start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
@@ -1689,6 +1693,18 @@ class SpeculativeConfig:
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
        if hf_config.model_type == "deepseek_v3":
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update({
                "n_predict": n_predict,
                "architectures": ["DeepSeekMTPModel"]
            })
        return hf_config

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
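To make the override above concrete, here is a small hedged illustration (editor's addition, not part of the PR) of its effect on a DeepSeek-V3 style config; the config object is built ad hoc rather than loaded from a real checkpoint.

# Editor's illustration of SpeculativeConfig.hf_config_override: a DeepSeek-V3 config
# is re-labelled as the MTP draft architecture so the draft worker loads DeepSeekMTP.
from transformers import PretrainedConfig

from vllm.config import SpeculativeConfig

cfg = PretrainedConfig()
cfg.model_type = "deepseek_v3"
cfg.num_nextn_predict_layers = 1

cfg = SpeculativeConfig.hf_config_override(cfg)
assert cfg.model_type == "deepseek_mtp"
assert cfg.n_predict == 1
assert cfg.architectures == ["DeepSeekMTPModel"]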
@@ -1771,12 +1787,18 @@ class SpeculativeConfig:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                the necessary conditions are met, else None.
        """

        if speculative_model is None:
            if num_speculative_tokens is not None:
-               raise ValueError("num_speculative_tokens was provided without "
-                                "speculative_model.")
-           return None
+               if target_model_config.hf_text_config.model_type \
+                       == "deepseek_v3":
+                   # use the draft model from the same model:
+                   speculative_model = target_model_config.model
+               else:
+                   raise ValueError(
+                       "num_speculative_tokens was provided without "
+                       "speculative_model.")
+           else:
+               return None

        if (speculative_disable_by_batch_size is not None
                and speculative_disable_by_batch_size < 2):
@@ -1830,6 +1852,7 @@ class SpeculativeConfig:
                max_seq_len_to_capture=target_model_config.
                max_seq_len_to_capture,
                max_logprobs=target_model_config.max_logprobs,
+               hf_overrides=SpeculativeConfig.hf_config_override,
            )

            draft_hf_config = draft_model_config.hf_config
@@ -1846,7 +1869,6 @@ class SpeculativeConfig:
            if (num_speculative_tokens is not None
                    and hasattr(draft_hf_config, "num_lookahead_tokens")):
                draft_hf_config.num_lookahead_tokens = num_speculative_tokens

            n_predict = getattr(draft_hf_config, "n_predict", None)
            if n_predict is not None:
                if num_speculative_tokens is None:
@@ -1960,8 +1982,9 @@ class SpeculativeConfig:
            speculative_draft_tensor_parallel_size = 1
            if target_parallel_config.tensor_parallel_size > 1:
                logger.warning(
-                   "MLPSpeculator cannot currently be run with tp>1; "
-                   "setting speculative_draft_tensor_parallel_size=1")
+                   "%s cannot currently be run with tp>1; "
+                   "setting speculative_draft_tensor_parallel_size=1",
+                   draft_hf_config.model_type)
        else:
            speculative_draft_tensor_parallel_size = \
                target_parallel_config.tensor_parallel_size
vllm/model_executor/models/deepseek_mtp.py  (new file, 284 lines)
@@ -0,0 +1,284 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .deepseek_v2 import (DeepseekV2DecoderLayer,
                          get_spec_layer_idx_from_weight_name)
from .utils import maybe_prefix


class SharedHead(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.head = ParallelLMHead(config.vocab_size,
                                   config.hidden_size,
                                   quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.norm(hidden_states)


class DeepSeekMultiTokenPredictorLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.eh_proj = nn.Linear(config.hidden_size * 2,
                                 config.hidden_size,
                                 bias=False)
        self.shared_head = SharedHead(config=config, quant_config=quant_config)
        self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
                                                cache_config, quant_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None
        # mask inputs at position 0, as they are not needed by MTP
        inputs_embeds[positions == 0] = 0
        inputs_embeds = self.enorm(inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)

        hidden_states = self.eh_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

        hidden_states, residual = self.mtp_block(positions=positions,
                                                 hidden_states=hidden_states,
                                                 kv_cache=kv_cache,
                                                 attn_metadata=attn_metadata,
                                                 residual=None)
        hidden_states = residual + hidden_states
        return self.shared_head(hidden_states)


class DeepSeekMultiTokenPredictor(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers
        # to map the exact layer index from weights
        self.layers = torch.nn.ModuleDict({
            str(idx):
            DeepSeekMultiTokenPredictorLayer(
                config,
                f"{prefix}.layers.{idx}",
                model_config=vllm_config.model_config,
                cache_config=vllm_config.cache_config,
                quant_config=vllm_config.quant_config,
            )
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)
        })

        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
            input_ids,
            positions,
            kv_caches[spec_step_idx],
            attn_metadata,
            previous_hidden_states,
            inputs_embeds,
            spec_step_idx,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
        logits = self.logits_processor(mtp_layer.shared_head.head,
                                       hidden_states, sampling_metadata)
        return logits


class DeepSeekMTP(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
                                                 prefix=maybe_prefix(
                                                     prefix, "model"))

        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, previous_hidden_states,
                                   inputs_embeds, spec_step_idx)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        spec_step_idx: int = 0,
    ) -> Optional[torch.Tensor]:
        return self.model.compute_logits(hidden_states, sampling_metadata,
                                         spec_step_idx)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        Add .mtp_block for modules in the transformer layer block of the
        spec layer.
        """
        spec_layer_weight_names = [
            "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
        ]
        spec_layer_weight = False
        for weight_name in spec_layer_weight_names:
            if weight_name in name:
                spec_layer_weight = True
                break
        if not spec_layer_weight:
            # treat the remaining weights as belonging to the transformer
            # layer block
            name = name.replace(f"model.layers.{spec_layer}.",
                                f"model.layers.{spec_layer}.mtp_block.")
        return name
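The MTP layer's forward pass above boils down to "fuse the next token's embedding with the target model's hidden state, then run one decoder block". A toy, hedged shape sketch (editor's addition, not PR code) with LayerNorm and a plain Linear standing in for the real RMSNorm/eh_proj modules, and the decoder block and shared head omitted:

# Editor's shape sketch of one MTP step for T tokens with hidden size H.
import torch

T, H = 4, 8                                  # tokens in this step, toy hidden size
enorm = torch.nn.LayerNorm(H)                # stand-in for RMSNorm on the embedding
hnorm = torch.nn.LayerNorm(H)                # stand-in for RMSNorm on the target hidden state
eh_proj = torch.nn.Linear(2 * H, H, bias=False)

inputs_embeds = torch.randn(T, H)            # embed_tokens(input_ids)
previous_hidden = torch.randn(T, H)          # hidden states handed over by the target model

fused = torch.cat([enorm(inputs_embeds), hnorm(previous_hidden)], dim=-1)  # [T, 2H]
hidden_states = eh_proj(fused)               # [T, H] -> DeepseekV2DecoderLayer -> shared_head.norm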
@@ -732,13 +732,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
            if "rotary_emb.inv_freq" in name:
                continue

-           # TODO(simon): support nextn predict layers
-           if hasattr(self.config, "num_nextn_predict_layers"
-                      ) and self.config.num_nextn_predict_layers > 0:
-               assert self.config.num_nextn_predict_layers == 1
-               layer_idx = self.config.num_hidden_layers
-               if name.startswith(f"model.layers.{layer_idx}"):
-                   continue
+           spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+           if spec_layer is not None:
+               continue  # skip spec decode layers for main model

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
@@ -805,3 +801,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):

class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    pass


def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
                                        weight_name: str) -> Optional[int]:
    if hasattr(config,
               "num_nextn_predict_layers") and (config.num_nextn_predict_layers
                                                > 0):
        layer_idx = config.num_hidden_layers
        for i in range(config.num_nextn_predict_layers):
            if weight_name.startswith(f"model.layers.{layer_idx+i}."):
                return layer_idx + i
    return None
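Since the helper above drives both the main-model skip in DeepseekV2ForCausalLM.load_weights and the MTP loader, here is a small hedged usage sketch (editor's addition; the layer counts assume DeepSeek-V3's 61 decoder layers with one extra MTP layer stored as layer 61, and the config object is a stand-in):

# Editor's usage sketch for get_spec_layer_idx_from_weight_name; not PR code.
from transformers import PretrainedConfig

from vllm.model_executor.models.deepseek_v2 import (
    get_spec_layer_idx_from_weight_name)

cfg = PretrainedConfig(num_hidden_layers=61, num_nextn_predict_layers=1)
# Weights under model.layers.61.* belong to the MTP head; everything below is the main model.
assert get_spec_layer_idx_from_weight_name(cfg, "model.layers.61.enorm.weight") == 61
assert get_spec_layer_idx_from_weight_name(cfg, "model.layers.10.mlp.gate.weight") is None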
@@ -187,6 +187,7 @@ _MULTIMODAL_MODELS = {

_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
@@ -1307,6 +1307,8 @@ class ExecuteModelRequest(
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # The step index for spec model input.
    spec_step_idx: Optional[int] = None
    # Finished request ids since last step.
    finished_requests_ids: List[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
@@ -153,7 +153,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
            return False

        # TODO: Add support for other attn backends
-       if self.attn_backend.get_name() != "FLASH_ATTN":
+       if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
            return False

        # TODO: Add support for LORA
@@ -175,6 +175,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        **kwargs,
    ) -> Optional[List[SamplerOutput]]:
        """Executes num_steps forward passes with advancement of input tensors
        on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
@@ -271,10 +272,17 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
        for step in range(num_steps):
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}

-           kwargs = {"previous_hidden_states": hidden_states} \
+           model_execute_kwargs = {"previous_hidden_states": hidden_states} \
                if previous_hidden_states is not None else {}

            compute_logits_kwargs = {}
            # Run model
            if hasattr(self.model.config, "num_nextn_predict_layers"):
                # for DeepSeek MTP only: use the corresponding layer for
                # each step
                spec_step_idx = kwargs.get("spec_step_idx", step)
                model_execute_kwargs["spec_step_idx"] = spec_step_idx
                compute_logits_kwargs["spec_step_idx"] = spec_step_idx
            with set_forward_context(model_input.attn_metadata,
                                     self.vllm_config):
                hidden_states = model_executable(
@@ -285,13 +293,15 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
                    intermediate_tensors=intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                                 device=self.device),
-                   **kwargs,
+                   **model_execute_kwargs,
                )

            # Compute the logits.
            logits = self.model.compute_logits(hidden_states,
-                                              model_input.sampling_metadata)
+                                              model_input.sampling_metadata,
+                                              **compute_logits_kwargs)
            if not self.is_driver_worker:
                return []
            # Sample the next token.
            output = self.model.sample(
                logits=logits,
@ -108,6 +108,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
|
||||
typical_acceptance_sampler_posterior_alpha,
|
||||
disable_logprobs=speculative_config.disable_logprobs,
|
||||
disable_log_stats=speculative_config.disable_log_stats,
|
||||
num_speculative_tokens=speculative_config.num_speculative_tokens,
|
||||
)
|
||||
|
||||
return spec_decode_worker
|
||||
@ -153,10 +154,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
typical_acceptance_sampler_posterior_alpha: float,
|
||||
disable_logprobs: bool,
|
||||
disable_log_stats: bool,
|
||||
num_speculative_tokens: int,
|
||||
) -> "SpecDecodeWorker":
|
||||
|
||||
allow_zero_draft_token_step = True
|
||||
enable_lm_head_weight_load = False
|
||||
num_spec_prefill_steps = 1
|
||||
ngram_prompt_lookup_max = (
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
|
||||
ngram_prompt_lookup_min = (
|
||||
@ -179,14 +182,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
elif draft_model_config.hf_config.model_type == "medusa":
|
||||
proposer_worker = MedusaWorker(**draft_worker_kwargs)
|
||||
else:
|
||||
if draft_tp == 1:
|
||||
if draft_tp == 1 or draft_model_config.hf_config.model_type ==\
|
||||
"deepseek_mtp":
|
||||
if current_platform.is_cuda_alike():
|
||||
draft_worker_kwargs[
|
||||
"model_runner_cls"] = TP1DraftModelRunner
|
||||
else:
|
||||
if draft_model_config.hf_config.model_type == "eagle":
|
||||
raise NotImplementedError(
|
||||
"EAGLE does not support TP > 1 yet")
|
||||
f"{draft_model_config.hf_config.model_type} "
|
||||
"does not support TP > 1 yet")
|
||||
|
||||
allow_zero_draft_token_step = False
|
||||
|
||||
@ -195,6 +200,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
enable_lm_head_weight_load = True
|
||||
|
||||
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
|
||||
if draft_model_config.hf_config.model_type == "deepseek_mtp":
|
||||
num_spec_prefill_steps = num_speculative_tokens
|
||||
|
||||
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
|
||||
proposer_worker, draft_tp, target_tp)
|
||||
@ -247,7 +254,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
disable_by_batch_size=disable_by_batch_size,
|
||||
spec_decode_sampler=spec_decode_sampler,
|
||||
allow_zero_draft_token_step=allow_zero_draft_token_step,
|
||||
enable_lm_head_weight_load=enable_lm_head_weight_load)
|
||||
enable_lm_head_weight_load=enable_lm_head_weight_load,
|
||||
num_spec_prefill_steps=num_spec_prefill_steps)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -261,6 +269,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
disable_by_batch_size: Optional[int] = None,
|
||||
allow_zero_draft_token_step: Optional[bool] = True,
|
||||
enable_lm_head_weight_load: Optional[bool] = False,
|
||||
num_spec_prefill_steps: int = 1,
|
||||
):
|
||||
"""
|
||||
Create a SpecDecodeWorker.
|
||||
@ -293,6 +302,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
draft model is larger than 1 (TODO: #5814)
|
||||
enable_lm_head_weight_load: whether to load lm_head weight for
|
||||
draft models like eagle.
|
||||
num_spec_prefill_steps: number of speculative prefill steps to run
|
||||
before the speculative decoding starts. This is only used when
|
||||
the draft model is a deepseek_mtp model that requires prefill
|
||||
kv cache separately for each MTP layer.
|
||||
"""
|
||||
self.proposer_worker = proposer_worker
|
||||
self.scorer_worker = scorer_worker
|
||||
@ -326,6 +339,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self.previous_hidden_states: Optional[HiddenStates] = None
|
||||
self._disable_logprobs = disable_logprobs
|
||||
self._disable_log_stats = disable_log_stats
|
||||
self._num_spec_prefill_steps = num_spec_prefill_steps
|
||||
|
||||
def init_device(self) -> None:
|
||||
"""Initialize both scorer and proposer models.
|
||||
@ -685,8 +699,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
execute_model_req.previous_hidden_states = \
|
||||
prepare_prefill_hidden_states(
|
||||
sampler_output.prefill_hidden_states)
|
||||
|
||||
self.proposer_worker.execute_model(execute_model_req)
|
||||
for i in range(self._num_spec_prefill_steps):
|
||||
execute_model_req.spec_step_idx = i
|
||||
self.proposer_worker.execute_model(execute_model_req)
|
||||
|
||||
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
|
||||
execute_model_req=execute_model_req, sampler_output=sampler_output)
|
||||
|
@@ -99,6 +99,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
    virtual_engine: int = 0
    async_callback: Optional[Callable] = None
    scheduler_outputs: Optional[SchedulerOutputs] = None
    previous_hidden_states: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict = {
@@ -1649,6 +1650,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        **kwargs,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")
@@ -1706,6 +1708,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_inner_state else {}
        previous_hidden_states = kwargs.get("previous_hidden_states")
        model_kwargs = {}
        if previous_hidden_states is not None:
            model_kwargs["previous_hidden_states"] = previous_hidden_states
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_start = torch.cuda.Event(enable_timing=True)
@@ -1723,7 +1729,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
-               **seqlen_agnostic_kwargs)
+               **seqlen_agnostic_kwargs,
+               **model_kwargs,
+           )

        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
@@ -1815,7 +1823,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
        1. current vLLM instance is KV cache consumer/decode vLLM instance
        2. this batch is not a profiling run
        3. this batch is a prefill run

        Args:
            model_input: input to the model executable
            kv_caches: vLLM's paged memory
@@ -1840,7 +1848,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
        1. current vLLM instance is KV cache producer/prefill vLLM instance
        2. this batch is not a profiling run
        3. this batch is a prefill run

        Args:
            model_input: input to the model executable
            kv_caches: vLLM's paged memory
@@ -1976,7 +1984,11 @@ class CUDAGraphRunner(nn.Module):
        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        if positions is not None:
-           self.input_buffers["positions"].copy_(positions, non_blocking=True)
+           # in some cases, like MLA, positions from the metadata are reused
+           # but truncated to the original size, so the shape is not padded
+           # and we need to copy only the partial tensor
+           self.input_buffers["positions"][:positions.shape[0]].copy_(
+               positions, non_blocking=True)

        if self.backend_name != "NO_ATTENTION":
            self.input_buffers["slot_mapping"].copy_(
@@ -46,7 +46,10 @@ def _init_attn_metadata_from_tensor_dict(
    valid_attn_kwargs = {}
    for field in dataclasses.fields(attn_backend.get_metadata_cls()):
        if field.name in tensor_dict:
-           valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
+           if field.name == "input_positions":
+               valid_attn_kwargs[field.name] = tensor_dict[field.name]
+           else:
+               valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)

    attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
    tensor_dict["attn_metadata"] = attn_metadata
@@ -68,10 +68,10 @@ class Worker(LocalOrDistributedWorkerBase):
        speculative_config = self.speculative_config
        model_config = self.model_config
        speculative_args = {} if speculative_config is None \
-           or (speculative_config.draft_model_config.model ==
-               model_config.model) \
+           or (speculative_config.draft_model_config.hf_config.model_type ==
+               model_config.hf_config.model_type) \
            or (speculative_config.draft_model_config.hf_config.model_type
-               not in ["medusa", "mlp_speculator", "eagle"]) \
+               not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \
            else {"return_hidden_states": True}

        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
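The multi-clause conditional above is easy to misread, so here is an editor's hedged paraphrase as a small helper (hypothetical name, not part of the PR): hidden states are requested from the target model only when a draft model exists, its model_type differs from the target's (for DeepSeek MTP the draft reuses the target checkpoint, but its hf_config is overridden to deepseek_mtp, so the types differ), and the draft is one of the hidden-state-consuming proposers.

# Editor's paraphrase of the worker's speculative_args condition; hypothetical helper.
def needs_hidden_states(speculative_config, model_config) -> bool:
    if speculative_config is None:
        return False
    draft_type = speculative_config.draft_model_config.hf_config.model_type
    if draft_type == model_config.hf_config.model_type:
        return False
    return draft_type in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")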
@@ -397,6 +397,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):

        model_input, worker_input, kwargs = inputs
        num_steps = worker_input.num_steps
        if (execute_model_req is not None and execute_model_req.spec_step_idx):
            kwargs["spec_step_idx"] = execute_model_req.spec_step_idx

        self.execute_worker(worker_input)