From f89978ad7cf3732a4a482db8e784aba7e417abbd Mon Sep 17 00:00:00 2001 From: kushanam <42385577+kushanam@users.noreply.github.com> Date: Tue, 4 Mar 2025 07:55:07 -0800 Subject: [PATCH] add cutlass support for blackwell fp8 gemm (#13798) --- CMakeLists.txt | 32 ++++++--- .../epilogue/scaled_mm_epilogues_c3x.hpp | 48 +++++++------ .../cutlass_w8a8/c3x/cutlass_gemm_caller.cuh | 32 +++++---- .../cutlass_w8a8/c3x/scaled_mm.cuh | 68 +++++++++++++++++-- .../cutlass_w8a8/c3x/scaled_mm_kernels.hpp | 6 ++ .../cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu | 24 +++++++ .../c3x/scaled_mm_sm100_fp8_dispatch.cuh | 67 ++++++++++++++++++ .../cutlass_w8a8/scaled_mm_c3x.cu | 25 +++++++ .../cutlass_w8a8/scaled_mm_entry.cu | 21 +++++- .../machete/machete_mm_kernel.cuh | 7 +- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 7 +- 11 files changed, 272 insertions(+), 65 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu create mode 100644 csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index c5fc2f3c..f7e32929 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,7 +31,7 @@ set(ignoreMe "${VLLM_PYTHON_PATH}") set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") # Supported NVIDIA architectures. -set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0") +set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101") @@ -297,7 +297,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Only build Marlin kernels if we are building for at least some compatible archs. # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") if (MARLIN_ARCHS) set(MARLIN_SRCS "csrc/quantization/fp8/fp8_marlin.cu" @@ -335,7 +335,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require # CUDA 12.0 or later (and only work on Hopper, 9.0a for now). - cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" @@ -369,7 +369,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # kernels for the remaining archs that are not already built for 3x. cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS - "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") + "7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) if (SCALED_MM_2X_ARCHS) @@ -394,7 +394,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # 2:4 Sparse Kernels # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor - # require CUDA 12.2 or later (and only work on Hopper, 9.0a for now). + # require CUDA 12.2 or later (and only work on Hopper and Blackwell). 
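  # (Illustrative sketch, not part of the patch) The Blackwell additions in this
  # file all follow the same gating pattern: intersect the requested CUDA_ARCHS
  # with a feature-specific arch list, then guard the sources on both that
  # intersection and the minimum nvcc version. The "a" suffix (e.g. 9.0a, 10.0a)
  # selects the arch-conditional feature set (sm_90a / sm_100a) these CUTLASS
  # kernels need. With hypothetical names EXAMPLE_ARCHS and example_kernels.cu:
  #
  #   cuda_archs_loose_intersection(EXAMPLE_ARCHS "10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
  #   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND EXAMPLE_ARCHS)
  #     set(SRCS "csrc/quantization/example/example_kernels.cu")
  #     set_gencode_flags_for_srcs(SRCS "${SRCS}" CUDA_ARCHS "${EXAMPLE_ARCHS}")
  #     list(APPEND VLLM_EXT_SRC "${SRCS}")
  #   else()
  #     set(EXAMPLE_ARCHS)
  #   endif()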
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS) set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") set_gencode_flags_for_srcs( @@ -419,8 +419,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" - ) + "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${FP4_ARCHS}") @@ -433,6 +432,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(FP4_ARCHS) endif() + # FP8 Blackwell Archs + cuda_archs_loose_intersection(BLACKWELL_ARCHS "10.0;10.1;12.0" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND BLACKWELL_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" + ) + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${BLACKWELL_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + message(STATUS "Building FP8 for archs: ${BLACKWELL_ARCHS}") + else() + # clear BLACKWELL_ARCHS + set(BLACKWELL_ARCHS) + endif() + # # Machete kernels @@ -514,6 +529,7 @@ define_gpu_extension_target( COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) @@ -537,7 +553,7 @@ set_gencode_flags_for_srcs( CUDA_ARCHS "${CUDA_ARCHS}") if(VLLM_GPU_LANG STREQUAL "CUDA") - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) set(MARLIN_MOE_SRC "csrc/moe/marlin_kernels/marlin_moe_kernel.h" diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 1a0cd45f..0a812dc5 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -22,7 +22,7 @@ struct identity { T operator()(T lhs) const { return lhs; } }; -template +template struct TrivialEpilogue { private: using Accum = cutlass::epilogue::fusion::Sm90AccFetch; @@ -44,32 +44,30 @@ struct TrivialEpilogue { * This class provides the common load descriptors for the * ScaledEpilogue[...] 
classes */ -template +template struct ScaledEpilogueBase { protected: using Accum = cutlass::epilogue::fusion::Sm90AccFetch; template using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>>; + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; template using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>>; + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; // Don't want to support nullptr by default template using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, - Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + 0 /*Stages*/, TileShape, T, T, Stride, Int<0>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; // Don't want to support nullptr by default template using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, - Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + 0 /*Stages*/, TileShape, T, T, Stride, Int<1>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; // This utility function constructs the arguments for the load descriptors // from a tensor. It can handle both row and column, as well as row/column or @@ -116,11 +114,11 @@ struct ScaledEpilogueBase { the A and B operands respectively. These scales may be either per-tensor or per row or column. */ -template +template struct ScaledEpilogue - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; @@ -160,11 +158,11 @@ struct ScaledEpilogue * The bias tensor must be per-output channel. * ScaleA and ScaleB can be per-tensor or per-token/per-channel. */ -template +template struct ScaledEpilogueBias - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; @@ -203,11 +201,11 @@ struct ScaledEpilogueBias * bias is a column vector instead of a row vector. Useful e.g. if we are * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels. */ -template +template struct ScaledEpilogueColumnBias - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; @@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias * * This epilogue also supports bias, which remains per-channel. */ -template +template struct ScaledEpilogueBiasAzp - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; @@ -314,11 +312,11 @@ struct ScaledEpilogueBiasAzp * * This epilogue also supports bias, which remains per-channel. 
*/ -template +template struct ScaledEpilogueBiasAzpToken - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh index 69a3f64c..26de32ce 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh @@ -16,6 +16,7 @@ #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/util/packed_stride.hpp" #include "core/math.hpp" #include "cutlass_extensions/common.hpp" @@ -64,22 +65,28 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, EpilogueArgs&&... epilogue_params) { using ElementAB = typename Gemm::ElementAB; + using ElementC = typename Gemm::ElementC; using ElementD = typename Gemm::ElementD; using GemmKernel = typename Gemm::GemmKernel; - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = cute::Stride, int64_t>; - using StrideB = cute::Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, cute::Int<1>{}, 0}; - StrideB b_stride{ldb, cute::Int<1>{}, 0}; - StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}}; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = StrideC; + using StrideAux = StrideC; typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b); + auto [M, N, K, L] = prob_shape; + + StrideA a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + StrideB b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + StrideC c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + StrideD d_stride = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + StrideAux aux_stride = d_stride; auto a_ptr = static_cast(a.data_ptr()); auto b_ptr = static_cast(b.data_ptr()); @@ -87,10 +94,11 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, b_stride}; auto c_ptr = static_cast(out.data_ptr()); + // auto d_ptr = static_cast(out.data_ptr()); typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, c_stride}; + c_ptr, c_stride, c_ptr, d_stride}; cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, epilogue_args); diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh index d2f43e2b..8f4df836 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh @@ -40,12 +40,7 @@ struct cutlass_3x_gemm { typename std::conditional, int32_t, float>::type; - using EpilogueDescriptor = - cutlass::epilogue::collective::detail::EpilogueDescriptor< - TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, - ElementD, EpilogueSchedule>; - - using Epilogue = Epilogue_; + using Epilogue = Epilogue_; using StrideD = Stride, Int<0>>; 
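  // ElementC is void here: these kernels read no separate C source operand, so
  // the EVT epilogue produces D directly from the accumulator plus the
  // broadcast scale (and optional bias) loads.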
 using ElementC = void;
@@ -88,4 +83,65 @@ struct cutlass_3x_gemm {
   struct GemmKernel : public KernelType {};
 };
 
+template <typename ElementAB_, typename ElementD_,
+          template <typename, typename, typename> typename Epilogue_,
+          typename TileShape, typename ClusterShape, typename KernelSchedule,
+          typename EpilogueSchedule>
+struct cutlass_3x_gemm_sm100 {
+  using ElementAB = ElementAB_;
+  using LayoutA = cutlass::layout::RowMajor;
+  static constexpr int AlignmentA =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+
+  using LayoutB = cutlass::layout::ColumnMajor;
+  static constexpr int AlignmentB =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+
+  using ElementC = void;
+  using LayoutC = cutlass::layout::RowMajor;
+  static constexpr int AlignmentC =
+      128 / cutlass::sizeof_bits<ElementD_>::value;
+
+  using ElementD = ElementD_;
+  using LayoutD = cutlass::layout::RowMajor;
+  static constexpr int AlignmentD = AlignmentC;
+
+  using ElementAcc =
+      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
+                                float>::type;
+  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
+
+  // MMA type
+  using ElementAccumulator = float;
+
+  // Epilogue types
+  using ElementBias = cutlass::half_t;
+  using ElementCompute = float;
+  using ElementAux = ElementD;
+  using LayoutAux = LayoutD;
+  using ElementAmax = float;
+
+  using EVTCompute = typename Epilogue::EVTCompute;
+
+  using CollectiveEpilogue =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
+          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
+          ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
+          ElementD, LayoutD, AlignmentD, EpilogueSchedule,
+          EVTCompute>::CollectiveOp;
+
+  using CollectiveMainloop =
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
+          LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
+          ElementAccumulator, TileShape, ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+              sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          KernelSchedule>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
+};
+
 }  // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
index 7ede9e06..85272804 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
@@ -30,4 +30,10 @@ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
                                           torch::Tensor const& a_scales,
                                           torch::Tensor const& b_scales);
 
+void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
+                                 torch::Tensor const& b,
+                                 torch::Tensor const& a_scales,
+                                 torch::Tensor const& b_scales,
+                                 std::optional<torch::Tensor> const& bias);
+
 }  // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
new file mode 100644
index 00000000..cf2cccc9
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
@@ -0,0 +1,24 @@
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_sm100_fp8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+
+namespace vllm {
+
+void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
+                                 torch::Tensor const& b,
+                                 torch::Tensor const& a_scales,
+                                 torch::Tensor const& b_scales,
+                                 std::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogue>(
+        out, a, b, a_scales, b_scales);
+  }
+}
+
+}  // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
new file mode 100644
index 00000000..468b77d9
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
@@ -0,0 +1,67 @@
+#pragma once
+
+#include "scaled_mm.cuh"
+#include "cutlass_gemm_caller.cuh"
+
+/**
+ * This file defines Gemm kernel configurations for SM100 (fp8) based on the
+ * Gemm shape.
+ */
+
+namespace vllm {
+
+using c3x::cutlass_gemm_caller;
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm100_fp8_config_default {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_256, _128, _64>;
+  using ClusterShape = Shape<_2, _2, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
+                            KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
+                                            torch::Tensor const& a,
+                                            torch::Tensor const& b,
+                                            EpilogueArgs&&... args) {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
+
+  using Cutlass3xGemmDefault =
+      typename sm100_fp8_config_default<InType, OutType,
+                                        Epilogue>::Cutlass3xGemm;
+  return cutlass_gemm_caller<Cutlass3xGemmDefault>(
+      out, a, b, std::forward<EpilogueArgs>(args)...);
+}
+
+template