[Feature] Enhance EAGLE Architecture with Proper RMS Norms (#14990)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 5aefd6ac31 · commit 781d056280
@@ -800,10 +800,18 @@ class ModelConfig:

     @property
     def is_deepseek_mla(self) -> bool:
-        return (hasattr(self.hf_text_config, "model_type")) \
-            and (self.hf_text_config.model_type in \
-                ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\
-            and (self.hf_text_config.kv_lora_rank is not None)
+        if not hasattr(self.hf_text_config, "model_type"):
+            return False
+        elif self.hf_text_config.model_type in \
+                ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
+            return self.hf_text_config.kv_lora_rank is not None
+        elif self.hf_text_config.model_type == 'eagle':
+            # if the model is an EAGLE module, check for the
+            # underlying architecture
+            return self.hf_text_config.model.model_type in \
+                    ('deepseek_v2', 'deepseek_v3') \
+                and self.hf_text_config.kv_lora_rank is not None
+        return False

     def get_head_size(self) -> int:
         # TODO remove hard code

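For reference, a minimal standalone sketch of the new dispatch (illustrative only: SimpleNamespace stands in for hf_text_config, and the attribute values are made up). An EAGLE draft config that wraps a DeepSeek V2/V3 model is now also reported as using MLA:

    # Minimal sketch of the new is_deepseek_mla dispatch (illustrative only;
    # SimpleNamespace stands in for hf_text_config, not a real vLLM config).
    from types import SimpleNamespace

    def is_deepseek_mla(hf_text_config) -> bool:
        if not hasattr(hf_text_config, "model_type"):
            return False
        elif hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3',
                                           'deepseek_mtp'):
            return hf_text_config.kv_lora_rank is not None
        elif hf_text_config.model_type == 'eagle':
            # EAGLE wraps another model; check the underlying architecture
            return (hf_text_config.model.model_type in ('deepseek_v2',
                                                        'deepseek_v3')
                    and hf_text_config.kv_lora_rank is not None)
        return False

    # An EAGLE draft wrapping a DeepSeek V3 target now reports MLA as well.
    eagle_cfg = SimpleNamespace(
        model_type="eagle",
        kv_lora_rank=512,  # made-up value for illustration
        model=SimpleNamespace(model_type="deepseek_v3"),
    )
    assert is_deepseek_mla(eagle_cfg)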
@@ -7,6 +7,7 @@ import torch.nn as nn

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -59,7 +60,15 @@ class EAGLE(nn.Module):
     truncated_vocab_size < vocab_size. To use this technique, one has to find
     the top-k most frequent tokens in target dataset and add that as a tensor
     in the draft checkpoint (using key token_map). Also, the draft config
-    needs to have truncated_vocab_size (=k) as an attribute."""
+    needs to have truncated_vocab_size (=k) as an attribute.
+    4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
+    module with regard to the use of additional RMS norms. The original
+    EAGLE architecture 1) skips the pre-attention norm in its first
+    transformer block, and 2) skips the final output norm, both of which we
+    found to be suboptimal. We also add support for separate norms applied
+    to the token embedding and hidden states before projection, as in
+    DeepSeek MTP, which we found to improve performance as well.
+    """

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -81,10 +90,23 @@ class EAGLE(nn.Module):
         # While weights and biases are generally not needed,
         # they are retained here to support certain unit tests
         # (e.g., spec_decode/e2e/test_eagle_correctness.py).
-        self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
-            weight=self.model.model.layers[0].input_layernorm.weight)
-        self.model.model.norm = DummyOutputNorm()
+        if not hasattr(self.config.model,
+                       "skip_prenorm") or self.config.model.skip_prenorm:
+            self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
+                weight=self.model.model.layers[0].input_layernorm.weight)
+
+        if not hasattr(
+                self.config.model,
+                "skip_output_norm") or self.config.model.skip_output_norm:
+            self.model.model.norm = DummyOutputNorm()
+
+        self.add_para_norm = False
+        if hasattr(self.config.model,
+                   "add_para_norm") and self.config.model.add_para_norm:
+            self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.add_para_norm = True

         self.orig_vocab_size = config.vocab_size
         self.truncated_vocab_size = config.truncated_vocab_size
         self.unpadded_vocab_size = self.truncated_vocab_size
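All three flags are read from the draft model's nested config (self.config.model); when a flag is absent, the original EAGLE behaviour of skipping the norms is kept. A sketch of a draft config fragment that opts into the DeepSeek-MTP-style norms (field names other than skip_prenorm, skip_output_norm and add_para_norm are illustrative, and the exact layout depends on how the draft checkpoint was exported):

    # Illustrative draft-model config fragment (not an official schema): the
    # nested "model" section is what the EAGLE module sees as self.config.model.
    draft_config_fragment = {
        "model_type": "eagle",
        "model": {
            "model_type": "llama",          # made-up base architecture
            "hidden_size": 4096,            # made-up size
            "rms_norm_eps": 1e-5,
            "skip_prenorm": False,          # keep the first block's pre-attention norm
            "skip_output_norm": False,      # keep the final output norm
            "add_para_norm": True,          # enable separate enorm/hnorm before the fc
        },
        "truncated_vocab_size": 32000,      # made-up value
    }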
@@ -128,8 +150,17 @@ class EAGLE(nn.Module):
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings(input_ids)

-        inputs_embeds = self.fc(
-            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+        if self.add_para_norm:
+            inputs_embeds = torch.cat([
+                self.enorm(inputs_embeds),
+                self.hnorm(previous_hidden_states)
+            ],
+                                      dim=-1)
+        else:
+            inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
+                                      dim=-1)
+
+        inputs_embeds = self.fc(inputs_embeds)

         inputs_embeds[positions == 0] = 0  # masking inputs at position=0

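A standalone sketch of the add_para_norm data flow (PyTorch only, using a hand-rolled RMSNorm instead of vLLM's layer; sizes are made up): the token embeddings and the previous hidden states are normalized separately, concatenated along the feature dimension, and projected back to hidden_size by the fc layer.

    # Standalone sketch of the parallel-norm projection (illustrative shapes;
    # SimpleRMSNorm is a hand-rolled stand-in for vLLM's RMSNorm layer).
    import torch
    import torch.nn as nn

    class SimpleRMSNorm(nn.Module):
        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            var = x.pow(2).mean(dim=-1, keepdim=True)
            return x * torch.rsqrt(var + self.eps) * self.weight

    hidden_size, num_tokens = 64, 5
    enorm = SimpleRMSNorm(hidden_size)   # normalizes the token embeddings
    hnorm = SimpleRMSNorm(hidden_size)   # normalizes the previous hidden states
    fc = nn.Linear(2 * hidden_size, hidden_size)

    inputs_embeds = torch.randn(num_tokens, hidden_size)
    previous_hidden_states = torch.randn(num_tokens, hidden_size)

    # add_para_norm=True: normalize each stream separately before concatenation
    fused = torch.cat([enorm(inputs_embeds), hnorm(previous_hidden_states)],
                      dim=-1)
    out = fc(fused)
    assert out.shape == (num_tokens, hidden_size)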
@@ -190,6 +221,14 @@ class EAGLE(nn.Module):
                 else:
                     logger.warning_once("Found bias in the loaded weights but "
                                         "the model config doesn't have bias.")
+            elif name.startswith("enorm.weight"):
+                weight_loader = getattr(self.enorm.weight, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(self.enorm.weight, loaded_weight)
+            elif name.startswith("hnorm.weight"):
+                weight_loader = getattr(self.hnorm.weight, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(self.hnorm.weight, loaded_weight)
             elif name.startswith("model.lm_head.") or name.startswith(
                     "model.model."):
                 model_weights[name.split("model.", 1)[-1]] = loaded_weight
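The new branches route checkpoint tensors named enorm.weight and hnorm.weight into the corresponding RMSNorm parameters. A minimal sketch of the getattr fallback pattern used here (default_copy is a stand-in for vLLM's default_weight_loader, assumed to simply copy the checkpoint tensor into the parameter):

    # Sketch of the weight_loader fallback (default_copy stands in for
    # vLLM's default_weight_loader; a bare nn.Parameter carries no custom
    # "weight_loader" attribute, so getattr falls back to the copy).
    import torch

    def default_copy(param: torch.nn.Parameter,
                     loaded_weight: torch.Tensor) -> None:
        param.data.copy_(loaded_weight)

    enorm_weight = torch.nn.Parameter(torch.ones(8))
    checkpoint = {"enorm.weight": torch.full((8,), 0.5)}  # made-up checkpoint

    for name, loaded_weight in checkpoint.items():
        if name.startswith("enorm.weight"):
            loader = getattr(enorm_weight, "weight_loader", default_copy)
            loader(enorm_weight, loaded_weight)

    assert torch.allclose(enorm_weight, torch.full((8,), 0.5))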
@@ -5,6 +5,8 @@ from typing import Optional, Union

 from transformers import AutoConfig, PretrainedConfig

+from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
+

 class EAGLEConfig(PretrainedConfig):
     model_type = "eagle"
@@ -14,8 +16,17 @@ class EAGLEConfig(PretrainedConfig):
                  truncated_vocab_size: Optional[int] = None,
                  **kwargs):

-        model_config = None if model is None else (AutoConfig.for_model(
-            **model) if isinstance(model, dict) else model)
+        model_config: Union[PretrainedConfig, DeepseekV2Config, None]
+        if isinstance(model, dict):
+            archs = model.get("architectures", [])
+            target_archs = ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]
+            if any(target_arch in archs for target_arch in target_archs):
+                # AutoConfig does not support DeepSeek MoE models yet
+                model_config = DeepseekV2Config(**model)
+            else:
+                model_config = AutoConfig.for_model(**model)
+        else:
+            model_config = model

         for k, v in kwargs.items():
             if k != "architectures" and k != "model_type" and hasattr(
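A small sketch restating the dispatch above (standalone, without building a full EAGLEConfig): a model dict whose architectures list names a DeepSeek V2/V3 causal LM is routed to DeepseekV2Config, while everything else goes through AutoConfig.for_model.

    # Illustrative restatement of the architecture-based dispatch (returns the
    # name of the config class that would be used, instead of constructing it).
    def select_config_class(model: dict) -> str:
        archs = model.get("architectures", [])
        target_archs = ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]
        if any(arch in archs for arch in target_archs):
            # AutoConfig does not support DeepSeek MoE models yet
            return "DeepseekV2Config"
        return "AutoConfig.for_model"

    assert select_config_class(
        {"architectures": ["DeepseekV3ForCausalLM"]}) == "DeepseekV2Config"
    assert select_config_class(
        {"architectures": ["LlamaForCausalLM"]}) == "AutoConfig.for_model"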