[Kernel] Use moe_wna16 kernel for compressed tensors wna16 moe models (#16038)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent a5d11a54dc
commit c70cf0fe06

@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16 -b auto -l 1319 -f 5 -t 1
+model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.31
+  - name: "exact_match,flexible-extract"
+    value: 0.47
+limit: 1319
+num_fewshot: 5
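Note: this new config feeds the lm-eval-harness CI check referenced in its first line; the recorded GSM8K scores act as regression baselines for the quantized Qwen1.5-MoE checkpoint. A minimal sketch of how such a config could be compared against measured results follows; the loader, tolerance, and results layout are assumptions for illustration, not the actual CI script.

import yaml

TOLERANCE = 0.05  # illustrative slack; the real harness script defines its own

def check_baseline(results: dict, config_path: str) -> None:
    """Compare measured lm-eval metrics against the recorded baseline values."""
    with open(config_path) as f:
        expected = yaml.safe_load(f)
    for task in expected["tasks"]:
        for metric in task["metrics"]:
            # Hypothetical results layout: {task_name: {metric_name: value}}
            measured = results[task["name"]][metric["name"]]
            assert measured >= metric["value"] - TOLERANCE, (
                f'{task["name"]}/{metric["name"]}: {measured:.3f} fell below '
                f'baseline {metric["value"]:.3f} minus {TOLERANCE}')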
@@ -4,7 +4,7 @@ Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
-Minitron-4B-Base-FP8.yaml
+Qwen1.5-MoE-W4A16-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-FP8W8.yaml
 Meta-Llama-3-8B-QQQ.yaml
@@ -512,7 +512,9 @@ class FusedMoE(torch.nn.Module):
         }
         # need full intermediate size pre-sharding for WNA16 act order
         if (self.quant_method.__class__.__name__
-                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
+                in ("GPTQMarlinMoEMethod",
+                    "CompressedTensorsWNA16MarlinMoEMethod",
+                    "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size

         self.quant_method.create_weights(layer=self, **moe_quant_params)
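The extra `intermediate_size_full` entry matters because, with WNA16 activation ordering, quantization groups are defined over the full (pre-tensor-parallel) intermediate dimension, so a sharded rank cannot derive the global group layout from its local slice alone. A rough sketch of that idea, using a hypothetical helper name rather than a vLLM API:

import torch

def shard_group_ids(intermediate_size_full: int, group_size: int,
                    tp_rank: int, tp_size: int) -> torch.Tensor:
    # Group ids are assigned over the *full* K dimension first, then sliced
    # for this rank; computing them from the shard size alone would restart
    # group numbering at zero on every rank.
    g_idx_full = torch.arange(intermediate_size_full) // group_size
    shard = intermediate_size_full // tp_size
    return g_idx_full[tp_rank * shard:(tp_rank + 1) * shard]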
@@ -648,9 +650,10 @@
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
-        loaded_weight = loaded_weight.t().contiguous() if (
-            self.quant_method.__class__.__name__
-            == "CompressedTensorsWNA16MoEMethod") else loaded_weight
+        if self.quant_method.__class__.__name__ in (
+                "CompressedTensorsWNA16MarlinMoEMethod",
+                "CompressedTensorsWNA16MoEMethod"):
+            loaded_weight = loaded_weight.t().contiguous()

         if shard_id not in ("w1", "w2", "w3"):
             raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
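As the existing comment notes, compressed-tensors checkpoints store these packed MoE weights flipped, so the loader transposes them back before the usual sharding path; the only change here is that both WNA16 MoE methods (Marlin and non-Marlin) now take that branch. A toy illustration of what `.t().contiguous()` does to a packed tensor, with shapes invented for the example:

import torch

# A packed int32 weight as it might sit in the checkpoint (shape made up).
checkpoint_weight = torch.randint(0, 2**31 - 1, (128, 768), dtype=torch.int32)

# The loader flips it and forces a dense copy, so downstream sharding sees a
# plain row-major tensor rather than a strided transposed view.
loaded_weight = checkpoint_weight.t().contiguous()
assert loaded_weight.shape == (768, 128) and loaded_weight.is_contiguous()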
@@ -96,8 +96,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         if isinstance(layer, FusedMoE):
-            return CompressedTensorsMoEMethod.get_moe_method(
-                self, layer.activation, layer.expert_map)
+            return CompressedTensorsMoEMethod.get_moe_method(self, layer)
         return None

     @classmethod
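Passing the whole FusedMoE layer instead of individual attributes lets `get_moe_method` inspect whatever it needs (expert count, dtype, activation, expert map) without growing the signature again. Since only attributes are read off the layer, a duck-typed stand-in is enough for quick experimentation; the stand-in below is hypothetical, not vLLM test code.

import types
import torch

# Minimal object exposing the attributes get_moe_method reads off the layer.
fake_layer = types.SimpleNamespace(
    local_num_experts=60,
    params_dtype=torch.bfloat16,
    activation="silu",
    expert_map=None,
)
# method = CompressedTensorsMoEMethod.get_moe_method(quant_config, fake_layer)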
@@ -6,7 +6,8 @@ from typing import Callable, List, Optional

 import torch
 from compressed_tensors import CompressionFormat
-from compressed_tensors.quantization import QuantizationStrategy
+from compressed_tensors.quantization import (ActivationOrdering,
+                                             QuantizationStrategy)

 import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
@@ -30,9 +31,11 @@ class GPTQMarlinState(Enum):


 __all__ = [
-    "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsMoEMethod",
+    "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsW8A8Fp8MoECutlassMethod",
-    "CompressedTensorsWNA16MoEMethod"
+    "CompressedTensorsWNA16MarlinMoEMethod",
+    "CompressedTensorsWNA16MoEMethod",
 ]

@@ -41,8 +44,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     @staticmethod
     def get_moe_method(
         quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
-        activation: str,
-        expert_map: Optional[torch.Tensor],
+        layer: torch.nn.Module,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -51,9 +53,21 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
             "input_activations")

         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            return CompressedTensorsWNA16MoEMethod(quant_config)
+            # Prefer to use the non-marlin kernel when:
+            # 1. Many experts (MarlinMoE gives poor performance when >= 16)
+            # 2. Non-FP16 dtype (MarlinMoE only supports FP16)
+            # 3. Actorder is not group/dynamic (g_idx is unsupported)
+            # 4. Scales are grouped (channelwise is unsupported)
+            if ((layer.local_num_experts >= 16
+                 or layer.params_dtype != torch.float16) and
+                    weight_quant.actorder not in (ActivationOrdering.GROUP,
+                                                  ActivationOrdering.DYNAMIC)
+                    and weight_quant.strategy in QuantizationStrategy.GROUP):
+                return CompressedTensorsWNA16MoEMethod(quant_config)
+            else:
+                return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
         elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-              and activation == "silu" and expert_map is None):
+              and layer.activation == "silu" and layer.expert_map is None):
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
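Restated outside the factory, the new rule prefers the Triton moe_wna16 path whenever Marlin would be slow (many experts) or unusable (non-FP16 dtype), provided the checkpoint has no group/dynamic activation ordering and uses grouped scales. A free-standing sketch of that predicate, with illustrative argument names; the real check lives inline above, and this mirrors its membership tests:

import torch
from compressed_tensors.quantization import (ActivationOrdering,
                                             QuantizationStrategy)

def prefer_moe_wna16(local_num_experts: int, params_dtype: torch.dtype,
                     actorder, strategy) -> bool:
    # Mirrors the inline condition: use the non-Marlin kernel only when
    # Marlin is at a disadvantage and the scheme is one moe_wna16 handles.
    marlin_unattractive = (local_num_experts >= 16
                           or params_dtype != torch.float16)
    no_g_idx = actorder not in (ActivationOrdering.GROUP,
                                ActivationOrdering.DYNAMIC)
    grouped_scales = strategy in QuantizationStrategy.GROUP
    return marlin_unattractive and no_g_idx and grouped_scales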
@@ -482,7 +496,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
         )


-class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):

     def __init__(
             self,
@@ -823,3 +837,215 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
             is_k_full=self.is_k_full)
+
+
+class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+            self,
+            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        config = self.quant_config.target_scheme_map["Linear"].get("weights")
+        self.num_bits = config.num_bits
+        self.packed_factor = 32 // config.num_bits
+        self.strategy = config.strategy
+        # channelwise is not supported by this kernel
+        assert config.strategy == "group"
+        self.group_size = config.group_size
+        # grouped actorder isn't supported by this kernel
+        assert config.actorder != "group"
+        assert config.symmetric, (
+            "Only symmetric quantization is supported for MoE")
+
+        if not (self.quant_config.quant_format
+                == CompressionFormat.pack_quantized.value
+                and self.num_bits in WNA16_SUPPORTED_BITS):
+            raise ValueError("For Fused MoE layers, only ",
+                             f"{CompressionFormat.pack_quantized.value} ",
+                             "is supported for the following bits: ",
+                             f"{WNA16_SUPPORTED_BITS}")
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        # Will transpose the loaded weight along the
+        # intermediate and hidden dim sizes. Will
+        # shard for TP along the transposed dims
+        extra_weight_attrs.update({
+            "is_transposed": True,
+            "quant_method": self.strategy
+        })
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size // self.packed_factor,
+            2 * intermediate_size_per_partition,
+            dtype=torch.int32),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight_packed", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            intermediate_size_per_partition // self.packed_factor,
+            hidden_size,
+            dtype=torch.int32),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight_packed", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w2_scales_size = intermediate_size_per_partition
+
+        if self.strategy == "channel":
+            num_groups_w2 = num_groups_w13 = 1
+            self.group_size = -1
+        else:
+            num_groups_w2 = w2_scales_size // self.group_size
+            num_groups_w13 = hidden_size // self.group_size
+
+        w13_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition,
+            dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_scale)
+        set_weight_attrs(w13_scale, extra_weight_attrs)
+
+        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                 num_groups_w2,
+                                                 hidden_size,
+                                                 dtype=params_dtype),
+                                      requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_scale)
+        set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": False})
+
+        w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_shape", w2_weight_shape)
+        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
+        w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                              requires_grad=False)
+
+        layer.register_parameter("w13_weight_shape", w13_weight_shape)
+        set_weight_attrs(w13_weight_shape, extra_weight_attrs)
+
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices",
+                                 w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices",
+                                 w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+        layer.a13_scale = None
+        layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Reconfigure packed weights and scales to match moe_wna16 format
+        layer.w13_weight_packed = torch.nn.Parameter(
+            layer.w13_weight_packed.transpose(1, 2).contiguous().view(
+                torch.uint8),
+            requires_grad=False)
+        layer.w2_weight_packed = torch.nn.Parameter(
+            layer.w2_weight_packed.transpose(1,
+                                             2).contiguous().view(torch.uint8),
+            requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(
+            layer.w13_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(
+            layer.w2_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe import fused_experts
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        return fused_experts(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_int4_w4a16=self.num_bits == 4,
+            use_int8_w8a16=self.num_bits == 8,
+            global_num_experts=global_num_experts,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size])
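The repacking in `process_weights_after_loading` is what lets this method hand the weights straight to the Triton int4/int8 `fused_experts` path: the int32-packed weights are transposed back and reinterpreted byte-wise (the "moe_wna16 format" the comment refers to), and `block_shape=[0, self.group_size]` carries the per-group scale granularity along K. A toy sketch of that reinterpretation, with invented sizes and not the vLLM kernel code:

import torch

num_experts, k, n = 4, 64, 32  # made-up sizes; k is the quantized input dim

# Eight 4-bit values per int32 word, stored as (experts, K // 8, N) before
# process_weights_after_loading runs.
packed_i32 = torch.randint(0, 2**31 - 1, (num_experts, k // 8, n),
                           dtype=torch.int32)

# Transpose to (experts, N, K // 8) and view the same bytes as uint8, giving
# (experts, N, K // 2): two 4-bit values per byte.
repacked_u8 = packed_i32.transpose(1, 2).contiguous().view(torch.uint8)
assert repacked_u8.shape == (num_experts, n, k // 2)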