[Feature][Spec Decode] Simplify the use of Eagle Spec Decode (#12304)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
shangmingc 2025-02-17 11:32:26 +08:00 committed by GitHub
parent 2010f04c17
commit 46cdd59577
8 changed files with 273 additions and 18 deletions

View File

@@ -175,7 +175,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
tensor_parallel_size=4,
speculative_model="path/to/modified/eagle/model",
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
speculative_draft_tensor_parallel_size=1,
)
@@ -190,14 +190,12 @@ for output in outputs:
A few important things to consider when using EAGLE-based draft models:
1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) cannot be
used directly with vLLM due to differences in the expected layer names and model definition.
To use these models with vLLM, use the [following script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d)
to convert them. Note that this script does not modify the model's weights.
In the above example, use the script to first convert
the [yuhuili/EAGLE-LLaMA3-Instruct-8B](https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B) model
and then use the converted checkpoint as the draft model in vLLM.
1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) should
load and work directly with vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304); a usage sketch
follows this hunk. If you are using a vLLM version released before [PR 12304](https://github.com/vllm-project/vllm/pull/12304),
please use this [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the
speculative model and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still
occur with the latest version of vLLM, please leave a comment or raise an issue.
2. EAGLE-based draft models need to be run without tensor parallelism
(i.e. `speculative_draft_tensor_parallel_size` is set to 1), although
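
For reference, a minimal offline-inference sketch of the simplified flow described in item 1. Model names are the ones already used on this page; num_speculative_tokens=5 is an arbitrary illustrative value, and a single GPU is assumed (so tensor_parallel_size is omitted).

from vllm import LLM, SamplingParams

# Sketch only: after PR 12304 the yuhuili checkpoint is passed directly as
# the draft model, with no conversion script required.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
    num_speculative_tokens=5,  # illustrative value
    speculative_draft_tensor_parallel_size=1,
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)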

View File

@@ -305,6 +305,150 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": "float16",
# Main model
"model_name": "meta-llama/Llama-2-7b-chat-hf",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": "float16",
# Main model
"model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": "float16",
# Main model
"model_name": "Qwen/Qwen2-7B-Instruct",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_qwen2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])

View File

@@ -13,15 +13,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
split_num_cache_blocks_evenly)
from vllm.worker.worker import Worker
from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, create_sampler_output_list, mock_worker
from .utils import (create_batch, create_sampler_output_list, create_worker,
mock_worker)


@pytest.mark.parametrize('k', [1, 2, 6])
@@ -905,3 +908,38 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
worker.execute_model(execute_model_req=execute_model_req)
# but first draft still counted
assert draft_worker.get_spec_proposals.call_count == 1


def test_correctly_load_weight_for_eagle():
"""
Verify SpecDecodeWorker loads lm_head weight for eagle correctly.
"""
seed = 100
block_size = 32
num_gpu_blocks = 8096 // block_size
target_worker = create_worker(
Worker,
"JackFram/llama-68m",
block_size,
num_gpu_blocks,
seed,
)
draft_worker = create_worker(
MultiStepWorker,
"abhigoyal/vllm-eagle-llama-68m-random",
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False)
worker.proposer_worker.maybe_load_lm_head_weight(
target_worker.model_runner.model.lm_head.weight.data)
assert torch.allclose(
worker.proposer_worker.worker.model_runner.model.lm_head.weight.data,
worker.scorer_worker.model_runner.model.lm_head.weight.data)
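
To run just this new test, a programmatic invocation mirroring the pytest.main([__file__]) pattern visible at the end of the test file above could look like the sketch below; the file path is an assumption inferred from the relative imports (from .test_utils / from .utils), which place the module under a tests/spec_decode package.

import pytest

# Hypothetical path; adjust to wherever this test module actually lives.
pytest.main(["tests/spec_decode/test_spec_decode_worker.py",
             "-k", "test_correctly_load_weight_for_eagle"])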

View File

@@ -1833,6 +1833,15 @@ class SpeculativeConfig:
draft_hf_config = draft_model_config.hf_config
# Detect EAGLE prefix to replace hf_config for EAGLE draft_model
if "eagle-" in draft_model_config.model.lower():
from vllm.transformers_utils.configs.eagle import EAGLEConfig
if isinstance(draft_model_config.hf_config, EAGLEConfig):
pass
else:
eagle_config = EAGLEConfig(draft_model_config.hf_config)
draft_model_config.hf_config = eagle_config
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
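
The detection above is purely name-based. Below is a standalone restatement of the heuristic, illustrative only (the real check operates on draft_model_config.model):

def looks_like_eagle_checkpoint(model_name_or_path: str) -> bool:
    # Mirrors the check above: any "eagle-" substring in the lowercased model
    # name or path marks the checkpoint as an EAGLE draft model, whose
    # hf_config is then wrapped in EAGLEConfig.
    return "eagle-" in model_name_or_path.lower()

assert looks_like_eagle_checkpoint("yuhuili/EAGLE-LLaMA3-Instruct-8B")
assert not looks_like_eagle_checkpoint("meta-llama/Meta-Llama-3-8B-Instruct")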

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -18,6 +19,8 @@ from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
logger = init_logger(__name__)


class DummyInputLayerNorm(nn.Module):
@@ -190,8 +193,8 @@ class EAGLE(nn.Module):
default_weight_loader)
weight_loader(self.fc.bias, loaded_weight)
else:
raise ValueError("Found bias in the loaded weights "
"but the model config doesn't have bias")
logger.warning_once("Found bias in the loaded weights but "
"the model config doesn't have bias.")
elif name.startswith("model.lm_head.") or name.startswith(
"model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight
@@ -200,12 +203,21 @@ class EAGLE(nn.Module):
else:
model_weights[f"model.{name}"] = loaded_weight
lm_head_weight = model_weights.pop("lm_head.weight")
if "lm_head.weight" in model_weights:
lm_head_weight = model_weights.pop("lm_head.weight")
if self.token_map is not None and\
lm_head_weight.shape[0] > self.token_map.shape[0]:
if self.token_map is not None and\
lm_head_weight.shape[0] > self.token_map.shape[0]:
lm_head_weight = lm_head_weight[self.token_map]
lm_head_weight = lm_head_weight[self.token_map]
else:
# NOTE(Shangming): initialize the placeholder for lm_head weight.
lm_head_weight = torch.zeros(
self.lm_head.org_vocab_size,
self.lm_head.embedding_dim,
dtype=self.config.torch_dtype,
)
weight_loader = getattr(self.lm_head.weight, "weight_loader",
default_weight_loader)
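
Illustrative sketch (plain torch, names and shapes assumed) of the fallback added above: when the draft checkpoint ships no "lm_head.weight", a zero placeholder of shape (org_vocab_size, embedding_dim) is allocated now and filled later from the target model's lm_head (see maybe_load_lm_head_weight further down).

import torch

vocab_size, hidden_size = 128, 64          # small illustrative shapes
checkpoint = {}                            # pretend checkpoint without lm_head

if "lm_head.weight" in checkpoint:
    lm_head_weight = checkpoint.pop("lm_head.weight")
else:
    # Placeholder, analogous to the torch.zeros(...) branch above.
    lm_head_weight = torch.zeros(vocab_size, hidden_size, dtype=torch.float16)

# Later, once the target model is available, the placeholder is overwritten
# in place with the target's lm_head weight.
target_lm_head = torch.randn(vocab_size, hidden_size, dtype=torch.float16)
lm_head_weight.copy_(target_lm_head)
assert torch.equal(lm_head_weight, target_lm_head)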

View File

@@ -7,6 +7,7 @@ from typing import Dict, List, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
@@ -386,3 +387,14 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")
def maybe_load_lm_head_weight(
self,
lm_head_weight: torch.Tensor,
) -> None:
weight_loader = getattr(
self.worker.model_runner.model_runner.model.lm_head.weight,
"weight_loader", default_weight_loader)
weight_loader(
self.worker.model_runner.model_runner.model.lm_head.weight,
lm_head_weight)

View File

@@ -10,6 +10,7 @@ from vllm.distributed.parallel_state import (get_tp_group,
patch_tensor_parallel_group)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.multi_step_worker import MultiStepWorker
@@ -173,3 +174,21 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
@property
def vocab_size(self) -> int:
return self._worker.vocab_size
def maybe_load_lm_head_weight(
self,
lm_head_weight: torch.Tensor,
) -> None:
if self._is_dummy:
return
with self._patch_tensor_parallel_group():
weight_loader = getattr(
self._worker.worker.model_runner.model_runner.model.\
lm_head.weight,
"weight_loader",
default_weight_loader)
weight_loader(
self._worker.worker.model_runner.model_runner.model.\
lm_head.weight,
lm_head_weight)

View File

@@ -9,7 +9,8 @@ import torch
import torch.nn as nn
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.distributed.communication_op import (broadcast_tensor_dict,
tensor_model_parallel_gather)
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -155,6 +156,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True
enable_lm_head_weight_load = False
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
@@ -187,6 +189,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"EAGLE does not support TP > 1 yet")
allow_zero_draft_token_step = False
# Load lm_head weight for eagle in init_device
if draft_model_config.hf_config.model_type == "eagle":
enable_lm_head_weight_load = True
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
@@ -239,7 +246,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step)
allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load)
def __init__(
self,
@@ -252,6 +260,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
enable_lm_head_weight_load: Optional[bool] = False,
):
"""
Create a SpecDecodeWorker.
@@ -282,6 +291,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
enable_lm_head_weight_load: whether to load lm_head weight for
draft models like eagle.
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
@@ -291,6 +302,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step
self._enable_lm_head_weight_load = enable_lm_head_weight_load
self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler
) if metrics_collector is None else metrics_collector
@@ -327,6 +339,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.scorer_worker.load_model()
self.proposer_worker.load_model()
if self._enable_lm_head_weight_load:
# NOTE(Shangming): gather lm_head weight when tp enabled
target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather(
self.scorer_worker.model_runner.model_runner.model.lm_head.\
weight.data,
dim=0,
)
self.proposer_worker.maybe_load_lm_head_weight(
target_lm_head_weight)
self._metrics.init_tensors(self.rank, device_type=self.device)
self.spec_decode_sampler.init_tensors(self.rank,
device_type=self.device)
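
For intuition, a single-process sketch (plain torch, shapes assumed) of what the dim=0 gather above reconstructs: the target's lm_head is vocab-sharded across TP ranks along dim 0, so concatenating the per-rank shards yields the full weight that the TP=1 draft worker expects.

import torch

tp_size, vocab_size, hidden_size = 4, 128, 64
full_lm_head = torch.randn(vocab_size, hidden_size)
rank_shards = torch.chunk(full_lm_head, tp_size, dim=0)   # one shard per rank
gathered = torch.cat(rank_shards, dim=0)                   # the dim=0 gather
assert torch.equal(gathered, full_lm_head)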