# SPDX-License-Identifier: Apache-2.0
"""Tests for sparse cutlass kernels
|
|
|
|
|
|
|
|
Run `pytest tests/kernels/test_semi_structured.py`.
|
|
|
|
"""
from typing import Tuple, Type

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported)
from vllm.platforms import current_platform

from .utils import baseline_scaled_mm, to_fp8, to_int8

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]


def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.bfloat16)


def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.float16)


def prune_to_2_4(tensor):
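    # Enforce a 2:4 (semi-structured) sparsity pattern: within every
    # contiguous group of four values, keep the two largest-magnitude
    # entries and zero out the other two.
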
    # Reshape tensor to [N, 4] where N is number of groups of 4
    original_shape = tensor.shape
    reshaped = tensor.reshape(-1, 4)

    # Get indices of top 2 absolute values in each group of 4
    _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)

    # Create binary mask
    mask = torch.zeros_like(reshaped)
    mask.scatter_(dim=1,
                  index=indices,
                  src=torch.ones_like(indices, dtype=mask.dtype))

    # Apply mask and reshape back
    pruned = reshaped * mask

    # Turn all -0.0 to 0.0
    pruned[pruned == -0.0] = 0.0

    return pruned.reshape(original_shape)


# This function checks that applying an identity matrix multiplication
# to the compressed weights yields the original uncompressed weights.
def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
                                         b_compressed: torch.Tensor,
                                         b_metadata: torch.Tensor):
    # For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
    # same dtype as its inputs. This line addresses that constraint while
    # arbitrarily using bfloat16 for the int8/fp8 cases.
    out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16
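
    # eye has b.shape[0] rows and the scales are 1.0, so the kernel computes
    # eye @ B; the result reconstructs the dense B from the compressed values
    # and metadata.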
    eye = torch.eye(b.shape[0], device='cuda', dtype=dtype)
    eye_scale = torch.ones(1, device='cuda', dtype=torch.float32)
    b_decomp = ops.cutlass_scaled_sparse_mm(eye,
                                            b_compressed,
                                            b_metadata,
                                            eye_scale,
                                            eye_scale,
                                            out_dtype=out_dtype)

    torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)


def make_rand_sparse_tensors(
    dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    a = torch.randn((m, k), device='cuda')
    b = torch.randn((n, k), device='cuda').t()

    if dtype == torch.int8:
        # ensure A and B aren't all zeros after rounding
        a = a * 5.0
        b = b * 5.0

    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")
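
    # cutlass_sparse_compress packs the non-zero values of the 2:4-sparse B
    # and returns the sparsity metadata consumed by cutlass_scaled_sparse_mm.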
    b_compressed, e = ops.cutlass_sparse_compress(b.t())
    check_compress_decompress_invariance(dtype, b, b_compressed, e)

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
    big_m = 1024
    m, n, k = 512, 512, 512

    # Create tensors
    b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
                                                     big_m, n, k)
    a = whole_a[0:m, 0:k]
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


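# Problem sizes used to parametrize the sparse GEMM tests below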
MNK_FACTORS = [
    (1, 256, 128),
    (1, 16384, 1024),
    (1, 24576, 512),
    (16, 256, 512),
    (16, 16384, 128),
    (16, 24576, 4096),
    (32, 8192, 4096),
    (32, 16384, 4096),
    (33, 1024, 1024),
    (33, 8192, 128),
    (64, 2048, 512),
    (64, 16384, 1024),
    (100, 8192, 512),
    (128, 32768, 4096),
    (256, 4096, 4096),
    (512, 256, 1024),
    (512, 8192, 4096),
    (512, 16384, 128),
    (512, 24576, 128),
]


# Test a fp16/bf16 sparse GEMM against a dense reference over a range of shapes
@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype],
                             use_bias: bool):
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
    scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)

    bias = torch.rand((n, ), device="cuda", dtype=dtype) if use_bias else None

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=dtype,
                                       bias=bias)

    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=dtype,
                                  bias=bias)

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)


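# Test an fp8 (e4m3) sparse GEMM with random per-tensor scales and an
# optional bias against a dense reference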
@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
    out_dtype = torch.bfloat16

    bias = torch.rand(
        (n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=out_dtype,
                                       bias=bias)

    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=out_dtype,
                                  bias=bias)

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)


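# Test an int8 sparse GEMM with random per-tensor scales and an optional bias
# against a dense reference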
@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                                  per_out_ch: bool, use_bias: bool):
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
    out_dtype = torch.bfloat16

    bias = torch.rand(
        (n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=out_dtype,
                                       bias=bias)

    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=out_dtype,
                                  bias=bias)

    torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)