[Hardware][NV] Fix Modelopt model loading for k-v-scales for Llama models. (#11787)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Pavani Majety 2025-01-29 01:46:12 -08:00 committed by GitHub
parent ff7424f491
commit b02fd288b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 6 deletions

View File

@ -652,9 +652,18 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
return remapped_name
possible_scale_names = [".k_scale", ".v_scale"]
modelopt_scale_names = [
".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
]
for scale_name in possible_scale_names:
if name.endswith(scale_name):
remapped_name = name.replace(scale_name, f".attn{scale_name}")
if any(mo_scale_name in name
for mo_scale_name in modelopt_scale_names):
remapped_name = name.replace(
f".self_attn.{scale_name[1]}_proj{scale_name}",
f".self_attn.attn{scale_name}")
else:
remapped_name = name.replace(scale_name, f".attn{scale_name}")
if remapped_name not in params_dict:
logger.warning_once(
f"Found {scale_name} in the checkpoint (e.g. {name}), "

View File

@ -404,6 +404,11 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
@ -423,10 +428,6 @@ class LlamaModel(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue

View File

@ -452,7 +452,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name.endswith("scale"):
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)