[Feature][Spec Decode] Simplify the use of Eagle Spec Decode (#12304)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
shangmingc 2025-02-17 11:32:26 +08:00 committed by GitHub
parent 2010f04c17
commit 46cdd59577
8 changed files with 273 additions and 18 deletions

View File

@@ -175,7 +175,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=4,
-   speculative_model="path/to/modified/eagle/model",
+   speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
    speculative_draft_tensor_parallel_size=1,
)
@@ -190,14 +190,12 @@ for output in outputs:
A few important things to consider when using the EAGLE based draft models:

-1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) cannot be
-   used directly with vLLM due to differences in the expected layer names and model definition.
-   To use these models with vLLM, use the [following script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d)
-   to convert them. Note that this script does not modify the model's weights.
-
-   In the above example, use the script to first convert
-   the [yuhuili/EAGLE-LLaMA3-Instruct-8B](https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B) model
-   and then use the converted checkpoint as the draft model in vLLM.
+1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) should now
+   be loadable and usable directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304).
+   If you are using a vLLM version released before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), use the
+   [conversion script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model
+   and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur with the
+   latest version of vLLM, please leave a comment or raise an issue.

2. The EAGLE based draft models need to be run without tensor parallelism
   (i.e. speculative_draft_tensor_parallel_size is set to 1), although

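Put together, the simplified flow documented above amounts to pointing `speculative_model` straight at the unmodified Hugging Face checkpoint. Below is a minimal single-GPU offline-inference sketch based on the doc example; `num_speculative_tokens=5` is an illustrative value, not something prescribed by this diff.

from vllm import LLM, SamplingParams

# Minimal sketch of the simplified EAGLE setup described above; no conversion
# script is needed because the draft model is the raw HF checkpoint.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
    num_speculative_tokens=5,  # illustrative value, not taken from the docs
)

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(["The future of AI is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)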
View File

@@ -305,6 +305,150 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": "float16",

        # Main model
        "model_name": "meta-llama/Llama-2-7b-chat-hf",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                             per_test_common_llm_kwargs,
                                             baseline_llm_kwargs,
                                             test_llm_kwargs, batch_size: int,
                                             output_len: int, seed: int):
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": "float16",

        # Main model
        "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                             per_test_common_llm_kwargs,
                                             baseline_llm_kwargs,
                                             test_llm_kwargs, batch_size: int,
                                             output_len: int, seed: int):
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": "float16",

        # Main model
        "model_name": "Qwen/Qwen2-7B-Instruct",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_qwen2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                            per_test_common_llm_kwargs,
                                            baseline_llm_kwargs,
                                            test_llm_kwargs, batch_size: int,
                                            output_len: int, seed: int):
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])

View File

@@ -13,15 +13,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                      SpecDecodeWorkerMetrics)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                 split_num_cache_blocks_evenly)
from vllm.worker.worker import Worker

from .test_utils import mock_spec_decode_sampler
-from .utils import create_batch, create_sampler_output_list, mock_worker
+from .utils import (create_batch, create_sampler_output_list, create_worker,
+                    mock_worker)


@pytest.mark.parametrize('k', [1, 2, 6])
@@ -905,3 +908,38 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
    worker.execute_model(execute_model_req=execute_model_req)
    # but first draft still counted
    assert draft_worker.get_spec_proposals.call_count == 1


def test_correctly_load_weight_for_eagle():
    """
    Verify SpecDecodeWorker loads lm_head weight for eagle correctly.
    """
    seed = 100
    block_size = 32
    num_gpu_blocks = 8096 // block_size
    target_worker = create_worker(
        Worker,
        "JackFram/llama-68m",
        block_size,
        num_gpu_blocks,
        seed,
    )
    draft_worker = create_worker(
        MultiStepWorker,
        "abhigoyal/vllm-eagle-llama-68m-random",
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )

    spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False)
    worker.proposer_worker.maybe_load_lm_head_weight(
        target_worker.model_runner.model.lm_head.weight.data)
    assert torch.allclose(
        worker.proposer_worker.worker.model_runner.model.lm_head.weight.data,
        worker.scorer_worker.model_runner.model.lm_head.weight.data)

View File

@@ -1833,6 +1833,15 @@ class SpeculativeConfig:
            draft_hf_config = draft_model_config.hf_config

            # Detect EAGLE prefix to replace hf_config for EAGLE draft_model
            if "eagle-" in draft_model_config.model.lower():
                from vllm.transformers_utils.configs.eagle import EAGLEConfig

                if isinstance(draft_model_config.hf_config, EAGLEConfig):
                    pass
                else:
                    eagle_config = EAGLEConfig(draft_model_config.hf_config)
                    draft_model_config.hf_config = eagle_config

            if (num_speculative_tokens is not None
                    and hasattr(draft_hf_config, "num_lookahead_tokens")):
                draft_hf_config.num_lookahead_tokens = num_speculative_tokens
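The hunk above is what removes the need for converted checkpoints: any draft model whose repo name carries the "eagle-" prefix gets its HF config wrapped in EAGLEConfig automatically. A condensed, standalone sketch of the same rule follows; the helper function and its name are hypothetical, and only the EAGLEConfig import and `EAGLEConfig(hf_config)` constructor come from the hunk itself.

from transformers import PretrainedConfig

from vllm.transformers_utils.configs.eagle import EAGLEConfig


def wrap_eagle_hf_config(model_name: str,
                         hf_config: PretrainedConfig) -> PretrainedConfig:
    # Hypothetical helper mirroring the detection above: wrap the draft
    # model's HF config in EAGLEConfig when the repo name contains "eagle-"
    # and the config is not already an EAGLEConfig.
    if "eagle-" in model_name.lower() and not isinstance(
            hf_config, EAGLEConfig):
        return EAGLEConfig(hf_config)
    return hf_config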

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -18,6 +19,8 @@ from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix

logger = init_logger(__name__)


class DummyInputLayerNorm(nn.Module):
@@ -190,8 +193,8 @@ class EAGLE(nn.Module):
                                            default_weight_loader)
                    weight_loader(self.fc.bias, loaded_weight)
                else:
-                    raise ValueError("Found bias in the loaded weights "
-                                     "but the model config doesn't have bias")
+                    logger.warning_once("Found bias in the loaded weights but "
+                                        "the model config doesn't have bias.")
            elif name.startswith("model.lm_head.") or name.startswith(
                    "model.model."):
                model_weights[name.split("model.", 1)[-1]] = loaded_weight
@@ -200,12 +203,21 @@ class EAGLE(nn.Module):
            else:
                model_weights[f"model.{name}"] = loaded_weight

-        lm_head_weight = model_weights.pop("lm_head.weight")
-
-        if self.token_map is not None and\
-            lm_head_weight.shape[0] > self.token_map.shape[0]:
-            lm_head_weight = lm_head_weight[self.token_map]
+        if "lm_head.weight" in model_weights:
+            lm_head_weight = model_weights.pop("lm_head.weight")
+
+            if self.token_map is not None and\
+                lm_head_weight.shape[0] > self.token_map.shape[0]:
+                lm_head_weight = lm_head_weight[self.token_map]
+        else:
+            # NOTE(Shangming): initialize the placeholder for lm_head weight.
+            lm_head_weight = torch.zeros(
+                self.lm_head.org_vocab_size,
+                self.lm_head.embedding_dim,
+                dtype=self.config.torch_dtype,
+            )

        weight_loader = getattr(self.lm_head.weight, "weight_loader",
                                default_weight_loader)
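The zero-initialized placeholder above is only safe because it is never used as-is: SpecDecodeWorker.load_model (last file in this commit) overwrites it with the target model's lm_head weight via maybe_load_lm_head_weight before any drafting happens. Below is a shape-only sketch of that two-step hand-off, with toy sizes that are not taken from the diff.

import torch

# Toy sizes for illustration only; real values come from the draft config.
vocab_size, hidden_size = 1024, 64

# Step 1 (EAGLE.load_weights): the raw yuhuili checkpoint has no
# lm_head.weight, so a zero placeholder of the right shape is created.
lm_head_weight = torch.zeros(vocab_size, hidden_size, dtype=torch.float16)

# Step 2 (SpecDecodeWorker.load_model -> maybe_load_lm_head_weight): the
# target model's lm_head weight is copied over the placeholder in place.
target_lm_head = torch.randn(vocab_size, hidden_size, dtype=torch.float16)
lm_head_weight.copy_(target_lm_head)
assert torch.equal(lm_head_weight, target_lm_head)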

View File

@@ -7,6 +7,7 @@ from typing import Dict, List, Set, Tuple
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                           SequenceGroupMetadata)
@@ -386,3 +387,14 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")

    def maybe_load_lm_head_weight(
        self,
        lm_head_weight: torch.Tensor,
    ) -> None:
        weight_loader = getattr(
            self.worker.model_runner.model_runner.model.lm_head.weight,
            "weight_loader", default_weight_loader)
        weight_loader(
            self.worker.model_runner.model_runner.model.lm_head.weight,
            lm_head_weight)

View File

@@ -10,6 +10,7 @@ from vllm.distributed.parallel_state import (get_tp_group,
                                             patch_tensor_parallel_group)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.multi_step_worker import MultiStepWorker
@@ -173,3 +174,21 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
    @property
    def vocab_size(self) -> int:
        return self._worker.vocab_size

    def maybe_load_lm_head_weight(
        self,
        lm_head_weight: torch.Tensor,
    ) -> None:
        if self._is_dummy:
            return

        with self._patch_tensor_parallel_group():
            weight_loader = getattr(
                self._worker.worker.model_runner.model_runner.model.\
                lm_head.weight,
                "weight_loader",
                default_weight_loader)
            weight_loader(
                self._worker.worker.model_runner.model_runner.model.\
                lm_head.weight,
                lm_head_weight)

View File

@@ -9,7 +9,8 @@ import torch
import torch.nn as nn

from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
-from vllm.distributed.communication_op import broadcast_tensor_dict
+from vllm.distributed.communication_op import (broadcast_tensor_dict,
+                                               tensor_model_parallel_gather)
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -155,6 +156,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    ) -> "SpecDecodeWorker":
        allow_zero_draft_token_step = True
        enable_lm_head_weight_load = False
        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
@@ -187,6 +189,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                        "EAGLE does not support TP > 1 yet")
                allow_zero_draft_token_step = False

                # Load lm_head weight for eagle in init_device
                if draft_model_config.hf_config.model_type == "eagle":
                    enable_lm_head_weight_load = True

            proposer_worker = MultiStepWorker(**draft_worker_kwargs)

            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
@@ -239,7 +246,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            disable_log_stats=disable_log_stats,
            disable_by_batch_size=disable_by_batch_size,
            spec_decode_sampler=spec_decode_sampler,
-            allow_zero_draft_token_step=allow_zero_draft_token_step)
+            allow_zero_draft_token_step=allow_zero_draft_token_step,
+            enable_lm_head_weight_load=enable_lm_head_weight_load)

    def __init__(
        self,
@@ -252,6 +260,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
        allow_zero_draft_token_step: Optional[bool] = True,
        enable_lm_head_weight_load: Optional[bool] = False,
    ):
        """
        Create a SpecDecodeWorker.
@@ -282,6 +291,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            allow_zero_draft_token_step: whether to allow a step where the draft
                model generates no draft token; should disallow when the tp of
                draft model is larger than 1 (TODO: #5814)
            enable_lm_head_weight_load: whether to load lm_head weight for
                draft models like eagle.
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
@@ -291,6 +302,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        self.disable_by_batch_size = disable_by_batch_size or float("inf")
        self.spec_decode_sampler = spec_decode_sampler
        self._allow_zero_draft_token_step = allow_zero_draft_token_step
        self._enable_lm_head_weight_load = enable_lm_head_weight_load
        self._metrics = AsyncMetricsCollector(
            self.spec_decode_sampler
        ) if metrics_collector is None else metrics_collector
@@ -327,6 +339,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        if self._enable_lm_head_weight_load:
            # NOTE(Shangming): gather lm_head weight when tp enabled
            target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather(
                self.scorer_worker.model_runner.model_runner.model.lm_head.\
                weight.data,
                dim=0,
            )

            self.proposer_worker.maybe_load_lm_head_weight(
                target_lm_head_weight)

        self._metrics.init_tensors(self.rank, device_type=self.device)
        self.spec_decode_sampler.init_tensors(self.rank,
                                              device_type=self.device)
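The dim=0 gather above reflects how the target model's lm_head is laid out under tensor parallelism: the weight is vocab-parallel, so each rank holds a [vocab_size / tp, hidden_size] shard and the draft worker needs the full matrix. Below is a shape-only sketch of what the gather reconstructs, with plain torch.cat standing in for tensor_model_parallel_gather and illustrative sizes that are not taken from the diff.

import torch

# Illustrative sizes; not taken from the diff.
tp_size, vocab_size, hidden_size = 4, 1024, 64

# Under tensor parallelism each rank holds a vocab-parallel shard of lm_head.
shards = [torch.randn(vocab_size // tp_size, hidden_size)
          for _ in range(tp_size)]

# tensor_model_parallel_gather(..., dim=0) reassembles the full weight on the
# destination rank; a plain concatenation along dim 0 stands in for it here.
full_lm_head = torch.cat(shards, dim=0)
assert full_lm_head.shape == (vocab_size, hidden_size)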