#include <stddef.h>
#include <torch/extension.h>

// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/util/device_memory.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

#include "cutlass_visitor_2x_broadcast_epilogue.hpp"
#include "common.hpp"
// clang-format on

using namespace cute;

/*
   This defines a quantized GEMM operation with dequantized output, similar to
   torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
   NVIDIA GPUs with SM versions prior to sm90 (Hopper).

   A and B may both be either int8 or fp8_e4m3. A can be quantized per-tensor
   or per-row. B can be quantized per-tensor or per-column. Any combination of
   per-tensor and per-row or per-column is supported. A and B must have
   symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per-row or per-column.
*/

namespace {

template <typename Arch, typename ElementAB_, typename ElementD_,
          typename TileShape, typename WarpShape, typename InstructionShape,
          int32_t MainLoopStages>
struct cutlass_2x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;

  // int8 inputs accumulate in int32; fp8 inputs accumulate in float.
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using Operator =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
                                cutlass::arch::OpMultiplyAddSaturate,
                                cutlass::arch::OpMultiplyAdd>::type;

  using OutputTileThreadMap =
      cutlass::epilogue::threadblock::OutputTileThreadLayout<
          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
          >;

  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  // a_scales broadcasts down the M dimension (per-row of A), or is a scalar.
  using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
      OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;

  // b_scales broadcasts across the N dimension (per-column of B), or is a
  // scalar.
  using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
      OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  // EVTCompute0 = b_scales * accum
  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  // EVTCompute1 = a_scales * EVTCompute0, converted to ElementD
  using EVTCompute1 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;

  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, Int<1>, Int<0>>>;

  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;

  // clang-format off
  using RowMajor = typename cutlass::layout::RowMajor;
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
  using KernelType =
    typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
      float, cutlass::layout::RowMajor, 4,
      ElementAcc, float, cutlass::arch::OpClassTensorOp,
      Arch,
      TileShape, WarpShape, InstructionShape,
      EVTD,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
      MainLoopStages, Operator,
      1 /* epilogue stages */
      >::GemmKernel;
  // clang-format on

  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
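// Example instantiation (illustrative only): this mirrors the sm80
// int8 -> bf16 configuration dispatched below, so the tile shapes and stage
// count are taken from this file rather than being new assumptions.
//
//   using ExampleGemm =
//       cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::bfloat16_t,
//                       cutlass::gemm::GemmShape<128, 128, 64>,
//                       cutlass::gemm::GemmShape<64, 64, 64>,
//                       cutlass::gemm::GemmShape<16, 8, 32>, 5>;
//   using ExampleOp = typename ExampleGemm::Op;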
template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out,
                                     torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     torch::Tensor const& a_scales,
                                     torch::Tensor const& b_scales) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  cutlass::gemm::GemmCoord problem_size{m, n, k};

  // A is row-major, B is column-major, and the output C/D is row-major.
  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideC = Stride<int64_t, Int<1>, Int<0>>;
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  // If A and B are quantized per-tensor, then these scale tensors are scalars,
  // and they are passed in via the second argument.
  using ScaleAArgs = typename Gemm::ScaleA::Arguments;
  ScaleAArgs a_args = a_scales.numel() == 1
                          ? ScaleAArgs{nullptr, a_scales.item<float>(), {}}
                          : ScaleAArgs{a_scales.data_ptr<float>(), {}, {}};

  using ScaleBArgs = typename Gemm::ScaleB::Arguments;
  ScaleBArgs b_args = b_scales.numel() == 1
                          ? ScaleBArgs{nullptr, b_scales.item<float>(), {}}
                          : ScaleBArgs{b_scales.data_ptr<float>(), {}, {}};

  typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};

  typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
                                                          evt0_compute_args};
  typename Gemm::D::Arguments d_args{c_ptr, c_stride};

  typename Gemm::EVTD::Arguments epilogue_args{
      evt1_compute_args,
      d_args,
  };

  typename Gemm::Op::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
      problem_size,                                           // problem size
      1,                                                      // batch count
      epilogue_args,
      a_ptr,
      b_ptr,
      nullptr,
      nullptr,
      0,
      0,
      0,
      0,
      lda,
      ldb,
      ldc,
      ldc};

  // Launch the CUTLASS GEMM kernel.
  typename Gemm::Op gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  CUTLASS_CHECK(gemm_op.can_implement(args));
  cutlass::Status status = gemm_op(args, workspace.get());
  CUTLASS_CHECK(status);
}

}  // namespace

void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_scaled_mm_dq_dispatcher<
        cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::bfloat16_t,
                        TileShape, WarpShape, InstructionShape, 2>>(
        out, a, b, a_scales, b_scales);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_scaled_mm_dq_dispatcher<
        cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::half_t,
                        TileShape, WarpShape, InstructionShape, 2>>(
        out, a, b, a_scales, b_scales);
  }
}

void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_scaled_mm_dq_dispatcher<
        cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::bfloat16_t,
                        TileShape, WarpShape, InstructionShape, 5>>(
        out, a, b, a_scales, b_scales);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_scaled_mm_dq_dispatcher<
        cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::half_t,
                        TileShape, WarpShape, InstructionShape, 5>>(
        out, a, b, a_scales, b_scales);
  }
}
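// Illustrative host-side usage (hypothetical tensor setup, not part of this
// file's API): A must be row-major int8 [m, k], B column-major int8 [k, n]
// (e.g. a row-major [n, k] tensor transposed), out row-major [m, n], and each
// scale fp32 with numel 1 (per-tensor), m (per-row of A), or n (per-column
// of B).
//
//   auto i8 = torch::dtype(torch::kInt8).device(torch::kCUDA);
//   auto f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);
//   torch::Tensor a = torch::randint(-128, 128, {m, k}, i8);
//   torch::Tensor b = torch::randint(-128, 128, {n, k}, i8).t();
//   torch::Tensor out = torch::empty(
//       {m, n}, torch::dtype(torch::kBFloat16).device(torch::kCUDA));
//   torch::Tensor a_scales = torch::rand({m}, f32);
//   torch::Tensor b_scales = torch::rand({n}, f32);
//   cutlass_scaled_mm_dq_sm80(out, a, b, a_scales, b_scales);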
void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales) {
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_scaled_mm_dq_dispatcher<
          cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::bfloat16_t,
                          TileShape, WarpShape, InstructionShape, 5>>(
          out, a, b, a_scales, b_scales);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_scaled_mm_dq_dispatcher<
          cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::half_t,
                          TileShape, WarpShape, InstructionShape, 5>>(
          out, a, b, a_scales, b_scales);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_scaled_mm_dq_dispatcher<
          cutlass_2x_gemm<cutlass::arch::Sm89, cutlass::float_e4m3_t,
                          cutlass::bfloat16_t, TileShape, WarpShape,
                          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_scaled_mm_dq_dispatcher<
          cutlass_2x_gemm<cutlass::arch::Sm89, cutlass::float_e4m3_t,
                          cutlass::half_t, TileShape, WarpShape,
                          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
    }
  }
}
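// Reference computation for correctness checks (illustrative only): the
// numpy-style broadcasting described in the header comment is equivalent to
// scaling A's rows and B's columns in fp32 before the matmul:
//
//   torch::Tensor ref = (a_scales.view({-1, 1}) * a.to(torch::kFloat32))
//                           .matmul(b_scales.view({1, -1}) *
//                                   b.to(torch::kFloat32))
//                           .to(out.scalar_type());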