[Kernel][Bugfix] Refactor and Fix CUTLASS 2:4 Sparse Kernels (#13198)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
2344192a55
commit
c1e37bf71b
@ -228,7 +228,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||
|
||||
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
|
||||
set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")
|
||||
# Please keep this in sync with FetchContent_Declare line below.
|
||||
set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use")
|
||||
|
||||
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
|
||||
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
|
||||
@ -245,6 +246,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
FetchContent_Declare(
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
# Please keep this in sync with CUTLASS_REVISION line above.
|
||||
GIT_TAG v3.7.0
|
||||
GIT_PROGRESS TRUE
|
||||
|
||||
@ -266,7 +268,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||
"csrc/sparse/cutlass/sparse_compressor_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
@ -359,8 +360,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
|
||||
@ -476,7 +476,7 @@ define_gpu_extension_target(
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
|
@ -16,6 +16,30 @@ namespace vllm::c3x {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T>
|
||||
struct identity {
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T lhs) const { return lhs; }
|
||||
};
|
||||
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct TrivialEpilogue {
|
||||
private:
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
using Compute = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
template <typename... Args>
|
||||
static ArgumentType prepare_args(Args... args) {
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This class provides the common load descriptors for the
|
||||
* ScaledEpilogue[...] classes
|
||||
@ -174,6 +198,49 @@ struct ScaledEpilogueBias
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue performs the same operation as ScaledEpilogueBias, but the
|
||||
* bias is a column vector instead of a row vector. Useful e.g. if we are
|
||||
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueColumnBias
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template ColLoad<ElementD>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue directly supports per-tensor azp in int32 form.
|
||||
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
|
||||
@ -314,4 +381,4 @@ struct ScaledEpilogueBiasAzpToken
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace vllm::c3x
|
||||
}; // namespace vllm::c3x
|
||||
|
@ -176,8 +176,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
|
||||
torch::Tensor& e, torch::Tensor const& a);
|
||||
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
|
||||
#endif
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
|
@ -53,12 +53,17 @@ struct cutlass_3x_gemm {
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
|
||||
EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||
ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
|
||||
AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
@ -69,8 +74,8 @@ struct cutlass_3x_gemm {
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementAB, cutlass::layout::RowMajor, 16,
|
||||
ElementAB, cutlass::layout::ColumnMajor, 16,
|
||||
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
|
||||
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
|
||||
ElementAcc, TileShape, ClusterShape,
|
||||
Stages,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
@ -103,14 +103,19 @@ struct cutlass_2x_gemm {
|
||||
|
||||
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
|
||||
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
|
||||
// clang-format off
|
||||
using RowMajor = typename cutlass::layout::RowMajor;
|
||||
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||
using KernelType =
|
||||
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
||||
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
|
||||
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
|
||||
float, cutlass::layout::RowMajor, 4,
|
||||
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
|
||||
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
|
||||
float, cutlass::layout::RowMajor, AlignmentCD,
|
||||
ElementAcc, float, cutlass::arch::OpClassTensorOp,
|
||||
Arch,
|
||||
TileShape, WarpShape, InstructionShape,
|
||||
|
@ -1,165 +0,0 @@
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
|
||||
#include "sparse_scaled_mm_c3x.cuh"
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
using namespace vllm;
|
||||
|
||||
/// Make A structured sparse by replacing elements with 0 and compress it
|
||||
template <typename ElementA_, typename ElementAcc_>
|
||||
bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
|
||||
torch::Tensor const& a) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
|
||||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(a.dim() == 2)
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
|
||||
TORCH_CHECK(a.stride(1) == 1)
|
||||
|
||||
int m = a.size(0);
|
||||
int k = a.size(1);
|
||||
|
||||
// Sparse kernel setup; this kernel is not used for matmul,
|
||||
// but just for setting up the compressor utility
|
||||
// A matrix configuration
|
||||
using ElementA = ElementA_;
|
||||
using LayoutTagA = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
// B matrix configuration
|
||||
using ElementB = ElementA;
|
||||
using LayoutTagB = cutlass::layout::ColumnMajor;
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
// C/D matrix configuration
|
||||
using ElementC = float;
|
||||
using LayoutTagC = cutlass::layout::ColumnMajor;
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = ElementAcc_;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using TileShapeRef = Shape<_128, _128, _64>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule = typename std::conditional<
|
||||
std::is_same_v<ElementA, cutlass::float_e4m3_t>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>::type;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
|
||||
using ProblemShape = Shape<int, int, int, int>;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
|
||||
AlignmentC, ElementC, LayoutTagC, AlignmentC,
|
||||
EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
|
||||
LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
|
||||
ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
|
||||
using StrideE = StrideA;
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
|
||||
// The n (=1) dimension does not matter for the compressor
|
||||
typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};
|
||||
|
||||
using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;
|
||||
|
||||
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
|
||||
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
|
||||
// Offline compressor kernel
|
||||
using CompressorUtility =
|
||||
cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
||||
ProblemShape, ElementA, LayoutTagA, SparseConfig>;
|
||||
|
||||
using CompressorKernel =
|
||||
cutlass::transform::kernel::StructuredSparseCompressor<
|
||||
ProblemShape, ElementA, LayoutTagA, SparseConfig,
|
||||
cutlass::arch::Sm90>;
|
||||
|
||||
using Compressor =
|
||||
cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
auto [M, N, K, L] = prob_shape;
|
||||
|
||||
StrideA stride_A;
|
||||
stride_A =
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
|
||||
CompressorUtility compressor_utility(prob_shape, stride_A);
|
||||
|
||||
int ME = compressor_utility.get_metadata_m_physical();
|
||||
int KE = compressor_utility.get_metadata_k_physical();
|
||||
int KC = compressor_utility.get_tensorA_k_physical();
|
||||
|
||||
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
|
||||
|
||||
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
|
||||
auto a_meta_ptr = static_cast<typename Gemm::CollectiveMainloop::ElementE*>(
|
||||
a_meta.data_ptr());
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
typename Compressor::Arguments arguments{
|
||||
prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};
|
||||
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
CUTLASS_CHECK(compressor_op.can_implement(arguments));
|
||||
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(compressor_op.run());
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
|
||||
torch::Tensor const& a) {
|
||||
if (a.dtype() == torch::kBFloat16) {
|
||||
return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_nzs, a_meta,
|
||||
a);
|
||||
} else if (a.dtype() == torch::kFloat16) {
|
||||
return cutlass_sparse_compress<cutlass::half_t, float>(a_nzs, a_meta, a);
|
||||
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_nzs, a_meta,
|
||||
a);
|
||||
} else if (a.dtype() == torch::kInt8) {
|
||||
return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif
|
90
csrc/sparse/cutlass/sparse_compressor_c3x.cuh
Normal file
90
csrc/sparse/cutlass/sparse_compressor_c3x.cuh
Normal file
@ -0,0 +1,90 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
|
||||
#include "sparse_scaled_mm_c3x.cuh"
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
using namespace vllm;
|
||||
|
||||
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
|
||||
/// Make A structured sparse by replacing elements with 0 and compress it
|
||||
template <typename Gemm>
|
||||
CompressorResult cutlass_sparse_compress(torch::Tensor const& a) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
|
||||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(a.dim() == 2)
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
|
||||
TORCH_CHECK(a.stride(1) == 1)
|
||||
|
||||
using GemmKernel = typename Gemm::KernelType;
|
||||
using ElementA = typename Gemm::ElementAB;
|
||||
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
|
||||
|
||||
int m = a.size(0);
|
||||
int k = a.size(1);
|
||||
using ProblemShape = typename GemmKernel::ProblemShape;
|
||||
ProblemShape prob_shape{m, 1, k, 1};
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
|
||||
using CompressorUtility = typename Gemm::CompressorUtility;
|
||||
CompressorUtility compressor_utility(prob_shape, a_stride);
|
||||
|
||||
// Allocate buffers for the metadata E and the compressed matrix A
|
||||
int ME = compressor_utility.get_metadata_m_physical();
|
||||
int KE = compressor_utility.get_metadata_k_physical();
|
||||
int MC = compressor_utility.get_tensorA_m_physical();
|
||||
int KC = compressor_utility.get_tensorA_k_physical();
|
||||
|
||||
auto const a_meta_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto const a_nzs_options =
|
||||
torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||
|
||||
auto a_meta = torch::zeros({ME, KE}, a_meta_options);
|
||||
auto a_nzs = torch::zeros({MC, KC}, a_nzs_options);
|
||||
|
||||
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
|
||||
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
|
||||
auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = a.device().index();
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
|
||||
using Compressor = typename Gemm::Compressor;
|
||||
typename Compressor::Arguments arguments{
|
||||
prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}};
|
||||
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
CUTLASS_CHECK(compressor_op.can_implement(arguments));
|
||||
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr()));
|
||||
CUTLASS_CHECK(compressor_op.run());
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return {a_meta, a_nzs};
|
||||
}
|
||||
|
||||
#endif
|
@ -1,42 +0,0 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
|
||||
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
|
||||
torch::Tensor const& a);
|
||||
#endif
|
||||
|
||||
bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
|
||||
torch::Tensor const& a) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2);
|
||||
TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) &&
|
||||
a_nzs.size(1) * 2 == a.size(1) &&
|
||||
a_meta.size(1) * 2 * 4 == a.size(1));
|
||||
// Considering elemsPerMetaElem = 8b / 2b_per_nz = 4
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 &&
|
||||
a_meta.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
|
||||
if (version_num >= 90) {
|
||||
return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
@ -9,17 +9,30 @@
|
||||
using namespace cute;
|
||||
using namespace vllm;
|
||||
|
||||
struct GemmCallerTraits {
|
||||
using return_type = void;
|
||||
|
||||
template <typename GemmConfig, typename... Args>
|
||||
static return_type invoke(Args&&... args) {
|
||||
return cutlass_sparse_gemm_caller<GemmConfig>(std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
struct GemmCompressorTraits {
|
||||
using return_type = CompressorResult;
|
||||
|
||||
template <typename GemmConfig, typename... Args>
|
||||
static return_type invoke(Args&&... args) {
|
||||
return cutlass_sparse_compress<GemmConfig>(std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
|
||||
typename DispatchFunc, typename... Args>
|
||||
typename DispatchFunc::return_type cutlass_gemm_sm90_fp8_dispatch(
|
||||
uint32_t m, uint32_t n, Args&&... args) {
|
||||
static_assert(std::is_same_v<InType, cutlass::float_e4m3_t>);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
@ -49,122 +62,87 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
using Cutlass3xGemm8 =
|
||||
typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const n = bt_nzs.size(0);
|
||||
uint32_t const m = a.size(0); // Batch size
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 64) {
|
||||
if (n == 28672) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm2>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm2>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (n == 4096 || n == 6144) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm1>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm1>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
} else if (mp2 <= 128) {
|
||||
if (n == 4096) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm3>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm3>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (n == 28672) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm5>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm5>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (n == 6144) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm4>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm4>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
} else if (mp2 <= 256) {
|
||||
if (n == 4096) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm6>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm6>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (n == 28672) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm8>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (n == 6144) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm7>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
} else {
|
||||
if (n == 6144 || n == 28672) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm8>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (n == 4096) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemm7>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise the default heuristic
|
||||
if (mp2 <= 64) {
|
||||
// n in [1, 64]
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM64>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// n in (64, 128]
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM128>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// n in (128, 256]
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM256>(
|
||||
std::forward<Args>(args)...);
|
||||
} else {
|
||||
// n in (256, inf)
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM512>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM512>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::half_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
|
||||
|
||||
typename DispatchFunc, typename... Args>
|
||||
typename DispatchFunc::return_type cutlass_gemm_sm90_16bit_dispatch(
|
||||
uint32_t m, uint32_t n, Args&&... args) {
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
// m in (128, inf)
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::bfloat16_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
// m in (128, inf)
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
|
||||
typename DispatchFunc, typename... Args>
|
||||
typename DispatchFunc::return_type cutlass_gemm_sm90_int8_dispatch(
|
||||
uint32_t m, uint32_t n, Args&&... args) {
|
||||
static_assert(std::is_same_v<InType, int8_t>);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
@ -179,37 +157,35 @@ void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
typename sm90_int8_config_M32_NSmall<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
bool const is_small_n = n < 8192;
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 32) {
|
||||
// m in [1, 32]
|
||||
if (is_small_n) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NSmall>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM32NSmall>(
|
||||
std::forward<Args>(args)...);
|
||||
} else {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NBig>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM32NBig>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
} else if (mp2 <= 64) {
|
||||
// m in (32, 64]
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM64>(
|
||||
std::forward<Args>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// m in (64, 128]
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmM128>(
|
||||
std::forward<Args>(args)...);
|
||||
} else {
|
||||
// m in (128, inf)
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch to GEMM implementations based on element types
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
|
||||
@ -217,19 +193,24 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
uint32_t const m = out.size(0);
|
||||
uint32_t const n = out.size(1);
|
||||
|
||||
// TODO: add dispatch functions to all of these
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
Epilogue, GemmCallerTraits>(
|
||||
m, n, out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue,
|
||||
GemmCallerTraits>(
|
||||
m, n, out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
@ -237,47 +218,34 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
cutlass::bfloat16_t, Epilogue,
|
||||
GemmCallerTraits>(
|
||||
m, n, out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
return cutlass_gemm_sm90_fp8_dispatch<
|
||||
cutlass::float_e4m3_t, cutlass::half_t, Epilogue, GemmCallerTraits>(
|
||||
m, n, out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else if (a.dtype() == torch::kFloat16) {
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t, cutlass::half_t,
|
||||
Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
|
||||
Epilogue, GemmCallerTraits>(
|
||||
m, n, out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else { // a.dtype() == torch::kBFloat16
|
||||
TORCH_CHECK(a.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(out.dtype() == torch::kBFloat16);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
return cutlass_gemm_sm90_16bit_dispatch<
|
||||
cutlass::bfloat16_t, cutlass::bfloat16_t, Epilogue, GemmCallerTraits>(
|
||||
m, n, out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
@ -287,17 +255,53 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, bt_nzs, bt_meta, b_scales, a_scales, *bias);
|
||||
"CUTLASS scaled_mm bias dtype must match output dtype ",
|
||||
out.dtype());
|
||||
return cutlass_scaled_sparse_mm_sm90_epilogue<
|
||||
c3x::ScaledEpilogueColumnBias>(out, a, bt_nzs, bt_meta, b_scales,
|
||||
a_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, bt_nzs, bt_meta, b_scales, a_scales);
|
||||
}
|
||||
}
|
||||
|
||||
CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a) {
|
||||
// These m and n variables are fordispatching to different GEMM algorithms.
|
||||
uint32_t const m = 1; // Set M to 1 for compression
|
||||
uint32_t const n = a.size(1);
|
||||
|
||||
// Note: For correctess, the compressed format must be invariant in:
|
||||
// - M, the flattened number of tokens
|
||||
// - Whether output dtype is fp16 or bf16
|
||||
// - CUTLASS epilogues
|
||||
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
c3x::TrivialEpilogue,
|
||||
GemmCompressorTraits>(m, n, a);
|
||||
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<
|
||||
cutlass::float_e4m3_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
|
||||
GemmCompressorTraits>(m, n, a);
|
||||
} else if (a.dtype() == torch::kFloat16) {
|
||||
return cutlass_gemm_sm90_16bit_dispatch<
|
||||
cutlass::bfloat16_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
|
||||
GemmCompressorTraits>(m, n, a);
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kBFloat16,
|
||||
"cutlass_sparse_compress only supports int8, fp8_e4m3, fp16, "
|
||||
"and bf16 datatypes");
|
||||
return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
|
||||
c3x::TrivialEpilogue,
|
||||
GemmCompressorTraits>(m, n, a);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
@ -12,6 +14,9 @@
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/cute_utils.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
@ -22,7 +27,7 @@
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
This file defines sparse quantized GEMM operations using the CUTLASS 3.x API,
|
||||
This file defines 2:4 sparse GEMM operations using the CUTLASS 3.x API,
|
||||
for NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
*/
|
||||
|
||||
@ -45,17 +50,20 @@ struct enable_sm90_or_later : Kernel {
|
||||
|
||||
using GemmUniversalMode = cutlass::gemm::GemmUniversalMode;
|
||||
|
||||
/*
|
||||
* cutlass_sparse_3x_gemm defines a 2:4 sparse GEMM kernel via CUTLASS
|
||||
* for SM90 Hopper systems.
|
||||
*/
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule, typename AccType,
|
||||
typename TileSchedule = cutlass::gemm::PersistentScheduler,
|
||||
GemmUniversalMode Mode_ = GemmUniversalMode::kGemm>
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_sparse_3x_gemm {
|
||||
static const GemmUniversalMode Mode = Mode_;
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementAcc = AccType;
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using EpilogueDescriptor =
|
||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
@ -66,30 +74,22 @@ struct cutlass_sparse_3x_gemm {
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = LayoutC;
|
||||
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
|
||||
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
|
||||
|
||||
using LayoutC_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
|
||||
using LayoutD_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
static constexpr int AlignmentA =
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD =
|
||||
128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAcc, ElementAcc, ElementC, LayoutC_Transpose, AlignmentCD,
|
||||
ElementD, LayoutD_Transpose, AlignmentCD, EpilogueSchedule,
|
||||
ElementAcc, float, ElementC, LayoutC_Transpose, AlignmentCD, ElementD,
|
||||
LayoutC_Transpose, AlignmentCD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
@ -101,8 +101,8 @@ struct cutlass_sparse_3x_gemm {
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
|
||||
ElementAB, cutlass::layout::RowMajor, AlignmentA,
|
||||
ElementAB, cutlass::layout::ColumnMajor, AlignmentB,
|
||||
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
|
||||
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
|
||||
ElementAcc, TileShape, ClusterShape,
|
||||
Stages,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
@ -110,11 +110,100 @@ struct cutlass_sparse_3x_gemm {
|
||||
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
TileSchedule>>;
|
||||
cutlass::gemm::PersistentScheduler>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
|
||||
// Sparse compressor definitions
|
||||
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
using LayoutTagA = cutlass::layout::RowMajor;
|
||||
using CompressorUtility =
|
||||
cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
||||
typename GemmKernel::ProblemShape, ElementAB, LayoutTagA,
|
||||
SparseConfig>;
|
||||
using CompressorKernel =
|
||||
cutlass::transform::kernel::StructuredSparseCompressor<
|
||||
typename GemmKernel::ProblemShape, ElementAB, LayoutTagA,
|
||||
SparseConfig, cutlass::arch::Sm90>;
|
||||
using Compressor =
|
||||
cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
};
|
||||
|
||||
/*
|
||||
* This class defines kernel to compress a 2:4 sparse matrix.
|
||||
* The particular format is defined by the Gemm template parameter,
|
||||
* which is a cutlass_sparse_3x_gemm.
|
||||
*/
|
||||
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
|
||||
/// Make A structured sparse by replacing elements with 0 and compress it
|
||||
template <typename Gemm>
|
||||
CompressorResult cutlass_sparse_compress(torch::Tensor const& a) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
|
||||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(a.dim() == 2)
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
|
||||
TORCH_CHECK(a.stride(1) == 1)
|
||||
|
||||
using GemmKernel = typename Gemm::KernelType;
|
||||
using ElementA = typename Gemm::ElementAB;
|
||||
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
|
||||
|
||||
int m = a.size(0);
|
||||
int k = a.size(1);
|
||||
using ProblemShape = typename GemmKernel::ProblemShape;
|
||||
ProblemShape prob_shape{m, 1, k, 1};
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
|
||||
using CompressorUtility = typename Gemm::CompressorUtility;
|
||||
CompressorUtility compressor_utility(prob_shape, a_stride);
|
||||
|
||||
// Allocate buffers for the metadata E and the compressed matrix A
|
||||
int ME = compressor_utility.get_metadata_m_physical();
|
||||
int KE = compressor_utility.get_metadata_k_physical();
|
||||
int MC = compressor_utility.get_tensorA_m_physical();
|
||||
int KC = compressor_utility.get_tensorA_k_physical();
|
||||
|
||||
auto const a_meta_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto const a_nzs_options =
|
||||
torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||
|
||||
auto a_meta = torch::zeros({ME, KE}, a_meta_options);
|
||||
auto a_nzs = torch::zeros({MC, KC}, a_nzs_options);
|
||||
|
||||
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
|
||||
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
|
||||
auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = a.device().index();
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
|
||||
using Compressor = typename Gemm::Compressor;
|
||||
typename Compressor::Arguments arguments{
|
||||
prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}};
|
||||
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
CUTLASS_CHECK(compressor_op.can_implement(arguments));
|
||||
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr()));
|
||||
CUTLASS_CHECK(compressor_op.run());
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return {a_meta, a_nzs};
|
||||
}
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
@ -126,27 +215,25 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
// Interface stride expected from the argument a (will get transposed)
|
||||
// We compute C^T = B^T * A^T, but we assume B is transposed before
|
||||
// compression and hence the bt_* naming
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
|
||||
// M, N, K after transposition
|
||||
int32_t m = out.size(1);
|
||||
int32_t n = out.size(0);
|
||||
int32_t k = a.size(1);
|
||||
|
||||
auto layout_A = make_cute_layout<StrideA>(a, "A");
|
||||
auto layout_D = make_cute_layout<StrideD>(out, "D");
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
// Transpose A and D
|
||||
// A doesn't need to be transposed since cutlass expects a NxK matrix
|
||||
// for B (which is At)
|
||||
auto stride_At = layout_A.stride();
|
||||
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideC = Stride<Int<1>, int64_t, int64_t>;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, Int<0>{}};
|
||||
StrideC c_stride{Int<1>{}, ldc, Int<0>{}};
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
typename GemmKernel::ProblemShape prob_shape{
|
||||
static_cast<int>(bt_nzs.size(0)), static_cast<int>(size<0>(layout_A)),
|
||||
static_cast<int>(size<1>(layout_A)), 1};
|
||||
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
||||
|
||||
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
|
||||
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
@ -158,13 +245,13 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
auto b_ptr = static_cast<ElementAB*>(bt_nzs.data_ptr());
|
||||
auto e_ptr = static_cast<ElementE*>(bt_meta.data_ptr());
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
b_ptr, b_layout, a_ptr, stride_At, e_ptr, e_layout};
|
||||
b_ptr, b_layout, a_ptr, a_stride, e_ptr, e_layout};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||
c_ptr, stride_Dt, c_ptr, stride_Dt};
|
||||
c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
@ -185,6 +272,10 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
// Gemm Configs are defined below
|
||||
//////////////////////////////////////////////////
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_config_default {};
|
||||
@ -192,28 +283,25 @@ struct sm90_config_default {};
|
||||
template <typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_config_default<half_t, OutType, Epilogue> {
|
||||
// M in (128, inf)
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> {
|
||||
// M in (128, inf)
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule,
|
||||
float>;
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
//////////////////////// Cherry-Picking Kernels ////////////////////////
|
||||
@ -227,7 +315,7 @@ struct sm90_fp8_config_1 {
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -242,7 +330,7 @@ struct sm90_fp8_config_2 {
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -255,7 +343,7 @@ struct sm90_fp8_config_3 {
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -269,7 +357,7 @@ struct sm90_fp8_config_4 {
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -283,7 +371,7 @@ struct sm90_fp8_config_5 {
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -296,7 +384,7 @@ struct sm90_fp8_config_6 {
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -311,7 +399,7 @@ struct sm90_fp8_config_7 {
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -326,7 +414,7 @@ struct sm90_fp8_config_8 {
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -341,7 +429,7 @@ struct sm90_config_default<cutlass::float_e4m3_t, OutType, Epilogue> {
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue,
|
||||
TileShape, ClusterShape, KernelSchedule,
|
||||
EpilogueSchedule, float>;
|
||||
EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -355,12 +443,9 @@ struct sm90_fp8_config_M64 {
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
|
||||
using TileSchedule = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float,
|
||||
TileSchedule>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -374,12 +459,9 @@ struct sm90_fp8_config_M128 {
|
||||
using TileShape = Shape<_64, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
|
||||
using TileSchedule = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float,
|
||||
TileSchedule>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -394,12 +476,9 @@ struct sm90_fp8_config_M256 {
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
|
||||
using TileSchedule = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float,
|
||||
TileSchedule>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -414,12 +493,9 @@ struct sm90_fp8_config_M512 {
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
|
||||
using TileSchedule = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float,
|
||||
TileSchedule>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename OutType,
|
||||
@ -433,7 +509,7 @@ struct sm90_config_default<int8_t, OutType, Epilogue> {
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<int8_t, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, int32_t>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -448,7 +524,7 @@ struct sm90_int8_config_M128 {
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, int32_t>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -462,7 +538,7 @@ struct sm90_int8_config_M64 {
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, int32_t>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -476,7 +552,7 @@ struct sm90_int8_config_M32_NBig {
|
||||
using ClusterShape = Shape<_1, _4, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, int32_t>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -490,7 +566,7 @@ struct sm90_int8_config_M32_NSmall {
|
||||
using ClusterShape = Shape<_1, _8, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, int32_t>;
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace
|
||||
|
@ -23,6 +23,9 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
|
||||
CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
@ -68,3 +71,30 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a) {
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
|
||||
if (version_num >= 90) {
|
||||
std::vector<torch::Tensor> result_tensors;
|
||||
|
||||
auto [a_meta, a_nzs] = cutlass_sparse_compress_sm90(a);
|
||||
result_tensors.push_back(std::move(a_nzs));
|
||||
result_tensors.push_back(std::move(a_meta));
|
||||
return result_tensors;
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_sparse_compress for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
|
@ -348,10 +348,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
|
||||
|
||||
// CUTLASS sparse matrix compressor
|
||||
ops.def(
|
||||
"cutlass_sparse_compress_entry(Tensor! a_nzs, Tensor! a_meta,"
|
||||
" Tensor a) -> bool");
|
||||
ops.impl("cutlass_sparse_compress_entry", &cutlass_sparse_compress_entry);
|
||||
ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
|
||||
ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
|
||||
|
||||
// Mamba selective scan kernel
|
||||
ops.def(
|
||||
|
@ -7,7 +7,6 @@ from typing import Tuple, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
@ -55,11 +54,39 @@ def prune_to_2_4(tensor):
|
||||
return pruned.reshape(original_shape)
|
||||
|
||||
|
||||
# This function checks that applying an identity matrix multiplication
|
||||
# to the compressed weights yields the original uncompressed weights.
|
||||
def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
|
||||
b_compressed: torch.Tensor,
|
||||
b_metadata: torch.Tensor):
|
||||
|
||||
# For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
|
||||
# same dtype as its inputs. This line addresses that constraint while
|
||||
# arbitrarily using bfloat16 for the int8/fp8 cases.
|
||||
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16
|
||||
|
||||
eye = torch.eye(b.shape[0], device='cuda', dtype=dtype)
|
||||
eye_scale = torch.ones(1, device='cuda', dtype=torch.float32)
|
||||
b_decomp = ops.cutlass_scaled_sparse_mm(eye,
|
||||
b_compressed,
|
||||
b_metadata,
|
||||
eye_scale,
|
||||
eye_scale,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)
|
||||
|
||||
|
||||
def make_rand_sparse_tensors(
|
||||
dtype: torch.dtype, m: int, n: int, k: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
a = torch.randn((m, k), device='cuda')
|
||||
b = torch.randn((n, k), device='cuda').t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
# ensure A and B aren't all zeros after rounding
|
||||
a = a * 5.0
|
||||
b = b * 5.0
|
||||
|
||||
b = prune_to_2_4(b.t()).t()
|
||||
|
||||
@ -75,6 +102,7 @@ def make_rand_sparse_tensors(
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
b_compressed, e = ops.cutlass_sparse_compress(b.t())
|
||||
check_compress_decompress_invariance(dtype, b, b_compressed, e)
|
||||
|
||||
# Compressed B, Metadata, Original A, B
|
||||
return b_compressed, e, a, b
|
||||
@ -134,27 +162,37 @@ MNK_FACTORS = [
|
||||
|
||||
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]):
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype],
|
||||
use_bias: bool):
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
||||
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
|
||||
|
||||
bias = torch.rand((n, ), device="cuda", dtype=dtype) if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a,
|
||||
b_comp,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=dtype)
|
||||
baseline = F.linear(a, b.T)
|
||||
out_dtype=dtype,
|
||||
bias=bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1e-2)
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=dtype,
|
||||
bias=bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
@ -162,27 +200,34 @@ def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]):
|
||||
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int):
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
|
||||
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
|
||||
out_dtype = torch.bfloat16
|
||||
|
||||
bias = torch.rand(
|
||||
(n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a,
|
||||
b_comp,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
out_dtype=out_dtype,
|
||||
bias=bias)
|
||||
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
out_dtype=out_dtype,
|
||||
bias=bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
@ -198,18 +243,24 @@ def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
||||
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
|
||||
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
|
||||
out_dtype = torch.bfloat16
|
||||
|
||||
bias = torch.rand(
|
||||
(n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a,
|
||||
b_comp,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
out_dtype=out_dtype,
|
||||
bias=bias)
|
||||
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
out_dtype=out_dtype,
|
||||
bias=bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
|
||||
|
@ -564,22 +564,9 @@ def cutlass_sparse_compress(a: torch.Tensor) \
|
||||
|
||||
# a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
|
||||
elemsPerMetaElem = 4
|
||||
assert (a.shape[1] % (2 * elemsPerMetaElem) == 0)
|
||||
|
||||
m = a.shape[0]
|
||||
k = a.shape[1]
|
||||
assert (k % 2 == 0)
|
||||
a_nzs = torch.empty((m, k // 2), dtype=a.dtype, device=a.device)
|
||||
a_meta = torch.empty((m, k // 2 // elemsPerMetaElem),
|
||||
dtype=torch.uint8,
|
||||
device=a.device)
|
||||
|
||||
if not (torch.ops._C.cutlass_sparse_compress_entry(a_nzs, a_meta, a)):
|
||||
raise ValueError
|
||||
|
||||
assert (a_nzs.is_contiguous())
|
||||
assert (a_meta.is_contiguous())
|
||||
|
||||
return a_nzs, a_meta
|
||||
return torch.ops._C.cutlass_sparse_compress(a)
|
||||
|
||||
|
||||
def cutlass_scaled_sparse_mm(
|
||||
|
@ -408,13 +408,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
if self.supports_cutlass_24(weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
sparsity_scheme=sparsity_scheme):
|
||||
# FIXME(tlrmchlsmth): layers using W16A16 CUTLASS 2:4 sparse kernels
|
||||
# currently produce bad output in some cases
|
||||
if weight_quant is None:
|
||||
logger.warning_once(
|
||||
"CompressedTensors24 scheme is disabled for the w16a16 "
|
||||
"case. Falling back to UnquantizedLinearMethod")
|
||||
return None
|
||||
# Have a valid sparsity scheme
|
||||
# Validate layer is supported by Cutlass 2:4 Kernel
|
||||
model_compression_config = (None if sparsity_scheme is None
|
||||
|
@ -64,7 +64,6 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
"Sparse CUTLASS not supported. vLLM must be built with "
|
||||
"CUDA 12.2 or later to use this feature")
|
||||
|
||||
self.output_dtype = params_dtype
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size = input_size
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
@ -205,6 +204,11 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data, requires_grad=False)
|
||||
|
||||
# Set all negative zero values to 0 prior to compression
|
||||
if (layer.weight.dtype.is_floating_point
|
||||
and layer.weight.dtype.itemsize >= 2):
|
||||
layer.weight.data[layer.weight.data == -0.0] = 0.0
|
||||
|
||||
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
|
||||
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
|
||||
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
|
||||
@ -254,9 +258,10 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
bt_meta=layer.meta,
|
||||
scale_a=input_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=self.output_dtype,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
assert out.is_contiguous()
|
||||
return out
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user