[Kernel] Factor out epilogues from cutlass kernels (#5391)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: zifeitong <zifei.tong@parasail.io>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
This commit is contained in:
Tyler Michael Smith 2024-06-13 14:22:19 -04:00 committed by GitHub
parent 0ce7b952f8
commit 85657b5607
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 274 additions and 232 deletions

View File

@ -179,9 +179,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/custom_all_reduce.cu" "csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu") "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
# #
# The CUTLASS kernels for Hopper require sm90a to be enabled. # The CUTLASS kernels for Hopper require sm90a to be enabled.
@ -189,7 +189,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# That adds an extra 17MB to compiled binary, so instead we selectively enable it. # That adds an extra 17MB to compiled binary, so instead we selectively enable it.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties( set_source_files_properties(
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
PROPERTIES PROPERTIES
COMPILE_FLAGS COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a") "-gencode arch=compute_90a,code=sm_90a")

View File

@ -76,11 +76,7 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor, scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor: out_dtype: torch.dtype) -> torch.tensor:
return ops.cutlass_scaled_mm_dq(a, return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
b,
scale_a,
scale_b,
out_dtype=out_dtype)
# bench # bench

View File

@ -90,7 +90,7 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
int64_t num_bits); int64_t num_bits);
void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales); torch::Tensor const& b_scales);

View File

