[Bugfix] Fix compute datatype for cutlass 3.x epilogues (#5931)
parent b2c620230a
commit 6a2d659d28
@@ -144,14 +144,14 @@ struct ScaledEpilogueBias
   using ScaleB = typename SUPER::ScaleB;
 
   using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
-      cutlass::multiplies, ElementD, ElementD,
+      cutlass::multiplies, float, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
   using EVTCompute0 =
       cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
 
   using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
-      cutlass::multiply_add, ElementD, ElementD,
+      cutlass::multiply_add, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
   using BiasDescriptor =
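Why the datatype matters: `Sm90Compute<Op, ElementOutput, ElementCompute, ...>` takes the output and compute element types, so the old definitions ran the `scale_b * accum` multiply and the fused `scale_a` multiply-add in `ElementD` (the kernel's fp16/bf16 output type), while the fixed versions keep the intermediates and the compute type in `float` and convert to `ElementD` only at the end. A rough torch sketch of the extra rounding step this removes (illustrative values only, not taken from the kernel):

import torch

torch.manual_seed(0)

# Stand-ins for the epilogue inputs: fp32 accumulator values plus scales
# and a bias. The real epilogue applies scale_b first, then fuses scale_a
# and the bias in a multiply-add.
accum = torch.randn(4096, dtype=torch.float32) * 100.0
scale_a = torch.tensor(0.0123, dtype=torch.float32)
scale_b = torch.tensor(0.0456, dtype=torch.float32)
bias = torch.randn(4096, dtype=torch.float32)

# Fixed behavior: every intermediate stays in float32, one conversion at the end.
fixed = (scale_a * (scale_b * accum) + bias).to(torch.bfloat16)

# Old behavior (sketched): the scale_b product is rounded to the output
# dtype (bfloat16 here) before the multiply-add, an extra rounding step.
intermediate = (scale_b * accum).to(torch.bfloat16).to(torch.float32)
old = (scale_a * intermediate + bias).to(torch.bfloat16)

# Typically nonzero, on the order of the bf16 resolution of the outputs.
print((fixed.float() - old.float()).abs().max())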
@@ -2,7 +2,7 @@
 
 Run `pytest tests/kernels/test_cutlass.py`.
 """
-from typing import Type
+from typing import Optional, Type
 
 import pytest
 import torch
@@ -27,12 +27,27 @@ def to_int8(tensor: torch.Tensor):
     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
 
 
+def baseline_scaled_mm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       scale_a: torch.Tensor,
+                       scale_b: torch.Tensor,
+                       out_dtype: Type[torch.dtype],
+                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+    output = (scale_a * (scale_b * (torch.mm(
+        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
+    if bias is not None:
+        output = output + bias
+
+    return output
+
+
 def cutlass_fp8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             per_token_act_quant: bool,
                             per_out_channel_weight_quant: bool,
-                            bias: bool,
+                            use_bias: bool,
                             out_dtype: Type[torch.dtype] = torch.bfloat16,
                             device: str = "cuda"):
     # Test for a cutlass kernel with per-token activation quantization
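The hunks below switch the helpers to this new float32 reference. As a self-contained sketch of how `baseline_scaled_mm` is meant to be called (the helper body is copied from the hunk above; the input tensors are made up for illustration and run on CPU, whereas the tests feed the same tensors to `ops.cutlass_scaled_mm` on CUDA):

from typing import Optional, Type

import torch


def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Full-precision reference: dequantize in float32, convert once, then add bias.
    output = (scale_a * (scale_b * (torch.mm(
        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
    if bias is not None:
        output = output + bias
    return output


# Made-up int8 inputs with per-token and per-output-channel scales.
m, n, k = 16, 32, 64
a = torch.randint(-128, 127, (m, k), dtype=torch.int8)
b = torch.randint(-128, 127, (k, n), dtype=torch.int8)
scale_a = torch.rand((m, 1), dtype=torch.float32)   # per-token scales
scale_b = torch.rand((1, n), dtype=torch.float32)   # per-output-channel scales
bias = torch.rand((n, ), dtype=torch.bfloat16) * 10

ref = baseline_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, bias)
print(ref.shape, ref.dtype)  # torch.Size([16, 32]) torch.bfloat16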
@@ -43,23 +58,19 @@ def cutlass_fp8_gemm_helper(m: int,
     m_a_scales = m if per_token_act_quant else 1
     n_b_scales = n if per_out_channel_weight_quant else 1
 
-    scale_a = (torch.randn(
-        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
-    scale_b = (torch.randn(
-        (1, n_b_scales), device=device, dtype=torch.float32) / 10)
-    if bias:
-        # bias term should be > 1 so that the absolute tolerance can catch it
-        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
-    else:
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
-        bias_t = 0
-
-    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
-                         scale_b * b.to(dtype=torch.float32)) +
-                bias_t).to(out_dtype)
-
-    assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
+    else:
+        bias = None
+
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+
+    assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)
 
 
 def cutlass_int8_gemm_helper(m: int,
@@ -67,7 +78,7 @@ def cutlass_int8_gemm_helper(m: int,
                              k: int,
                              per_token_act_quant: bool,
                              per_out_channel_weight_quant: bool,
-                             bias: bool,
+                             use_bias: bool,
                              out_dtype: Type[torch.dtype] = torch.bfloat16,
                              device: str = "cuda"):
     # Test for a cutlass kernel with per-token activation quantization
@@ -78,22 +89,19 @@ def cutlass_int8_gemm_helper(m: int,
     m_a_scales = m if per_token_act_quant else 1
     n_b_scales = n if per_out_channel_weight_quant else 1
 
-    scale_a = (torch.randn(
-        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
-    scale_b = (torch.randn(
-        (1, n_b_scales), device=device, dtype=torch.float32) / 10)
-
-    if bias:
-        # bias term should be > 1 so that the absolute tolerance can catch it
-        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
-    else:
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
-        bias_t = 0
-
-    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
-                         scale_b * b.to(dtype=torch.float32)) +
-                bias_t).to(dtype=out_dtype)
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))
+
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
+    else:
+        bias = None
+
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
 
 
@@ -102,12 +110,12 @@ def cutlass_int8_gemm_helper(m: int,
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
-                          per_out_ch: bool, bias: bool):
-    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
+                          per_out_ch: bool, use_bias: bool):
+    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
 
 
 @pytest.mark.parametrize("m", [512, 222, 33, 1])
@@ -115,70 +123,70 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
-                           per_out_ch: bool, bias: bool):
-    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
+                           per_out_ch: bool, use_bias: bool):
+    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                         out_dtype: Type[torch.dtype],
-                                        bias: bool):
+                                        use_bias: bool):
     cutlass_int8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
-                             bias,
+                             use_bias,
                              out_dtype=out_dtype)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                        out_dtype: Type[torch.dtype],
-                                       bias: bool):
+                                       use_bias: bool):
     cutlass_fp8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
-                            bias,
+                            use_bias,
                             out_dtype=out_dtype)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
-                                  bias: bool, device: str):
-    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, bias,
+                                  use_bias: bool, device: str):
+    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
                             torch.bfloat16, device)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
-                                   bias: bool, device: str):
+                                   use_bias: bool, device: str):
     cutlass_int8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
-                             bias,
+                             use_bias,
                              out_dtype=torch.bfloat16,
                              device=device)
 
@@ -190,25 +198,26 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
 # kernel must handle any M thrown at it.
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
-                                  bias: bool):
+                                  use_bias: bool):
     for nk in range(32, 128, 32):
         for m in range(1, 128):
-            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch, bias)
+            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
+                                    use_bias)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
-                                   bias: bool):
+                                   use_bias: bool):
     for nk in range(32, 128, 32):
         for m in range(1, 128):
             cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
-                                     bias)
+                                     use_bias)
 
 
 # Test working with a subset of A and B
@@ -229,9 +238,11 @@ def test_cutlass_subset():
                                 scale_a,
                                 scale_b,
                                 out_dtype=torch.bfloat16)
-    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
-                        scale_b *
-                        b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
+    baseline = baseline_scaled_mm(a,
+                                  b,
+                                  scale_a,
+                                  scale_b,
+                                  out_dtype=torch.bfloat16)
 
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
 