[Kernel] Use moe_wna16 kernel for compressed tensors wna16 moe models (#16038)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin 2025-04-10 01:08:47 -06:00 committed by GitHub
parent a5d11a54dc
commit c70cf0fe06
5 changed files with 254 additions and 15 deletions

.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml

@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16 -b auto -l 1319 -f 5 -t 1
+model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.31
+  - name: "exact_match,flexible-extract"
+    value: 0.47
+limit: 1319
+num_fewshot: 5
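The header comment records the exact command used to produce this baseline. A minimal sketch of how such a config can be read back (assumes PyYAML; the real consumer is the lm-eval-harness checker under .buildkite, which compares measured scores against these reference values):

    # Sketch: load the eval config and list its reference metrics.
    # Assumes PyYAML is installed and the file name matches the config above.
    import yaml

    with open("Qwen1.5-MoE-W4A16-compressed-tensors.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["model_name"])
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            # e.g. gsm8k exact_match,strict-match -> 0.31
            print(task["name"], metric["name"], "->", metric["value"])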

.buildkite/lm-eval-harness/configs/models-small.txt

@@ -4,7 +4,7 @@ Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
-Minitron-4B-Base-FP8.yaml
+Qwen1.5-MoE-W4A16-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-FP8W8.yaml
 Meta-Llama-3-8B-QQQ.yaml

vllm/model_executor/layers/fused_moe/layer.py

@@ -512,7 +512,9 @@ class FusedMoE(torch.nn.Module):
         }
         # need full intermediate size pre-sharding for WNA16 act order
         if (self.quant_method.__class__.__name__
-                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
+                in ("GPTQMarlinMoEMethod",
+                    "CompressedTensorsWNA16MarlinMoEMethod",
+                    "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size

         self.quant_method.create_weights(layer=self, **moe_quant_params)
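Why the full (pre-sharding) intermediate size is needed: with activation ordering, the g_idx group mapping for w2 is defined over the full intermediate dimension, so each tensor-parallel shard has to slice its groups and scales relative to the unsharded size. An illustrative sketch (not vLLM code; sizes are made up):

    # Illustration: act-order group indices span the FULL intermediate dim,
    # so each TP shard must slice g_idx relative to the unsharded size.
    import torch

    intermediate_size_full = 1408      # assumed value, for illustration
    group_size, tp_size, tp_rank = 128, 2, 1

    # Without act-order this is a simple arange; act-order shuffles it.
    g_idx = torch.arange(intermediate_size_full) // group_size
    shard = intermediate_size_full // tp_size
    local_g_idx = g_idx[tp_rank * shard:(tp_rank + 1) * shard]
    # Local group ids still index into the full-size scale tensor:
    print(local_g_idx.min().item(), local_g_idx.max().item())  # 5 10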
@@ -648,9 +650,10 @@ class FusedMoE(torch.nn.Module):
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
-        loaded_weight = loaded_weight.t().contiguous() if (
-            self.quant_method.__class__.__name__
-            == "CompressedTensorsWNA16MoEMethod") else loaded_weight
+        if self.quant_method.__class__.__name__ in (
+                "CompressedTensorsWNA16MarlinMoEMethod",
+                "CompressedTensorsWNA16MoEMethod"):
+            loaded_weight = loaded_weight.t().contiguous()

         if shard_id not in ("w1", "w2", "w3"):
             raise ValueError(f"shard_id must be ['w1','w2','w3'] but "

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -96,8 +96,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         if isinstance(layer, FusedMoE):
-            return CompressedTensorsMoEMethod.get_moe_method(
-                self, layer.activation, layer.expert_map)
+            return CompressedTensorsMoEMethod.get_moe_method(self, layer)
         return None

     @classmethod

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -6,7 +6,8 @@ from typing import Callable, List, Optional

 import torch
 from compressed_tensors import CompressionFormat
-from compressed_tensors.quantization import QuantizationStrategy
+from compressed_tensors.quantization import (ActivationOrdering,
+                                             QuantizationStrategy)

 import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
@@ -30,9 +31,11 @@ class GPTQMarlinState(Enum):

 __all__ = [
-    "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsMoEMethod",
+    "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsW8A8Fp8MoECutlassMethod",
-    "CompressedTensorsWNA16MoEMethod"
+    "CompressedTensorsWNA16MarlinMoEMethod",
+    "CompressedTensorsWNA16MoEMethod",
 ]
@@ -41,8 +44,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):

     @staticmethod
     def get_moe_method(
         quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
-        activation: str,
-        expert_map: Optional[torch.Tensor],
+        layer: torch.nn.Module,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -51,9 +53,21 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
             "input_activations")

         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            return CompressedTensorsWNA16MoEMethod(quant_config)
+            # Prefer to use the non-marlin kernel when:
+            # 1. Many experts (MarlinMoE gives poor performance when >= 16)
+            # 2. Non-FP16 dtype (MarlinMoE only supports FP16)
+            # 3. Actorder is not group/dynamic (g_idx is unsupported)
+            # 4. Scales are grouped (channelwise is unsupported)
+            if ((layer.local_num_experts >= 16
+                 or layer.params_dtype != torch.float16) and
+                    weight_quant.actorder not in (ActivationOrdering.GROUP,
+                                                  ActivationOrdering.DYNAMIC)
+                    and weight_quant.strategy in QuantizationStrategy.GROUP):
+                return CompressedTensorsWNA16MoEMethod(quant_config)
+            else:
+                return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
         elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-              and activation == "silu" and expert_map is None):
+              and layer.activation == "silu" and layer.expert_map is None):
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
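The dispatch above is the heart of the change: group-quantized WNA16 MoE layers now prefer the fused moe_wna16 path whenever MarlinMoE would be slow or unsupported. For example, Qwen1.5-MoE-A2.7B (the model added to the eval configs, with 60 routed experts and usually served in bfloat16) satisfies the first branch. A standalone sketch of the same predicate, with plain strings standing in for the ActivationOrdering/QuantizationStrategy enums:

    # Sketch of the selection logic above (simplified; not the vLLM API).
    import torch

    def use_moe_wna16(num_experts: int, dtype: torch.dtype,
                      actorder: str, strategy: str) -> bool:
        marlin_slow_or_unsupported = (num_experts >= 16
                                      or dtype != torch.float16)
        no_g_idx = actorder not in ("group", "dynamic")
        grouped_scales = strategy == "group"
        return marlin_slow_or_unsupported and no_g_idx and grouped_scales

    # Many experts + bfloat16 -> new fused moe_wna16 path
    print(use_moe_wna16(60, torch.bfloat16, "none", "group"))  # True
    # Few experts in float16 -> keep MarlinMoE
    print(use_moe_wna16(8, torch.float16, "none", "group"))    # False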
@@ -482,7 +496,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
         )


-class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):

     def __init__(
             self,
@@ -823,3 +837,215 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
             is_k_full=self.is_k_full)
+
+
+class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+            self,
+            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        config = self.quant_config.target_scheme_map["Linear"].get("weights")
+        self.num_bits = config.num_bits
+        self.packed_factor = 32 // config.num_bits
+        self.strategy = config.strategy
+        # channelwise is not supported by this kernel
+        assert config.strategy == "group"
+        self.group_size = config.group_size
+        # grouped actorder isn't supported by this kernel
+        assert config.actorder != "group"
+        assert config.symmetric, (
+            "Only symmetric quantization is supported for MoE")
+
+        if not (self.quant_config.quant_format
+                == CompressionFormat.pack_quantized.value
+                and self.num_bits in WNA16_SUPPORTED_BITS):
+            raise ValueError("For Fused MoE layers, only "
+                             f"{CompressionFormat.pack_quantized.value} "
+                             "is supported for the following bits: "
+                             f"{WNA16_SUPPORTED_BITS}")
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        # Will transpose the loaded weight along the
+        # intermediate and hidden dim sizes. Will
+        # shard for TP along the transposed dims
+        extra_weight_attrs.update({
+            "is_transposed": True,
+            "quant_method": self.strategy
+        })
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size // self.packed_factor,
+            2 * intermediate_size_per_partition,
+            dtype=torch.int32),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight_packed", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            intermediate_size_per_partition // self.packed_factor,
+            hidden_size,
+            dtype=torch.int32),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight_packed", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w2_scales_size = intermediate_size_per_partition
+
+        if self.strategy == "channel":
+            num_groups_w2 = num_groups_w13 = 1
+            self.group_size = -1
+        else:
+            num_groups_w2 = w2_scales_size // self.group_size
+            num_groups_w13 = hidden_size // self.group_size
+
+        w13_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition,
+            dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_scale)
+        set_weight_attrs(w13_scale, extra_weight_attrs)
+
+        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                 num_groups_w2,
+                                                 hidden_size,
+                                                 dtype=params_dtype),
+                                      requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_scale)
+        set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": False})
+
+        w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_shape", w2_weight_shape)
+        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
+        w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                              requires_grad=False)
+        layer.register_parameter("w13_weight_shape", w13_weight_shape)
+        set_weight_attrs(w13_weight_shape, extra_weight_attrs)
+
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices",
+                                 w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices",
+                                 w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+        layer.a13_scale = None
+        layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Reconfigure packed weights and scales to match moe_wna16 format
+        layer.w13_weight_packed = torch.nn.Parameter(
+            layer.w13_weight_packed.transpose(1, 2).contiguous().view(
+                torch.uint8),
+            requires_grad=False)
+        layer.w2_weight_packed = torch.nn.Parameter(
+            layer.w2_weight_packed.transpose(1,
+                                             2).contiguous().view(torch.uint8),
+            requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(
+            layer.w13_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(
+            layer.w2_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        assert activation == "silu", "Only SiLU activation is supported."
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        return fused_experts(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_int4_w4a16=self.num_bits == 4,
+            use_int8_w8a16=self.num_bits == 8,
+            global_num_experts=global_num_experts,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size])
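To make the tensor bookkeeping above concrete, here is a shape walkthrough of create_weights followed by process_weights_after_loading, using illustrative dimensions (the sizes are assumptions for the sketch, not values from the commit):

    # Shape walkthrough (illustrative sizes; group-quantized int4, group=128).
    import torch

    num_experts, hidden, inter = 60, 2048, 1408
    num_bits, group_size = 4, 128
    packed_factor = 32 // num_bits            # 8 int4 values per int32

    # create_weights registers w13 transposed: (E, hidden/8, 2*inter), int32
    w13 = torch.zeros(num_experts, hidden // packed_factor, 2 * inter,
                      dtype=torch.int32)
    # process_weights_after_loading flips to (E, 2*inter, hidden/8) and then
    # reinterprets int32 as uint8, which multiplies the last dim by 4:
    w13 = w13.transpose(1, 2).contiguous().view(torch.uint8)
    print(w13.shape)        # torch.Size([60, 2816, 1024])

    # One scale per 128-wide group along the reduction (hidden) dim:
    w13_scale = torch.ones(num_experts, hidden // group_size, 2 * inter)
    print(w13_scale.transpose(1, 2).shape)    # torch.Size([60, 2816, 16])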
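Finally, an end-to-end smoke test using the checkpoint from the new eval config (a sketch: assumes a CUDA GPU and network access to pull the model; the generated text will vary):

    # Load the W4A16 MoE checkpoint; with many experts and bf16 weights this
    # now routes through CompressedTensorsWNA16MoEMethod / fused_experts.
    from vllm import LLM, SamplingParams

    llm = LLM(model="nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16")
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=8))
    print(outputs[0].outputs[0].text)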