[Feature][Spec Decode] Simplify the use of Eagle Spec Decode (#12304)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
parent 2010f04c17
commit 46cdd59577
@@ -175,7 +175,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

 llm = LLM(
     model="meta-llama/Meta-Llama-3-8B-Instruct",
     tensor_parallel_size=4,
-    speculative_model="path/to/modified/eagle/model",
+    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
     speculative_draft_tensor_parallel_size=1,
 )
@@ -190,14 +190,12 @@ for output in outputs:

 A few important things to consider when using the EAGLE based draft models:

-1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) cannot be
-   used directly with vLLM due to differences in the expected layer names and model definition.
-   To use these models with vLLM, use the [following script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d)
-   to convert them. Note that this script does not modify the model's weights.
-
-   In the above example, use the script to first convert
-   the [yuhuili/EAGLE-LLaMA3-Instruct-8B](https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B) model
-   and then use the converted checkpoint as the draft model in vLLM.
+1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) can be
+   loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304).
+   If you are using a vLLM version released before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the
+   [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model
+   and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur with
+   the latest version of vLLM, please leave a comment or raise an issue.

 2. The EAGLE based draft models need to be run without tensor parallelism
    (i.e. speculative_draft_tensor_parallel_size is set to 1), although
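For reference, the two usage modes described in note 1 of the documentation hunk above look roughly as follows. This is a minimal sketch, assuming the speculative-decoding arguments shown in the snippet earlier on this page; the converted-checkpoint path and the num_speculative_tokens value are placeholders, not values mandated by this commit:

from vllm import LLM, SamplingParams

# vLLM with PR 12304 merged: the HF EAGLE checkpoint is usable as-is.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
    speculative_draft_tensor_parallel_size=1,
    num_speculative_tokens=5,  # placeholder value
)

# Older vLLM releases: convert the checkpoint with the linked gist first,
# then point speculative_model at the converted directory (placeholder path).
# llm = LLM(
#     model="meta-llama/Meta-Llama-3-8B-Instruct",
#     speculative_model="path/to/modified/eagle/model",
#     speculative_draft_tensor_parallel_size=1,
#     num_speculative_tokens=5,
# )

outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.8, top_p=0.95))
print(outputs[0].outputs[0].text)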
@@ -305,6 +305,150 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
                                   batch_size, output_len, seed)


+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": "float16",
+
+        # Main model
+        "model_name": "meta-llama/Llama-2-7b-chat-hf",
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize("seed", [1])
+def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
+                                             per_test_common_llm_kwargs,
+                                             baseline_llm_kwargs,
+                                             test_llm_kwargs, batch_size: int,
+                                             output_len: int, seed: int):
+
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  output_len,
+                                  seed,
+                                  temperature=0.0)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": "float16",
+
+        # Main model
+        "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize("seed", [1])
+def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
+                                             per_test_common_llm_kwargs,
+                                             baseline_llm_kwargs,
+                                             test_llm_kwargs, batch_size: int,
+                                             output_len: int, seed: int):
+
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  output_len,
+                                  seed,
+                                  temperature=0.0)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": "float16",
+
+        # Main model
+        "model_name": "Qwen/Qwen2-7B-Instruct",
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize("seed", [1])
+def test_qwen2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
+                                            per_test_common_llm_kwargs,
+                                            baseline_llm_kwargs,
+                                            test_llm_kwargs, batch_size: int,
+                                            output_len: int, seed: int):
+
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  output_len,
+                                  seed,
+                                  temperature=0.0)
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
@@ -13,15 +13,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import ExecuteModelRequest, SequenceOutput
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                       SpecDecodeWorkerMetrics)
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                  split_num_cache_blocks_evenly)
+from vllm.worker.worker import Worker

 from .test_utils import mock_spec_decode_sampler
-from .utils import create_batch, create_sampler_output_list, mock_worker
+from .utils import (create_batch, create_sampler_output_list, create_worker,
+                    mock_worker)


 @pytest.mark.parametrize('k', [1, 2, 6])
@@ -905,3 +908,38 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
     worker.execute_model(execute_model_req=execute_model_req)
     # but first draft still counted
     assert draft_worker.get_spec_proposals.call_count == 1
+
+
+def test_correctly_load_weight_for_eagle():
+    """
+    Verify SpecDecodeWorker loads lm_head weight for eagle correctly.
+    """
+    seed = 100
+    block_size = 32
+    num_gpu_blocks = 8096 // block_size
+    target_worker = create_worker(
+        Worker,
+        "JackFram/llama-68m",
+        block_size,
+        num_gpu_blocks,
+        seed,
+    )
+    draft_worker = create_worker(
+        MultiStepWorker,
+        "abhigoyal/vllm-eagle-llama-68m-random",
+        block_size,
+        num_gpu_blocks,
+        seed,
+        model_runner_cls=TP1DraftModelRunner,
+    )
+
+    spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              disable_logprobs=False)
+    worker.proposer_worker.maybe_load_lm_head_weight(
+        target_worker.model_runner.model.lm_head.weight.data)
+    assert torch.allclose(
+        worker.proposer_worker.worker.model_runner.model.lm_head.weight.data,
+        worker.scorer_worker.model_runner.model.lm_head.weight.data)
@@ -1833,6 +1833,15 @@ class SpeculativeConfig:

             draft_hf_config = draft_model_config.hf_config

+            # Detect EAGLE prefix to replace hf_config for EAGLE draft_model
+            if "eagle-" in draft_model_config.model.lower():
+                from vllm.transformers_utils.configs.eagle import EAGLEConfig
+                if isinstance(draft_model_config.hf_config, EAGLEConfig):
+                    pass
+                else:
+                    eagle_config = EAGLEConfig(draft_model_config.hf_config)
+                    draft_model_config.hf_config = eagle_config
+
             if (num_speculative_tokens is not None
                     and hasattr(draft_hf_config, "num_lookahead_tokens")):
                 draft_hf_config.num_lookahead_tokens = num_speculative_tokens
@@ -7,6 +7,7 @@ import torch.nn as nn

 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
+from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -18,6 +19,8 @@ from vllm.sequence import IntermediateTensors

 from .utils import maybe_prefix

+logger = init_logger(__name__)
+

 class DummyInputLayerNorm(nn.Module):
@@ -190,8 +193,8 @@ class EAGLE(nn.Module):
                                         default_weight_loader)
                 weight_loader(self.fc.bias, loaded_weight)
             else:
-                raise ValueError("Found bias in the loaded weights "
-                                 "but the model config doesn't have bias")
+                logger.warning_once("Found bias in the loaded weights but "
+                                    "the model config doesn't have bias.")
         elif name.startswith("model.lm_head.") or name.startswith(
                 "model.model."):
             model_weights[name.split("model.", 1)[-1]] = loaded_weight
@@ -200,12 +203,21 @@ class EAGLE(nn.Module):
             else:
                 model_weights[f"model.{name}"] = loaded_weight

-        lm_head_weight = model_weights.pop("lm_head.weight")
+        if "lm_head.weight" in model_weights:
+            lm_head_weight = model_weights.pop("lm_head.weight")

-        if self.token_map is not None and\
-            lm_head_weight.shape[0] > self.token_map.shape[0]:
+            if self.token_map is not None and\
+                lm_head_weight.shape[0] > self.token_map.shape[0]:

-            lm_head_weight = lm_head_weight[self.token_map]
+                lm_head_weight = lm_head_weight[self.token_map]

+        else:
+            # NOTE(Shangming): initialize the placeholder for lm_head weight.
+            lm_head_weight = torch.zeros(
+                self.lm_head.org_vocab_size,
+                self.lm_head.embedding_dim,
+                dtype=self.config.torch_dtype,
+            )

         weight_loader = getattr(self.lm_head.weight, "weight_loader",
                                 default_weight_loader)
@@ -7,6 +7,7 @@ from typing import Dict, List, Set, Tuple
 import torch

 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.platforms import current_platform
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                            SequenceGroupMetadata)
@@ -386,3 +387,14 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
                 execute_model_req.seq_group_metadata_list):
             raise NotImplementedError(
                 "MultiStepWorker does not support beam search.")
+
+    def maybe_load_lm_head_weight(
+        self,
+        lm_head_weight: torch.Tensor,
+    ) -> None:
+        weight_loader = getattr(
+            self.worker.model_runner.model_runner.model.lm_head.weight,
+            "weight_loader", default_weight_loader)
+        weight_loader(
+            self.worker.model_runner.model_runner.model.lm_head.weight,
+            lm_head_weight)
@@ -10,6 +10,7 @@ from vllm.distributed.parallel_state import (get_tp_group,
                                              patch_tensor_parallel_group)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
@@ -173,3 +174,21 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
     @property
     def vocab_size(self) -> int:
         return self._worker.vocab_size
+
+    def maybe_load_lm_head_weight(
+        self,
+        lm_head_weight: torch.Tensor,
+    ) -> None:
+        if self._is_dummy:
+            return
+
+        with self._patch_tensor_parallel_group():
+            weight_loader = getattr(
+                self._worker.worker.model_runner.model_runner.model.\
+                    lm_head.weight,
+                "weight_loader",
+                default_weight_loader)
+            weight_loader(
+                self._worker.worker.model_runner.model_runner.model.\
+                    lm_head.weight,
+                lm_head_weight)
@@ -9,7 +9,8 @@ import torch
 import torch.nn as nn

 from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
-from vllm.distributed.communication_op import broadcast_tensor_dict
+from vllm.distributed.communication_op import (broadcast_tensor_dict,
+                                               tensor_model_parallel_gather)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -155,6 +156,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     ) -> "SpecDecodeWorker":

         allow_zero_draft_token_step = True
+        enable_lm_head_weight_load = False
         ngram_prompt_lookup_max = (
             draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
         ngram_prompt_lookup_min = (
@@ -187,6 +189,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                             "EAGLE does not support TP > 1 yet")

                 allow_zero_draft_token_step = False

+                # Load lm_head weight for eagle in init_device
+                if draft_model_config.hf_config.model_type == "eagle":
+                    enable_lm_head_weight_load = True
+
                 proposer_worker = MultiStepWorker(**draft_worker_kwargs)

             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
@@ -239,7 +246,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             disable_log_stats=disable_log_stats,
             disable_by_batch_size=disable_by_batch_size,
             spec_decode_sampler=spec_decode_sampler,
-            allow_zero_draft_token_step=allow_zero_draft_token_step)
+            allow_zero_draft_token_step=allow_zero_draft_token_step,
+            enable_lm_head_weight_load=enable_lm_head_weight_load)

     def __init__(
         self,
@@ -252,6 +260,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         metrics_collector: Optional[AsyncMetricsCollector] = None,
         disable_by_batch_size: Optional[int] = None,
         allow_zero_draft_token_step: Optional[bool] = True,
+        enable_lm_head_weight_load: Optional[bool] = False,
     ):
         """
         Create a SpecDecodeWorker.
@@ -282,6 +291,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             allow_zero_draft_token_step: whether to allow a step where the draft
                 model generates no draft token; should disallow when the tp of
                 draft model is larger than 1 (TODO: #5814)
+            enable_lm_head_weight_load: whether to load lm_head weight for
+                draft models like eagle.
         """
         self.proposer_worker = proposer_worker
         self.scorer_worker = scorer_worker
@@ -291,6 +302,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.disable_by_batch_size = disable_by_batch_size or float("inf")
         self.spec_decode_sampler = spec_decode_sampler
         self._allow_zero_draft_token_step = allow_zero_draft_token_step
+        self._enable_lm_head_weight_load = enable_lm_head_weight_load
         self._metrics = AsyncMetricsCollector(
             self.spec_decode_sampler
         ) if metrics_collector is None else metrics_collector
@@ -327,6 +339,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.scorer_worker.load_model()
         self.proposer_worker.load_model()

+        if self._enable_lm_head_weight_load:
+            # NOTE(Shangming): gather lm_head weight when tp enabled
+            target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather(
+                self.scorer_worker.model_runner.model_runner.model.lm_head.\
+                    weight.data,
+                dim=0,
+            )
+
+            self.proposer_worker.maybe_load_lm_head_weight(
+                target_lm_head_weight)
+
         self._metrics.init_tensors(self.rank, device_type=self.device)
         self.spec_decode_sampler.init_tensors(self.rank,
                                               device_type=self.device)
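Taken together, these hunks make the EAGLE draft model reuse the target model's lm_head rather than relying on a copy shipped inside a converted checkpoint: init_device gathers the target head across tensor-parallel ranks and hands it to the proposer via maybe_load_lm_head_weight. Below is a minimal torch-only sketch of that hand-off; the module names and sizes are illustrative stand-ins, not vLLM internals:

import torch
import torch.nn as nn

hidden_size, vocab_size = 64, 1000  # small illustrative sizes

# Stand-ins for the target model's lm_head and the EAGLE draft's placeholder head.
target_lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
draft_lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Equivalent of maybe_load_lm_head_weight: copy the target head into the draft head.
with torch.no_grad():
    draft_lm_head.weight.copy_(target_lm_head.weight)

# Mirrors the check in test_correctly_load_weight_for_eagle above.
assert torch.allclose(draft_lm_head.weight, target_lm_head.weight)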