[Feature] Enhance EAGLE Architecture with Proper RMS Norms (#14990)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 5aefd6ac31
commit 781d056280
vllm/config.py

@@ -800,10 +800,18 @@ class ModelConfig:

     @property
     def is_deepseek_mla(self) -> bool:
-        return (hasattr(self.hf_text_config, "model_type")) \
-                and (self.hf_text_config.model_type in \
-                    ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\
-                and (self.hf_text_config.kv_lora_rank is not None)
+        if not hasattr(self.hf_text_config, "model_type"):
+            return False
+        elif self.hf_text_config.model_type in \
+                ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
+            return self.hf_text_config.kv_lora_rank is not None
+        elif self.hf_text_config.model_type == 'eagle':
+            # if the model is an EAGLE module, check for the
+            # underlying architecture
+            return self.hf_text_config.model.model_type in \
+                    ('deepseek_v2', 'deepseek_v3') \
+                and self.hf_text_config.kv_lora_rank is not None
+        return False

     def get_head_size(self) -> int:
         # TODO remove hard code
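Reviewer note: the rewritten property unwraps an EAGLE draft to inspect the underlying target architecture. A minimal, self-contained sketch of the same branching, using types.SimpleNamespace stand-ins for the HF config objects (attribute names are taken from the diff above; the kv_lora_rank value is illustrative):

from types import SimpleNamespace

def is_deepseek_mla(hf_text_config) -> bool:
    # Mirrors the rewritten property above.
    if not hasattr(hf_text_config, "model_type"):
        return False
    if hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3',
                                     'deepseek_mtp'):
        return hf_text_config.kv_lora_rank is not None
    if hf_text_config.model_type == 'eagle':
        # unwrap the EAGLE draft and check the underlying architecture
        return (hf_text_config.model.model_type in
                ('deepseek_v2', 'deepseek_v3')
                and hf_text_config.kv_lora_rank is not None)
    return False

# An EAGLE draft wrapping DeepSeek-V3 is now recognized as an MLA model.
eagle_cfg = SimpleNamespace(model_type='eagle',
                            kv_lora_rank=512,
                            model=SimpleNamespace(model_type='deepseek_v3'))
assert is_deepseek_mla(eagle_cfg)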
vllm/model_executor/models/eagle.py

@@ -7,6 +7,7 @@ import torch.nn as nn

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -59,7 +60,15 @@ class EAGLE(nn.Module):
     truncated_vocab_size < vocab_size. To use this technique, one has to find
     the top-k most frequent tokens in target dataset and add that as a tensor
     in the draft checkpoint (using key token_map). Also, the draft config
-    needs to have truncated_vocab_size (=k) as an attribute."""
+    needs to have truncated_vocab_size (=k) as an attribute.
+
+    4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
+    module with regard to the use of additional RMS norms. The original
+    EAGLE architecture 1) skips the pre-attention norm in its first
+    transformer block, and 2) skips the final output norm, both of which we
+    found to be suboptimal. We also add support for separate norms
+    applying to both the token embedding and hidden states before projection
+    as in DeepSeek MTP, which we found to improve performance as well.
+    """

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
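Reviewer note: the behaviors described in point 4 are opt-in via attributes on the draft model's config. A hypothetical config fragment (attribute names come from the hasattr checks in the diff below; where exactly these keys live in the checkpoint's config.json is an assumption):

# Hypothetical draft-model config fragment; names match the checks below,
# but the exact nesting inside the checkpoint's config.json is assumed.
draft_model_config = {
    "skip_prenorm": False,      # keep the pre-attention norm in block 0
    "skip_output_norm": False,  # keep the final output norm
    "add_para_norm": True,      # separate RMS norms for embedding / hidden
}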
@@ -81,9 +90,22 @@ class EAGLE(nn.Module):
         # While weights and biases are generally not needed,
         # they are retained here to support certain unit tests
         # (e.g., spec_decode/e2e/test_eagle_correctness.py).
-        self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
-            weight=self.model.model.layers[0].input_layernorm.weight)
-        self.model.model.norm = DummyOutputNorm()
+        if not hasattr(self.config.model,
+                       "skip_prenorm") or self.config.model.skip_prenorm:
+            self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
+                weight=self.model.model.layers[0].input_layernorm.weight)
+
+        if not hasattr(
+                self.config.model,
+                "skip_output_norm") or self.config.model.skip_output_norm:
+            self.model.model.norm = DummyOutputNorm()
+
+        self.add_para_norm = False
+        if hasattr(self.config.model,
+                   "add_para_norm") and self.config.model.add_para_norm:
+            self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.add_para_norm = True

         self.orig_vocab_size = config.vocab_size
         self.truncated_vocab_size = config.truncated_vocab_size
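Note that all three guards default to the legacy behavior when the attribute is absent, so existing EAGLE checkpoints load unchanged. A quick sketch of that semantics, with SimpleNamespace stand-in configs:

from types import SimpleNamespace

legacy = SimpleNamespace()                      # old checkpoint: no flags set
enhanced = SimpleNamespace(skip_prenorm=False,  # enhanced checkpoint
                           skip_output_norm=False,
                           add_para_norm=True)

def uses_dummy_prenorm(cfg) -> bool:
    # Same guard as above: a missing attribute falls back to skipping the norm.
    return not hasattr(cfg, "skip_prenorm") or cfg.skip_prenorm

assert uses_dummy_prenorm(legacy)        # original EAGLE behavior preserved
assert not uses_dummy_prenorm(enhanced)  # real RMSNorm kept in block 0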
@@ -128,8 +150,17 @@ class EAGLE(nn.Module):
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings(input_ids)

-        inputs_embeds = self.fc(
-            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+        if self.add_para_norm:
+            inputs_embeds = torch.cat([
+                self.enorm(inputs_embeds),
+                self.hnorm(previous_hidden_states)
+            ],
+                                      dim=-1)
+        else:
+            inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
+                                      dim=-1)
+
+        inputs_embeds = self.fc(inputs_embeds)

         inputs_embeds[positions == 0] = 0  # masking inputs at position=0
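Reviewer note: with add_para_norm enabled, each input stream is RMS-normalized independently before the concat-and-project step, mirroring DeepSeek MTP. A self-contained numerical sketch (plain PyTorch >= 2.4, approximating vLLM's RMSNorm with torch.nn.RMSNorm; shapes are illustrative):

import torch

hidden_size = 8
enorm = torch.nn.RMSNorm(hidden_size, eps=1e-6)  # stand-in for vLLM's RMSNorm
hnorm = torch.nn.RMSNorm(hidden_size, eps=1e-6)
fc = torch.nn.Linear(2 * hidden_size, hidden_size, bias=False)

inputs_embeds = torch.randn(4, hidden_size)           # draft token embeddings
previous_hidden_states = torch.randn(4, hidden_size)  # target-model states

# add_para_norm=True path: normalize each stream, then concat and project.
fused = fc(torch.cat([enorm(inputs_embeds),
                      hnorm(previous_hidden_states)], dim=-1))
print(fused.shape)  # torch.Size([4, 8])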
@@ -190,6 +221,14 @@ class EAGLE(nn.Module):
                 else:
                     logger.warning_once("Found bias in the loaded weights but "
                                         "the model config doesn't have bias.")
+            elif name.startswith("enorm.weight"):
+                weight_loader = getattr(self.enorm.weight, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(self.enorm.weight, loaded_weight)
+            elif name.startswith("hnorm.weight"):
+                weight_loader = getattr(self.hnorm.weight, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(self.hnorm.weight, loaded_weight)
             elif name.startswith("model.lm_head.") or name.startswith(
                     "model.model."):
                 model_weights[name.split("model.", 1)[-1]] = loaded_weight
vllm/transformers_utils/configs/eagle.py

@@ -5,6 +5,8 @@ from typing import Optional, Union

 from transformers import AutoConfig, PretrainedConfig

+from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
+

 class EAGLEConfig(PretrainedConfig):
     model_type = "eagle"
@@ -14,8 +16,17 @@ class EAGLEConfig(PretrainedConfig):
                  truncated_vocab_size: Optional[int] = None,
                  **kwargs):

-        model_config = None if model is None else (AutoConfig.for_model(
-            **model) if isinstance(model, dict) else model)
+        model_config: Union[PretrainedConfig, DeepseekV2Config, None]
+        if isinstance(model, dict):
+            archs = model.get("architectures", [])
+            target_archs = ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]
+            if any(target_arch in archs for target_arch in target_archs):
+                # AutoConfig does not support DeepSeek MoE models yet
+                model_config = DeepseekV2Config(**model)
+            else:
+                model_config = AutoConfig.for_model(**model)
+        else:
+            model_config = model

         for k, v in kwargs.items():
             if k != "architectures" and k != "model_type" and hasattr(
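A minimal usage sketch of the new construction path (the EAGLEConfig import path is the one used in vLLM; the DeepSeek fields shown are illustrative, a real draft checkpoint carries the full config):

from vllm.transformers_utils.configs.eagle import EAGLEConfig

# Because "DeepseekV3ForCausalLM" appears in architectures, the inner config
# is built with DeepseekV2Config instead of AutoConfig.for_model.
cfg = EAGLEConfig(model={
    "architectures": ["DeepseekV3ForCausalLM"],
    # ... remaining DeepSeek config fields from the draft checkpoint ...
})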