99 lines
2.9 KiB
Python
99 lines
2.9 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# Cutlass bench utils
|
|
from collections.abc import Iterable
|
|
|
|
import torch
|
|
|
|
import vllm._custom_ops as ops
|
|
|
|
|
|
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
|
return torch.round(tensor.clamp(
|
|
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
|
|
|
|
|
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    """Quantize ``tensor`` to ``torch.int8``.

    Values are saturated to [-128, 127] and then rounded to the nearest
    integer (``torch.round`` semantics, i.e. ties to even) before the cast.
    """
    lowest, highest = -128, 127
    saturated = tensor.clamp(min=lowest, max=highest)
    return saturated.round().to(dtype=torch.int8)
|
|
|
|
|
|
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    """Cast ``tensor`` to ``torch.bfloat16`` (no clamping or rounding)."""
    # Tensor.bfloat16() is documented as equivalent to self.to(torch.bfloat16).
    return tensor.bfloat16()
|
|
|
|
|
|
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    """Cast ``tensor`` to ``torch.float16`` (no clamping or rounding)."""
    # Tensor.half() is documented as equivalent to self.to(torch.float16).
    return tensor.half()
|
|
|
|
|
|
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
|
k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
|
a = torch.randn((m, k), device='cuda') * 5
|
|
b = torch.randn((n, k), device='cuda').t() * 5
|
|
|
|
if dtype == torch.int8:
|
|
return to_int8(a), to_int8(b)
|
|
if dtype == torch.float8_e4m3fn:
|
|
return to_fp8(a), to_fp8(b)
|
|
|
|
raise ValueError("unsupported dtype")
|
|
|
|
|
|
def prune_to_2_4(tensor):
    """Apply 2:4 structured sparsity along the last dimension of ``tensor``.

    The last dimension is split into consecutive groups of 4 elements. In
    each group, the 2 entries with the largest magnitude (as chosen by
    ``torch.topk``) are kept and the other 2 are zeroed. The result has the
    same shape and dtype as the input. Zeros produced by masking are
    normalized to +0.0.

    Groups are required to lie within a single row. A flat
    ``reshape(-1, 4)`` on a tensor whose last dimension is not a multiple
    of 4 would let groups straddle rows, and the output would not be 2:4
    sparse per row.

    Raises:
        ValueError: if ``tensor`` is 0-dimensional or its last dimension is
            not a multiple of 4.
    """
    if tensor.dim() == 0 or tensor.shape[-1] % 4 != 0:
        raise ValueError(
            "prune_to_2_4 requires the last dimension to be a multiple of 4, "
            f"got shape {tuple(tensor.shape)}")

    original_shape = tensor.shape
    # Each row of `groups` is 4 consecutive elements of the last dimension.
    # Because that dimension is divisible by 4, no group crosses a row.
    groups = tensor.reshape(-1, 4)

    # Positions of the 2 largest-magnitude entries in every group.
    _, keep_idx = torch.topk(torch.abs(groups), k=2, dim=1)

    # Mask of 1 at kept positions, 0 elsewhere. It uses the tensor's own
    # dtype so the multiply below does not promote.
    mask = torch.zeros_like(groups)
    mask.scatter_(dim=1,
                  index=keep_idx,
                  src=torch.ones_like(keep_idx, dtype=mask.dtype))

    pruned = groups * mask

    # Multiplying a negative value by 0 yields -0.0. Since -0.0 == 0.0, this
    # rewrites every zero as +0.0, clearing the sign bit.
    pruned[pruned == -0.0] = 0.0

    return pruned.reshape(original_shape)
|
|
|
|
|
|
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
|
|
k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
|
a = torch.randn((m, k), device='cuda') * 5
|
|
b = torch.randn((n, k), device='cuda').t() * 5
|
|
|
|
b = prune_to_2_4(b.t()).t()
|
|
|
|
if dtype == torch.int8:
|
|
a, b = to_int8(a), to_int8(b)
|
|
elif dtype == torch.float8_e4m3fn:
|
|
a, b = to_fp8(a), to_fp8(b)
|
|
elif dtype == torch.float16:
|
|
a, b = to_fp16(a), to_fp16(b)
|
|
elif dtype == torch.bfloat16:
|
|
a, b = to_bf16(a), to_bf16(b)
|
|
else:
|
|
raise ValueError("unsupported dtype")
|
|
|
|
b_compressed, e = ops.cutlass_sparse_compress(b.t())
|
|
|
|
# Compressed B, Metadata, Original A, B
|
|
return b_compressed, e, a, b
|
|
|
|
|
|
def make_n_rand_sparse_tensors(
    num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor],
           list[torch.Tensor]]:
    """Generate ``num_tensors`` independent random sparse GEMM problems.

    Each problem comes from exactly one call to ``make_rand_sparse_tensors``.
    Problems whose compressed ``b`` is ``None`` are skipped. The components
    are regrouped into four parallel lists.

    Returns:
        ``(b_compressed_list, e_list, a_list, b_list)``, where index ``i`` of
        every list belongs to the same problem. All four lists are empty when
        ``num_tensors`` is 0 or every problem was skipped.
    """
    problems = []
    for _ in range(num_tensors):
        problem = make_rand_sparse_tensors(dtype, m, n, k)
        # Store the same tuple that was checked, rather than generating a
        # second, unchecked problem.
        if problem[0] is not None:
            problems.append(problem)

    # zip(*[]) yields nothing, so four-way unpacking would fail on an empty
    # list.
    if not problems:
        return [], [], [], []

    b_comps, es, a_list, b_list = zip(*problems)
    return list(b_comps), list(es), list(a_list), list(b_list)
|