[Kernel][Bugfix] Refactor and Fix CUTLASS 2:4 Sparse Kernels (#13198)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Tyler Michael Smith 2025-02-13 19:01:14 -05:00 committed by GitHub
parent 2344192a55
commit c1e37bf71b
16 changed files with 576 additions and 473 deletions

View File

@@ -228,7 +228,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
  # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
- set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")
+ # Please keep this in sync with FetchContent_Declare line below.
+ set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use")
  # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
  if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -245,6 +246,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  FetchContent_Declare(
        cutlass
        GIT_REPOSITORY https://github.com/nvidia/cutlass.git
+       # Please keep this in sync with CUTLASS_REVISION line above.
        GIT_TAG v3.7.0
        GIT_PROGRESS TRUE
@@ -266,7 +268,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
    "csrc/quantization/fp4/nvfp4_quant_entry.cu"
    "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
-   "csrc/sparse/cutlass/sparse_compressor_entry.cu"
    "csrc/cutlass_extensions/common.cpp")
  set_gencode_flags_for_srcs(
@@ -359,8 +360,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
  # require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
-   set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
-            "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
+   set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
    set_gencode_flags_for_srcs(
      SRCS "${SRCS}"
      CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -476,7 +476,7 @@ define_gpu_extension_target(
  SOURCES ${VLLM_EXT_SRC}
  COMPILE_FLAGS ${VLLM_GPU_FLAGS}
  ARCHITECTURES ${VLLM_GPU_ARCHES}
- INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
+ INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
  USE_SABI 3
  WITH_SOABI)

View File

@@ -16,6 +16,30 @@ namespace vllm::c3x {
using namespace cute;
template <typename T>
struct identity {
CUTLASS_HOST_DEVICE
T operator()(T lhs) const { return lhs; }
};
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct TrivialEpilogue {
private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Compute = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
using ArgumentType = typename EVTCompute::Arguments;
template <typename... Args>
static ArgumentType prepare_args(Args... args) {
return {};
}
};
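For reference (illustrative only, not part of this diff): the net effect of TrivialEpilogue is a plain conversion of the accumulator to the output type, with no scales and no bias. A minimal scalar sketch with hypothetical names:

// Scalar sketch of what TrivialEpilogue computes per output element.
template <typename ElementAcc, typename ElementD>
void trivial_epilogue_reference(ElementAcc const* accum, ElementD* d, int count) {
  for (int i = 0; i < count; ++i) {
    d[i] = static_cast<ElementD>(accum[i]);  // identity, then convert to ElementD
  }
}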
/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
@@ -174,6 +198,49 @@ struct ScaledEpilogueBias
  }
};
/*
* This epilogue performs the same operation as ScaledEpilogueBias, but the
* bias is a column vector instead of a row vector. Useful e.g. if we are
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueColumnBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template ColLoad<ElementD>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args, bias_args};
}
};
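For intuition (illustrative only, not part of this diff): the epilogue math is D[i][j] = scale_a[i] * (scale_b[j] * accum[i][j]) + bias[j]. When the kernel instead produces D^T = B^T A^T, the index j walks down the rows of D^T, so the same bias must be loaded as a column vector (hence ColLoad above). A scalar reference with hypothetical names:

#include <vector>

// D is M x N, row-major; scale_a is per-row of D, scale_b and bias per-column.
// In the transposed view Dt = D^T, bias[j] varies along the rows of Dt.
void scaled_epilogue_bias_reference(std::vector<float> const& accum,    // M*N
                                    std::vector<float> const& scale_a,  // size M
                                    std::vector<float> const& scale_b,  // size N
                                    std::vector<float> const& bias,     // size N
                                    std::vector<float>& d, int M, int N) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      d[i * N + j] = scale_a[i] * (scale_b[j] * accum[i * N + j]) + bias[j];
    }
  }
}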
/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj

View File

@@ -176,8 +176,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                              torch::Tensor const& b_scales,
                              std::optional<torch::Tensor> const& bias);
-bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
-                                   torch::Tensor& e, torch::Tensor const& a);
+std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
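For illustration only (not part of this diff), a hypothetical call site for the new declaration, assuming it is in scope; the ordering of the returned tensors is assumed here to match the CompressorResult {a_meta, a_nzs} tuple introduced later in this commit:

#include <torch/all.h>
#include <utility>
#include <vector>

// Hypothetical wrapper (illustrative): compress a row-major 2:4-sparsifiable
// matrix and split the result into (metadata, non-zero values).
std::pair<torch::Tensor, torch::Tensor> compress_2to4(torch::Tensor const& a) {
  std::vector<torch::Tensor> result = cutlass_sparse_compress(a);
  return {result[0] /* sparsity metadata (uint8) */,
          result[1] /* compressed values, same dtype as a */};
}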

View File

@@ -53,12 +53,17 @@ struct cutlass_3x_gemm {
  using EVTCompute = typename Epilogue::EVTCompute;

+ // These are the minimum alignments needed for the kernels to compile
+ static constexpr int AlignmentAB =
+     128 / cutlass::sizeof_bits<ElementAB>::value;
+ static constexpr int AlignmentCD = 4;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
-         ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
-         EpilogueSchedule, EVTCompute>::CollectiveOp;
+         ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
+         AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
@@ -69,8 +74,8 @@ struct cutlass_3x_gemm {
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
-         ElementAB, cutlass::layout::RowMajor, 16,
-         ElementAB, cutlass::layout::ColumnMajor, 16,
+         ElementAB, cutlass::layout::RowMajor, AlignmentAB,
+         ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
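A quick sanity check of what the new alignment constants evaluate to (illustrative only, plain C++ arithmetic on type widths; not part of this diff):

#include <cstdint>

// AlignmentAB is expressed in elements for a 128-bit (16-byte) access:
//   8-bit A/B types (int8, fp8):   128 / 8  = 16 elements
//   16-bit A/B types (fp16, bf16): 128 / 16 =  8 elements
// AlignmentCD = 4 output elements is, per the comment added in this diff,
// the minimum the epilogue builder accepts for these kernels.
static_assert(128 / (8 * sizeof(std::int8_t)) == 16, "8-bit A/B alignment");
static_assert(128 / (8 * sizeof(std::uint16_t)) == 8, "16-bit A/B alignment");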

View File

@@ -103,14 +103,19 @@ struct cutlass_2x_gemm {
  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

+ // These are the minimum alignments needed for the kernels to compile
+ static constexpr int AlignmentAB =
+     128 / cutlass::sizeof_bits<ElementAB>::value;
+ static constexpr int AlignmentCD = 4;

  // clang-format off
  using RowMajor = typename cutlass::layout::RowMajor;
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
  using KernelType =
    ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-     ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
-     ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
-     float, cutlass::layout::RowMajor, 4,
+     ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
+     ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
+     float, cutlass::layout::RowMajor, AlignmentCD,
      ElementAcc, float, cutlass::arch::OpClassTensorOp,
      Arch,
      TileShape, WarpShape, InstructionShape,

View File

@@ -1,165 +0,0 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
using namespace cute;
using namespace vllm;
/// Make A structured sparse by replacing elements with 0 and compress it
template <typename ElementA_, typename ElementAcc_>
bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
TORCH_CHECK(a.dim() == 2)
// Check for strides and alignment
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
TORCH_CHECK(a.stride(1) == 1)
int m = a.size(0);
int k = a.size(1);
// Sparse kernel setup; this kernel is not used for matmul,
// but just for setting up the compressor utility
// A matrix configuration
using ElementA = ElementA_;
using LayoutTagA = cutlass::layout::RowMajor;
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
// B matrix configuration
using ElementB = ElementA;
using LayoutTagB = cutlass::layout::ColumnMajor;
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
// C/D matrix configuration
using ElementC = float;
using LayoutTagC = cutlass::layout::ColumnMajor;
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
// Core kernel configurations
using ElementAccumulator = ElementAcc_;
using TileShape = Shape<_128, _128, _128>;
using TileShapeRef = Shape<_128, _128, _64>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule = typename std::conditional<
std::is_same_v<ElementA, cutlass::float_e4m3_t>,
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
cutlass::gemm::KernelTmaWarpSpecialized>::type;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using ProblemShape = Shape<int, int, int, int>;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
AlignmentC, ElementC, LayoutTagC, AlignmentC,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
using StrideE = StrideA;
using StrideA = Stride<int64_t, Int<1>, int64_t>;
// The n (=1) dimension does not matter for the compressor
typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};
using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
// Offline compressor kernel
using CompressorUtility =
cutlass::transform::kernel::StructuredSparseCompressorUtility<
ProblemShape, ElementA, LayoutTagA, SparseConfig>;
using CompressorKernel =
cutlass::transform::kernel::StructuredSparseCompressor<
ProblemShape, ElementA, LayoutTagA, SparseConfig,
cutlass::arch::Sm90>;
using Compressor =
cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
auto [M, N, K, L] = prob_shape;
StrideA stride_A;
stride_A =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
CompressorUtility compressor_utility(prob_shape, stride_A);
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int KC = compressor_utility.get_tensorA_k_physical();
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
auto a_meta_ptr = static_cast<typename Gemm::CollectiveMainloop::ElementE*>(
a_meta.data_ptr());
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
typename Compressor::Arguments arguments{
prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(compressor_op.can_implement(arguments));
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(compressor_op.run());
CUDA_CHECK(cudaDeviceSynchronize());
return true;
}
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a) {
if (a.dtype() == torch::kBFloat16) {
return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_nzs, a_meta,
a);
} else if (a.dtype() == torch::kFloat16) {
return cutlass_sparse_compress<cutlass::half_t, float>(a_nzs, a_meta, a);
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_nzs, a_meta,
a);
} else if (a.dtype() == torch::kInt8) {
return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
}
return false;
}
#endif

View File

@@ -0,0 +1,90 @@
#pragma once
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
// clang-format on
using namespace cute;
using namespace vllm;
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
/// Make A structured sparse by replacing elements with 0 and compress it
template <typename Gemm>
CompressorResult cutlass_sparse_compress(torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
TORCH_CHECK(a.dim() == 2)
// Check for strides and alignment
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
TORCH_CHECK(a.stride(1) == 1)
using GemmKernel = typename Gemm::KernelType;
using ElementA = typename Gemm::ElementAB;
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
int m = a.size(0);
int k = a.size(1);
using ProblemShape = typename GemmKernel::ProblemShape;
ProblemShape prob_shape{m, 1, k, 1};
int64_t lda = a.stride(0);
using StrideA = Stride<int64_t, Int<1>, int64_t>;
StrideA a_stride{lda, Int<1>{}, 0};
using CompressorUtility = typename Gemm::CompressorUtility;
CompressorUtility compressor_utility(prob_shape, a_stride);
// Allocate buffers for the metadata E and the compressed matrix A
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int MC = compressor_utility.get_tensorA_m_physical();
int KC = compressor_utility.get_tensorA_k_physical();
auto const a_meta_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto const a_nzs_options =
torch::TensorOptions().dtype(a.dtype()).device(a.device());
auto a_meta = torch::zeros({ME, KE}, a_meta_options);
auto a_nzs = torch::zeros({MC, KC}, a_nzs_options);
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = a.device().index();
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
using Compressor = typename Gemm::Compressor;
typename Compressor::Arguments arguments{
prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(compressor_op.can_implement(arguments));
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr()));
CUTLASS_CHECK(compressor_op.run());
CUDA_CHECK(cudaDeviceSynchronize());
return {a_meta, a_nzs};
}
#endif

View File

@@ -1,42 +0,0 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a);
#endif
bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2);
TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) &&
a_nzs.size(1) * 2 == a.size(1) &&
a_meta.size(1) * 2 * 4 == a.size(1));
// Considering elemsPerMetaElem = 8b / 2b_per_nz = 4
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 &&
a_meta.stride(1) == 1); // Row-major
TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if (version_num >= 90) {
return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}

View File

@@ -9,17 +9,30 @@
using namespace cute;
using namespace vllm;
struct GemmCallerTraits {
using return_type = void;
template <typename GemmConfig, typename... Args>
static return_type invoke(Args&&... args) {
return cutlass_sparse_gemm_caller<GemmConfig>(std::forward<Args>(args)...);
}
};
struct GemmCompressorTraits {
using return_type = CompressorResult;
template <typename GemmConfig, typename... Args>
static return_type invoke(Args&&... args) {
return cutlass_sparse_compress<GemmConfig>(std::forward<Args>(args)...);
}
};
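For illustration only (hypothetical names, not part of this diff): these traits let a single m/n-based heuristic drive two different actions, because the heuristic only selects a GEMM config type and the traits decide what to invoke with it. A minimal standalone sketch of the pattern:

#include <cstdint>
#include <utility>

struct ConfigSmall {};  // stand-ins for the Cutlass3xGemm* config types
struct ConfigLarge {};

struct NoopTraits {
  using return_type = void;
  template <typename Config, typename... Args>
  static return_type invoke(Args&&...) {
    // A real traits struct forwards to a kernel templated on Config, e.g.
    // cutlass_sparse_gemm_caller<Config> or cutlass_sparse_compress<Config>.
  }
};

template <typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type dispatch_on_m(uint32_t m, Args&&... args) {
  if (m <= 64) {
    return DispatchFunc::template invoke<ConfigSmall>(std::forward<Args>(args)...);
  }
  return DispatchFunc::template invoke<ConfigLarge>(std::forward<Args>(args)...);
}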
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
-         typename... EpilogueArgs>
-void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
-                                    torch::Tensor const& bt_nzs,
-                                    torch::Tensor const& bt_meta,
-                                    EpilogueArgs&&... args) {
-  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
-  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
-  TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
+         typename DispatchFunc, typename... Args>
+typename DispatchFunc::return_type cutlass_gemm_sm90_fp8_dispatch(
+    uint32_t m, uint32_t n, Args&&... args) {
+  static_assert(std::is_same_v<InType, cutlass::float_e4m3_t>);
  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
@@ -49,122 +62,87 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
  using Cutlass3xGemm8 =
      typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;
-  uint32_t const n = bt_nzs.size(0);
-  uint32_t const m = a.size(0); // Batch size
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    if (n == 28672) {
-     return cutlass_sparse_gemm_caller<Cutlass3xGemm2>(
-         out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+     return DispatchFunc::template invoke<Cutlass3xGemm2>(
+         std::forward<Args>(args)...);
    } else if (n == 4096 || n == 6144) {
-     return cutlass_sparse_gemm_caller<Cutlass3xGemm1>(
-         out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+     return DispatchFunc::template invoke<Cutlass3xGemm1>(
+         std::forward<Args>(args)...);
    }
  } else if (mp2 <= 128) {
    if (n == 4096) {
-     return cutlass_sparse_gemm_caller<Cutlass3xGemm3>(
-         out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+     return DispatchFunc::template invoke<Cutlass3xGemm3>(
+         std::forward<Args>(args)...);
    } else if (n == 28672) {
-     return cutlass_sparse_gemm_caller<Cutlass3xGemm5>(
-         out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+     return DispatchFunc::template invoke<Cutlass3xGemm5>(
+         std::forward<Args>(args)...);
    } else if (n == 6144) {
-     return cutlass_sparse_gemm_caller<Cutlass3xGemm4>(
-         out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+     return DispatchFunc::template invoke<Cutlass3xGemm4>(
+         std::forward<Args>(args)...);
    }
  } else if (mp2 <= 256) {
    if (n == 4096) {
-     return cutlass_sparse_gemm_caller<Cutlass3xGemm6>(
-         out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+     return DispatchFunc::template invoke<Cutlass3xGemm6>(
+         std::forward<Args>(args)...);
    } else if (n == 28672) {
-     return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
-         out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+     return DispatchFunc::template invoke<Cutlass3xGemm8>(
+         std::forward<Args>(args)...);
    } else if (n == 6144) {
-     return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
-         out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+     return DispatchFunc::template invoke<Cutlass3xGemm7>(
+         std::forward<Args>(args)...);
    }
  } else {
    if (n == 6144 || n == 28672) {
-     return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
-         out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+     return DispatchFunc::template invoke<Cutlass3xGemm8>(
+         std::forward<Args>(args)...);
    } else if (n == 4096) {
-     return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
-         out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+     return DispatchFunc::template invoke<Cutlass3xGemm7>(
+         std::forward<Args>(args)...);
    }
  }

  // Otherwise the default heuristic
  if (mp2 <= 64) {
    // n in [1, 64]
-   return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
-       out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+   return DispatchFunc::template invoke<Cutlass3xGemmM64>(
+       std::forward<Args>(args)...);
  } else if (mp2 <= 128) {
    // n in (64, 128]
-   return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
-       out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+   return DispatchFunc::template invoke<Cutlass3xGemmM128>(
+       std::forward<Args>(args)...);
  } else if (mp2 <= 256) {
    // n in (128, 256]
-   return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
-       out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+   return DispatchFunc::template invoke<Cutlass3xGemmM256>(
+       std::forward<Args>(args)...);
  } else {
    // n in (256, inf)
-   return cutlass_sparse_gemm_caller<Cutlass3xGemmM512>(
-       out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+   return DispatchFunc::template invoke<Cutlass3xGemmM512>(
+       std::forward<Args>(args)...);
  }
}
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
-         typename... EpilogueArgs>
-void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& bt_nzs,
-                                     torch::Tensor const& bt_meta,
-                                     EpilogueArgs&&... args) {
-  static_assert(std::is_same<InType, cutlass::half_t>());
-  TORCH_CHECK(a.dtype() == torch::kFloat16);
-  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
-  TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
+         typename DispatchFunc, typename... Args>
+typename DispatchFunc::return_type cutlass_gemm_sm90_16bit_dispatch(
+    uint32_t m, uint32_t n, Args&&... args) {
  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
-  // m in (128, inf)
-  return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
-      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+  return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
+      std::forward<Args>(args)...);
}
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
-         typename... EpilogueArgs>
-void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& bt_nzs,
-                                     torch::Tensor const& bt_meta,
-                                     EpilogueArgs&&... args) {
-  static_assert(std::is_same<InType, cutlass::bfloat16_t>());
-  TORCH_CHECK(a.dtype() == torch::kBFloat16);
-  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
-  TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
-  using Cutlass3xGemmDefault =
-      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
-  // m in (128, inf)
-  return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
-      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
-}
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue,
-          typename... EpilogueArgs>
-void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& bt_nzs,
-                                     torch::Tensor const& bt_meta,
-                                     EpilogueArgs&&... args) {
-  static_assert(std::is_same<InType, int8_t>());
-  TORCH_CHECK(a.dtype() == torch::kInt8);
-  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
-  TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
+         typename DispatchFunc, typename... Args>
+typename DispatchFunc::return_type cutlass_gemm_sm90_int8_dispatch(
+    uint32_t m, uint32_t n, Args&&... args) {
+  static_assert(std::is_same_v<InType, int8_t>);
  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
@@ -179,37 +157,35 @@ void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;
-  uint32_t const n = out.size(1);
  bool const is_small_n = n < 8192;
-  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
-     return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NSmall>(
-         out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+     return DispatchFunc::template invoke<Cutlass3xGemmM32NSmall>(
+         std::forward<Args>(args)...);
    } else {
-     return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NBig>(
-         out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+     return DispatchFunc::template invoke<Cutlass3xGemmM32NBig>(
+         std::forward<Args>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
-   return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
-       out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+   return DispatchFunc::template invoke<Cutlass3xGemmM64>(
+       std::forward<Args>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
-   return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
-       out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+   return DispatchFunc::template invoke<Cutlass3xGemmM128>(
+       std::forward<Args>(args)...);
  } else {
    // m in (128, inf)
-   return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
-       out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+   return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
+       std::forward<Args>(args)...);
  }
}
// Dispatch to GEMM implementations based on element types
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
@@ -217,19 +193,24 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
                                            torch::Tensor const& bt_nzs,
                                            torch::Tensor const& bt_meta,
                                            EpilogueArgs&&... epilogue_args) {
uint32_t const m = out.size(0);
uint32_t const n = out.size(1);
// TODO: add dispatch functions to all of these
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
-                                            Epilogue>(
-         out, a, bt_nzs, bt_meta,
+                                            Epilogue, GemmCallerTraits>(
+         m, n, out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
-     return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
-         out, a, bt_nzs, bt_meta,
+     return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue,
+                                            GemmCallerTraits>(
+         m, n, out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else if (a.dtype() == torch::kFloat8_e4m3fn) {
@@ -237,47 +218,34 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
-                                           cutlass::bfloat16_t, Epilogue>(
-         out, a, bt_nzs, bt_meta,
+                                           cutlass::bfloat16_t, Epilogue,
+                                           GemmCallerTraits>(
+         m, n, out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
-     return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
-                                           cutlass::half_t, Epilogue>(
-         out, a, bt_nzs, bt_meta,
+     return cutlass_gemm_sm90_fp8_dispatch<
+         cutlass::float_e4m3_t, cutlass::half_t, Epilogue, GemmCallerTraits>(
+         m, n, out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else if (a.dtype() == torch::kFloat16) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
-   if (out.dtype() == torch::kBFloat16) {
-     return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t,
-                                            cutlass::bfloat16_t, Epilogue>(
-         out, a, bt_nzs, bt_meta,
-         std::forward<EpilogueArgs>(epilogue_args)...);
-   } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
-     return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t, cutlass::half_t,
-                                            Epilogue>(
-         out, a, bt_nzs, bt_meta,
+   return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
+                                           Epilogue, GemmCallerTraits>(
+       m, n, out, a, bt_nzs, bt_meta,
        std::forward<EpilogueArgs>(epilogue_args)...);
-   }
  } else {  // a.dtype() == torch::kBFloat16
    TORCH_CHECK(a.dtype() == torch::kBFloat16);
    TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
+   TORCH_CHECK(out.dtype() == torch::kBFloat16);
-   if (out.dtype() == torch::kBFloat16) {
-     return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
-                                            cutlass::bfloat16_t, Epilogue>(
-         out, a, bt_nzs, bt_meta,
+   return cutlass_gemm_sm90_16bit_dispatch<
+       cutlass::bfloat16_t, cutlass::bfloat16_t, Epilogue, GemmCallerTraits>(
+       m, n, out, a, bt_nzs, bt_meta,
        std::forward<EpilogueArgs>(epilogue_args)...);
-   } else {
-     TORCH_CHECK(out.dtype() == torch::kFloat16);
-     return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
-                                            cutlass::half_t, Epilogue>(
-         out, a, bt_nzs, bt_meta,
-         std::forward<EpilogueArgs>(epilogue_args)...);
-   }
  }
}
@@ -287,17 +255,53 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
-               "currently bias dtype must match output dtype ", out.dtype());
-   return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
-       out, a, bt_nzs, bt_meta, b_scales, a_scales, *bias);
+               "CUTLASS scaled_mm bias dtype must match output dtype ",
+               out.dtype());
+   return cutlass_scaled_sparse_mm_sm90_epilogue<
+       c3x::ScaledEpilogueColumnBias>(out, a, bt_nzs, bt_meta, b_scales,
+                                      a_scales, *bias);
  } else {
    return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
        out, a, bt_nzs, bt_meta, b_scales, a_scales);
  }
}
CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a) {
// These m and n variables are for dispatching to different GEMM algorithms.
uint32_t const m = 1; // Set M to 1 for compression
uint32_t const n = a.size(1);
// Note: For correctness, the compressed format must be invariant in:
// - M, the flattened number of tokens
// - Whether output dtype is fp16 or bf16
// - CUTLASS epilogues
if (a.dtype() == torch::kInt8) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
c3x::TrivialEpilogue,
GemmCompressorTraits>(m, n, a);
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
return cutlass_gemm_sm90_fp8_dispatch<
cutlass::float_e4m3_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
GemmCompressorTraits>(m, n, a);
} else if (a.dtype() == torch::kFloat16) {
return cutlass_gemm_sm90_16bit_dispatch<
cutlass::bfloat16_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
GemmCompressorTraits>(m, n, a);
} else {
TORCH_CHECK(a.dtype() == torch::kBFloat16,
"cutlass_sparse_compress only supports int8, fp8_e4m3, fp16, "
"and bf16 datatypes");
return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
c3x::TrivialEpilogue,
GemmCompressorTraits>(m, n, a);
}
}
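For context (illustrative only, not part of this diff): the intended flow is to compress the transposed weight once and reuse the compressed pair on every GEMM call. A hypothetical C++-side sketch, assuming the declarations in this file are in scope; real callers go through the registered torch ops:

#include <optional>
#include <torch/all.h>

// bt = B^T must be row-major and 2:4-sparsifiable along K.
torch::Tensor sparse_scaled_matmul_once(torch::Tensor& out,        // [M, N]
                                        torch::Tensor const& a,    // [M, K]
                                        torch::Tensor const& bt,   // [N, K] = B^T
                                        torch::Tensor const& a_scales,
                                        torch::Tensor const& b_scales) {
  // One-time (offline) compression of the transposed weight.
  auto [bt_meta, bt_nzs] = cutlass_sparse_compress_sm90(bt);
  // Per-call GEMM on the compressed operand; no bias in this sketch.
  cutlass_scaled_sparse_mm_sm90(out, a, bt_nzs, bt_meta, a_scales, b_scales,
                                std::nullopt);
  return out;
}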
#endif

View File

@@ -1,3 +1,5 @@
+#pragma once

// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
@@ -12,6 +14,9 @@
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/transform/device/transform_universal_adapter.hpp"
+#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"

#include "core/math.hpp"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
@@ -22,7 +27,7 @@
using namespace cute;

/*
-  This file defines sparse quantized GEMM operations using the CUTLASS 3.x API,
+  This file defines 2:4 sparse GEMM operations using the CUTLASS 3.x API,
   for NVIDIA GPUs with sm90a (Hopper) or later.
*/
@@ -45,17 +50,20 @@ struct enable_sm90_or_later : Kernel {
using GemmUniversalMode = cutlass::gemm::GemmUniversalMode;
/*
* cutlass_sparse_3x_gemm defines a 2:4 sparse GEMM kernel via CUTLASS
* for SM90 Hopper systems.
*/
template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
-         typename EpilogueSchedule, typename AccType,
-         typename TileSchedule = cutlass::gemm::PersistentScheduler,
-         GemmUniversalMode Mode_ = GemmUniversalMode::kGemm>
+         typename EpilogueSchedule>
struct cutlass_sparse_3x_gemm {
- static const GemmUniversalMode Mode = Mode_;
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
- using ElementAcc = AccType;
+ using ElementAcc =
+     typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
+                               float>::type;

  using EpilogueDescriptor =
      cutlass::epilogue::collective::detail::EpilogueDescriptor<
@@ -66,30 +74,22 @@ struct cutlass_sparse_3x_gemm {
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
- using LayoutD = LayoutC;
- using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
- using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
  using LayoutC_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutC>::type;
- using LayoutD_Transpose =
-     typename cutlass::layout::LayoutTranspose<LayoutD>::type;

  using EVTCompute = typename Epilogue::EVTCompute;

- static constexpr int AlignmentA =
+ // These are the minimum alignments needed for the kernels to compile
+ static constexpr int AlignmentAB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
- static constexpr int AlignmentB =
-     128 / cutlass::sizeof_bits<ElementAB>::value;
- static constexpr int AlignmentCD =
-     128 / cutlass::sizeof_bits<ElementD>::value;
+ static constexpr int AlignmentCD = 4;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
-         ElementAcc, ElementAcc, ElementC, LayoutC_Transpose, AlignmentCD,
-         ElementD, LayoutD_Transpose, AlignmentCD, EpilogueSchedule,
+         ElementAcc, float, ElementC, LayoutC_Transpose, AlignmentCD, ElementD,
+         LayoutC_Transpose, AlignmentCD, EpilogueSchedule,
          EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
@@ -101,8 +101,8 @@ struct cutlass_sparse_3x_gemm {
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
-         ElementAB, cutlass::layout::RowMajor, AlignmentA,
-         ElementAB, cutlass::layout::ColumnMajor, AlignmentB,
+         ElementAB, cutlass::layout::RowMajor, AlignmentAB,
+         ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
@@ -110,11 +110,100 @@ struct cutlass_sparse_3x_gemm {
  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-     TileSchedule>>;
+     cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};
// Sparse compressor definitions
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
using LayoutTagA = cutlass::layout::RowMajor;
using CompressorUtility =
cutlass::transform::kernel::StructuredSparseCompressorUtility<
typename GemmKernel::ProblemShape, ElementAB, LayoutTagA,
SparseConfig>;
using CompressorKernel =
cutlass::transform::kernel::StructuredSparseCompressor<
typename GemmKernel::ProblemShape, ElementAB, LayoutTagA,
SparseConfig, cutlass::arch::Sm90>;
using Compressor =
cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
};
/*
* This class defines the kernel used to compress a 2:4 sparse matrix.
* The particular format is defined by the Gemm template parameter,
* which is a cutlass_sparse_3x_gemm.
*/
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
/// Make A structured sparse by replacing elements with 0 and compress it
template <typename Gemm>
CompressorResult cutlass_sparse_compress(torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
TORCH_CHECK(a.dim() == 2)
// Check for strides and alignment
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
TORCH_CHECK(a.stride(1) == 1)
using GemmKernel = typename Gemm::KernelType;
using ElementA = typename Gemm::ElementAB;
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
int m = a.size(0);
int k = a.size(1);
using ProblemShape = typename GemmKernel::ProblemShape;
ProblemShape prob_shape{m, 1, k, 1};
int64_t lda = a.stride(0);
using StrideA = Stride<int64_t, Int<1>, int64_t>;
StrideA a_stride{lda, Int<1>{}, 0};
using CompressorUtility = typename Gemm::CompressorUtility;
CompressorUtility compressor_utility(prob_shape, a_stride);
// Allocate buffers for the metadata E and the compressed matrix A
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int MC = compressor_utility.get_tensorA_m_physical();
int KC = compressor_utility.get_tensorA_k_physical();
auto const a_meta_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto const a_nzs_options =
torch::TensorOptions().dtype(a.dtype()).device(a.device());
auto a_meta = torch::zeros({ME, KE}, a_meta_options);
auto a_nzs = torch::zeros({MC, KC}, a_nzs_options);
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = a.device().index();
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
using Compressor = typename Gemm::Compressor;
typename Compressor::Arguments arguments{
prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(compressor_op.can_implement(arguments));
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr()));
CUTLASS_CHECK(compressor_op.run());
CUDA_CHECK(cudaDeviceSynchronize());
return {a_meta, a_nzs};
}
template <typename Gemm, typename... EpilogueArgs>
void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& bt_nzs,
@@ -126,27 +215,25 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
  // Interface stride expected from the argument a (will get transposed)
  // We compute C^T = B^T * A^T, but we assume B is transposed before
  // compression and hence the bt_* naming
- using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
  using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
- using LayoutD = cutlass::layout::RowMajor;
- using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
- using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
+ // M, N, K after transposition
+ int32_t m = out.size(1);
+ int32_t n = out.size(0);
+ int32_t k = a.size(1);

- auto layout_A = make_cute_layout<StrideA>(a, "A");
- auto layout_D = make_cute_layout<StrideD>(out, "D");
+ int64_t lda = a.stride(0);
+ int64_t ldc = out.stride(0);

- // Transpose A and D
- // A doesn't need to be transposed since cutlass expects a NxK matrix
- // for B (which is At)
- auto stride_At = layout_A.stride();
- auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
+ using StrideA = Stride<int64_t, Int<1>, int64_t>;
+ using StrideC = Stride<Int<1>, int64_t, int64_t>;
+
+ StrideA a_stride{lda, Int<1>{}, Int<0>{}};
+ StrideC c_stride{Int<1>{}, ldc, Int<0>{}};

  using GemmKernel = typename Gemm::GemmKernel;
- typename GemmKernel::ProblemShape prob_shape{
-     static_cast<int>(bt_nzs.size(0)), static_cast<int>(size<0>(layout_A)),
-     static_cast<int>(size<1>(layout_A)), 1};
+ typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};

  using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
  using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
@@ -158,13 +245,13 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
  auto b_ptr = static_cast<ElementAB*>(bt_nzs.data_ptr());
  auto e_ptr = static_cast<ElementE*>(bt_meta.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{
-     b_ptr, b_layout, a_ptr, stride_At, e_ptr, e_layout};
+     b_ptr, b_layout, a_ptr, a_stride, e_ptr, e_layout};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
-     c_ptr, stride_Dt, c_ptr, stride_Dt};
+     c_ptr, c_stride, c_ptr, c_stride};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};
@@ -185,6 +272,10 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
  CUTLASS_CHECK(status);
}
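A worked example of the stride setup above (illustrative only, hypothetical sizes; not part of this diff):

#include <cassert>
#include <cstdint>

// How the transposed problem shape and strides are derived for row-major
// a and out: the kernel computes out^T = bt * a^T.
void transposed_shape_example() {
  int64_t const out_rows = 4, out_cols = 8;  // out is [4, 8], row-major
  int64_t const a_cols = 64;                 // a   is [4, 64], row-major
  int32_t const m = static_cast<int32_t>(out_cols);  // rows of out^T
  int32_t const n = static_cast<int32_t>(out_rows);  // cols of out^T (token dim)
  int32_t const k = static_cast<int32_t>(a_cols);
  int64_t const lda = a_cols;    // a.stride(0) for a contiguous row-major tensor
  int64_t const ldc = out_cols;  // out.stride(0)
  // a keeps stride {lda, 1, 0}: it is consumed as the column-major B operand of
  // the transposed problem; out is addressed with stride {1, ldc, 0}, i.e. as
  // the column-major view of out^T.
  assert(m == 8 && n == 4 && k == 64 && lda == 64 && ldc == 8);
}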
//////////////////////////////////////////////////
// Gemm Configs are defined below
//////////////////////////////////////////////////
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default {};
@@ -192,28 +283,25 @@ struct sm90_config_default {};
template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<half_t, OutType, Epilogue> {
- // M in (128, inf)
- using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
+ using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, float>;
+                            KernelSchedule, EpilogueSchedule>;
};

template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> {
- // M in (128, inf)
- using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
+ using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape,
-                            ClusterShape, KernelSchedule, EpilogueSchedule,
-                            float>;
+                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};

//////////////////////// Cherry-Picking Kernels ////////////////////////
@@ -227,7 +315,7 @@ struct sm90_fp8_config_1 {
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, float>;
+                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
@@ -242,7 +330,7 @@ struct sm90_fp8_config_2 {
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, float>;
+                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
@@ -255,7 +343,7 @@ struct sm90_fp8_config_3 {
  using ClusterShape = Shape<_1, _2, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, float>;
+                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
@@ -269,7 +357,7 @@ struct sm90_fp8_config_4 {
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, float>;
+                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
@@ -283,7 +371,7 @@ struct sm90_fp8_config_5 {
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, float>;
+                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
@@ -296,7 +384,7 @@ struct sm90_fp8_config_6 {
  using ClusterShape = Shape<_1, _2, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, float>;
+                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
@@ -311,7 +399,7 @@ struct sm90_fp8_config_7 {
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, float>;
+                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
@@ -326,7 +414,7 @@ struct sm90_fp8_config_8 {
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, float>;
+                            KernelSchedule, EpilogueSchedule>;
};

////////////////////////////////////////////////////////////////////////
@ -341,7 +429,7 @@ struct sm90_config_default<cutlass::float_e4m3_t, OutType, Epilogue> {
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue, cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue,
TileShape, ClusterShape, KernelSchedule, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule, float>; EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
@ -355,12 +443,9 @@ struct sm90_fp8_config_M64 {
using TileShape = Shape<_64, _64, _256>; using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using TileSchedule = cutlass::gemm::PersistentScheduler;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float, KernelSchedule, EpilogueSchedule>;
TileSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
@ -374,12 +459,9 @@ struct sm90_fp8_config_M128 {
using TileShape = Shape<_64, _128, _256>; using TileShape = Shape<_64, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using TileSchedule = cutlass::gemm::PersistentScheduler;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float, KernelSchedule, EpilogueSchedule>;
TileSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
@ -394,12 +476,9 @@ struct sm90_fp8_config_M256 {
using TileShape = Shape<_128, _128, _256>; using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using TileSchedule = cutlass::gemm::PersistentScheduler;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float, KernelSchedule, EpilogueSchedule>;
TileSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
@ -414,12 +493,9 @@ struct sm90_fp8_config_M512 {
using TileShape = Shape<_128, _128, _256>; using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using TileSchedule = cutlass::gemm::PersistentScheduler;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float, KernelSchedule, EpilogueSchedule>;
TileSchedule>;
}; };
template <typename OutType, template <typename OutType,
@ -433,7 +509,7 @@ struct sm90_config_default<int8_t, OutType, Epilogue> {
using ClusterShape = Shape<_2, _1, _1>; using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<int8_t, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<int8_t, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
@ -448,7 +524,7 @@ struct sm90_int8_config_M128 {
using ClusterShape = Shape<_2, _1, _1>; using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
@ -462,7 +538,7 @@ struct sm90_int8_config_M64 {
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
@ -476,7 +552,7 @@ struct sm90_int8_config_M32_NBig {
using ClusterShape = Shape<_1, _4, _1>; using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
@ -490,7 +566,7 @@ struct sm90_int8_config_M32_NSmall {
using ClusterShape = Shape<_1, _8, _1>; using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>; KernelSchedule, EpilogueSchedule>;
}; };
} // namespace } // namespace

View File

@ -23,6 +23,9 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a);
#endif #endif
void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
@ -68,3 +71,30 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
"CUDA device capability: ", "CUDA device capability: ",
version_num); version_num);
} }
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a) {
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1); // Row-major
TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if (version_num >= 90) {
std::vector<torch::Tensor> result_tensors;
auto [a_meta, a_nzs] = cutlass_sparse_compress_sm90(a);
result_tensors.push_back(std::move(a_nzs));
result_tensors.push_back(std::move(a_meta));
return result_tensors;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_sparse_compress kernel for CUDA device capability: ",
version_num);
}
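For orientation, this new entry point is reached from Python through vllm._custom_ops; a minimal usage sketch (not part of this diff, assuming an SM90 build with CUDA >= 12.2) that respects the row-major and alignment checks above:

import torch
from vllm import _custom_ops as ops

# Row-major fp16 weight with a valid 2:4 pattern: at most 2 non-zeros
# in every contiguous group of 4 along the last dimension.
w = torch.randn(128, 256, device="cuda", dtype=torch.float16)
w.view(-1, 4)[:, :2] = 0.0          # crude 2:4 pattern, for illustration only

# Non-zero values come back first, metadata second (the push order above).
w_nzs, w_meta = ops.cutlass_sparse_compress(w)
print(w_nzs.shape)                  # expected (128, 128): k // 2 values per row
print(w_meta.shape, w_meta.dtype)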

View File

@ -348,10 +348,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm); ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
// CUTLASS sparse matrix compressor // CUTLASS sparse matrix compressor
ops.def( ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
"cutlass_sparse_compress_entry(Tensor! a_nzs, Tensor! a_meta," ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
" Tensor a) -> bool");
ops.impl("cutlass_sparse_compress_entry", &cutlass_sparse_compress_entry);
// Mamba selective scan kernel // Mamba selective scan kernel
ops.def( ops.def(

View File

@ -7,7 +7,6 @@ from typing import Tuple, Type
import pytest import pytest
import torch import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@ -55,11 +54,39 @@ def prune_to_2_4(tensor):
return pruned.reshape(original_shape) return pruned.reshape(original_shape)
# This function checks that applying an identity matrix multiplication
# to the compressed weights yields the original uncompressed weights.
def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
b_compressed: torch.Tensor,
b_metadata: torch.Tensor):
# For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
# same dtype as its inputs. This line addresses that constraint while
# arbitrarily using bfloat16 for the int8/fp8 cases.
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16
eye = torch.eye(b.shape[0], device='cuda', dtype=dtype)
eye_scale = torch.ones(1, device='cuda', dtype=torch.float32)
b_decomp = ops.cutlass_scaled_sparse_mm(eye,
b_compressed,
b_metadata,
eye_scale,
eye_scale,
out_dtype=out_dtype)
torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)
def make_rand_sparse_tensors( def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5 a = torch.randn((m, k), device='cuda')
b = torch.randn((n, k), device='cuda').t() * 5 b = torch.randn((n, k), device='cuda').t()
if dtype == torch.int8:
# ensure A and B aren't all zeros after rounding
a = a * 5.0
b = b * 5.0
b = prune_to_2_4(b.t()).t() b = prune_to_2_4(b.t()).t()
@ -75,6 +102,7 @@ def make_rand_sparse_tensors(
raise ValueError("unsupported dtype") raise ValueError("unsupported dtype")
b_compressed, e = ops.cutlass_sparse_compress(b.t()) b_compressed, e = ops.cutlass_sparse_compress(b.t())
check_compress_decompress_invariance(dtype, b, b_compressed, e)
# Compressed B, Metadata, Original A, B # Compressed B, Metadata, Original A, B
return b_compressed, e, a, b return b_compressed, e, a, b
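prune_to_2_4 is only shown by its tail in this hunk; a self-contained sketch of such a helper (an assumption about its behavior: keep the two largest-magnitude values in every group of four along the last dimension) could look like this:

import torch

def prune_to_2_4_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # 2:4 semi-structured sparsity: in each contiguous group of 4 values,
    # zero out the 2 with the smallest magnitude.
    original_shape = tensor.shape
    groups = tensor.reshape(-1, 4)              # numel must be divisible by 4
    _, keep_idx = torch.abs(groups).topk(2, dim=1)
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(1, keep_idx, True)
    pruned = groups * mask
    return pruned.reshape(original_shape)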
@ -134,27 +162,37 @@ MNK_FACTORS = [
# Test working with a subset of A and B for sparse matmul # Test working with a subset of A and B for sparse matmul
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(not sparse_cutlass_supported(), @pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.") reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS) @pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]): @pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype],
use_bias: bool):
# Create tensors # Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32) scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32) scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
bias = torch.rand((n, ), device="cuda", dtype=dtype) if use_bias else None
out = ops.cutlass_scaled_sparse_mm(a, out = ops.cutlass_scaled_sparse_mm(a,
b_comp, b_comp,
e, e,
scale_a, scale_a,
scale_b, scale_b,
out_dtype=dtype) out_dtype=dtype,
baseline = F.linear(a, b.T) bias=bias)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1e-2) baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=dtype,
bias=bias)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
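baseline_scaled_mm comes from the shared test utilities and is not part of this diff; roughly, it is a dequantize-then-matmul reference along these lines (a sketch, assuming per-tensor or broadcastable scales):

import torch
from typing import Optional

def baseline_scaled_mm_sketch(a: torch.Tensor,
                              b: torch.Tensor,
                              scale_a: torch.Tensor,
                              scale_b: torch.Tensor,
                              out_dtype: torch.dtype,
                              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Do the matmul in float32, apply the scales, then cast to out_dtype.
    out = torch.mm(a.to(torch.float32), b.to(torch.float32)) * scale_a * scale_b
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(out_dtype)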
@pytest.mark.skipif(not sparse_cutlass_supported(), @pytest.mark.skipif(not sparse_cutlass_supported(),
@ -162,27 +200,34 @@ def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]):
@pytest.mark.parametrize("m, k, n", MNK_FACTORS) @pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(not current_platform.has_device_capability(89), @pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.") reason="FP8 is not supported on this GPU type.")
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int): @pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
# Create tensors # Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
out_dtype = torch.bfloat16
bias = torch.rand(
(n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_sparse_mm(a, out = ops.cutlass_scaled_sparse_mm(a,
b_comp, b_comp,
e, e,
scale_a, scale_a,
scale_b, scale_b,
out_dtype=torch.bfloat16) out_dtype=out_dtype,
bias=bias)
baseline = baseline_scaled_mm(a, baseline = baseline_scaled_mm(a,
b, b,
scale_a, scale_a,
scale_b, scale_b,
out_dtype=torch.bfloat16) out_dtype=out_dtype,
bias=bias)
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0) torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
@pytest.mark.skipif(not sparse_cutlass_supported(), @pytest.mark.skipif(not sparse_cutlass_supported(),
@ -198,18 +243,24 @@ def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
out_dtype = torch.bfloat16
bias = torch.rand(
(n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_sparse_mm(a, out = ops.cutlass_scaled_sparse_mm(a,
b_comp, b_comp,
e, e,
scale_a, scale_a,
scale_b, scale_b,
out_dtype=torch.bfloat16) out_dtype=out_dtype,
bias=bias)
baseline = baseline_scaled_mm(a, baseline = baseline_scaled_mm(a,
b, b,
scale_a, scale_a,
scale_b, scale_b,
out_dtype=torch.bfloat16) out_dtype=out_dtype,
bias=bias)
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0) torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)

View File

@ -564,22 +564,9 @@ def cutlass_sparse_compress(a: torch.Tensor) \
# a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4 # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
elemsPerMetaElem = 4 elemsPerMetaElem = 4
assert (a.shape[1] % (2 * elemsPerMetaElem) == 0)
m = a.shape[0] return torch.ops._C.cutlass_sparse_compress(a)
k = a.shape[1]
assert (k % 2 == 0)
a_nzs = torch.empty((m, k // 2), dtype=a.dtype, device=a.device)
a_meta = torch.empty((m, k // 2 // elemsPerMetaElem),
dtype=torch.uint8,
device=a.device)
if not (torch.ops._C.cutlass_sparse_compress_entry(a_nzs, a_meta, a)):
raise ValueError
assert (a_nzs.is_contiguous())
assert (a_meta.is_contiguous())
return a_nzs, a_meta
def cutlass_scaled_sparse_mm( def cutlass_scaled_sparse_mm(

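The deleted allocation code above encoded the compressor's output geometry explicitly; for reference, those relationships can be kept in a small helper (a sketch, assuming uint8 metadata packing 4 non-zeros per element as noted in the surviving comment):

def sparse_24_compressed_shapes(m: int, k: int):
    # Expected shapes of the compressor's outputs for an (m, k) row-major input.
    elems_per_meta_elem = 4                     # uint8 meta, 2 bits per kept value
    assert k % (2 * elems_per_meta_elem) == 0, "k must be a multiple of 8"
    nzs_shape = (m, k // 2)                     # 2:4 keeps half of the values
    meta_shape = (m, k // 2 // elems_per_meta_elem)
    return nzs_shape, meta_shape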
View File

@ -408,13 +408,6 @@ class CompressedTensorsConfig(QuantizationConfig):
if self.supports_cutlass_24(weight_quant=weight_quant, if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant, input_quant=input_quant,
sparsity_scheme=sparsity_scheme): sparsity_scheme=sparsity_scheme):
# FIXME(tlrmchlsmth): layers using W16A16 CUTLASS 2:4 sparse kernels
# currently produce bad output in some cases
if weight_quant is None:
logger.warning_once(
"CompressedTensors24 scheme is disabled for the w16a16 "
"case. Falling back to UnquantizedLinearMethod")
return None
# Have a valid sparsity scheme # Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel # Validate layer is supported by Cutlass 2:4 Kernel
model_compression_config = (None if sparsity_scheme is None model_compression_config = (None if sparsity_scheme is None

View File

@ -64,7 +64,6 @@ class CompressedTensors24(CompressedTensorsScheme):
"Sparse CUTLASS not supported. vLLM must be built with " "Sparse CUTLASS not supported. vLLM must be built with "
"CUDA 12.2 or later to use this feature") "CUDA 12.2 or later to use this feature")
self.output_dtype = params_dtype
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
layer.input_size = input_size layer.input_size = input_size
layer.input_size_per_partition = input_size_per_partition layer.input_size_per_partition = input_size_per_partition
@ -205,6 +204,11 @@ class CompressedTensors24(CompressedTensorsScheme):
layer.weight_scale = torch.nn.Parameter( layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False) layer.weight_scale.data, requires_grad=False)
# Set all negative zero values to 0 prior to compression
if (layer.weight.dtype.is_floating_point
and layer.weight.dtype.itemsize >= 2):
layer.weight.data[layer.weight.data == -0.0] = 0.0
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data) w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False) layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
layer.meta = torch.nn.Parameter(meta, requires_grad=False) layer.meta = torch.nn.Parameter(meta, requires_grad=False)
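The new zero-canonicalization relies on IEEE-754 treating -0.0 as equal to +0.0, so the comparison mask rewrites both signed zeros in one pass (presumably so the compressor does not keep -0.0 as a non-zero value); a quick standalone check of that behavior, independent of vLLM:

import torch

w = torch.tensor([-0.0, 0.0, -1.5, 2.0], dtype=torch.float16)
w[w == -0.0] = 0.0                    # matches +0.0 and -0.0 alike
# All zeros are now +0.0; non-zero entries are untouched.
assert not torch.signbit(w[w == 0.0]).any()
assert w[2] == -1.5 and w[3] == 2.0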
@ -254,9 +258,10 @@ class CompressedTensors24(CompressedTensorsScheme):
bt_meta=layer.meta, bt_meta=layer.meta,
scale_a=input_scale, scale_a=input_scale,
scale_b=layer.weight_scale, scale_b=layer.weight_scale,
out_dtype=self.output_dtype, out_dtype=x.dtype,
bias=bias, bias=bias,
) )
assert out.is_contiguous() assert out.is_contiguous()
return out return out