[Build/BugFix] Fix hopper 12.8 build (#14354)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>

parent: be0b399d74
commit: 7caff01a7b
CMakeLists.txt
@@ -333,36 +333,64 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
                     " in CUDA target architectures, or CUDA not >= 12.0")
   endif()


+  set(SCALED_MM_3X_ARCHS)
   # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
-  # CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
-  cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
+  # CUDA 12.0 or later
+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
     set(SRCS
-      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
+      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
-      CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
     list(APPEND VLLM_EXT_SRC "${SRCS}")
-    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
-    message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
+    # Let scaled_mm_c2x know it doesn't need to build these arches
+    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
+    message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
   else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-      message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
                      "later if you intend on running FP8 quantized models on "
                      "Hopper.")
    else()
-      message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
+      message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
                      "in CUDA target architectures")
    endif()
+  endif()

-    # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
-    # build any 3x kernels
-    set(SCALED_MM_3X_ARCHS)
+  # The cutlass_scaled_mm kernels for Blackwell (c3x, i.e. CUTLASS 3.x) require
+  # CUDA 12.8 or later
+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
+    set(SRCS
+      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
+      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
+    )
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
+    # Let scaled_mm_c2x know it doesn't need to build these arches
+    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
+    message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
+                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or "
+                     "later if you intend on running FP8 quantized models on "
+                     "Blackwell.")
+    else()
+      message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
+                     "in CUDA target architectures")
+    endif()
   endif()

   #
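For context on how these flags are consumed downstream: each -DENABLE_SCALED_MM_*=1 define set above becomes a preprocessor guard in the C++ sources, and the entry point dispatches on the runtime SM version. A minimal standalone sketch of that pattern (toy code, not from the vLLM tree; the function names below are stand-ins):

#include <cstdio>
#include <stdexcept>

// Stand-in for a flag CMake would pass, e.g. -DENABLE_SCALED_MM_SM90=1.
#define ENABLE_SCALED_MM_SM90 1
// ENABLE_SCALED_MM_SM100 deliberately left undefined: a Hopper-only build.

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void scaled_mm_sm90() { std::puts("sm90 kernel"); }
#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void scaled_mm_sm100() { std::puts("sm100 kernel"); }
#endif

// Dispatch on the SM version reported at runtime; only branches whose
// kernels were actually compiled are considered.
void scaled_mm(int version_num) {
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
  if (version_num >= 100) { scaled_mm_sm100(); return; }
#endif
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  if (version_num >= 90 && version_num < 100) { scaled_mm_sm90(); return; }
#endif
  throw std::runtime_error("no kernel compiled for this SM version");
}

int main() { scaled_mm(90); }  // prints "sm90 kernel"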
@@ -395,16 +423,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

   # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
   # require CUDA 12.2 or later (and only work on Hopper and Blackwell).
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
     set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
-      CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
     list(APPEND VLLM_EXT_SRC "${SRCS}")
     list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
-    message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
+    message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
   else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
       message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
                      "not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
                      "if you intend on running FP8 sparse quantized models on Hopper.")
@@ -432,22 +460,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set(FP4_ARCHS)
   endif()

-  # FP8 Blackwell Archs
-  cuda_archs_loose_intersection(BLACKWELL_ARCHS "10.0;10.1;12.0" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND BLACKWELL_ARCHS)
-    set(SRCS
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
-    )
-    set_gencode_flags_for_srcs(
-      SRCS "${SRCS}"
-      CUDA_ARCHS "${BLACKWELL_ARCHS}")
-    list(APPEND VLLM_EXT_SRC "${SRCS}")
-    message(STATUS "Building FP8 for archs: ${BLACKWELL_ARCHS}")
-  else()
-    # clear BLACKWELL_ARCHS
-    set(BLACKWELL_ARCHS)
-  endif()
-
   #
   # Machete kernels

csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu (new file, 34 lines)
@@ -0,0 +1,34 @@
+#include <cudaTypedefs.h>
+#include "c3x/scaled_mm_kernels.hpp"
+
+#include "cuda_utils.h"
+
+/*
+   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
+   NVIDIA GPUs with sm100 (Blackwell).
+*/
+
+#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
+
+void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
+                             torch::Tensor const& b,
+                             torch::Tensor const& a_scales,
+                             torch::Tensor const& b_scales,
+                             std::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  int M = a.size(0), N = b.size(1), K = a.size(1);
+  TORCH_CHECK(
+      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
+      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
+
+  // Standard per-tensor/per-token/per-channel scaling
+  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
+              "Currently, only fp8 gemm is implemented for Blackwell");
+  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
+}
+
+#endif
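The TORCH_CHECKs above pin down a scale-shape contract: a_scales must be per-tensor (one element) or per-token (M elements), b_scales per-tensor or per-channel (N elements); blockwise fp8 scaling is explicitly rejected on Blackwell here. A plain-C++ restatement of that contract (illustrative sketch; the helper name is ours, not vLLM's):

#include <cassert>
#include <cstdint>

// For an (M x K) @ (K x N) fp8 GEMM, mirror the shape checks in
// cutlass_scaled_mm_sm100: scales are per-tensor, per-token (rows of a),
// or per-channel (columns of b); anything else (e.g. blockwise) fails.
bool sm100_scales_ok(int64_t M, int64_t N,
                     int64_t a_scales_numel, int64_t b_scales_numel) {
  const bool a_ok = a_scales_numel == 1 || a_scales_numel == M;
  const bool b_ok = b_scales_numel == 1 || b_scales_numel == N;
  return a_ok && b_ok;
}

int main() {
  assert(sm100_scales_ok(16, 32, 1, 1));    // per-tensor x per-tensor
  assert(sm100_scales_ok(16, 32, 16, 32));  // per-token x per-channel
  assert(!sm100_scales_ok(16, 32, 4, 32));  // 4 block scales for a: rejected
}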
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
@@ -5,9 +5,11 @@

 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
-   NVIDIA GPUs with sm90a (Hopper) or later.
+   NVIDIA GPUs with sm90a (Hopper).
 */

+#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
+
 void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
@@ -72,27 +74,4 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                               azp, bias);
 }

-#if defined CUDA_VERSION && CUDA_VERSION >= 12080
-
-void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
-                             torch::Tensor const& b,
-                             torch::Tensor const& a_scales,
-                             torch::Tensor const& b_scales,
-                             std::optional<torch::Tensor> const& bias) {
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
-  int M = a.size(0), N = b.size(1), K = a.size(1);
-  TORCH_CHECK(
-      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
-      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
-
-  // Standard per-tensor/per-token/per-channel scaling
-  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
-              "Currently, only fp8 gemm is implemented for Blackwell");
-  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
-}
-
 #endif
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -23,12 +23,15 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias);

-#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
+#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
 void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias);
+#endif
+
+#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
 void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& b,
                              torch::Tensor const& a_scales,
@@ -60,7 +63,7 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
                                 std::optional<torch::Tensor> const& azp,
                                 std::optional<torch::Tensor> const& bias);

-#if defined CUDA_VERSION && CUDA_VERSION >= 12000
+#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
 void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
@@ -121,26 +124,21 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,

   at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
   int32_t version_num = get_sm_version_num();
-  // Hopper

-  // Guard against compilation issues for sm90 kernels
-#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
-
-#if defined CUDA_VERSION && CUDA_VERSION < 12080
-  if (version_num >= 90 && version_num < 100) {
-    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
-    return;
-  }
-#else
-  if (version_num >= 90 && version_num < 100) {
-    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
-    return;
-  } else if (version_num >= 100) {
+#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
+  if (version_num >= 100) {
     cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
     return;
   }
 #endif
+
+  // Guard against compilation issues for sm90 kernels
+#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
+  if (version_num >= 90 && version_num < 100) {
+    // Hopper
+    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
+    return;
+  }
 #endif

 #if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
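Our reading of the failure this dispatch rewrite addresses (reconstructed from the diff, not an official write-up): with a CUDA 12.8 toolchain but only 9.0a in the arch list, the old CUDA_VERSION >= 12080 gate made the dispatcher reference cutlass_scaled_mm_sm100 even though CMake never compiled an sm100 definition, breaking the link; gating on ENABLE_SCALED_MM_SM100 ties the reference to what was actually built. A toy reproduction (stand-in macro and function names):

#include <cstdio>

#define TOOLKIT_IS_128 1  // stand-in for CUDA_VERSION >= 12080: true on any 12.8 toolchain
// #define HAVE_SM100 1   // stand-in for ENABLE_SCALED_MM_SM100: unset in a Hopper-only build

void sm90() { std::puts("sm90"); }
#if defined HAVE_SM100 && HAVE_SM100
void sm100() { std::puts("sm100"); }
#endif

void dispatch(int sm) {
  // Old style: '#if TOOLKIT_IS_128' here would emit a call to sm100() on any
  // 12.8 build, so a Hopper-only build fails to link (no definition exists).
  // New style: key the call to whether the kernel was compiled at all.
#if defined HAVE_SM100 && HAVE_SM100
  if (sm >= 100) { sm100(); return; }
#endif
  if (sm >= 90 && sm < 100) { sm90(); return; }
}

int main() { dispatch(90); }  // links and prints "sm90" in a Hopper-only build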
@@ -211,7 +209,7 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,

   int32_t version_num = get_sm_version_num();

-#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
+#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
   if (version_num >= 90) {
     cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
     return;