add cutlass support for blackwell fp8 gemm (#13798)
This commit is contained in:
parent
b3cf368d79
commit
f89978ad7c
@ -31,7 +31,7 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
|
|||||||
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
|
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
|
||||||
|
|
||||||
# Supported NVIDIA architectures.
|
# Supported NVIDIA architectures.
|
||||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
|
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
|
||||||
|
|
||||||
# Supported AMD GPU architectures.
|
# Supported AMD GPU architectures.
|
||||||
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101")
|
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101")
|
||||||
@ -297,7 +297,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# Only build Marlin kernels if we are building for at least some compatible archs.
|
# Only build Marlin kernels if we are building for at least some compatible archs.
|
||||||
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
||||||
# are not supported by Machete yet.
|
# are not supported by Machete yet.
|
||||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||||
if (MARLIN_ARCHS)
|
if (MARLIN_ARCHS)
|
||||||
set(MARLIN_SRCS
|
set(MARLIN_SRCS
|
||||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||||
@ -335,7 +335,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
|
|
||||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||||
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
|
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
|
||||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||||
set(SRCS
|
set(SRCS
|
||||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
|
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
|
||||||
@ -369,7 +369,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||||
# kernels for the remaining archs that are not already built for 3x.
|
# kernels for the remaining archs that are not already built for 3x.
|
||||||
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||||
"7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
|
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||||
# subtract out the archs that are already built for 3x
|
# subtract out the archs that are already built for 3x
|
||||||
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||||
if (SCALED_MM_2X_ARCHS)
|
if (SCALED_MM_2X_ARCHS)
|
||||||
@ -394,7 +394,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# 2:4 Sparse Kernels
|
# 2:4 Sparse Kernels
|
||||||
|
|
||||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||||
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
|
# require CUDA 12.2 or later (and only work on Hopper and Blackwell).
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
||||||
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||||
set_gencode_flags_for_srcs(
|
set_gencode_flags_for_srcs(
|
||||||
@ -419,8 +419,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
|
||||||
set(SRCS
|
set(SRCS
|
||||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||||
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
|
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu")
|
||||||
)
|
|
||||||
set_gencode_flags_for_srcs(
|
set_gencode_flags_for_srcs(
|
||||||
SRCS "${SRCS}"
|
SRCS "${SRCS}"
|
||||||
CUDA_ARCHS "${FP4_ARCHS}")
|
CUDA_ARCHS "${FP4_ARCHS}")
|
||||||
@ -433,6 +432,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
set(FP4_ARCHS)
|
set(FP4_ARCHS)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# FP8 Blackwell Archs
|
||||||
|
cuda_archs_loose_intersection(BLACKWELL_ARCHS "10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND BLACKWELL_ARCHS)
|
||||||
|
set(SRCS
|
||||||
|
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
|
||||||
|
)
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${BLACKWELL_ARCHS}")
|
||||||
|
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||||
|
message(STATUS "Building FP8 for archs: ${BLACKWELL_ARCHS}")
|
||||||
|
else()
|
||||||
|
# clear BLACKWELL_ARCHS
|
||||||
|
set(BLACKWELL_ARCHS)
|
||||||
|
endif()
|
||||||
|
|
||||||
#
|
#
|
||||||
# Machete kernels
|
# Machete kernels
|
||||||
|
|
||||||
@ -514,6 +529,7 @@ define_gpu_extension_target(
|
|||||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
|
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
|
||||||
|
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||||
USE_SABI 3
|
USE_SABI 3
|
||||||
WITH_SOABI)
|
WITH_SOABI)
|
||||||
|
|
||||||
@ -537,7 +553,7 @@ set_gencode_flags_for_srcs(
|
|||||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||||
if (MARLIN_MOE_ARCHS)
|
if (MARLIN_MOE_ARCHS)
|
||||||
set(MARLIN_MOE_SRC
|
set(MARLIN_MOE_SRC
|
||||||
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
|
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
|
||||||
|
@ -22,7 +22,7 @@ struct identity {
|
|||||||
T operator()(T lhs) const { return lhs; }
|
T operator()(T lhs) const { return lhs; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||||
struct TrivialEpilogue {
|
struct TrivialEpilogue {
|
||||||
private:
|
private:
|
||||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||||
@ -44,32 +44,30 @@ struct TrivialEpilogue {
|
|||||||
* This class provides the common load descriptors for the
|
* This class provides the common load descriptors for the
|
||||||
* ScaledEpilogue[...] classes
|
* ScaledEpilogue[...] classes
|
||||||
*/
|
*/
|
||||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||||
struct ScaledEpilogueBase {
|
struct ScaledEpilogueBase {
|
||||||
protected:
|
protected:
|
||||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
|
||||||
|
|
||||||
// Don't want to support nullptr by default
|
// Don't want to support nullptr by default
|
||||||
template <typename T, bool EnableNullPtr = false>
|
template <typename T, bool EnableNullPtr = false>
|
||||||
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
|
0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
|
||||||
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||||
|
|
||||||
// Don't want to support nullptr by default
|
// Don't want to support nullptr by default
|
||||||
template <typename T, bool EnableNullPtr = false>
|
template <typename T, bool EnableNullPtr = false>
|
||||||
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
|
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
|
||||||
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||||
|
|
||||||
// This utility function constructs the arguments for the load descriptors
|
// This utility function constructs the arguments for the load descriptors
|
||||||
// from a tensor. It can handle both row and column, as well as row/column or
|
// from a tensor. It can handle both row and column, as well as row/column or
|
||||||
@ -116,11 +114,11 @@ struct ScaledEpilogueBase {
|
|||||||
the A and B operands respectively. These scales may be either per-tensor or
|
the A and B operands respectively. These scales may be either per-tensor or
|
||||||
per row or column.
|
per row or column.
|
||||||
*/
|
*/
|
||||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||||
struct ScaledEpilogue
|
struct ScaledEpilogue
|
||||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||||
private:
|
private:
|
||||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||||
using Accum = typename SUPER::Accum;
|
using Accum = typename SUPER::Accum;
|
||||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
@ -160,11 +158,11 @@ struct ScaledEpilogue
|
|||||||
* The bias tensor must be per-output channel.
|
* The bias tensor must be per-output channel.
|
||||||
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||||
*/
|
*/
|
||||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||||
struct ScaledEpilogueBias
|
struct ScaledEpilogueBias
|
||||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||||
private:
|
private:
|
||||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||||
using Accum = typename SUPER::Accum;
|
using Accum = typename SUPER::Accum;
|
||||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
@ -203,11 +201,11 @@ struct ScaledEpilogueBias
|
|||||||
* bias is a column vector instead of a row vector. Useful e.g. if we are
|
* bias is a column vector instead of a row vector. Useful e.g. if we are
|
||||||
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
|
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
|
||||||
*/
|
*/
|
||||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||||
struct ScaledEpilogueColumnBias
|
struct ScaledEpilogueColumnBias
|
||||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||||
private:
|
private:
|
||||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||||
using Accum = typename SUPER::Accum;
|
using Accum = typename SUPER::Accum;
|
||||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias
|
|||||||
*
|
*
|
||||||
* This epilogue also supports bias, which remains per-channel.
|
* This epilogue also supports bias, which remains per-channel.
|
||||||
*/
|
*/
|
||||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||||
struct ScaledEpilogueBiasAzp
|
struct ScaledEpilogueBiasAzp
|
||||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||||
private:
|
private:
|
||||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||||
using Accum = typename SUPER::Accum;
|
using Accum = typename SUPER::Accum;
|
||||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
@ -314,11 +312,11 @@ struct ScaledEpilogueBiasAzp
|
|||||||
*
|
*
|
||||||
* This epilogue also supports bias, which remains per-channel.
|
* This epilogue also supports bias, which remains per-channel.
|
||||||
*/
|
*/
|
||||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||||
struct ScaledEpilogueBiasAzpToken
|
struct ScaledEpilogueBiasAzpToken
|
||||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||||
private:
|
private:
|
||||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||||
using Accum = typename SUPER::Accum;
|
using Accum = typename SUPER::Accum;
|
||||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||||
|
#include "cutlass/util/packed_stride.hpp"
|
||||||
|
|
||||||
#include "core/math.hpp"
|
#include "core/math.hpp"
|
||||||
#include "cutlass_extensions/common.hpp"
|
#include "cutlass_extensions/common.hpp"
|
||||||
@ -64,22 +65,28 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
torch::Tensor const& b,
|
torch::Tensor const& b,
|
||||||
EpilogueArgs&&... epilogue_params) {
|
EpilogueArgs&&... epilogue_params) {
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
|
using ElementC = typename Gemm::ElementC;
|
||||||
using ElementD = typename Gemm::ElementD;
|
using ElementD = typename Gemm::ElementD;
|
||||||
using GemmKernel = typename Gemm::GemmKernel;
|
using GemmKernel = typename Gemm::GemmKernel;
|
||||||
|
|
||||||
int64_t lda = a.stride(0);
|
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||||
int64_t ldb = b.stride(1);
|
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||||
int64_t ldc = out.stride(0);
|
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||||
|
using StrideD = StrideC;
|
||||||
using StrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;
|
using StrideAux = StrideC;
|
||||||
using StrideB = cute::Stride<int64_t, cute::Int<1>, int64_t>;
|
|
||||||
using StrideC = typename Gemm::StrideC;
|
|
||||||
|
|
||||||
StrideA a_stride{lda, cute::Int<1>{}, 0};
|
|
||||||
StrideB b_stride{ldb, cute::Int<1>{}, 0};
|
|
||||||
StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};
|
|
||||||
|
|
||||||
typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
|
typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
|
||||||
|
auto [M, N, K, L] = prob_shape;
|
||||||
|
|
||||||
|
StrideA a_stride =
|
||||||
|
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||||
|
StrideB b_stride =
|
||||||
|
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||||
|
StrideC c_stride =
|
||||||
|
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||||
|
StrideD d_stride =
|
||||||
|
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||||
|
StrideAux aux_stride = d_stride;
|
||||||
|
|
||||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||||
@ -87,10 +94,11 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
b_stride};
|
b_stride};
|
||||||
|
|
||||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||||
|
// auto d_ptr = static_cast<ElementC*>(out.data_ptr());
|
||||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||||
Gemm::Epilogue::prepare_args(
|
Gemm::Epilogue::prepare_args(
|
||||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||||
c_ptr, c_stride, c_ptr, c_stride};
|
c_ptr, c_stride, c_ptr, d_stride};
|
||||||
|
|
||||||
cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||||
epilogue_args);
|
epilogue_args);
|
||||||
|
@ -40,12 +40,7 @@ struct cutlass_3x_gemm {
|
|||||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||||
float>::type;
|
float>::type;
|
||||||
|
|
||||||
using EpilogueDescriptor =
|
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
|
||||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
|
||||||
ElementD, EpilogueSchedule>;
|
|
||||||
|
|
||||||
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
|
|
||||||
|
|
||||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||||
using ElementC = void;
|
using ElementC = void;
|
||||||
@ -88,4 +83,65 @@ struct cutlass_3x_gemm {
|
|||||||
struct GemmKernel : public KernelType {};
|
struct GemmKernel : public KernelType {};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename ElementAB_, typename ElementD_,
|
||||||
|
template <typename, typename, typename> typename Epilogue_,
|
||||||
|
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||||
|
typename EpilogueSchedule>
|
||||||
|
struct cutlass_3x_gemm_sm100 {
|
||||||
|
using ElementAB = ElementAB_;
|
||||||
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
|
static constexpr int AlignmentA =
|
||||||
|
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||||
|
|
||||||
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
|
static constexpr int AlignmentB =
|
||||||
|
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||||
|
|
||||||
|
using ElementC = void;
|
||||||
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
|
static constexpr int AlignmentC =
|
||||||
|
128 / cutlass::sizeof_bits<ElementD_>::value;
|
||||||
|
|
||||||
|
using ElementD = ElementD_;
|
||||||
|
using LayoutD = cutlass::layout::RowMajor;
|
||||||
|
static constexpr int AlignmentD = AlignmentC;
|
||||||
|
|
||||||
|
using ElementAcc =
|
||||||
|
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||||
|
float>::type;
|
||||||
|
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||||
|
|
||||||
|
// MMA type
|
||||||
|
using ElementAccumulator = float;
|
||||||
|
|
||||||
|
// Epilogue types
|
||||||
|
using ElementBias = cutlass::half_t;
|
||||||
|
using ElementCompute = float;
|
||||||
|
using ElementAux = ElementD;
|
||||||
|
using LayoutAux = LayoutD;
|
||||||
|
using ElementAmax = float;
|
||||||
|
|
||||||
|
using EVTCompute = typename Epilogue::EVTCompute;
|
||||||
|
|
||||||
|
using CollectiveEpilogue =
|
||||||
|
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
|
||||||
|
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||||
|
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
|
||||||
|
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
|
||||||
|
EVTCompute>::CollectiveOp;
|
||||||
|
|
||||||
|
using CollectiveMainloop =
|
||||||
|
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
|
||||||
|
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
|
||||||
|
ElementAccumulator, TileShape, ClusterShape,
|
||||||
|
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||||
|
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||||
|
KernelSchedule>::CollectiveOp;
|
||||||
|
|
||||||
|
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||||
|
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
@ -30,4 +30,10 @@ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
|||||||
torch::Tensor const& a_scales,
|
torch::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales);
|
torch::Tensor const& b_scales);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
std::optional<torch::Tensor> const& bias);
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
24
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
Normal file
24
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
#include "scaled_mm_kernels.hpp"
|
||||||
|
#include "scaled_mm_sm100_fp8_dispatch.cuh"
|
||||||
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
std::optional<torch::Tensor> const& bias) {
|
||||||
|
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||||
|
if (bias) {
|
||||||
|
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||||
|
"currently bias dtype must match output dtype ", out.dtype());
|
||||||
|
return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||||
|
out, a, b, a_scales, b_scales, *bias);
|
||||||
|
} else {
|
||||||
|
return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||||
|
out, a, b, a_scales, b_scales);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
@ -0,0 +1,67 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "scaled_mm.cuh"
|
||||||
|
#include "cutlass_gemm_caller.cuh"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This file defines Gemm kernel configurations for SM100 (fp8) based on the
|
||||||
|
* Gemm shape.
|
||||||
|
*/
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
using c3x::cutlass_gemm_caller;
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue>
|
||||||
|
struct sm100_fp8_config_default {
|
||||||
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||||
|
using TileShape = Shape<_256, _128, _64>;
|
||||||
|
using ClusterShape = Shape<_2, _2, _1>;
|
||||||
|
using Cutlass3xGemm =
|
||||||
|
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
|
KernelSchedule, EpilogueSchedule>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue,
|
||||||
|
typename... EpilogueArgs>
|
||||||
|
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
||||||
|
torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
EpilogueArgs&&... args) {
|
||||||
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||||
|
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||||
|
|
||||||
|
using Cutlass3xGemmDefault =
|
||||||
|
typename sm100_fp8_config_default<InType, OutType,
|
||||||
|
Epilogue>::Cutlass3xGemm;
|
||||||
|
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <template <typename, typename, typename> typename Epilogue,
|
||||||
|
typename... EpilogueArgs>
|
||||||
|
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
|
||||||
|
torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
EpilogueArgs&&... epilogue_args) {
|
||||||
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||||
|
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||||
|
|
||||||
|
if (out.dtype() == torch::kBFloat16) {
|
||||||
|
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||||
|
cutlass::bfloat16_t, Epilogue>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||||
|
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||||
|
cutlass::half_t, Epilogue>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
@ -71,3 +71,28 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
|
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
|
||||||
azp, bias);
|
azp, bias);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
std::optional<torch::Tensor> const& bias) {
|
||||||
|
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||||
|
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||||
|
|
||||||
|
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||||
|
TORCH_CHECK(
|
||||||
|
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||||
|
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
|
||||||
|
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
|
||||||
|
|
||||||
|
// Standard per-tensor/per-token/per-channel scaling
|
||||||
|
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||||
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
|
||||||
|
"Currently, only fp8 gemm is implemented for Blackwell");
|
||||||
|
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
@ -29,6 +29,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
|||||||
torch::Tensor const& a_scales,
|
torch::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales,
|
torch::Tensor const& b_scales,
|
||||||
std::optional<torch::Tensor> const& bias);
|
std::optional<torch::Tensor> const& bias);
|
||||||
|
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
|
std::optional<torch::Tensor> const& bias);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||||
@ -86,7 +91,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
|||||||
// and at least SM90 (Hopper)
|
// and at least SM90 (Hopper)
|
||||||
|
|
||||||
#if defined CUDA_VERSION
|
#if defined CUDA_VERSION
|
||||||
if (cuda_device_capability >= 90) {
|
if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
|
||||||
return CUDA_VERSION >= 12000;
|
return CUDA_VERSION >= 12000;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@ -120,10 +125,22 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
|||||||
|
|
||||||
// Guard against compilation issues for sm90 kernels
|
// Guard against compilation issues for sm90 kernels
|
||||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||||
if (version_num >= 90) {
|
|
||||||
|
#if defined CUDA_VERSION && CUDA_VERSION < 12080
|
||||||
|
if (version_num >= 90 && version_num < 100) {
|
||||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
if (version_num >= 90 && version_num < 100) {
|
||||||
|
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||||
|
return;
|
||||||
|
} else if (version_num >= 100) {
|
||||||
|
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||||
|
@ -126,15 +126,10 @@ struct MacheteKernelTemplate {
|
|||||||
std::is_same_v<ElementSChannel, ElementSToken>),
|
std::is_same_v<ElementSChannel, ElementSToken>),
|
||||||
"Currently token and channel scales (if present) must be the same type");
|
"Currently token and channel scales (if present) must be the same type");
|
||||||
|
|
||||||
using EpilogueDescriptor =
|
|
||||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
|
||||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
|
||||||
ElementD, EpilogueSchedule>;
|
|
||||||
|
|
||||||
// Currently only supports float scales
|
// Currently only supports float scales
|
||||||
using ChTokScalesEpilogue =
|
using ChTokScalesEpilogue =
|
||||||
typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
|
typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
|
||||||
EpilogueDescriptor>;
|
TileShape>;
|
||||||
static_assert((with_channel_scales || with_token_scales) ||
|
static_assert((with_channel_scales || with_token_scales) ||
|
||||||
(std::is_same_v<ElementSChannel, float> &&
|
(std::is_same_v<ElementSChannel, float> &&
|
||||||
std::is_same_v<ElementSToken, float>),
|
std::is_same_v<ElementSToken, float>),
|
||||||
|
@ -65,12 +65,7 @@ struct cutlass_sparse_3x_gemm {
|
|||||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||||
float>::type;
|
float>::type;
|
||||||
|
|
||||||
using EpilogueDescriptor =
|
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
|
||||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
|
||||||
ElementD, EpilogueSchedule>;
|
|
||||||
|
|
||||||
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
|
|
||||||
|
|
||||||
using ElementC = void;
|
using ElementC = void;
|
||||||
using LayoutC = cutlass::layout::RowMajor;
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user