[V1][Spec Decode] Eagle Model loading (#16035)
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
parent
9665313c39
commit
e8224f3dca
@ -76,6 +76,7 @@ def main():
|
||||
max_num_seqs=args.max_num_seqs,
|
||||
gpu_memory_utilization=0.8,
|
||||
speculative_config={
|
||||
"method": "eagle",
|
||||
"model": eagle_dir,
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
"draft_tensor_parallel_size": args.draft_tp,
|
||||
|
@ -374,6 +374,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
"DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
|
||||
speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||
trust_remote_code=True,
|
||||
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501
|
||||
}
|
||||
|
||||
_TRANSFORMERS_MODELS = {
|
||||
|
@ -53,6 +53,11 @@ def model_name():
|
||||
return "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def eagle_model_name():
|
||||
return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
|
||||
|
||||
|
||||
def test_ngram_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
@ -95,3 +100,47 @@ def test_ngram_correctness(
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.7 * len(ref_outputs))
|
||||
del spec_llm
|
||||
|
||||
|
||||
def test_eagle_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
eagle_model_name: str,
|
||||
):
|
||||
'''
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using eagle speculative decoding.
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
speculative_config={
|
||||
"method": "eagle",
|
||||
"model": eagle_model_name,
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
max_model_len=1024,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
# Heuristic: expect at least 70% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.7 * len(ref_outputs))
|
||||
del spec_llm
|
@ -414,7 +414,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
return ((source.prefix + name, tensor)
|
||||
for (name, tensor) in weights_iterator)
|
||||
|
||||
def _get_all_weights(
|
||||
def get_all_weights(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
model: nn.Module,
|
||||
@ -453,7 +453,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
loaded_weights = model.load_weights(
|
||||
self._get_all_weights(model_config, model))
|
||||
self.get_all_weights(model_config, model))
|
||||
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info(
|
||||
"Loading weights took %.2f seconds",
|
||||
|
151
vllm/model_executor/models/llama_eagle.py
Normal file
151
vllm/model_executor/models/llama_eagle.py
Normal file
@ -0,0 +1,151 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Iterable, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
|
||||
LlamaForCausalLM)
|
||||
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LlamaDecoderLayer(LlamaDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
disable_input_layernorm: bool,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config, prefix=prefix)
|
||||
|
||||
# Skip the input_layernorm
|
||||
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
|
||||
if disable_input_layernorm:
|
||||
del self.input_layernorm
|
||||
self.input_layernorm = nn.Identity()
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
start_layer_id: int = 0,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = model_config.hf_config
|
||||
self.vocab_size = self.config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "embed_tokens"),
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(
|
||||
self.config,
|
||||
i == 0,
|
||||
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||
) for i in range(self.config.num_hidden_layers)
|
||||
])
|
||||
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
|
||||
self.config.hidden_size,
|
||||
bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.fc(
|
||||
torch.cat((input_embeds, hidden_states), dim=-1))
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
)
|
||||
return hidden_states + residual
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class EagleLlamaForCausalLM(LlamaForCausalLM):
|
||||
|
||||
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
|
||||
nn.Module.__init__(self)
|
||||
self.config = model_config.hf_config
|
||||
self.model = LlamaModel(model_config=model_config,
|
||||
start_layer_id=start_layer_id,
|
||||
prefix="model")
|
||||
|
||||
logit_scale = getattr(self.config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size,
|
||||
scale=logit_scale)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return self.model(input_ids, positions, hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
|
||||
model_weights = {}
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head" not in name:
|
||||
name = "model." + name
|
||||
model_weights[name] = loaded_weight
|
||||
|
||||
loader.load_weights(model_weights.items())
|
@ -206,6 +206,7 @@ _MULTIMODAL_MODELS = {
|
||||
|
||||
_SPECULATIVE_DECODING_MODELS = {
|
||||
"EAGLEModel": ("eagle", "EAGLE"),
|
||||
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
||||
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
||||
"MedusaModel": ("medusa", "Medusa"),
|
||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||
|
@ -5,6 +5,7 @@ from typing import Optional, Union
|
||||
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
|
||||
|
||||
|
||||
@ -41,8 +42,10 @@ class EAGLEConfig(PretrainedConfig):
|
||||
self.truncated_vocab_size = self.model.vocab_size if \
|
||||
truncated_vocab_size is None else truncated_vocab_size
|
||||
|
||||
if "architectures" not in kwargs:
|
||||
if not envs.VLLM_USE_V1:
|
||||
kwargs["architectures"] = ["EAGLEModel"]
|
||||
else:
|
||||
kwargs["architectures"] = ["EagleLlamaForCausalLM"]
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
@ -4,8 +4,11 @@ import torch.nn as nn
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.model_loader.loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
@ -21,8 +24,12 @@ class EagleProposer:
|
||||
self.num_speculative_tokens = (
|
||||
vllm_config.speculative_config.num_speculative_tokens)
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
|
||||
device=device)
|
||||
# We need +1 here because the arange is used to set query_start_loc,
|
||||
# which has one more element than batch_size.
|
||||
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
|
||||
1,
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
|
||||
def propose(
|
||||
self,
|
||||
@ -54,7 +61,9 @@ class EagleProposer:
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
input_ids[last_token_indices] = next_token_ids
|
||||
|
||||
seq_lens = target_positions[last_token_indices] + 1
|
||||
# FA requires seq_len to have dtype int32.
|
||||
seq_lens = (target_positions[last_token_indices] + 1).int()
|
||||
|
||||
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
|
||||
max_seq_len = seq_lens.max().item()
|
||||
max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
|
||||
@ -98,7 +107,7 @@ class EagleProposer:
|
||||
hidden_states = sample_hidden_states
|
||||
attn_metadata.num_actual_tokens = batch_size
|
||||
attn_metadata.max_query_len = 1
|
||||
attn_metadata.query_start_loc = self.arange[:batch_size]
|
||||
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||
for _ in range(self.num_speculative_tokens - 1):
|
||||
# Update the inputs.
|
||||
input_ids = draft_token_ids_list[-1]
|
||||
@ -176,26 +185,28 @@ class EagleProposer:
|
||||
return cu_num_tokens, token_indices
|
||||
|
||||
def load_model(self, target_model: nn.Module) -> None:
|
||||
self.model = DummyEagleModel()
|
||||
self.model.get_input_embeddings = target_model.get_input_embeddings
|
||||
self.model.compute_logits = target_model.compute_logits
|
||||
loader = get_model_loader(self.vllm_config.load_config)
|
||||
target_layer_num = self.vllm_config.model_config.get_num_layers(
|
||||
self.vllm_config.parallel_config)
|
||||
|
||||
draft_model_config = \
|
||||
self.vllm_config.speculative_config.draft_model_config
|
||||
# FIXME(lily): This does not handle with distributed inference.
|
||||
target_device = self.vllm_config.device_config.device
|
||||
# We need to set the vllm_config here to register attention
|
||||
# layers in the forward context.
|
||||
with set_default_torch_dtype(
|
||||
draft_model_config.dtype), set_current_vllm_config(
|
||||
self.vllm_config):
|
||||
self.model = EagleLlamaForCausalLM(
|
||||
model_config=draft_model_config,
|
||||
start_layer_id=target_layer_num).to(target_device)
|
||||
|
||||
# FIXME(woosuk): This is a dummy model for testing.
|
||||
# Remove this once we have a real model.
|
||||
class DummyEagleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
input_embeddings = self.get_input_embeddings(input_ids)
|
||||
return hidden_states + input_embeddings # Dummy return.
|
||||
self.model.load_weights(
|
||||
loader.get_all_weights(
|
||||
self.vllm_config.speculative_config.draft_model_config,
|
||||
self.model))
|
||||
self.model.lm_head = target_model.lm_head
|
||||
|
||||
|
||||
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
|
||||
|
@ -1191,9 +1191,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
if spec_decode_metadata is None:
|
||||
# input_ids can be None for multimodal models.
|
||||
# We need to slice token_ids, positions, and hidden_states
|
||||
# because the eagle head does not use cuda graph and should
|
||||
# not include padding.
|
||||
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
||||
target_positions = positions
|
||||
target_hidden_states = hidden_states
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
target_slot_mapping = attn_metadata.slot_mapping
|
||||
cu_num_tokens = attn_metadata.query_start_loc
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user