[Model] Align nemotron config with final HF state and fix lm-eval-small (#7611)

Michael Goin 2024-08-16 18:56:34 -04:00 committed by GitHub
parent 37fd47e780
commit 44f26a9466
5 changed files with 29 additions and 35 deletions

.buildkite/lm-eval-harness/configs/{Minitron-4B-Base.yaml → Minitron-4B-Base-FP8.yaml}

@@ -1,11 +1,11 @@
-# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1
-model_name: "nvidia/Minitron-4B-Base"
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1
+model_name: "mgoin/Minitron-4B-Base-FP8"
 tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.252
+    value: 0.233
   - name: "exact_match,flexible-extract"
-    value: 0.252
+    value: 0.236
 limit: 1000
 num_fewshot: 5
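For context, the updated baseline values above can be checked mechanically against a fresh lm-eval run. A minimal sketch follows; it is not the actual .buildkite harness code, and the tolerance and the shape of the measured-results dict are assumptions for illustration.

# Sketch only: parse a config in this format and compare each measured
# lm-eval metric against the recorded baseline within a relative tolerance.
import yaml

CONFIG = """
model_name: "mgoin/Minitron-4B-Base-FP8"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.233
  - name: "exact_match,flexible-extract"
    value: 0.236
limit: 1000
num_fewshot: 5
"""

RTOL = 0.05  # assumed relative tolerance, for illustration only


def check(measured: dict) -> None:
    expected = yaml.safe_load(CONFIG)
    for task in expected["tasks"]:
        for metric in task["metrics"]:
            got = measured[task["name"]][metric["name"]]
            want = metric["value"]
            assert abs(got - want) <= RTOL * want, (
                f"{task['name']}/{metric['name']}: got {got}, want {want}")


# Example call with results matching the recorded baselines exactly.
check({"gsm8k": {"exact_match,strict-match": 0.233,
                 "exact_match,flexible-extract": 0.236}})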

.buildkite/lm-eval-harness/configs/models-small.txt

@@ -4,7 +4,7 @@ Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
-Minitron-4B-Base.yaml
+Minitron-4B-Base-FP8.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-FP8W8.yaml
 Meta-Llama-3-8B-QQQ.yaml

vllm/model_executor/layers/rotary_embedding.py

@@ -774,7 +774,7 @@ def get_rope(
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
-    rotary_percent: float = 1.0,
+    partial_rotary_factor: float = 1.0,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()
@@ -787,8 +787,8 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None
-    if rotary_percent < 1.0:
-        rotary_dim = int(rotary_dim * rotary_percent)
+    if partial_rotary_factor < 1.0:
+        rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
            rope_scaling_args, dtype)
     if key in _ROPE_DICT:
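The rename from rotary_percent to partial_rotary_factor matches the field name used in the Hugging Face Nemotron config. A standalone sketch of what the factor does to the rotary dimension, mirroring the two changed lines above (the head size is an illustrative example, not taken from a checkpoint):

# Only a fraction of each head's dimensions receives rotary embeddings when
# partial_rotary_factor < 1.0; otherwise the full head dimension is rotated.
def effective_rotary_dim(head_size: int,
                         partial_rotary_factor: float = 1.0) -> int:
    rotary_dim = head_size
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    return rotary_dim


# Nemotron uses partial_rotary_factor=0.5, so half of each head is rotated.
assert effective_rotary_dim(128, 0.5) == 64
assert effective_rotary_dim(128) == 128  # default: full rotary dimension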

vllm/model_executor/models/nemotron.py

@@ -53,7 +53,7 @@ from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 # - There is no gate_proj, just up_proj
 # - Normal LayerNorm (with a +1 to the weights) instead of RMSNorm
 # - Squared ReLU instead of SwiGLU
-# - Adds a rotary_percent to RoPE
+# - Adds a partial_rotary_factor to RoPE
 
 
 def _cast_if_autocast_enabled(*args):
@@ -161,7 +161,7 @@ class NemotronAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.rotary_percent = config.rope_percent
+        self.partial_rotary_factor = config.partial_rotary_factor
         self.max_position_embeddings = max_position_embeddings
 
         self.qkv_proj = QKVParallelLinear(
@@ -187,7 +187,7 @@ class NemotronAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
-            rotary_percent=self.rotary_percent,
+            partial_rotary_factor=self.partial_rotary_factor,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
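With the config field renamed, the attention layer now reads config.partial_rotary_factor and forwards it to get_rope, which shrinks the rotary dimension accordingly. A small usage sketch, assuming a vLLM build that includes this change; the sizes are illustrative examples, not read from a real checkpoint.

# Wire a Nemotron-style config into get_rope, as the attention module does.
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.transformers_utils.configs.nemotron import NemotronConfig

config = NemotronConfig(hidden_size=3072, num_attention_heads=24,
                        partial_rotary_factor=0.5)
head_dim = config.hidden_size // config.num_attention_heads  # 128

rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim,
    max_position=config.max_position_embeddings,
    base=config.rope_theta,
    rope_scaling=config.rope_scaling,
    partial_rotary_factor=config.partial_rotary_factor,
)
print(rotary_emb.rotary_dim)  # 64: only half of each head is rotated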

vllm/transformers_utils/configs/nemotron.py

@@ -35,20 +35,20 @@ class NemotronConfig(PretrainedConfig):
     Args:
-        vocab_size (`int`, *optional*, defaults to 32000):
+        vocab_size (`int`, *optional*, defaults to 256000):
             Vocabulary size of the Nemotron model. Defines the number of
             different tokens that can be represented by the
             `inputs_ids` passed when calling [`NemotronModel`]
-        hidden_size (`int`, *optional*, defaults to 4096):
+        hidden_size (`int`, *optional*, defaults to 6144):
             Dimension of the hidden representations.
-        intermediate_size (`int`, *optional*, defaults to 11008):
+        intermediate_size (`int`, *optional*, defaults to 24576):
             Dimension of the MLP representations.
         num_hidden_layers (`int`, *optional*, defaults to 32):
             Number of hidden layers in the Transformer decoder.
-        num_attention_heads (`int`, *optional*, defaults to 32):
+        num_attention_heads (`int`, *optional*, defaults to 48):
             Number of attention heads for each attention layer in the
             Transformer decoder.
-        head_dim (`int`, *optional*, defaults to None):
+        head_dim (`int`, *optional*):
             Projection weights dimension in multi-head attention. Set to
             hidden_size // num_attention_heads if None
         num_key_value_heads (`int`, *optional*):
@@ -63,16 +63,16 @@ class NemotronConfig(PretrainedConfig):
             heads within that group. For more details checkout
             [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it
             is not specified, will default to `num_attention_heads`.
-        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
             The non-linear activation function (function or string) in the
             decoder.
-        max_position_embeddings (`int`, *optional*, defaults to 2048):
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
             The maximum sequence length that this model might ever be used
             with.
-        initializer_range (`float`, *optional*, defaults to 0.02):
+        initializer_range (`float`, *optional*, defaults to 0.0134):
             The standard deviation of the truncated_normal_initializer for
             initializing all weight matrices.
-        norm_eps (`float`, *optional*, defaults to 1e-06):
+        norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the normalization layers.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values
@@ -80,21 +80,16 @@ class NemotronConfig(PretrainedConfig):
             `config.is_decoder=True`.
         pad_token_id (`int`, *optional*):
             Padding token id.
-        bos_token_id (`int`, *optional*, defaults to 1):
+        bos_token_id (`int`, *optional*, defaults to 2):
             Beginning of stream token id.
-        eos_token_id (`int`, *optional*, defaults to 2):
+        eos_token_id (`int`, *optional*, defaults to 3):
             End of stream token id.
         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
-        rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE
-            embeddings. Currently supports two scaling strategies: linear
-            and dynamic. Their scaling factor must be a float greater than 1.
-            The expected format is `{"type": strategy name,
-            "factor": scaling factor}`. When using this flag, don't update
-            `max_position_embeddings` to the expected new maximum.
+        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+            Percentage of the query and keys which will have rotary embedding.
         attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output
             projection layers during self-attention.
@@ -106,13 +101,10 @@
 
     ```python
     >>> from transformers import NemotronModel, NemotronConfig
-
     >>> # Initializing a Nemotron nemotron-15b style configuration
     >>> configuration = NemotronConfig()
-
     >>> # Initializing a model from the nemotron-15b style configuration
     >>> model = NemotronModel(configuration)
-
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
@@ -140,7 +132,7 @@
         tie_word_embeddings=False,
         rope_theta=10000.0,
         rope_scaling=None,
-        rope_percent=0.5,
+        partial_rotary_factor=0.5,
         attention_bias=False,
         attention_dropout=0.0,
         mlp_bias=False,
@@ -167,8 +159,10 @@
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
-        rope_percent = rope_percent or kwargs.get("rope_percentage", None)
-        self.rope_percent = rope_percent
+        # for backward compatibility
+        partial_rotary_factor = kwargs.get("rope_percent", None) or kwargs.get(
+            "rope_percentage", None) or partial_rotary_factor
+        self.partial_rotary_factor = partial_rotary_factor
         self._rope_scaling_validation()
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
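The constructor keeps old configs loadable: a checkpoint that still carries rope_percent or rope_percentage resolves to the same value, while new configs use partial_rotary_factor directly. A standalone sketch of that resolution order, mirroring the kwargs handling above:

# Old-style kwargs take precedence over the default; otherwise the explicit
# partial_rotary_factor argument is used as-is.
def resolve_partial_rotary_factor(partial_rotary_factor: float = 0.5,
                                  **kwargs) -> float:
    return (kwargs.get("rope_percent", None)
            or kwargs.get("rope_percentage", None)
            or partial_rotary_factor)


assert resolve_partial_rotary_factor() == 0.5
assert resolve_partial_rotary_factor(rope_percent=0.25) == 0.25
assert resolve_partial_rotary_factor(partial_rotary_factor=1.0) == 1.0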