From 44f26a94664584f20458ae0ddf1a826b2d79a13c Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Fri, 16 Aug 2024 18:56:34 -0400
Subject: [PATCH] [Model] Align nemotron config with final HF state and fix
 lm-eval-small (#7611)

---
 ...4B-Base.yaml => Minitron-4B-Base-FP8.yaml} |  8 ++--
 .../lm-eval-harness/configs/models-small.txt  |  2 +-
 .../model_executor/layers/rotary_embedding.py |  6 +--
 vllm/model_executor/models/nemotron.py        |  6 +--
 vllm/transformers_utils/configs/nemotron.py   | 42 ++++++++-----------
 5 files changed, 29 insertions(+), 35 deletions(-)
 rename .buildkite/lm-eval-harness/configs/{Minitron-4B-Base.yaml => Minitron-4B-Base-FP8.yaml} (60%)

diff --git a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml
similarity index 60%
rename from .buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml
rename to .buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml
index a0466748..3ea0b7bb 100644
--- a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml
+++ b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml
@@ -1,11 +1,11 @@
-# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1
-model_name: "nvidia/Minitron-4B-Base"
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1
+model_name: "mgoin/Minitron-4B-Base-FP8"
 tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.252
+    value: 0.233
   - name: "exact_match,flexible-extract"
-    value: 0.252
+    value: 0.236
 limit: 1000
 num_fewshot: 5
diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt
index bca89f00..bb9cd43e 100644
--- a/.buildkite/lm-eval-harness/configs/models-small.txt
+++ b/.buildkite/lm-eval-harness/configs/models-small.txt
@@ -4,7 +4,7 @@ Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
-Minitron-4B-Base.yaml
+Minitron-4B-Base-FP8.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-FP8W8.yaml
 Meta-Llama-3-8B-QQQ.yaml
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 95888e79..7b3acd7f 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -774,7 +774,7 @@ def get_rope(
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
-    rotary_percent: float = 1.0,
+    partial_rotary_factor: float = 1.0,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()
@@ -787,8 +787,8 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None
-    if rotary_percent < 1.0:
-        rotary_dim = int(rotary_dim * rotary_percent)
+    if partial_rotary_factor < 1.0:
+        rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
            rope_scaling_args, dtype)
     if key in _ROPE_DICT:
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index 57598b49..7d92a1ff 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -53,7 +53,7 @@ from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 # - There is no gate_proj, just up_proj
 # - Normal LayerNorm (with a +1 to the weights) instead of RMSNorm
 # - Squared ReLU instead of SwiGLU
-# - Adds a rotary_percent to RoPE
+# - Adds a partial_rotary_factor to RoPE
 
 
 def _cast_if_autocast_enabled(*args):
@@ -161,7 +161,7 @@ class NemotronAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.rotary_percent = config.rope_percent
+        self.partial_rotary_factor = config.partial_rotary_factor
         self.max_position_embeddings = max_position_embeddings
 
         self.qkv_proj = QKVParallelLinear(
@@ -187,7 +187,7 @@ class NemotronAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
-            rotary_percent=self.rotary_percent,
+            partial_rotary_factor=self.partial_rotary_factor,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
diff --git a/vllm/transformers_utils/configs/nemotron.py b/vllm/transformers_utils/configs/nemotron.py
index a22a9f47..139e6b3c 100644
--- a/vllm/transformers_utils/configs/nemotron.py
+++ b/vllm/transformers_utils/configs/nemotron.py
@@ -35,20 +35,20 @@ class NemotronConfig(PretrainedConfig):
 
 
     Args:
-        vocab_size (`int`, *optional*, defaults to 32000):
+        vocab_size (`int`, *optional*, defaults to 256000):
             Vocabulary size of the Nemotron model. Defines the number of
             different tokens that can be represented by the
             `inputs_ids` passed when calling [`NemotronModel`]
-        hidden_size (`int`, *optional*, defaults to 4096):
+        hidden_size (`int`, *optional*, defaults to 6144):
            Dimension of the hidden representations.
-        intermediate_size (`int`, *optional*, defaults to 11008):
+        intermediate_size (`int`, *optional*, defaults to 24576):
             Dimension of the MLP representations.
         num_hidden_layers (`int`, *optional*, defaults to 32):
             Number of hidden layers in the Transformer decoder.
-        num_attention_heads (`int`, *optional*, defaults to 32):
+        num_attention_heads (`int`, *optional*, defaults to 48):
             Number of attention heads for each attention layer in the
             Transformer decoder.
-        head_dim (`int`, *optional*, defaults to None):
+        head_dim (`int`, *optional*):
             Projection weights dimension in multi-head attention. Set to
             hidden_size // num_attention_heads if None
         num_key_value_heads (`int`, *optional*):
@@ -63,16 +63,16 @@ class NemotronConfig(PretrainedConfig):
             heads within that group. For more details checkout [this
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not
             specified, will default to `num_attention_heads`.
-        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
            The non-linear activation function (function or string) in the
            decoder.
-        max_position_embeddings (`int`, *optional*, defaults to 2048):
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used
            with.
-        initializer_range (`float`, *optional*, defaults to 0.02):
+        initializer_range (`float`, *optional*, defaults to 0.0134):
            The standard deviation of the truncated_normal_initializer for
            initializing all weight matrices.
-        norm_eps (`float`, *optional*, defaults to 1e-06):
+        norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the normalization layers.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values
             attentions (not used by all models). Only relevant if
             `config.is_decoder=True`.
         pad_token_id (`int`, *optional*):
             Padding token id.
-        bos_token_id (`int`, *optional*, defaults to 1):
+        bos_token_id (`int`, *optional*, defaults to 2):
             Beginning of stream token id.
-        eos_token_id (`int`, *optional*, defaults to 2):
+        eos_token_id (`int`, *optional*, defaults to 3):
             End of stream token id.
         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
-        rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE
-            embeddings. Currently supports two scaling strategies: linear
-            and dynamic. Their scaling factor must be a float greater than 1.
-            The expected format is `{"type": strategy name,
-            "factor": scaling factor}`. When using this flag, don't update
-            `max_position_embeddings` to the expected new maximum.
+        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+            Percentage of the query and keys which will have rotary embedding.
         attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output
             projection layers during self-attention.
@@ -106,13 +101,10 @@ class NemotronConfig(PretrainedConfig):
 
     ```python
     >>> from transformers import NemotronModel, NemotronConfig
 
-    >>> # Initializing a Nemotron nemotron-15b style configuration
     >>> configuration = NemotronConfig()
 
-    >>> # Initializing a model from the nemotron-15b style configuration
     >>> model = NemotronModel(configuration)
 
-    >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
@@ -140,7 +132,7 @@ class NemotronConfig(PretrainedConfig):
         tie_word_embeddings=False,
         rope_theta=10000.0,
         rope_scaling=None,
-        rope_percent=0.5,
+        partial_rotary_factor=0.5,
         attention_bias=False,
         attention_dropout=0.0,
         mlp_bias=False,
@@ -167,8 +159,10 @@ class NemotronConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
-        rope_percent = rope_percent or kwargs.get("rope_percentage", None)
-        self.rope_percent = rope_percent
+        # for backward compatibility
+        partial_rotary_factor = kwargs.get("rope_percent", None) or kwargs.get(
+            "rope_percentage", None) or partial_rotary_factor
+        self.partial_rotary_factor = partial_rotary_factor
         self._rope_scaling_validation()
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
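
Reviewer note (not part of the diff): the sketch below is a standalone toy, not vLLM code, illustrating the arithmetic the renamed `partial_rotary_factor` argument drives inside `get_rope()`; the helper name `effective_rotary_dim` is made up for this illustration.

```python
# Toy reproduction of the partial-rotary arithmetic touched by this patch:
# when partial_rotary_factor < 1.0, only that fraction of the head dimension
# receives rotary position embeddings (the remaining dims are left unrotated).
def effective_rotary_dim(head_size: int,
                         partial_rotary_factor: float = 1.0) -> int:
    rotary_dim = head_size
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    return rotary_dim


# Nemotron uses partial_rotary_factor=0.5, so a 128-dim head rotates 64 dims.
assert effective_rotary_dim(128, 0.5) == 64
assert effective_rotary_dim(128) == 128
```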
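
Similarly, a minimal sketch of the backward-compatibility precedence added in `NemotronConfig.__init__`: configs that still carry the legacy `rope_percent` or `rope_percentage` keys keep working, and the new `partial_rotary_factor` value is used only when neither legacy key is present. The free function `resolve_partial_rotary_factor` is hypothetical and exists only for this example.

```python
# Mirrors the kwargs fallback in the hunk above: legacy "rope_percent" wins,
# then legacy "rope_percentage", then the new partial_rotary_factor argument.
def resolve_partial_rotary_factor(partial_rotary_factor: float = 0.5,
                                  **kwargs) -> float:
    return (kwargs.get("rope_percent", None)
            or kwargs.get("rope_percentage", None)
            or partial_rotary_factor)


assert resolve_partial_rotary_factor() == 0.5                  # new default
assert resolve_partial_rotary_factor(rope_percent=1.0) == 1.0  # legacy key wins
assert resolve_partial_rotary_factor(0.25, rope_percentage=0.75) == 0.75
```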