[V1][Spec Decode] Eagle Model loading (#16035)
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
parent
9665313c39
commit
e8224f3dca
@ -76,6 +76,7 @@ def main():
|
|||||||
max_num_seqs=args.max_num_seqs,
|
max_num_seqs=args.max_num_seqs,
|
||||||
gpu_memory_utilization=0.8,
|
gpu_memory_utilization=0.8,
|
||||||
speculative_config={
|
speculative_config={
|
||||||
|
"method": "eagle",
|
||||||
"model": eagle_dir,
|
"model": eagle_dir,
|
||||||
"num_speculative_tokens": args.num_spec_tokens,
|
"num_speculative_tokens": args.num_spec_tokens,
|
||||||
"draft_tensor_parallel_size": args.draft_tp,
|
"draft_tensor_parallel_size": args.draft_tp,
|
||||||
|
@ -374,6 +374,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
|||||||
"DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
|
"DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
|
||||||
speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
|
speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
|
"EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||||
|
trust_remote_code=True,
|
||||||
|
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||||
|
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501
|
||||||
}
|
}
|
||||||
|
|
||||||
_TRANSFORMERS_MODELS = {
|
_TRANSFORMERS_MODELS = {
|
||||||
|
@ -53,6 +53,11 @@ def model_name():
|
|||||||
return "meta-llama/Meta-Llama-3-8B-Instruct"
|
return "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def eagle_model_name():
|
||||||
|
return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
|
||||||
|
|
||||||
|
|
||||||
def test_ngram_correctness(
|
def test_ngram_correctness(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
test_prompts: list[list[dict[str, Any]]],
|
test_prompts: list[list[dict[str, Any]]],
|
||||||
@ -95,3 +100,47 @@ def test_ngram_correctness(
|
|||||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||||
assert matches > int(0.7 * len(ref_outputs))
|
assert matches > int(0.7 * len(ref_outputs))
|
||||||
del spec_llm
|
del spec_llm
|
||||||
|
|
||||||
|
|
||||||
|
def test_eagle_correctness(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
test_prompts: list[list[dict[str, Any]]],
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
eagle_model_name: str,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Compare the outputs of a original LLM and a speculative LLM
|
||||||
|
should be the same when using eagle speculative decoding.
|
||||||
|
'''
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
|
||||||
|
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||||
|
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||||
|
del ref_llm
|
||||||
|
|
||||||
|
spec_llm = LLM(
|
||||||
|
model=model_name,
|
||||||
|
speculative_config={
|
||||||
|
"method": "eagle",
|
||||||
|
"model": eagle_model_name,
|
||||||
|
"num_speculative_tokens": 3,
|
||||||
|
},
|
||||||
|
max_model_len=1024,
|
||||||
|
)
|
||||||
|
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||||
|
matches = 0
|
||||||
|
misses = 0
|
||||||
|
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||||
|
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||||
|
matches += 1
|
||||||
|
else:
|
||||||
|
misses += 1
|
||||||
|
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||||
|
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||||
|
|
||||||
|
# Heuristic: expect at least 70% of the prompts to match exactly
|
||||||
|
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||||
|
assert matches > int(0.7 * len(ref_outputs))
|
||||||
|
del spec_llm
|
@ -414,7 +414,7 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
return ((source.prefix + name, tensor)
|
return ((source.prefix + name, tensor)
|
||||||
for (name, tensor) in weights_iterator)
|
for (name, tensor) in weights_iterator)
|
||||||
|
|
||||||
def _get_all_weights(
|
def get_all_weights(
|
||||||
self,
|
self,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -453,7 +453,7 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
|
|
||||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||||
loaded_weights = model.load_weights(
|
loaded_weights = model.load_weights(
|
||||||
self._get_all_weights(model_config, model))
|
self.get_all_weights(model_config, model))
|
||||||
self.counter_after_loading_weights = time.perf_counter()
|
self.counter_after_loading_weights = time.perf_counter()
|
||||||
logger.info(
|
logger.info(
|
||||||
"Loading weights took %.2f seconds",
|
"Loading weights took %.2f seconds",
|
||||||
|
151
vllm/model_executor/models/llama_eagle.py
Normal file
151
vllm/model_executor/models/llama_eagle.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Iterable, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
|
||||||
|
LlamaForCausalLM)
|
||||||
|
|
||||||
|
from .utils import AutoWeightsLoader, maybe_prefix
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDecoderLayer(LlamaDecoderLayer):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LlamaConfig,
|
||||||
|
disable_input_layernorm: bool,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__(config, prefix=prefix)
|
||||||
|
|
||||||
|
# Skip the input_layernorm
|
||||||
|
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
|
||||||
|
if disable_input_layernorm:
|
||||||
|
del self.input_layernorm
|
||||||
|
self.input_layernorm = nn.Identity()
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
start_layer_id: int = 0,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = model_config.hf_config
|
||||||
|
self.vocab_size = self.config.vocab_size
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
self.config.vocab_size,
|
||||||
|
self.config.hidden_size,
|
||||||
|
prefix=maybe_prefix(prefix, "embed_tokens"),
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
LlamaDecoderLayer(
|
||||||
|
self.config,
|
||||||
|
i == 0,
|
||||||
|
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||||
|
) for i in range(self.config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
|
||||||
|
self.config.hidden_size,
|
||||||
|
bias=False)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
input_embeds = self.embed_tokens(input_ids)
|
||||||
|
hidden_states = self.fc(
|
||||||
|
torch.cat((input_embeds, hidden_states), dim=-1))
|
||||||
|
residual = None
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
)
|
||||||
|
return hidden_states + residual
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
|
(".gate_up_proj", ".gate_proj", 0),
|
||||||
|
(".gate_up_proj", ".up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: Set[str] = set()
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
|
class EagleLlamaForCausalLM(LlamaForCausalLM):
|
||||||
|
|
||||||
|
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.config = model_config.hf_config
|
||||||
|
self.model = LlamaModel(model_config=model_config,
|
||||||
|
start_layer_id=start_layer_id,
|
||||||
|
prefix="model")
|
||||||
|
|
||||||
|
logit_scale = getattr(self.config, "logit_scale", 1.0)
|
||||||
|
self.logits_processor = LogitsProcessor(self.config.vocab_size,
|
||||||
|
scale=logit_scale)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.model(input_ids, positions, hidden_states)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
loader = AutoWeightsLoader(
|
||||||
|
self,
|
||||||
|
skip_prefixes=(["lm_head."]
|
||||||
|
if self.config.tie_word_embeddings else None),
|
||||||
|
)
|
||||||
|
|
||||||
|
model_weights = {}
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "lm_head" not in name:
|
||||||
|
name = "model." + name
|
||||||
|
model_weights[name] = loaded_weight
|
||||||
|
|
||||||
|
loader.load_weights(model_weights.items())
|
@ -206,6 +206,7 @@ _MULTIMODAL_MODELS = {
|
|||||||
|
|
||||||
_SPECULATIVE_DECODING_MODELS = {
|
_SPECULATIVE_DECODING_MODELS = {
|
||||||
"EAGLEModel": ("eagle", "EAGLE"),
|
"EAGLEModel": ("eagle", "EAGLE"),
|
||||||
|
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
||||||
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
||||||
"MedusaModel": ("medusa", "Medusa"),
|
"MedusaModel": ("medusa", "Medusa"),
|
||||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||||
|
@ -5,6 +5,7 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
from transformers import AutoConfig, PretrainedConfig
|
from transformers import AutoConfig, PretrainedConfig
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
|
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
|
||||||
|
|
||||||
|
|
||||||
@ -41,8 +42,10 @@ class EAGLEConfig(PretrainedConfig):
|
|||||||
self.truncated_vocab_size = self.model.vocab_size if \
|
self.truncated_vocab_size = self.model.vocab_size if \
|
||||||
truncated_vocab_size is None else truncated_vocab_size
|
truncated_vocab_size is None else truncated_vocab_size
|
||||||
|
|
||||||
if "architectures" not in kwargs:
|
if not envs.VLLM_USE_V1:
|
||||||
kwargs["architectures"] = ["EAGLEModel"]
|
kwargs["architectures"] = ["EAGLEModel"]
|
||||||
|
else:
|
||||||
|
kwargs["architectures"] = ["EagleLlamaForCausalLM"]
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
@ -4,8 +4,11 @@ import torch.nn as nn
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
|
from vllm.model_executor.model_loader.loader import get_model_loader
|
||||||
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
|
from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
|
||||||
@ -21,8 +24,12 @@ class EagleProposer:
|
|||||||
self.num_speculative_tokens = (
|
self.num_speculative_tokens = (
|
||||||
vllm_config.speculative_config.num_speculative_tokens)
|
vllm_config.speculative_config.num_speculative_tokens)
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
|
# We need +1 here because the arange is used to set query_start_loc,
|
||||||
device=device)
|
# which has one more element than batch_size.
|
||||||
|
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
|
||||||
|
1,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
def propose(
|
def propose(
|
||||||
self,
|
self,
|
||||||
@ -54,7 +61,9 @@ class EagleProposer:
|
|||||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||||
input_ids[last_token_indices] = next_token_ids
|
input_ids[last_token_indices] = next_token_ids
|
||||||
|
|
||||||
seq_lens = target_positions[last_token_indices] + 1
|
# FA requires seq_len to have dtype int32.
|
||||||
|
seq_lens = (target_positions[last_token_indices] + 1).int()
|
||||||
|
|
||||||
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
|
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
|
||||||
max_seq_len = seq_lens.max().item()
|
max_seq_len = seq_lens.max().item()
|
||||||
max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
|
max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
|
||||||
@ -98,7 +107,7 @@ class EagleProposer:
|
|||||||
hidden_states = sample_hidden_states
|
hidden_states = sample_hidden_states
|
||||||
attn_metadata.num_actual_tokens = batch_size
|
attn_metadata.num_actual_tokens = batch_size
|
||||||
attn_metadata.max_query_len = 1
|
attn_metadata.max_query_len = 1
|
||||||
attn_metadata.query_start_loc = self.arange[:batch_size]
|
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||||
for _ in range(self.num_speculative_tokens - 1):
|
for _ in range(self.num_speculative_tokens - 1):
|
||||||
# Update the inputs.
|
# Update the inputs.
|
||||||
input_ids = draft_token_ids_list[-1]
|
input_ids = draft_token_ids_list[-1]
|
||||||
@ -176,26 +185,28 @@ class EagleProposer:
|
|||||||
return cu_num_tokens, token_indices
|
return cu_num_tokens, token_indices
|
||||||
|
|
||||||
def load_model(self, target_model: nn.Module) -> None:
|
def load_model(self, target_model: nn.Module) -> None:
|
||||||
self.model = DummyEagleModel()
|
loader = get_model_loader(self.vllm_config.load_config)
|
||||||
self.model.get_input_embeddings = target_model.get_input_embeddings
|
target_layer_num = self.vllm_config.model_config.get_num_layers(
|
||||||
self.model.compute_logits = target_model.compute_logits
|
self.vllm_config.parallel_config)
|
||||||
|
|
||||||
|
draft_model_config = \
|
||||||
|
self.vllm_config.speculative_config.draft_model_config
|
||||||
|
# FIXME(lily): This does not handle with distributed inference.
|
||||||
|
target_device = self.vllm_config.device_config.device
|
||||||
|
# We need to set the vllm_config here to register attention
|
||||||
|
# layers in the forward context.
|
||||||
|
with set_default_torch_dtype(
|
||||||
|
draft_model_config.dtype), set_current_vllm_config(
|
||||||
|
self.vllm_config):
|
||||||
|
self.model = EagleLlamaForCausalLM(
|
||||||
|
model_config=draft_model_config,
|
||||||
|
start_layer_id=target_layer_num).to(target_device)
|
||||||
|
|
||||||
# FIXME(woosuk): This is a dummy model for testing.
|
self.model.load_weights(
|
||||||
# Remove this once we have a real model.
|
loader.get_all_weights(
|
||||||
class DummyEagleModel(nn.Module):
|
self.vllm_config.speculative_config.draft_model_config,
|
||||||
|
self.model))
|
||||||
def __init__(self):
|
self.model.lm_head = target_model.lm_head
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
input_embeddings = self.get_input_embeddings(input_ids)
|
|
||||||
return hidden_states + input_embeddings # Dummy return.
|
|
||||||
|
|
||||||
|
|
||||||
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
|
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
|
||||||
|
@ -1191,9 +1191,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
if spec_decode_metadata is None:
|
if spec_decode_metadata is None:
|
||||||
# input_ids can be None for multimodal models.
|
# input_ids can be None for multimodal models.
|
||||||
|
# We need to slice token_ids, positions, and hidden_states
|
||||||
|
# because the eagle head does not use cuda graph and should
|
||||||
|
# not include padding.
|
||||||
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
||||||
target_positions = positions
|
target_positions = positions[:num_scheduled_tokens]
|
||||||
target_hidden_states = hidden_states
|
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||||
target_slot_mapping = attn_metadata.slot_mapping
|
target_slot_mapping = attn_metadata.slot_mapping
|
||||||
cu_num_tokens = attn_metadata.query_start_loc
|
cu_num_tokens = attn_metadata.query_start_loc
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user