[Hardware][NV] Fix Modelopt model loading for k-v-scales for Llama models. (#11787)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
commit b02fd288b2
parent ff7424f491
@@ -652,9 +652,18 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
         return remapped_name
 
     possible_scale_names = [".k_scale", ".v_scale"]
+    modelopt_scale_names = [
+        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
+    ]
     for scale_name in possible_scale_names:
         if name.endswith(scale_name):
-            remapped_name = name.replace(scale_name, f".attn{scale_name}")
+            if any(mo_scale_name in name
+                   for mo_scale_name in modelopt_scale_names):
+                remapped_name = name.replace(
+                    f".self_attn.{scale_name[1]}_proj{scale_name}",
+                    f".self_attn.attn{scale_name}")
+            else:
+                remapped_name = name.replace(scale_name, f".attn{scale_name}")
             if remapped_name not in params_dict:
                 logger.warning_once(
                     f"Found {scale_name} in the checkpoint (e.g. {name}), "
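For context: checkpoints quantized with NVIDIA ModelOpt store per-projection KV-cache scales under names like "model.layers.0.self_attn.k_proj.k_scale", whereas vLLM's attention parameters live at "...self_attn.attn.k_scale"; the new branch rewrites the former into the latter. Below is a minimal standalone sketch of that rule (the layer index and params_dict contents are assumed for illustration; the real function also emits a warning and handles other checkpoint formats):

# Standalone sketch of the remap rule above; names are illustrative only.
from typing import Optional

def remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
    possible_scale_names = [".k_scale", ".v_scale"]
    modelopt_scale_names = [
        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
    ]
    for scale_name in possible_scale_names:
        if name.endswith(scale_name):
            if any(mo in name for mo in modelopt_scale_names):
                # ModelOpt: ...self_attn.k_proj.k_scale -> ...self_attn.attn.k_scale
                remapped = name.replace(
                    f".self_attn.{scale_name[1]}_proj{scale_name}",
                    f".self_attn.attn{scale_name}")
            else:
                # Generic: ...k_scale -> ...attn.k_scale
                remapped = name.replace(scale_name, f".attn{scale_name}")
            # The real function warns and returns None if the target is missing.
            return remapped if remapped in params_dict else None
    return name

params = {"model.layers.0.self_attn.attn.k_scale": None}
assert remap_kv_scale_name(
    "model.layers.0.self_attn.k_proj.k_scale", params
) == "model.layers.0.self_attn.attn.k_scale"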
@@ -404,6 +404,11 @@ class LlamaModel(nn.Module):
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
+            if "scale" in name:
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -423,10 +428,6 @@ class LlamaModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
 
                 if is_pp_missing_parameter(name, self):
                     continue
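Taken together, the two LlamaModel.load_weights hunks move the KV-scale remap ahead of the stacked_params_mapping loop. A plausible reading of why: ModelOpt scale names contain "k_proj"/"v_proj", so the q/k/v stacking rule would otherwise rewrite them toward a nonexistent qkv_proj parameter before the remap could run. A sketch of that failure mode (the mapping tuples are assumed here, mirroring vLLM's (param_name, weight_name, shard_id) convention):

# Illustration of why the KV-scale remap must run before qkv stacking.
stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]

name = "model.layers.0.self_attn.k_proj.k_scale"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        # Without the early remap, the scale is misrouted to qkv_proj:
        print(name.replace(weight_name, param_name))
        # -> model.layers.0.self_attn.qkv_proj.k_scale (no such parameter)
        break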
@@ -452,7 +452,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
-
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
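The MixtralForCausalLM path gains the same remap-or-skip guard immediately before the params_dict lookup, so scale entries that cannot be mapped onto a model parameter are skipped rather than raising KeyError. A toy version of the pattern (the remap function and params_dict here are hypothetical stand-ins):

# Toy version of the guard added above: remap scale names before the
# params_dict lookup, and skip entries with no destination parameter.
def load_one(name, loaded_weight, params_dict, remap):
    if name.endswith("scale"):
        name = remap(name, params_dict)  # stands in for maybe_remap_kv_scale_name
        if name is None:
            return  # unmappable scale: skip instead of KeyError below
    params_dict[name] = loaded_weight  # stands in for weight_loader(param, ...)

def remap(n, p):
    cand = n.replace(".k_proj.k_scale", ".attn.k_scale")
    return cand if cand in p else None

params = {"model.layers.0.self_attn.attn.k_scale": None}
load_one("model.layers.0.self_attn.k_proj.k_scale", 0.5, params, remap)
load_one("model.layers.0.mlp.w1.weight_scale", 0.25, params, remap)  # skipped
print(params)  # {'model.layers.0.self_attn.attn.k_scale': 0.5}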