[Misc] Update w2 scale loading for GPTQMarlinMoE (#12757)

Dipika Sikka, 2025-02-06 04:02:14 -05:00, committed by GitHub
parent 0408efc6d0
commit 7ca9934fe7
3 changed files with 21 additions and 8 deletions


@@ -1,5 +1,7 @@
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
+compressed-tensors, nm-testing/test-w4a16-mixtral-actorder-group, main
 gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
+gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True
 awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
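
The two act-order entries (the nm-testing actorder-group checkpoint and the gptq-8bit-128g-actorder_True revision of TheBloke/Mixtral-8x7B-v0.1-GPTQ) exercise the new w2 scale loading path. For illustration only, each line reads as a comma-separated (quantization method, model id, revision) triple; the parse_case helper and WeightLoadingCase type below are hypothetical, not the actual test harness:

from typing import NamedTuple

class WeightLoadingCase(NamedTuple):
    quant_method: str   # e.g. "gptq_marlin"
    model_id: str       # e.g. "TheBloke/Mixtral-8x7B-v0.1-GPTQ"
    revision: str       # e.g. "gptq-8bit-128g-actorder_True"

def parse_case(line: str) -> WeightLoadingCase:
    """Split one config line into its three comma-separated fields."""
    quant_method, model_id, revision = (part.strip() for part in line.split(","))
    return WeightLoadingCase(quant_method, model_id, revision)

print(parse_case("gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True"))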


@@ -302,8 +302,8 @@ class FusedMoE(torch.nn.Module):
             "weight_loader": self.weight_loader,
         }
         # need full intermediate size pre-sharding for WNA16 act order
-        if (self.quant_method.__class__.__name__ ==
-                "CompressedTensorsWNA16MoEMethod"):
+        if (self.quant_method.__class__.__name__
+                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size

         self.quant_method.create_weights(layer=self, **moe_quant_params)
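
The quant method's create_weights normally sees only the per-partition intermediate size; for act-order checkpoints the full (pre-sharding) size is forwarded as well, and this hunk extends that to GPTQMarlinMoEMethod alongside the existing CompressedTensorsWNA16MoEMethod case. A minimal standalone sketch of the idea (build_moe_quant_params is a hypothetical helper, not vLLM code):

def build_moe_quant_params(quant_method_name: str,
                           intermediate_size: int,
                           tp_size: int) -> dict:
    """Sketch: per-partition sizes are what most methods need, but act-order
    methods also need the full intermediate size from before sharding."""
    params = {
        "intermediate_size_per_partition": intermediate_size // tp_size,
    }
    if quant_method_name in ("GPTQMarlinMoEMethod",
                             "CompressedTensorsWNA16MoEMethod"):
        params["intermediate_size_full"] = intermediate_size
    return params

# e.g. a Mixtral-like 14336-wide FFN with tensor parallelism of 2 (illustrative numbers):
print(build_moe_quant_params("GPTQMarlinMoEMethod", 14336, 2))
# -> {'intermediate_size_per_partition': 7168, 'intermediate_size_full': 14336}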


@@ -323,13 +323,18 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        # Currently assuming is_k_full is always True
-        # (input size per partition is the same as full input size)
-        # Supports only sym for now (no zp)
+        intermediate_size_full = extra_weight_attrs.pop(
+            "intermediate_size_full")
+
+        self.is_k_full = (not self.quant_config.desc_act) or (
+            intermediate_size_per_partition == intermediate_size_full)
+
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            scales_size2 = (intermediate_size_per_partition //
-                            self.quant_config.group_size)
+            w2_scales_size = (intermediate_size_full
+                              if self.quant_config.desc_act else
+                              intermediate_size_per_partition)
+            scales_size2 = (w2_scales_size // self.quant_config.group_size)
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
             scales_size13 = 1
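
With desc_act (activation reordering), w2's group scales are indexed along the full k dimension of the down projection, so each tensor-parallel rank sizes them by intermediate_size_full rather than by its shard, and is_k_full records whether the local partition already covers the whole k dimension. A self-contained sketch of that sizing logic, using illustrative Mixtral-like numbers (group_size 128, intermediate size 14336, TP 2):

def w2_scale_rows(intermediate_size_per_partition: int,
                  intermediate_size_full: int,
                  group_size: int,
                  desc_act: bool) -> tuple[int, bool]:
    """Sketch of the sizing in the hunk above: number of w2 scale groups
    per expert, and whether k is 'full' on this rank."""
    is_k_full = (not desc_act) or (
        intermediate_size_per_partition == intermediate_size_full)
    # act-order scales span the unsharded k dimension of the down projection
    w2_scales_size = (intermediate_size_full
                      if desc_act else intermediate_size_per_partition)
    return w2_scales_size // group_size, is_k_full

# TP=2 shard of a 14336-wide intermediate dimension, 128-wide groups:
print(w2_scale_rows(7168, 14336, 128, desc_act=False))  # (56, True)
print(w2_scale_rows(7168, 14336, 128, desc_act=True))   # (112, False)
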
@@ -385,6 +390,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         layer.register_parameter("w2_scales", w2_scales)
         set_weight_attrs(w2_scales, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_scales,
+                         {"load_full_w2": self.quant_config.desc_act})
         # up_proj scales
         w13_qzeros = torch.nn.Parameter(
             torch.empty(num_experts,
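
load_full_w2 is set to desc_act so the weight loader keeps the w2 scales (and, in the next hunk, the w2 qzeros) unsharded when activation reordering is on. A highly simplified, hypothetical loader branch showing what the flag signals; this is not vLLM's actual FusedMoE.weight_loader:

import torch

def load_w2_scale_shard(param: torch.Tensor,
                        loaded: torch.Tensor,
                        tp_rank: int,
                        load_full_w2: bool) -> None:
    """Sketch: either copy the whole checkpoint tensor (act order) or only
    this rank's slice along the group/k dimension."""
    if load_full_w2:
        # act order: scales stay unsharded, every rank loads all groups
        param.copy_(loaded)
    else:
        shard_size = param.shape[0]
        param.copy_(loaded.narrow(0, shard_size * tp_rank, shard_size))

# illustrative shapes for one expert: 112 scale groups total, 2 ranks
full = torch.randn(112, 4096)
dst_full = torch.empty(112, 4096)
dst_half = torch.empty(56, 4096)
load_w2_scale_shard(dst_full, full, tp_rank=0, load_full_w2=True)
load_w2_scale_shard(dst_half, full, tp_rank=1, load_full_w2=False)
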
@@ -406,6 +414,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         layer.register_parameter("w2_qzeros", w2_qzeros)
         set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_qzeros,
+                         {"load_full_w2": self.quant_config.desc_act})
         w13_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
@@ -575,4 +586,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.quant_config.quant_type.size_bits,
-        ).to(orig_dtype)
+            is_k_full=self.is_k_full).to(orig_dtype)