
# SPDX-License-Identifier: Apache-2.0

from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import maybe_prefix

logger = init_logger(__name__)


class DummyInputLayerNorm(nn.Module):

    def __init__(self, weight=None, bias=None):
        super().__init__()
        self.weight = nn.Parameter(weight) if weight is not None else None
        self.bias = nn.Parameter(bias) if bias is not None else None

    def forward(self, x):
        return x


class DummyOutputNorm(nn.Module):

    def forward(self, x, residual):
        if residual is None:
            return x
        else:
            return x + residual, None


class EAGLE(nn.Module):
    """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
    Reference implementation: https://github.com/SafeAILab/EAGLE

    Differences from reference implementation:
    1. In the reference, the LlamaDecoderLayer implementation doesn't have
       input_layernorm for the 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
       Following this approach, our implementation also disables
       the input_layernorm for the first decoder layer.
    2. We allow any decoder layer to be used in EAGLE, whereas in the
       reference the decoder layer is fixed to be LlamaDecoderLayer.
    3. We have an optional token_map which reduces the draft vocab to the
       most frequently used tokens, giving some additional speed-up by
       reducing sampling overhead. This is disabled unless the checkpoint
       file has an explicit token_map tensor and the config has the optional
       attribute truncated_vocab_size < vocab_size. To use this technique,
       one has to find the top-k most frequent tokens in the target dataset
       and add them as a tensor in the draft checkpoint (using the key
       token_map). Also, the draft config needs to have truncated_vocab_size
       (=k) as an attribute; see the sketch at the end of this docstring.
    4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
       module with regards to the use of additional RMS norms. The original
       EAGLE architecture 1) skips the pre-attention norm in its first
       transformer block, and 2) skips the final output norm, both of which
       we found to be suboptimal. We also add support for separate norms
       applied to the token embedding and the hidden states before
       projection, as in DeepSeek MTP, which we found to improve performance
       as well.
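
    A minimal sketch (illustrative, not part of this module) of how a
    token_map could be added to a draft checkpoint stored as a single
    pytorch_model.bin. The path and token ids below are placeholders; the
    real list should be the k most frequent token ids measured on the
    target dataset:

        import torch

        top_k_token_ids = [128000, 271, 198, 13]  # hypothetical ids, k=4
        ckpt = torch.load("draft/pytorch_model.bin")
        ckpt["token_map"] = torch.tensor(top_k_token_ids, dtype=torch.long)
        torch.save(ckpt, "draft/pytorch_model.bin")
        # Finally, set truncated_vocab_size (=k) in the draft config.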
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config

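        # The draft wraps an arbitrary decoder-only model: its class is
        # resolved from the architectures listed in the nested `model`
        # section of the draft config.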
        architectures = getattr(self.config.model, "architectures", [])
        model_cls, _ = ModelRegistry.resolve_model_cls(architectures)

        self.model = model_cls(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

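        # EAGLE conditions the draft on the target model's hidden state: the
        # FC layer projects the concatenated [token embedding; previous
        # hidden state] vector (2 * hidden_size) back to hidden_size before
        # it enters the decoder layers (see forward()).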
        self.fc = nn.Linear(config.model.hidden_size * 2,
                            config.model.hidden_size,
                            bias=getattr(self.config, "eagle_fc_bias", False))

        # Modify layer normalization and residual connections as suggested
        # in the EAGLE framework: https://github.com/SafeAILab/EAGLE
        # While weights and biases are generally not needed,
        # they are retained here to support certain unit tests
        # (e.g., spec_decode/e2e/test_eagle_correctness.py).
        if not hasattr(self.config.model,
                       "skip_prenorm") or self.config.model.skip_prenorm:
            self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
                weight=self.model.model.layers[0].input_layernorm.weight)

        if not hasattr(
                self.config.model,
                "skip_output_norm") or self.config.model.skip_output_norm:
            self.model.model.norm = DummyOutputNorm()

        self.add_para_norm = False
        if hasattr(self.config.model,
                   "add_para_norm") and self.config.model.add_para_norm:
            self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.add_para_norm = True

        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

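        # The draft keeps its own LM head, sized to the (possibly truncated)
        # draft vocabulary rather than the full target vocabulary.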
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=self.truncated_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # token_map is an idx-to-token mapping used to reduce the vocab size
        # of the draft model. Using a smaller draft vocab that contains only
        # the most frequent tokens reduces the speculation overhead without
        # affecting the acceptance rate much, and thus gives additional
        # speed-up. By default, this is disabled and is only used if the
        # EAGLE checkpoint file has an explicit token_map tensor.
        self.token_map = None

    @property
    def sampler(self):
        return self.model.sampler

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

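        # Fuse the draft token embedding with the target model's hidden state
        # from the previous step, optionally RMS-normalizing each part first
        # (DeepSeek-MTP-style), then project back to hidden_size via self.fc.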
        if self.add_para_norm:
            inputs_embeds = torch.cat([
                self.enorm(inputs_embeds),
                self.hnorm(previous_hidden_states)
            ],
                                      dim=-1)
        else:
            inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
                                      dim=-1)

        inputs_embeds = self.fc(inputs_embeds)

        inputs_embeds[positions == 0] = 0  # masking inputs at position=0

        hidden_states = self.model.model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

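        # With a truncated draft vocab, scatter the draft logits back into a
        # full-vocab tensor so token ids line up with the target model;
        # tokens outside the draft vocab get -inf and are never proposed.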
        if self.token_map is not None:
            _logits = logits
            logits = -torch.inf * torch.ones(
                size=(*_logits.shape[:-1], self.orig_vocab_size),
                device=_logits.device,
                dtype=_logits.dtype)

            logits[..., self.token_map] = _logits

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
        # due to missing lm_head weights and its config being that of a
        # Llama model. Here's a compatible version with the same weights:
        # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
        # Also, here's an example script for converting a trained EAGLE
        # checkpoint to a vLLM-compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
        model_weights = {}
        for name, loaded_weight in weights:
            if name == "token_map":
                if self.config.truncated_vocab_size < self.config.vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name.startswith("fc.weight"):
                weight_loader = getattr(self.fc.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.fc.weight, loaded_weight)
            elif name.startswith("fc.bias"):
                if self.fc.bias is not None:
                    weight_loader = getattr(self.fc.bias, "weight_loader",
                                            default_weight_loader)
                    weight_loader(self.fc.bias, loaded_weight)
                else:
                    logger.warning_once("Found bias in the loaded weights but "
                                        "the model config doesn't have bias.")
            elif name.startswith("enorm.weight"):
                weight_loader = getattr(self.enorm.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.enorm.weight, loaded_weight)
            elif name.startswith("hnorm.weight"):
                weight_loader = getattr(self.hnorm.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.hnorm.weight, loaded_weight)
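            # Normalize the remaining names so that the wrapped model's
            # weights are all prefixed with "model." and the LM head weight
            # is exposed as "lm_head.weight" before delegating the load.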
            elif name.startswith("model.lm_head.") or name.startswith(
                    "model.model."):
                model_weights[name.split("model.", 1)[-1]] = loaded_weight
            elif name.startswith("lm_head.") or name.startswith("model."):
                model_weights[name] = loaded_weight
            else:
                model_weights[f"model.{name}"] = loaded_weight

        if "lm_head.weight" in model_weights:
            lm_head_weight = model_weights.pop("lm_head.weight")

            if self.token_map is not None and\
                lm_head_weight.shape[0] > self.token_map.shape[0]:

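                # Keep only the rows of the full-vocab LM head that
                # correspond to tokens in the truncated draft vocabulary.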
                lm_head_weight = lm_head_weight[self.token_map]

        else:
            # NOTE(Shangming): initialize the placeholder for lm_head weight.
            lm_head_weight = torch.zeros(
                self.lm_head.org_vocab_size,
                self.lm_head.embedding_dim,
                dtype=self.config.torch_dtype,
            )

        weight_loader = getattr(self.lm_head.weight, "weight_loader",
                                default_weight_loader)
        weight_loader(self.lm_head.weight, lm_head_weight)

        self.model.load_weights(model_weights.items())