Enable PTPC FP8 for CompressedTensorsW8A8Fp8MoEMethod (triton fused_moe) (#16537)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
f49e5aff11
commit
d085a44082
@ -88,14 +88,23 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||||
"input_activations")
|
"input_activations")
|
||||||
|
|
||||||
if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
|
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
|
||||||
and self.input_quant.strategy == QuantizationStrategy.TENSOR):
|
and self.input_quant.strategy
|
||||||
|
== QuantizationStrategy.TENSOR)
|
||||||
|
per_channel = (
|
||||||
|
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
|
||||||
|
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
||||||
|
if not (per_tensor or per_channel):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"For FP8 Fused MoE layers, only per-tensor scales "
|
"For FP8 Fused MoE layers, we require per tensor "
|
||||||
"for weights and activations are supported. Found "
|
"or channelwise, dynamic per token quantization. Found "
|
||||||
f"{self.weight_quant}, {self.input_quant}")
|
f"{self.weight_quant}, {self.input_quant}")
|
||||||
|
|
||||||
self.static_input_scales = not self.input_quant.dynamic
|
self.static_input_scales = not self.input_quant.dynamic
|
||||||
|
if self.static_input_scales and per_channel:
|
||||||
|
raise ValueError(
|
||||||
|
"For FP8 Fused MoE layer, we require either per tensor or "
|
||||||
|
"channelwise, dynamic per token quantization.")
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
hidden_size: int, intermediate_size_per_partition: int,
|
hidden_size: int, intermediate_size_per_partition: int,
|
||||||
@ -123,22 +132,38 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
# WEIGHT_SCALES
|
# WEIGHT_SCALES
|
||||||
|
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
|
||||||
# Allocate 2 scales for w1 and w3 respectively.
|
# Allocate 2 scales for w1 and w3 respectively.
|
||||||
# They will be combined to a single scale after weight loading.
|
# They are combined to a single scale after weight loading.
|
||||||
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||||
2,
|
num_experts, 2, dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||||
|
num_experts, dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||||
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
|
||||||
|
w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
1,
|
||||||
dtype=torch.float32),
|
dtype=torch.float32),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||||
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
num_experts, hidden_size, 1, dtype=torch.float32),
|
||||||
dtype=torch.float32),
|
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
# Add the quantization method used (per tensor/grouped/channel)
|
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
|
||||||
# to ensure the weight scales are loaded in properly
|
|
||||||
extra_weight_attrs.update(
|
extra_weight_attrs.update(
|
||||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
|
||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
@ -163,6 +188,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
# Fp8 moe kernels require a single activation scale.
|
# Fp8 moe kernels require a single activation scale.
|
||||||
# We take the max of all the scales in case they differ.
|
# We take the max of all the scales in case they differ.
|
||||||
if self.static_input_scales:
|
if self.static_input_scales:
|
||||||
|
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
|
||||||
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
|
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"QuantConfig has static quantization, but found "
|
"QuantConfig has static quantization, but found "
|
||||||
@ -204,8 +230,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
|
layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
|
||||||
# We take the max then dequant and requant each expert.
|
# for w13 per expert. Use max then dequant and requant each expert.
|
||||||
|
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
|
||||||
assert layer.w13_weight_scale is not None
|
assert layer.w13_weight_scale is not None
|
||||||
shard_size = layer.intermediate_size_per_partition
|
shard_size = layer.intermediate_size_per_partition
|
||||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||||
@ -213,13 +240,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
start = 0
|
start = 0
|
||||||
for shard_id in range(2):
|
for shard_id in range(2):
|
||||||
dq_weight = per_tensor_dequantize(
|
dq_weight = per_tensor_dequantize(
|
||||||
layer.w13_weight[expert_id][start:start + shard_size, :],
|
layer.w13_weight[expert_id][start:start +
|
||||||
|
shard_size, :],
|
||||||
layer.w13_weight_scale[expert_id][shard_id])
|
layer.w13_weight_scale[expert_id][shard_id])
|
||||||
layer.w13_weight[expert_id][
|
layer.w13_weight[expert_id][
|
||||||
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
|
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
|
||||||
dq_weight, max_w13_scales[expert_id])
|
dq_weight, max_w13_scales[expert_id])
|
||||||
start += shard_size
|
start += shard_size
|
||||||
|
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
@ -265,6 +292,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
|
per_channel_quant=self.weight_quant.strategy ==
|
||||||
|
QuantizationStrategy.CHANNEL,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_weight_scale,
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user