[Model] Support Nemotron models (Nemotron-3, Nemotron-4, Minitron) (#6611)
This commit is contained in:
parent
85ad7e2d01
commit
07278c37dd
11
.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml
Normal file
11
.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml
Normal file
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nvidia/Minitron-4B-Base"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.252
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.252
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
@ -4,5 +4,6 @@ Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
|
||||
Minitron-4B-Base.yaml
|
||||
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
|
||||
Qwen2-1.5B-Instruct-FP8W8.yaml
|
||||
|
@ -159,6 +159,21 @@ class QuickGELU(CustomOp):
|
||||
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
class ReLUSquaredActivation(CustomOp):
|
||||
"""
|
||||
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
|
||||
"""
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
relu_applied = nn.functional.relu(x)
|
||||
squared = torch.square(relu_applied)
|
||||
return squared
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.forward_native(x)
|
||||
|
||||
|
||||
class ScaledActivation(nn.Module):
|
||||
"""An activation function with post-scale parameters.
|
||||
|
||||
@ -207,6 +222,7 @@ _ACTIVATION_REGISTRY = {
|
||||
"gelu_new": NewGELU(),
|
||||
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
|
||||
"relu": nn.ReLU(),
|
||||
"relu2": ReLUSquaredActivation(),
|
||||
"quick_gelu": QuickGELU(),
|
||||
}
|
||||
|
||||
|
@ -774,6 +774,7 @@ def get_rope(
|
||||
is_neox_style: bool = True,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
rotary_percent: float = 1.0,
|
||||
) -> RotaryEmbedding:
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
@ -786,6 +787,8 @@ def get_rope(
|
||||
rope_scaling_args = tuple(rope_scaling_tuple.items())
|
||||
else:
|
||||
rope_scaling_args = None
|
||||
if rotary_percent < 1.0:
|
||||
rotary_dim = int(rotary_dim * rotary_percent)
|
||||
key = (head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
rope_scaling_args, dtype)
|
||||
if key in _ROPE_DICT:
|
||||
|
@ -51,6 +51,7 @@ _GENERATION_MODELS = {
|
||||
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
|
||||
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
|
||||
"MiniCPMV": ("minicpmv", "MiniCPMV"),
|
||||
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
|
||||
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
|
||||
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
||||
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
||||
|
531
vllm/model_executor/models/nemotron.py
Normal file
531
vllm/model_executor/models/nemotron.py
Normal file
@ -0,0 +1,531 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Nemotron model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.transformers_utils.configs import NemotronConfig
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
||||
|
||||
# The architecture is pretty similar to Llama, with these changes:
|
||||
# - There is no gate_proj, just up_proj
|
||||
# - Normal LayerNorm (with a +1 to the weights) instead of RMSNorm
|
||||
# - Squared ReLU instead of SwiGLU
|
||||
# - Adds a rotary_percent to RoPE
|
||||
|
||||
|
||||
def _cast_if_autocast_enabled(*args):
|
||||
if not torch.is_autocast_enabled():
|
||||
return args
|
||||
else:
|
||||
return torch.cuda.amp.autocast_mode._cast(
|
||||
args, torch.get_autocast_gpu_dtype())
|
||||
|
||||
|
||||
class NemotronLayerNorm1P(nn.LayerNorm):
|
||||
|
||||
def __init__(self,
|
||||
normalized_shape: Union[int, List[int], torch.Size],
|
||||
eps: float = 1e-5,
|
||||
elementwise_affine: bool = True,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None):
|
||||
super().__init__(normalized_shape, eps, elementwise_affine, bias,
|
||||
device, dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if residual is not None:
|
||||
x = x + residual
|
||||
residual = x
|
||||
args = _cast_if_autocast_enabled(x, self.normalized_shape,
|
||||
self.weight + 1, self.bias, self.eps)
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
x = torch.nn.functional.layer_norm(*args)
|
||||
return x if residual is None else (x, residual)
|
||||
|
||||
|
||||
class NemotronMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.up_proj = ColumnParallelLinear(input_size=hidden_size,
|
||||
output_size=intermediate_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.up_proj")
|
||||
self.down_proj = RowParallelLinear(input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
self.act_fn = get_act_fn(hidden_act)
|
||||
|
||||
def forward(self, x):
|
||||
up, _ = self.up_proj(x)
|
||||
x = self.act_fn(up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class NemotronAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: NemotronConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
|
||||
self.head_dim = getattr(config, "head_dim",
|
||||
self.hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.rotary_percent = config.rope_percent
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=hidden_size,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.total_num_heads,
|
||||
total_num_kv_heads=self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
input_size=self.total_num_heads * self.head_dim,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
rotary_percent=self.rotary_percent,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class NemotronDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: NemotronConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
# Support abacusai/Smaug-72B-v0.1 with attention_bias
|
||||
# Support internlm/internlm-7b with bias
|
||||
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||
config, "bias", False)
|
||||
self.self_attn = NemotronAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(config, "num_key_value_heads",
|
||||
config.num_attention_heads),
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
bias=attention_bias,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.mlp = NemotronMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = NemotronLayerNorm1P(config.hidden_size,
|
||||
eps=config.norm_eps)
|
||||
self.post_attention_layernorm = NemotronLayerNorm1P(
|
||||
config.hidden_size, eps=config.norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class NemotronModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: NemotronConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
if get_pp_group().is_first_rank or (config.tie_word_embeddings
|
||||
and get_pp_group().is_last_rank):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: NemotronDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = NemotronLayerNorm1P(config.hidden_size,
|
||||
eps=config.norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NemotronForCausalLM(nn.Module, SupportsLoRA):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"qkv_proj", "o_proj", "up_proj", "down_proj", "embed_tokens", "lm_head"
|
||||
]
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: NemotronConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert isinstance(config, NemotronConfig)
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = NemotronModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config,
|
||||
prefix="model")
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
logit_scale)
|
||||
self.sampler = Sampler()
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
return model_output
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
@ -8,7 +8,7 @@ from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
|
||||
JAISConfig, MedusaConfig,
|
||||
MLPSpeculatorConfig, MPTConfig,
|
||||
RWConfig)
|
||||
NemotronConfig, RWConfig)
|
||||
|
||||
if VLLM_USE_MODELSCOPE:
|
||||
from modelscope import AutoConfig
|
||||
@ -26,6 +26,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||
"jais": JAISConfig,
|
||||
"mlp_speculator": MLPSpeculatorConfig,
|
||||
"medusa": MedusaConfig,
|
||||
"nemotron": NemotronConfig,
|
||||
}
|
||||
|
||||
for name, cls in _CONFIG_REGISTRY.items():
|
||||
|
@ -8,6 +8,7 @@ from vllm.transformers_utils.configs.jais import JAISConfig
|
||||
from vllm.transformers_utils.configs.medusa import MedusaConfig
|
||||
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
|
||||
from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||
from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
||||
|
||||
__all__ = [
|
||||
"ChatGLMConfig",
|
||||
@ -17,4 +18,5 @@ __all__ = [
|
||||
"JAISConfig",
|
||||
"MedusaConfig",
|
||||
"MLPSpeculatorConfig",
|
||||
"NemotronConfig",
|
||||
]
|
||||
|
209
vllm/transformers_utils/configs/nemotron.py
Normal file
209
vllm/transformers_utils/configs/nemotron.py
Normal file
@ -0,0 +1,209 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Nemotron model configuration"""
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class NemotronConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a
|
||||
[`NemotronModel`]. It is used to instantiate an Nemotron model
|
||||
according to the specified arguments, defining the model architecture.
|
||||
Instantiating a configuration with the defaults will yield a similar
|
||||
configuration to that of the Nemotron-8B.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be
|
||||
used to control the model outputs. Read the documentation from
|
||||
[`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the Nemotron model. Defines the number of
|
||||
different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`NemotronModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the
|
||||
Transformer decoder.
|
||||
head_dim (`int`, *optional*, defaults to None):
|
||||
Projection weights dimension in multi-head attention. Set to
|
||||
hidden_size // num_attention_heads if None
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to
|
||||
implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use
|
||||
Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1 the model will use Multi Query Attention
|
||||
(MQA) otherwise GQA is used. When converting a multi-head
|
||||
checkpoint to a GQA checkpoint, each group key and value
|
||||
head should be constructed by meanpooling all the original
|
||||
heads within that group. For more details checkout
|
||||
[this paper](https://arxiv.org/pdf/2305.13245.pdf). If it
|
||||
is not specified, will default to `num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the
|
||||
decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used
|
||||
with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values
|
||||
attentions (not used by all models). Only relevant if
|
||||
`config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE
|
||||
embeddings. Currently supports two scaling strategies: linear
|
||||
and dynamic. Their scaling factor must be a float greater than 1.
|
||||
The expected format is `{"type": strategy name,
|
||||
"factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output
|
||||
projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in up_proj and down_proj layers in the MLP
|
||||
layers.
|
||||
|
||||
```python
|
||||
>>> from transformers import NemotronModel, NemotronConfig
|
||||
|
||||
>>> # Initializing a Nemotron nemotron-15b style configuration
|
||||
>>> configuration = NemotronConfig()
|
||||
|
||||
>>> # Initializing a model from the nemotron-15b style configuration
|
||||
>>> model = NemotronModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "nemotron"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256000,
|
||||
hidden_size=6144,
|
||||
intermediate_size=24576,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=48,
|
||||
head_dim=None,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="relu2",
|
||||
max_position_embeddings=4096,
|
||||
initializer_range=0.0134,
|
||||
norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=2,
|
||||
eos_token_id=3,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
rope_percent=0.5,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
mlp_bias=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
head_dim = head_dim or kwargs.get("kv_channels", None)
|
||||
self.head_dim = head_dim if head_dim is not None else (
|
||||
hidden_size // num_attention_heads)
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.norm_eps = norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
rope_percent = rope_percent or kwargs.get("rope_percentage", None)
|
||||
self.rope_percent = rope_percent
|
||||
self._rope_scaling_validation()
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.mlp_bias = mlp_bias
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _rope_scaling_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling,
|
||||
dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with two fields, "
|
||||
f"`type` and `factor`, got {self.rope_scaling}")
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in [
|
||||
"linear", "dynamic"
|
||||
]:
|
||||
raise ValueError(
|
||||
"`rope_scaling`'s type field must be one of ['linear', "
|
||||
f"'dynamic'], got {rope_scaling_type}")
|
||||
if rope_scaling_factor is None or not isinstance(
|
||||
rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||
raise ValueError(
|
||||
"`rope_scaling`'s factor field must be a float > 1, got "
|
||||
f"{rope_scaling_factor}")
|
Loading…
x
Reference in New Issue
Block a user