From 7ca9934fe773edf8680aed287b0a05cb195bd8e4 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Thu, 6 Feb 2025 04:02:14 -0500
Subject: [PATCH] [Misc] Update w2 scale loading for GPTQMarlinMoE (#12757)

---
 tests/weight_loading/models-large.txt         |  2 ++
 vllm/model_executor/layers/fused_moe/layer.py |  4 ++--
 .../layers/quantization/gptq_marlin.py        | 23 ++++++++++++++-----
 3 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt
index 8ab7f05d..9c1c11da 100644
--- a/tests/weight_loading/models-large.txt
+++ b/tests/weight_loading/models-large.txt
@@ -1,5 +1,7 @@
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
+compressed-tensors, nm-testing/test-w4a16-mixtral-actorder-group, main
 gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
+gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True
 awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
\ No newline at end of file
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 3c7ef5e0..f18c0313 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -302,8 +302,8 @@ class FusedMoE(torch.nn.Module):
             "weight_loader": self.weight_loader,
         }
         # need full intermediate size pre-sharding for WNA16 act order
-        if (self.quant_method.__class__.__name__ ==
-                "CompressedTensorsWNA16MoEMethod"):
+        if (self.quant_method.__class__.__name__
+                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size
 
         self.quant_method.create_weights(layer=self, **moe_quant_params)
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 99ab2999..84c53b2c 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -323,13 +323,18 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        # Currently assuming is_k_full is always True
-        # (input size per partition is the same as full input size)
-        # Supports only sym for now (no zp)
+        intermediate_size_full = extra_weight_attrs.pop(
+            "intermediate_size_full")
+
+        self.is_k_full = (not self.quant_config.desc_act) or (
+            intermediate_size_per_partition == intermediate_size_full)
+
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            scales_size2 = (intermediate_size_per_partition //
-                            self.quant_config.group_size)
+            w2_scales_size = (intermediate_size_full
+                              if self.quant_config.desc_act else
+                              intermediate_size_per_partition)
+            scales_size2 = (w2_scales_size // self.quant_config.group_size)
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
             scales_size13 = 1
@@ -385,6 +390,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         layer.register_parameter("w2_scales", w2_scales)
         set_weight_attrs(w2_scales, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_scales,
+                         {"load_full_w2": self.quant_config.desc_act})
         # up_proj scales
         w13_qzeros = torch.nn.Parameter(
             torch.empty(num_experts,
@@ -406,6 +414,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_qzeros,
+                         {"load_full_w2": self.quant_config.desc_act})
         w13_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
@@ -575,4 +586,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.quant_config.quant_type.size_bits,
-        ).to(orig_dtype)
+            is_k_full=self.is_k_full).to(orig_dtype)
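
Reviewer note: the standalone sketch below restates the sizing logic this patch adds to GPTQMarlinMoEMethod.create_weights. The helper name w2_scale_sizing and the example shapes are hypothetical, and the channel-wise fallback for scales_size2 lies outside the hunks shown and is assumed; the real code reads these values from the quant config and the MoE layer.

# Minimal sketch (not part of the patch) of the new is_k_full / w2 scale sizing.

def w2_scale_sizing(hidden_size: int,
                    intermediate_size_per_partition: int,
                    intermediate_size_full: int,
                    group_size: int,
                    desc_act: bool):
    """Return (is_k_full, scales_size13, scales_size2) as the patch computes them."""
    # k is "full" when there is no act-order reordering, or when this rank
    # holds the entire intermediate dimension (no tensor-parallel sharding of w2's k).
    is_k_full = (not desc_act) or (
        intermediate_size_per_partition == intermediate_size_full)

    if group_size != -1:
        scales_size13 = hidden_size // group_size
        # With act order, w2 scales are allocated (and loaded) at the full
        # intermediate size rather than the per-partition slice.
        w2_scales_size = (intermediate_size_full
                          if desc_act else intermediate_size_per_partition)
        scales_size2 = w2_scales_size // group_size
    else:
        # Channel-wise quantization (outside the hunks shown, assumed): one group.
        scales_size13 = 1
        scales_size2 = 1
    return is_k_full, scales_size13, scales_size2


if __name__ == "__main__":
    # Example: Mixtral-like shapes with 2-way tensor parallelism and act order.
    print(w2_scale_sizing(hidden_size=4096,
                          intermediate_size_per_partition=7168,
                          intermediate_size_full=14336,
                          group_size=128,
                          desc_act=True))
    # -> (False, 32, 112): k is not full, and w2 scales cover all 14336 rows.

The new load_full_w2 attribute on w2_scales and w2_qzeros complements this: when desc_act is set, the weight loader keeps those tensors unsharded so they match the full-size allocation sketched above.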