From c70cf0fe061dc92a5608a67adbd12f82c52f8d9c Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 10 Apr 2025 01:08:47 -0600 Subject: [PATCH] [Kernel] Use moe_wna16 kernel for compressed tensors wna16 moe models (#16038) Signed-off-by: mgoin --- .../Qwen1.5-MoE-W4A16-compressed-tensors.yaml | 11 + .../lm-eval-harness/configs/models-small.txt | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 11 +- .../compressed_tensors/compressed_tensors.py | 3 +- .../compressed_tensors_moe.py | 242 +++++++++++++++++- 5 files changed, 254 insertions(+), 15 deletions(-) create mode 100644 .buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml diff --git a/.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml new file mode 100644 index 00000000..166af81a --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16 -b auto -l 1319 -f 5 -t 1 +model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.31 + - name: "exact_match,flexible-extract" + value: 0.47 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index 6057229a..254d01ed 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -4,7 +4,7 @@ Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml -Minitron-4B-Base-FP8.yaml +Qwen1.5-MoE-W4A16-compressed-tensors.yaml Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml Qwen2-1.5B-Instruct-FP8W8.yaml Meta-Llama-3-8B-QQQ.yaml diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 80ac5f42..89a7548d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -512,7 +512,9 @@ class FusedMoE(torch.nn.Module): } # need full intermediate size pre-sharding for WNA16 act order if (self.quant_method.__class__.__name__ - in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): + in ("GPTQMarlinMoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod")): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) @@ -648,9 +650,10 @@ class FusedMoE(torch.nn.Module): # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality - loaded_weight = loaded_weight.t().contiguous() if ( - self.quant_method.__class__.__name__ - == "CompressedTensorsWNA16MoEMethod") else loaded_weight + if self.quant_method.__class__.__name__ in ( + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod"): + loaded_weight = loaded_weight.t().contiguous() if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 4b2d7ca2..b714d95b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -96,8 +96,7 @@ class CompressedTensorsConfig(QuantizationConfig): if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) if isinstance(layer, FusedMoE): - return CompressedTensorsMoEMethod.get_moe_method( - self, layer.activation, layer.expert_map) + return CompressedTensorsMoEMethod.get_moe_method(self, layer) return None @classmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f573c8ae..d2299965 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -6,7 +6,8 @@ from typing import Callable, List, Optional import torch from compressed_tensors import CompressionFormat -from compressed_tensors.quantization import QuantizationStrategy +from compressed_tensors.quantization import (ActivationOrdering, + QuantizationStrategy) import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops @@ -30,9 +31,11 @@ class GPTQMarlinState(Enum): __all__ = [ - "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsMoEMethod", + "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsW8A8Fp8MoECutlassMethod", - "CompressedTensorsWNA16MoEMethod" + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", ] @@ -41,8 +44,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): @staticmethod def get_moe_method( quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 - activation: str, - expert_map: Optional[torch.Tensor], + layer: torch.nn.Module, ) -> "CompressedTensorsMoEMethod": # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. @@ -51,9 +53,21 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): "input_activations") if quant_config._is_wNa16_group_channel(weight_quant, input_quant): - return CompressedTensorsWNA16MoEMethod(quant_config) + # Prefer to use the non-marlin kernel when: + # 1. Many experts (MarlinMoE gives poor performance when >= 16) + # 2. Non-FP16 dtype (MarlinMoE only supports FP16) + # 3. Actorder is not group/dynamic (g_idx is unsupported) + # 4. 
Scales are grouped (channelwise is unsupported) + if ((layer.local_num_experts >= 16 + or layer.params_dtype != torch.float16) and + weight_quant.actorder not in (ActivationOrdering.GROUP, + ActivationOrdering.DYNAMIC) + and weight_quant.strategy in QuantizationStrategy.GROUP): + return CompressedTensorsWNA16MoEMethod(quant_config) + else: + return CompressedTensorsWNA16MarlinMoEMethod(quant_config) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) - and activation == "silu" and expert_map is None): + and layer.activation == "silu" and layer.expert_map is None): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) @@ -482,7 +496,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ) -class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): +class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): def __init__( self, @@ -823,3 +837,215 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): sort_indices2=layer.w2_g_idx_sort_indices, num_bits=self.num_bits, is_k_full=self.is_k_full) + + +class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy + # channelwise is not supported by this kernel + assert config.strategy == "group" + self.group_size = config.group_size + # grouped actorder isn't supported by this kernel + assert config.actorder != "group" + assert config.symmetric, ( + "Only symmetric quantization is supported for MoE") + + if not (self.quant_config.quant_format + == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS): + raise ValueError("For Fused MoE layers, only " + f"{CompressionFormat.pack_quantized.value} " + "is supported for the following bits: " + f"{WNA16_SUPPORTED_BITS}") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. 
Will + # shard for TP along the transposed dims + extra_weight_attrs.update({ + "is_transposed": True, + "quant_method": self.strategy + }) + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size // self.packed_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w2_scales_size = intermediate_size_per_partition + + if self.strategy == "channel": + num_groups_w2 = num_groups_w13 = 1 + self.group_size = -1 + else: + num_groups_w2 = w2_scales_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + w13_scale = torch.nn.Parameter(torch.ones( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, {"load_full_w2": False}) + + w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Reconfigure packed weights and scales to match moe_wna16 format + layer.w13_weight_packed = torch.nn.Parameter( + layer.w13_weight_packed.transpose(1, 2).contiguous().view( + torch.uint8), + requires_grad=False) + layer.w2_weight_packed = torch.nn.Parameter( + layer.w2_weight_packed.transpose(1, + 2).contiguous().view(torch.uint8), + 
requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + layer.w13_weight_scale.transpose(1, 2).contiguous(), + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + layer.w2_weight_scale.transpose(1, 2).contiguous(), + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + assert activation == "silu", "Only SiLU activation is supported." + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_int4_w4a16=self.num_bits == 4, + use_int8_w8a16=self.num_bits == 8, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_zp=None, + w2_zp=None, + block_shape=[0, self.group_size])
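
For readers tracing the dispatch change: the new branch in get_moe_method prefers the moe_wna16 fused kernel whenever MarlinMoE would be slow or unsupported. The sketch below restates that predicate outside the class for illustration only; prefer_moe_wna16 and its plain-string actorder/strategy arguments are hypothetical stand-ins for layer.local_num_experts, layer.params_dtype, weight_quant.actorder, and weight_quant.strategy, and are not part of the patch.

import torch


def prefer_moe_wna16(local_num_experts: int, params_dtype: torch.dtype,
                     actorder: str, strategy: str) -> bool:
    # Mirrors the selection logic added above: use moe_wna16 when Marlin
    # would be slow (>= 16 local experts) or unsupported (non-FP16 dtype),
    # provided there is no group/dynamic activation ordering (g_idx is
    # unsupported) and the scales are group-wise (channelwise unsupported).
    marlin_slow_or_unsupported = (local_num_experts >= 16
                                  or params_dtype != torch.float16)
    return (marlin_slow_or_unsupported
            and actorder not in ("group", "dynamic")
            and strategy == "group")

Under this rule the Qwen1.5-MoE checkpoint used in the new lm-eval config (well over 16 experts, group-wise W4A16 scales, no activation ordering) lands on CompressedTensorsWNA16MoEMethod, while a checkpoint with fewer than 16 experts running in FP16 and without act-order still falls back to CompressedTensorsWNA16MarlinMoEMethod.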
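
The checkpoint referenced by the new Qwen1.5-MoE-W4A16-compressed-tensors.yaml can also be exercised directly; a minimal sketch, assuming only the public vllm.LLM API and the model name taken from that config:

from vllm import LLM, SamplingParams

# W4A16 compressed-tensors MoE checkpoint from the new eval config; its
# fused MoE layers should now dispatch to the moe_wna16 kernel instead of
# MarlinMoE.
llm = LLM(model="nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16")
out = llm.generate(["Question: 2 + 2 = ? Answer:"],
                   SamplingParams(temperature=0.0, max_tokens=16))
print(out[0].outputs[0].text)

The full GSM8K baseline for this model is the command recorded in the header comment of the yaml config (run-lm-eval-gsm-vllm-baseline.sh with -b auto -l 1319 -f 5 -t 1).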