@ -29,21 +29,14 @@
using namespace cute; using namespace cute;
/* /*
This defines a quantized GEMM operation with dequantized output, similar to This file defines quantized GEMM operations using the CUTLASS 2.x API, for
torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
NVIDIA GPUs with SM versions prior to sm90 (Hopper). NVIDIA GPUs with SM versions prior to sm90 (Hopper).
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or Epilogue functions can be defined to post-process the output before it is
per-row. B can be quantized per-tensor or per-column. written to GPU memory.
Any combination of per-tensor and per-row or column is supported. Epilogues must contain a public type named EVTCompute of type Sm80EVT,
A and B must have symmetric quantization (zero point == 0). as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/ */
namespace { namespace {
@ -83,27 +76,25 @@ struct enable_sm89_to_sm90 : Kernel {
} }
}; };
template <typename Arch, template <typename> typename ArchGuard, /*
typename ElementAB_, typename ElementD_, typename TileShape, This epilogue function defines a quantized GEMM operation similar to
typename WarpShape, typename InstructionShape, int32_t MainLoopStages> torch._scaled_mm.
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc = A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t, per-row. B can be quantized per-tensor or per-column.
float>::type; Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
using Operator = So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
typename std::conditional<std::is_same_v<ElementAB, int8_t>, scales are applied elementwise with numpy-style broadcasting.
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue {
private:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
@ -123,14 +114,56 @@ struct cutlass_2x_gemm {
cutlass::multiplies, ElementD, float, cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>; cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 = public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>; cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
return evt_compute_args;
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
using EVTCompute = typename Epilogue::EVTCompute;
using D = cutlass::epilogue::threadblock::VisitorAuxStore< using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
Stride<int64_t, Int<1>, Int<0>>>; Stride<int64_t, Int<1>, Int<0>>>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>; using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
// clang-format off // clang-format off
using RowMajor = typename cutlass::layout::RowMajor; using RowMajor = typename cutlass::layout::RowMajor;
@ -153,11 +186,10 @@ struct cutlass_2x_gemm {
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>; using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
}; };
template <typename Gemm> template <typename Gemm, typename... EpilogueArgs>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, EpilogueArgs&&... epilogue_params) {
torch::Tensor const& b_scales) {
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
@ -177,23 +209,14 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr()); auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr()); auto c_ptr = static_cast<ElementD*>(out.data_ptr());
auto a_scales_ptr = a_scales.data_ptr<float>();
auto b_scales_ptr = b_scales.data_ptr<float>();
using ScaleAArgs = typename Gemm::ScaleA::Arguments;
using ScaleBArgs = typename Gemm::ScaleB::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
evt0_compute_args};
typename Gemm::D::Arguments d_args{c_ptr, c_stride}; typename Gemm::D::Arguments d_args{c_ptr, c_stride};
using Epilogue = typename Gemm::Epilogue;
auto evt_args =
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
typename Gemm::EVTD::Arguments epilogue_args{ typename Gemm::EVTD::Arguments epilogue_args{
evt1_compute_args, evt_args,
d_args, d_args,
}; };
@ -229,7 +252,7 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
} // namespace } // namespace
void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::Tensor const& b_scales) {
@ -243,20 +266,20 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>; using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm< return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t, cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales, ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
b_scales); out, a, b, a_scales, b_scales);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm< return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t, cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales, ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
b_scales); out, a, b, a_scales, b_scales);
} }
} }
void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::Tensor const& b_scales) {
@ -270,20 +293,20 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm< return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t, cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales, ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
b_scales); out, a, b, a_scales, b_scales);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm< return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t, cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales, ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
b_scales); out, a, b, a_scales, b_scales);
} }
} }
void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::Tensor const& b_scales) {
@ -298,32 +321,32 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm< return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t, cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales, ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
b_scales); out, a, b, a_scales, b_scales);
} else { } else {
assert(out.dtype() == torch::kFloat16); assert(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm< return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t, cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales, ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
b_scales); out, a, b, a_scales, b_scales);
} }
} else { } else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm< return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t, cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>( cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
out, a, b, a_scales, b_scales); InstructionShape, 5>>(out, a, b, a_scales, b_scales);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm< return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t, cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>( cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
out, a, b, a_scales, b_scales); InstructionShape, 5>>(out, a, b, a_scales, b_scales);
} }
} }
} }

View File

@ -32,21 +32,14 @@
using namespace cute; using namespace cute;
/* /*
This defines a quantized GEMM operation with dequantized output, similar to This file defines quantized GEMM operations using the CUTLASS 3.x API, for
torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for
NVIDIA GPUs with sm90a (Hopper) or later. NVIDIA GPUs with sm90a (Hopper) or later.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or Epilogue functions can be defined to post-process the output before it is
per-row. B can be quantized per-tensor or per-column. written to GPU memory.
Any combination of per-tensor and per-row or column is supported. Epilogues must contain a public type named EVTCompute of type Sm90EVT,
A and B must have symmetric quantization (zero point == 0). as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/ */
namespace { namespace {
@ -71,21 +64,25 @@ struct enable_sm90_or_later : Kernel {
} }
}; };
template <typename ElementAB_, typename ElementD_, typename TileShape, /*
typename ClusterShape, typename KernelSchedule, This epilogue function defines a quantized GEMM operation similar to
typename EpilogueSchedule> torch.scaled_mm_.
struct cutlass_3x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using EpilogueDescriptor = A and B may be both either int8 or fp8_e4m3. A can be
cutlass::epilogue::collective::detail::EpilogueDescriptor< quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, Any combination of per-tensor and per-row or column is supported.
ElementD, EpilogueSchedule>; A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue {
private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
@ -111,19 +108,53 @@ struct cutlass_3x_gemm {
cutlass::multiplies, ElementD, float, cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>; cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 = public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>; cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleA_Args = typename ScaleA::Arguments;
using ScaleB_Args = typename ScaleB::Arguments;
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
return ArgumentType{a_args, {b_args}};
}
};
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
using StrideD = Stride<int64_t, Int<1>, Int<0>>; using StrideD = Stride<int64_t, Int<1>, Int<0>>;
using ElementC = void; using ElementC = void;
using StrideC = StrideD; using StrideC = StrideD;
using EVTCompute = typename Epilogue::EVTCompute;
using CollectiveEpilogue = using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder< typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
EpilogueSchedule, EVTCompute1>::CollectiveOp; EpilogueSchedule, EVTCompute>::CollectiveOp;
static constexpr size_t CEStorageSize = static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage); sizeof(typename CollectiveEpilogue::SharedStorage);
@ -148,11 +179,10 @@ struct cutlass_3x_gemm {
struct GemmKernel : public KernelType {}; struct GemmKernel : public KernelType {};
}; };
template <typename Gemm> template <typename Gemm, typename... EpilogueArgs>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, EpilogueArgs&&... epilogue_params) {
torch::Tensor const& b_scales) {
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
@ -182,19 +212,13 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
auto c_ptr = static_cast<ElementD*>(out.data_ptr()); auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{ typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride}; Gemm::Epilogue::prepare_args(
std::forward<EpilogueArgs>(epilogue_params)...),
c_ptr, c_stride, c_ptr, c_stride};
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape, mainloop_args, epilogue_args}; prob_shape, mainloop_args, epilogue_args};
using ScaleA_Args = typename Gemm::ScaleA::Arguments;
using ScaleB_Args = typename Gemm::ScaleB::Arguments;
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
args.epilogue.thread = {a_args, {b_args}};
// Launch the CUTLASS GEMM kernel. // Launch the CUTLASS GEMM kernel.
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op; GemmOp gemm_op;
@ -209,7 +233,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK(status); CUTLASS_CHECK(status);
} }
template <typename InType, typename OutType, int32_t M> template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, int32_t M>
struct sm90_fp8_config { struct sm90_fp8_config {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = using KernelSchedule =
@ -219,12 +244,13 @@ struct sm90_fp8_config {
using ClusterShape = Shape<_2, _1, _1>; using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule, cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
EpilogueSchedule>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType> template <typename InType, typename OutType,
struct sm90_fp8_config<InType, OutType, 128> { template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
@ -233,12 +259,13 @@ struct sm90_fp8_config<InType, OutType, 128> {
using ClusterShape = Shape<_2, _1, _1>; using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule, cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
EpilogueSchedule>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType> template <typename InType, typename OutType,
struct sm90_fp8_config<InType, OutType, 64> { template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
@ -247,30 +274,28 @@ struct sm90_fp8_config<InType, OutType, 64> {
using ClusterShape = Shape<_1, _8, _1>; using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule, cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
EpilogueSchedule>; KernelSchedule, EpilogueSchedule>;
}; };
} // namespace } // namespace
template <typename InType, typename OutType> template <typename InType, typename OutType,
void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out, template <typename, typename, typename> typename Epilogue,
torch::Tensor const& a, typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, EpilogueArgs&&... args) {
torch::Tensor const& b_scales) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
using Cutlass3xGemmDefault = using Cutlass3xGemmDefault =
typename sm90_fp8_config<InType, OutType, 0>::Cutlass3xGemm; typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
using Cutlass3xGemmM64 = using Cutlass3xGemmM64 =
typename sm90_fp8_config<InType, OutType, 64>::Cutlass3xGemm; typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
using Cutlass3xGemmM128 = using Cutlass3xGemmM128 =
typename sm90_fp8_config<InType, OutType, 128>::Cutlass3xGemm; typename sm90_fp8_config<InType, OutType, Epilogue, 128>::Cutlass3xGemm;
uint32_t const m = a.size(0); uint32_t const m = a.size(0);
uint32_t const mp2 = uint32_t const mp2 =
@ -278,20 +303,20 @@ void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
if (mp2 <= 64) { if (mp2 <= 64) {
// m in [1, 64] // m in [1, 64]
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM64>( return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, a_scales, b_scales); out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) { } else if (mp2 <= 128) {
// m in (64, 128] // m in (64, 128]
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM128>( return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, a_scales, b_scales); out, a, b, std::forward<EpilogueArgs>(args)...);
} else { } else {
// m in (128, inf) // m in (128, inf)
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmDefault>( return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, a_scales, b_scales); out, a, b, std::forward<EpilogueArgs>(args)...);
} }
} }
void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::Tensor const& b_scales) {
@ -308,16 +333,15 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher< return cutlass_gemm_caller<cutlass_3x_gemm<
cutlass_3x_gemm<int8_t, cutlass::bfloat16_t, TileShape, ClusterShape, int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>>( KernelSchedule, EpilogueSchedule>>(out, a, b, a_scales, b_scales);
out, a, b, a_scales, b_scales);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher< return cutlass_gemm_caller<
cutlass_3x_gemm<int8_t, cutlass::half_t, TileShape, ClusterShape, cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
KernelSchedule, EpilogueSchedule>>( ClusterShape, KernelSchedule, EpilogueSchedule>>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
} else { } else {
@ -325,13 +349,13 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm90_fp8_dispatch<
cutlass::bfloat16_t>( cutlass::float_e4m3_t, cutlass::bfloat16_t, ScaledEpilogue>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t>( cutlass::half_t, ScaledEpilogue>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
} }

View File

@ -3,29 +3,29 @@
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <torch/all.h> #include <torch/all.h>
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales); torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales); torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales); torch::Tensor const& b_scales);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales); torch::Tensor const& b_scales);
#endif #endif
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::Tensor const& b_scales) {
int32_t major_capability; int32_t major_capability;
@ -57,19 +57,19 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
// Guard against compilation issues for sm90 kernels // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales); cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales);
#else #else
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales); cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
#endif #endif
} else if (version_num == 89) { } else if (version_num == 89) {
// Ada Lovelace // Ada Lovelace
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales); cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales);
} else if (version_num >= 80) { } else if (version_num >= 80) {
// Ampere // Ampere
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales); cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
} else { } else {
// Turing // Turing
TORCH_CHECK(version_num >= 75); TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales); cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales);
} }
} }

View File

@ -136,10 +136,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization. // quantization.
ops.def( ops.def(
"cutlass_scaled_mm_dq(Tensor! out, Tensor a," "cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales," " Tensor b, Tensor a_scales,"
" Tensor b_scales) -> ()"); " Tensor b_scales) -> ()");
ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq); ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
#endif #endif
// Quantized GEMM for GPTQ. // Quantized GEMM for GPTQ.

View File

@ -47,7 +47,7 @@ def cutlass_fp8_gemm_helper(m: int,
scale_b = (torch.randn( scale_b = (torch.randn(
(1, n_b_scales), device=device, dtype=torch.float32) / 10) (1, n_b_scales), device=device, dtype=torch.float32) / 10)
out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32), baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(out_dtype) scale_b * b.to(dtype=torch.float32)).to(out_dtype)
@ -74,7 +74,7 @@ def cutlass_int8_gemm_helper(m: int,
scale_b = (torch.randn( scale_b = (torch.randn(
(1, n_b_scales), device=device, dtype=torch.float32) / 10) (1, n_b_scales), device=device, dtype=torch.float32) / 10)
out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32), baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * scale_b *
b.to(dtype=torch.float32)).to(dtype=out_dtype) b.to(dtype=torch.float32)).to(dtype=out_dtype)
@ -180,7 +180,7 @@ def test_cutlass_subset():
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
out = ops.cutlass_scaled_mm_dq(a, out = ops.cutlass_scaled_mm(a,
b, b,
scale_a, scale_a,
scale_b, scale_b,
@ -203,7 +203,7 @@ class CutlassLayer(torch.nn.Module):
self.out_dtype = out_dtype self.out_dtype = out_dtype
def forward(self, a): def forward(self, a):
return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b, return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
self.out_dtype) self.out_dtype)

View File

@ -212,8 +212,8 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# cutlass # cutlass
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype]) -> torch.Tensor: out_dtype: Type[torch.dtype]) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
@ -222,8 +222,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n = b.shape[1] n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device) out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b) torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b)
return out return out

View File

@ -81,5 +81,5 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
weight_scale = layer.weight_scale weight_scale = layer.weight_scale
x_q, input_scales = custom_ops.scaled_int8_quant(x) x_q, input_scales = custom_ops.scaled_int8_quant(x)
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), input_scales, return custom_ops.cutlass_scaled_mm(x_q, weight.t(), input_scales,
weight_scale, x.dtype) weight_scale, x.dtype)

View File

@ -99,5 +99,5 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
# Input quantize # Input quantize
x_q, _ = custom_ops.scaled_int8_quant(x, act_scale) x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, return custom_ops.cutlass_scaled_mm(x_q, weight.t(), act_scale,
weight_scale, x.dtype) weight_scale, x.dtype)

View File

@ -261,7 +261,7 @@ class Fp8LinearMethod(LinearMethodBase):
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale) qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
# Fused GEMM_DQ # Fused GEMM_DQ
output = ops.cutlass_scaled_mm_dq( output = ops.cutlass_scaled_mm(
qinput, qinput,
layer.weight, layer.weight,
out_dtype=x.dtype, out_dtype=x.dtype,