[Kernel] Factor out epilogues from cutlass kernels (#5391)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: zifeitong <zifei.tong@parasail.io>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
This commit is contained in:
Tyler Michael Smith 2024-06-13 14:22:19 -04:00 committed by GitHub
parent 0ce7b952f8
commit 85657b5607
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 274 additions and 232 deletions

View File

@ -179,9 +179,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
#
# The CUTLASS kernels for Hopper require sm90a to be enabled.
@ -189,7 +189,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")

View File

@ -76,11 +76,7 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor:
return ops.cutlass_scaled_mm_dq(a,
b,
scale_a,
scale_b,
out_dtype=out_dtype)
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
# bench

View File

@ -90,9 +90,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
#endif

View File

@ -29,21 +29,14 @@
using namespace cute;
/*
This defines a quantized GEMM operation with dequantized output, similar to
torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace {
@ -83,27 +76,25 @@ struct enable_sm89_to_sm90 : Kernel {
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue {
private:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
@ -123,14 +114,56 @@ struct cutlass_2x_gemm {
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 =
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
return evt_compute_args;
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
using EVTCompute = typename Epilogue::EVTCompute;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
Stride<int64_t, Int<1>, Int<0>>>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
@ -153,11 +186,10 @@ struct cutlass_2x_gemm {
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
@ -177,23 +209,14 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
auto a_scales_ptr = a_scales.data_ptr<float>();
auto b_scales_ptr = b_scales.data_ptr<float>();
using ScaleAArgs = typename Gemm::ScaleA::Arguments;
using ScaleBArgs = typename Gemm::ScaleB::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
evt0_compute_args};
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
using Epilogue = typename Gemm::Epilogue;
auto evt_args =
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
typename Gemm::EVTD::Arguments epilogue_args{
evt1_compute_args,
evt_args,
d_args,
};
@ -229,10 +252,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
} // namespace
void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@ -243,23 +266,23 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@ -270,23 +293,23 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@ -298,32 +321,32 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
} else {
assert(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
}
}
}

View File

@ -32,21 +32,14 @@
using namespace cute;
/*
This defines a quantized GEMM operation with dequantized output, similar to
torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace {
@ -71,21 +64,25 @@ struct enable_sm90_or_later : Kernel {
}
};
template <typename ElementAB_, typename ElementD_, typename TileShape,
typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
/*
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;
A and B may be both either int8 or fp8_e4m3. A can be
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue {
private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
@ -111,19 +108,53 @@ struct cutlass_3x_gemm {
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 =
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleA_Args = typename ScaleA::Arguments;
using ScaleB_Args = typename ScaleB::Arguments;
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
return ArgumentType{a_args, {b_args}};
}
};
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
using ElementC = void;
using StrideC = StrideD;
using EVTCompute = typename Epilogue::EVTCompute;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
EpilogueSchedule, EVTCompute1>::CollectiveOp;
EpilogueSchedule, EVTCompute>::CollectiveOp;
static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage);
@ -148,11 +179,10 @@ struct cutlass_3x_gemm {
struct GemmKernel : public KernelType {};
};
template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
@ -182,19 +212,13 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride};
Gemm::Epilogue::prepare_args(
std::forward<EpilogueArgs>(epilogue_params)...),
c_ptr, c_stride, c_ptr, c_stride};
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape, mainloop_args, epilogue_args};
using ScaleA_Args = typename Gemm::ScaleA::Arguments;
using ScaleB_Args = typename Gemm::ScaleB::Arguments;
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
args.epilogue.thread = {a_args, {b_args}};
// Launch the CUTLASS GEMM kernel.
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op;
@ -209,7 +233,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK(status);
}
template <typename InType, typename OutType, int32_t M>
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, int32_t M>
struct sm90_fp8_config {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
@ -219,12 +244,13 @@ struct sm90_fp8_config {
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType>
struct sm90_fp8_config<InType, OutType, 128> {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
@ -233,12 +259,13 @@ struct sm90_fp8_config<InType, OutType, 128> {
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType>
struct sm90_fp8_config<InType, OutType, 64> {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
@ -247,30 +274,28 @@ struct sm90_fp8_config<InType, OutType, 64> {
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>;
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
} // namespace
template <typename InType, typename OutType>
void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
using Cutlass3xGemmDefault =
typename sm90_fp8_config<InType, OutType, 0>::Cutlass3xGemm;
typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_fp8_config<InType, OutType, 64>::Cutlass3xGemm;
typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_fp8_config<InType, OutType, 128>::Cutlass3xGemm;
typename sm90_fp8_config<InType, OutType, Epilogue, 128>::Cutlass3xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
@ -278,23 +303,23 @@ void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
if (mp2 <= 64) {
// m in [1, 64]
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM64>(
out, a, b, a_scales, b_scales);
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM128>(
out, a, b, a_scales, b_scales);
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmDefault>(
out, a, b, a_scales, b_scales);
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
@ -308,16 +333,15 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<
cutlass_3x_gemm<int8_t, cutlass::bfloat16_t, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>>(
out, a, b, a_scales, b_scales);
return cutlass_gemm_caller<cutlass_3x_gemm<
int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>>(out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<
cutlass_3x_gemm<int8_t, cutlass::half_t, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>>(
return cutlass_gemm_caller<
cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>>(
out, a, b, a_scales, b_scales);
}
} else {
@ -325,13 +349,13 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t>(
return cutlass_gemm_sm90_fp8_dispatch<
cutlass::float_e4m3_t, cutlass::bfloat16_t, ScaledEpilogue>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t>(
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}

View File

@ -3,31 +3,31 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
#endif
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
int32_t major_capability;
int32_t minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
@ -57,19 +57,19 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales);
#else
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
#endif
} else if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales);
} else if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
} else {
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales);
}
}

View File

@ -136,10 +136,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_dq(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales) -> ()");
ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq);
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
#endif
// Quantized GEMM for GPTQ.

View File

@ -47,7 +47,7 @@ def cutlass_fp8_gemm_helper(m: int,
scale_b = (torch.randn(
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype)
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(out_dtype)
@ -74,7 +74,7 @@ def cutlass_int8_gemm_helper(m: int,
scale_b = (torch.randn(
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype)
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b *
b.to(dtype=torch.float32)).to(dtype=out_dtype)
@ -180,11 +180,11 @@ def test_cutlass_subset():
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
out = ops.cutlass_scaled_mm_dq(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
out = ops.cutlass_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b *
b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
@ -203,8 +203,8 @@ class CutlassLayer(torch.nn.Module):
self.out_dtype = out_dtype
def forward(self, a):
return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b,
self.out_dtype)
return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
self.out_dtype)
@pytest.mark.parametrize("per_act_token", [True, False])

View File

@ -212,9 +212,9 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# cutlass
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype]) -> torch.Tensor:
def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype]) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
@ -222,8 +222,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b)
return out

View File

@ -81,5 +81,5 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
weight_scale = layer.weight_scale
x_q, input_scales = custom_ops.scaled_int8_quant(x)
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), input_scales,
weight_scale, x.dtype)
return custom_ops.cutlass_scaled_mm(x_q, weight.t(), input_scales,
weight_scale, x.dtype)

View File

@ -99,5 +99,5 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
# Input quantize
x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
weight_scale, x.dtype)
return custom_ops.cutlass_scaled_mm(x_q, weight.t(), act_scale,
weight_scale, x.dtype)

View File

@ -261,7 +261,7 @@ class Fp8LinearMethod(LinearMethodBase):
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm_dq(
output = ops.cutlass_scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,