[Hardware][NV] Fix Modelopt model loading for k-v-scales for Llama models. (#11787)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent ff7424f491
commit b02fd288b2
@@ -652,8 +652,17 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
         return remapped_name
 
     possible_scale_names = [".k_scale", ".v_scale"]
+    modelopt_scale_names = [
+        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
+    ]
     for scale_name in possible_scale_names:
         if name.endswith(scale_name):
-            remapped_name = name.replace(scale_name, f".attn{scale_name}")
+            if any(mo_scale_name in name
+                   for mo_scale_name in modelopt_scale_names):
+                remapped_name = name.replace(
+                    f".self_attn.{scale_name[1]}_proj{scale_name}",
+                    f".self_attn.attn{scale_name}")
+            else:
+                remapped_name = name.replace(scale_name, f".attn{scale_name}")
             if remapped_name not in params_dict:
                 logger.warning_once(
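For context, here is a minimal standalone sketch of the remapping behavior introduced above. It omits the params_dict validation and the warning from the real function, and the example layer index is made up:

# Standalone sketch of the string remapping above; the real function also
# checks the result against params_dict and warns when the name is unknown.
modelopt_scale_names = [
    ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
]

def remap_kv_scale_name(name: str) -> str:
    for scale_name in (".k_scale", ".v_scale"):
        if name.endswith(scale_name):
            if any(mo_name in name for mo_name in modelopt_scale_names):
                # ModelOpt: the scale lives on k_proj/v_proj; move it to the
                # fused attention module (scale_name[1] is "k" or "v").
                return name.replace(
                    f".self_attn.{scale_name[1]}_proj{scale_name}",
                    f".self_attn.attn{scale_name}")
            # Other FP8 checkpoints: plain suffix remap, as before.
            return name.replace(scale_name, f".attn{scale_name}")
    return name

assert (remap_kv_scale_name("model.layers.0.self_attn.k_proj.k_scale") ==
        "model.layers.0.self_attn.attn.k_scale")
assert (remap_kv_scale_name("model.layers.0.self_attn.k_scale") ==
        "model.layers.0.self_attn.attn.k_scale")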
@@ -404,6 +404,11 @@ class LlamaModel(nn.Module):
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
+            if "scale" in name:
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -423,10 +428,6 @@ class LlamaModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
 
                 if is_pp_missing_parameter(name, self):
                     continue
@@ -452,7 +452,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
-
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
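The LlamaModel and MixtralForCausalLM hunks apply the same loader-side pattern: remap a checkpoint scale name before it is used to index params_dict, and skip entries the remap cannot resolve. A hedged sketch of that pattern follows; load_kv_scales, weights, params_dict, and the plain tensor copy are hypothetical stand-ins for the real loader state, and the import path is the one these models use for the helper:

# Sketch of the guard both load_weights() bodies now use; `weights` and
# `params_dict` are hypothetical stand-ins for the real loader state.
from vllm.model_executor.model_loader.weight_utils import (
    maybe_remap_kv_scale_name)

def load_kv_scales(weights, params_dict):
    for name, loaded_weight in weights:
        if name.endswith("scale"):
            # Remap FP8 kv-scale names before indexing params_dict;
            # the helper returns None for names it cannot resolve.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue  # unknown scale: skip instead of raising KeyError
        # Stand-in for the real weight_loader(param, loaded_weight) call.
        params_dict[name].data.copy_(loaded_weight)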