Enable PTPC FP8 for CompressedTensorsW8A8Fp8MoEMethod (triton fused_moe) (#16537)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin <mgoin64@gmail.com>, committed by GitHub
Date: 2025-04-13 08:55:18 -06:00
Commit: d085a44082
Parent: f49e5aff11


@@ -88,14 +88,23 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")
 
-        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
+                "For FP8 Fused MoE layers, we require per tensor "
+                "or channelwise, dynamic per token quantization. Found "
                 f"{self.weight_quant}, {self.input_quant}")
 
         self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
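
Note on the check added above: it accepts exactly two scheme combinations, per-tensor weights with per-tensor activations, or channelwise weights with dynamic per-token activations (the PTPC case in the title), and static input scales are only allowed on the per-tensor path. A minimal standalone sketch of that rule, using a local stand-in enum rather than the real compressed-tensors QuantizationStrategy:

    from enum import Enum

    class Strategy(str, Enum):  # illustrative stand-in, not the library enum
        TENSOR = "tensor"
        CHANNEL = "channel"
        TOKEN = "token"

    def is_supported(weight_strategy: Strategy, act_strategy: Strategy,
                     act_dynamic: bool) -> bool:
        """Mirror of the combinations the new check accepts."""
        per_tensor = (weight_strategy == Strategy.TENSOR
                      and act_strategy == Strategy.TENSOR)
        per_channel = (weight_strategy == Strategy.CHANNEL
                       and act_strategy == Strategy.TOKEN)
        if not (per_tensor or per_channel):
            return False
        # Static input scales are only meaningful on the per-tensor path.
        return not (per_channel and not act_dynamic)

    assert is_supported(Strategy.TENSOR, Strategy.TENSOR, act_dynamic=False)
    assert is_supported(Strategy.CHANNEL, Strategy.TOKEN, act_dynamic=True)   # PTPC
    assert not is_supported(Strategy.CHANNEL, Strategy.TOKEN, act_dynamic=False)
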
@@ -123,24 +132,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                1,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, 1, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
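
The practical difference between the two branches above is the shape of the registered scale parameters. A torch-only sketch of those shapes (the sizes below are made-up example values, not from any real model):

    import torch

    num_experts = 8
    intermediate_size_per_partition = 1536  # example value
    hidden_size = 2048                      # example value

    # TENSOR strategy: one scale per (expert, w1/w3 shard), one per expert for w2.
    w13_scale_tensor = torch.ones(num_experts, 2, dtype=torch.float32)
    w2_scale_tensor = torch.ones(num_experts, dtype=torch.float32)

    # CHANNEL strategy: one scale per output channel; w1 and w3 are fused, hence 2x.
    w13_scale_channel = torch.ones(num_experts,
                                   2 * intermediate_size_per_partition,
                                   1,
                                   dtype=torch.float32)
    w2_scale_channel = torch.ones(num_experts, hidden_size, 1, dtype=torch.float32)

    print(w13_scale_tensor.shape, w2_scale_tensor.shape)
    # torch.Size([8, 2]) torch.Size([8])
    print(w13_scale_channel.shape, w2_scale_channel.shape)
    # torch.Size([8, 3072, 1]) torch.Size([8, 2048, 1])
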
@@ -163,6 +188,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
+            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
@@ -204,24 +230,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                       requires_grad=False)
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
+        # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
+        # for w13 per expert. Use max then dequant and requant each expert.
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
+                                                    shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
 
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
 
     def apply(
         self,
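
The requantization above (per-tensor path only) folds the separate w1 and w3 scales into a single per-expert scale: each shard is dequantized with its own scale, then requantized against the max. A rough, torch-only emulation of the idea, with fake_fp8_quant/fake_dequant standing in for ops.scaled_fp8_quant and per_tensor_dequantize (plain float32, no actual fp8 cast):

    import torch

    FP8_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

    def fake_fp8_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Stand-in for ops.scaled_fp8_quant: scale down and clamp to the fp8
        # range (the real kernel also casts to torch.float8_e4m3fn).
        return torch.clamp(x / scale, -FP8_MAX, FP8_MAX)

    def fake_dequant(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Stand-in for per_tensor_dequantize.
        return q * scale

    shard_size = 4
    w13_q = torch.randn(2 * shard_size, 8)     # "quantized" w1||w3 of one expert
    shard_scales = torch.tensor([0.01, 0.03])  # separate w1 / w3 scales
    max_scale = shard_scales.max()             # single per-expert scale

    start = 0
    for shard_id in range(2):
        dq = fake_dequant(w13_q[start:start + shard_size, :],
                          shard_scales[shard_id])
        w13_q[start:start + shard_size, :] = fake_fp8_quant(dq, max_scale)
        start += shard_size
    # w13_q is now expressed against max_scale alone, which is the single w13
    # weight scale per expert that the fused fp8 kernel expects.
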
@@ -265,6 +292,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy ==
+            QuantizationStrategy.CHANNEL,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             w1_scale=layer.w13_weight_scale,