Enable PTPC FP8 for CompressedTensorsW8A8Fp8MoEMethod (triton fused_moe) (#16537)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent f49e5aff11
commit d085a44082
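Note on terminology: PTPC FP8 here means dynamic Per-Token activation scales paired with Per-Channel weight scales, in contrast to the per-tensor scheme this MoE method previously required. The sketch below illustrates the two scale granularities in plain PyTorch; the helper names (`quantize_per_channel`, `quantize_per_token`) are illustrative only and are not part of this diff or of vLLM.

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max


def quantize_per_channel(weight: torch.Tensor):
    """One FP8 scale per output channel (row) of a weight matrix."""
    scale = weight.abs().amax(dim=1, keepdim=True) / FP8_MAX        # [out, 1]
    qweight = (weight / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)
    return qweight, scale


def quantize_per_token(x: torch.Tensor):
    """One dynamic FP8 scale per token (row) of an activation matrix."""
    scale = x.abs().amax(dim=-1, keepdim=True) / FP8_MAX            # [tokens, 1]
    qx = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)
    return qx, scale


w = torch.randn(128, 64)   # [out_features, in_features]
x = torch.randn(8, 64)     # [num_tokens, in_features]
qw, w_scale = quantize_per_channel(w)
qx, x_scale = quantize_per_token(x)
# Approximate GEMM: rescale by the outer product of token and channel scales.
y = (qx.to(torch.float32) @ qw.to(torch.float32).t()) * (x_scale * w_scale.t())
print(y.shape)             # torch.Size([8, 128])
```

The commit wires this combination (CHANNEL weights with dynamic TOKEN activations) through CompressedTensorsW8A8Fp8MoEMethod and the triton fused_moe path, alongside the existing per-tensor path.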
@@ -88,14 +88,23 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")
 
-        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
+                "For FP8 Fused MoE layers, we require per tensor "
+                "or channelwise, dynamic per token quantization. Found "
                 f"{self.weight_quant}, {self.input_quant}")
 
         self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
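The validation above accepts exactly two (weight, activation) strategy pairs and additionally rejects static input scales on the per-channel path. A simplified stand-in for that decision logic (the enum here is illustrative; the real `QuantizationStrategy` comes from compressed-tensors):

```python
from enum import Enum


class QuantizationStrategy(str, Enum):
    # Simplified stand-in for the compressed-tensors enum.
    TENSOR = "tensor"
    CHANNEL = "channel"
    TOKEN = "token"


def is_supported(weight_strategy, act_strategy, act_dynamic: bool) -> bool:
    per_tensor = (weight_strategy == QuantizationStrategy.TENSOR
                  and act_strategy == QuantizationStrategy.TENSOR)
    per_channel = (weight_strategy == QuantizationStrategy.CHANNEL
                   and act_strategy == QuantizationStrategy.TOKEN)
    if not (per_tensor or per_channel):
        return False
    # Static (non-dynamic) input scales are only allowed on the per-tensor path.
    if per_channel and not act_dynamic:
        return False
    return True


assert is_supported(QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR, False)
assert is_supported(QuantizationStrategy.CHANNEL, QuantizationStrategy.TOKEN, True)
assert not is_supported(QuantizationStrategy.CHANNEL, QuantizationStrategy.TOKEN, False)
```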
@@ -123,24 +132,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                1,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, 1, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
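For reference, the weight-scale parameters registered by the two branches above have these shapes, with E experts, intermediate_size_per_partition I and hidden_size H (the concrete sizes below are made up for illustration):

```python
import torch

E, I, H = 4, 256, 384   # example sizes, not taken from any real model

# Per-tensor branch: one scale per (expert, w1/w3 shard) and one per expert for w2.
w13_scale_tensor = torch.ones(E, 2, dtype=torch.float32)
w2_scale_tensor = torch.ones(E, dtype=torch.float32)

# Per-channel branch: one scale per output channel of each expert weight.
w13_scale_channel = torch.ones(E, 2 * I, 1, dtype=torch.float32)
w2_scale_channel = torch.ones(E, H, 1, dtype=torch.float32)

print(w13_scale_tensor.shape, w2_scale_tensor.shape)      # [4, 2] [4]
print(w13_scale_channel.shape, w2_scale_channel.shape)    # [4, 512, 1] [4, 384, 1]
```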
@@ -163,6 +188,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
+            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
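As the surrounding comment says, the FP8 MoE kernels consume a single static activation scale, so differing loaded scales are collapsed with a max. A tiny illustration (the values are made up):

```python
import torch

# Hypothetical static activation scales, e.g. one per loaded shard.
w13_input_scale = torch.tensor([0.05, 0.07, 0.06])
# Taking the max is the safe choice: a larger scale can only cost precision,
# never overflow the FP8 range.
single_scale = w13_input_scale.max()
print(single_scale)   # tensor(0.0700)
```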
@@ -204,24 +230,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                       requires_grad=False)
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+        # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
+        # for w13 per expert. Use max then dequant and requant each expert.
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
+                                                    shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
 
     def apply(
         self,
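The per-tensor branch above folds the two w1/w3 scales of each expert into one by dequantizing each shard and requantizing it against the larger scale. A standalone sketch of that merge in plain PyTorch (not the `per_tensor_dequantize`/`ops.scaled_fp8_quant` helpers used in the diff):

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max


def merge_w13_scales(w13_q: torch.Tensor, w13_scale: torch.Tensor):
    """w13_q: [2 * shard_size, hidden] FP8 weight of one expert; w13_scale: [2]."""
    shard_size = w13_q.shape[0] // 2
    max_scale = w13_scale.max()
    merged = torch.empty_like(w13_q)
    for shard_id in range(2):
        rows = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
        dq = w13_q[rows].to(torch.float32) * w13_scale[shard_id]         # dequantize
        merged[rows] = (dq / max_scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)  # requantize
    return merged, max_scale


w13_scale = torch.tensor([0.01, 0.02])
w13_q = torch.randn(8, 4).clamp(-FP8_MAX, FP8_MAX).to(FP8)
merged_q, scale = merge_w13_scales(w13_q, w13_scale)
print(scale)   # tensor(0.0200): one scale now covers both the w1 and w3 shards
```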
@@ -265,6 +292,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy ==
+            QuantizationStrategy.CHANNEL,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             w1_scale=layer.w13_weight_scale,