[Misc] Update w2 scale loading for GPTQMarlinMoE (#12757)
This commit is contained in:
parent
0408efc6d0
commit
7ca9934fe7
@ -1,5 +1,7 @@
|
||||
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
|
||||
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
|
||||
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
|
||||
compressed-tensors, nm-testing/test-w4a16-mixtral-actorder-group, main
|
||||
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
|
||||
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True
|
||||
awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
|
@ -302,8 +302,8 @@ class FusedMoE(torch.nn.Module):
|
||||
"weight_loader": self.weight_loader,
|
||||
}
|
||||
# need full intermediate size pre-sharding for WNA16 act order
|
||||
if (self.quant_method.__class__.__name__ ==
|
||||
"CompressedTensorsWNA16MoEMethod"):
|
||||
if (self.quant_method.__class__.__name__
|
||||
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
@ -323,13 +323,18 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Currently assuming is_k_full is always True
|
||||
# (input size per partition is the same as full input size)
|
||||
# Supports only sym for now (no zp)
|
||||
intermediate_size_full = extra_weight_attrs.pop(
|
||||
"intermediate_size_full")
|
||||
|
||||
self.is_k_full = (not self.quant_config.desc_act) or (
|
||||
intermediate_size_per_partition == intermediate_size_full)
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
scales_size13 = hidden_size // self.quant_config.group_size
|
||||
scales_size2 = (intermediate_size_per_partition //
|
||||
self.quant_config.group_size)
|
||||
w2_scales_size = (intermediate_size_full
|
||||
if self.quant_config.desc_act else
|
||||
intermediate_size_per_partition)
|
||||
scales_size2 = (w2_scales_size // self.quant_config.group_size)
|
||||
strategy = FusedMoeWeightScaleSupported.GROUP.value
|
||||
else:
|
||||
scales_size13 = 1
|
||||
@ -385,6 +390,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
# dont shard the w2 scales when running act order
|
||||
set_weight_attrs(w2_scales,
|
||||
{"load_full_w2": self.quant_config.desc_act})
|
||||
# up_proj scales
|
||||
w13_qzeros = torch.nn.Parameter(
|
||||
torch.empty(num_experts,
|
||||
@ -406,6 +414,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
# dont shard the w2 scales when running act order
|
||||
set_weight_attrs(w2_qzeros,
|
||||
{"load_full_w2": self.quant_config.desc_act})
|
||||
w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
@ -575,4 +586,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
num_bits=self.quant_config.quant_type.size_bits,
|
||||
).to(orig_dtype)
|
||||
is_k_full=self.is_k_full).to(orig_dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user