[Bugfix] Fix compute datatype for cutlass 3.x epilogues (#5931)
This commit is contained in:
parent
b2c620230a
commit
6a2d659d28
@ -144,14 +144,14 @@ struct ScaledEpilogueBias
|
||||
using ScaleB = typename SUPER::ScaleB;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, ElementD, ElementD,
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, ElementD,
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using BiasDescriptor =
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
Run `pytest tests/kernels/test_cutlass.py`.
|
||||
"""
|
||||
from typing import Type
|
||||
from typing import Optional, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -27,12 +27,27 @@ def to_int8(tensor: torch.Tensor):
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
def baseline_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: Type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
output = (scale_a * (scale_b * (torch.mm(
|
||||
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def cutlass_fp8_gemm_helper(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
per_token_act_quant: bool,
|
||||
per_out_channel_weight_quant: bool,
|
||||
bias: bool,
|
||||
use_bias: bool,
|
||||
out_dtype: Type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda"):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
@ -43,23 +58,19 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
m_a_scales = m if per_token_act_quant else 1
|
||||
n_b_scales = n if per_out_channel_weight_quant else 1
|
||||
|
||||
scale_a = (torch.randn(
|
||||
(m_a_scales, 1), device=device, dtype=torch.float32) / 10)
|
||||
scale_b = (torch.randn(
|
||||
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
|
||||
if bias:
|
||||
# bias term should be > 1 so that the absolute tolerance can catch it
|
||||
bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
|
||||
scale_a = (torch.randn((m_a_scales, 1), device=device,
|
||||
dtype=torch.float32))
|
||||
scale_b = (torch.randn((1, n_b_scales), device=device,
|
||||
dtype=torch.float32))
|
||||
if use_bias:
|
||||
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
|
||||
else:
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
|
||||
bias_t = 0
|
||||
bias = None
|
||||
|
||||
baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b * b.to(dtype=torch.float32)) +
|
||||
bias_t).to(out_dtype)
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)
|
||||
assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)
|
||||
|
||||
|
||||
def cutlass_int8_gemm_helper(m: int,
|
||||
@ -67,7 +78,7 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
k: int,
|
||||
per_token_act_quant: bool,
|
||||
per_out_channel_weight_quant: bool,
|
||||
bias: bool,
|
||||
use_bias: bool,
|
||||
out_dtype: Type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda"):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
@ -78,22 +89,19 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
m_a_scales = m if per_token_act_quant else 1
|
||||
n_b_scales = n if per_out_channel_weight_quant else 1
|
||||
|
||||
scale_a = (torch.randn(
|
||||
(m_a_scales, 1), device=device, dtype=torch.float32) / 10)
|
||||
scale_b = (torch.randn(
|
||||
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
|
||||
scale_a = (torch.randn((m_a_scales, 1), device=device,
|
||||
dtype=torch.float32))
|
||||
scale_b = (torch.randn((1, n_b_scales), device=device,
|
||||
dtype=torch.float32))
|
||||
|
||||
if bias:
|
||||
# bias term should be > 1 so that the absolute tolerance can catch it
|
||||
bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
|
||||
if use_bias:
|
||||
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
|
||||
else:
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
|
||||
bias_t = 0
|
||||
bias = None
|
||||
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b * b.to(dtype=torch.float32)) +
|
||||
bias_t).to(dtype=out_dtype)
|
||||
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
@ -102,12 +110,12 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
@pytest.mark.parametrize("k", [128, 496, 1024])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(capability < 89,
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
|
||||
per_out_ch: bool, bias: bool):
|
||||
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
|
||||
per_out_ch: bool, use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [512, 222, 33, 1])
|
||||
@ -115,70 +123,70 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
|
||||
@pytest.mark.parametrize("k", [128, 496, 1024])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
|
||||
per_out_ch: bool, bias: bool):
|
||||
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
|
||||
per_out_ch: bool, use_bias: bool):
|
||||
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
|
||||
out_dtype: Type[torch.dtype],
|
||||
bias: bool):
|
||||
use_bias: bool):
|
||||
cutlass_int8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
bias,
|
||||
use_bias,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(capability < 89,
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
|
||||
out_dtype: Type[torch.dtype],
|
||||
bias: bool):
|
||||
use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
bias,
|
||||
use_bias,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(capability < 89,
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
|
||||
bias: bool, device: str):
|
||||
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, bias,
|
||||
use_bias: bool, device: str):
|
||||
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
|
||||
torch.bfloat16, device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
|
||||
bias: bool, device: str):
|
||||
use_bias: bool, device: str):
|
||||
cutlass_int8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
bias,
|
||||
use_bias,
|
||||
out_dtype=torch.bfloat16,
|
||||
device=device)
|
||||
|
||||
@ -190,25 +198,26 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
|
||||
# kernel must handle any M thrown at it.
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(capability < 89,
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
|
||||
bias: bool):
|
||||
use_bias: bool):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch, bias)
|
||||
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
|
||||
use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
|
||||
bias: bool):
|
||||
use_bias: bool):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
|
||||
bias)
|
||||
use_bias)
|
||||
|
||||
|
||||
# Test working with a subset of A and B
|
||||
@ -229,9 +238,11 @@ def test_cutlass_subset():
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b *
|
||||
b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user