From d085a4408244502dd496cae0ccbc524352e87ef8 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Sun, 13 Apr 2025 08:55:18 -0600
Subject: [PATCH] Enable PTPC FP8 for CompressedTensorsW8A8Fp8MoEMethod (triton
 fused_moe) (#16537)

Signed-off-by: mgoin
---
 .../compressed_tensors_moe.py                 | 107 +++++++++++-------
 1 file changed, 68 insertions(+), 39 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index d2299965..628724c5 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -88,14 +88,23 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")
 
-        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
+                "For FP8 Fused MoE layers, we require per tensor "
+                "or channelwise, dynamic per token quantization. Found "
                 f"{self.weight_quant}, {self.input_quant}")
 
         self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -123,24 +132,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                1,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, 1, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
@@ -163,6 +188,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
+            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
@@ -204,24 +230,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                       requires_grad=False)
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+        # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
+        # for w13 per expert. Use max then dequant and requant each expert.
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
+                                                    shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
 
     def apply(
         self,
@@ -265,6 +292,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy ==
+            QuantizationStrategy.CHANNEL,
             global_num_experts=global_num_experts,
            expert_map=expert_map,
             w1_scale=layer.w13_weight_scale,
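
Note (illustration, not part of the patch): the scheme enabled here is channelwise FP8 weight scales (shape [num_experts, out_channels, 1], matching the w13_weight_scale / w2_weight_scale allocated in create_weights above) combined with dynamic per-token FP8 activation scales, i.e. PTPC. Below is a minimal PyTorch sketch of how those two scale layouts combine in a single GEMM; it assumes a PyTorch build with float8_e4m3fn (>= 2.1), the helper names are invented for the example, and the math only emulates what the triton fused_moe kernel does in fused form.

    # Illustrative sketch only (names below are not vLLM APIs): per-output-channel
    # FP8 weight scales plus dynamic per-token FP8 activation scales ("PTPC").
    import torch

    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

    def quantize_weight_per_channel(w: torch.Tensor):
        # w: [out_features, in_features] -> one static scale per output channel,
        # stored as [out_features, 1] (the same trailing-1 layout as the
        # per-expert scales created in create_weights above).
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / FP8_MAX
        return (w / scale).to(torch.float8_e4m3fn), scale

    def quantize_activation_per_token(x: torch.Tensor):
        # x: [num_tokens, hidden] -> one scale per token, [num_tokens, 1],
        # computed at runtime (hence "dynamic": no static input scale allowed).
        scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / FP8_MAX
        return (x / scale).to(torch.float8_e4m3fn), scale

    x = torch.randn(4, 16)    # 4 tokens, hidden size 16
    w = torch.randn(32, 16)   # 32 output channels

    x_q, x_scale = quantize_activation_per_token(x)   # x_scale: [4, 1]
    w_q, w_scale = quantize_weight_per_channel(w)     # w_scale: [32, 1]

    # Emulated FP8 GEMM: matmul in higher precision, then rescale by the outer
    # product of the per-token and per-channel scales.
    y = (x_q.float() @ w_q.float().t()) * (x_scale @ w_scale.t())
    print("max abs error vs fp32 matmul:",
          (y - x @ w.t()).abs().max().item())

This is also why apply() now passes per_channel_quant based on weight_quant.strategy == CHANNEL: the fused_moe kernel has to be told to expect [out_channels, 1] weight scales and per-token activation scales instead of one scalar scale per expert.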