[Build] Only build 9.0a for scaled_mm and sparse kernels (#12339)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
ce69f7f754
commit
103bd17ac5
@ -275,7 +275,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# Only build Marlin kernels if we are building for at least some compatible archs.
|
# Only build Marlin kernels if we are building for at least some compatible archs.
|
||||||
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
||||||
# are not supported by Machete yet.
|
# are not supported by Machete yet.
|
||||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS})
|
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
|
||||||
if (MARLIN_ARCHS)
|
if (MARLIN_ARCHS)
|
||||||
set(MARLIN_SRCS
|
set(MARLIN_SRCS
|
||||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||||
@ -296,8 +296,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||||
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
|
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
|
||||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||||
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
||||||
set_gencode_flags_for_srcs(
|
set_gencode_flags_for_srcs(
|
||||||
@ -351,7 +351,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# 2:4 Sparse Kernels
|
# 2:4 Sparse Kernels
|
||||||
|
|
||||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||||
# require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
|
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
||||||
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
|
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
|
||||||
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||||
|
@ -259,7 +259,7 @@ endmacro()
|
|||||||
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
|
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
|
||||||
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
|
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
|
||||||
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
|
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
|
||||||
# 9.0a to the result.
|
# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS).
|
||||||
# The result is stored in `OUT_CUDA_ARCHS`.
|
# The result is stored in `OUT_CUDA_ARCHS`.
|
||||||
#
|
#
|
||||||
# Example:
|
# Example:
|
||||||
@ -270,32 +270,45 @@ endmacro()
|
|||||||
#
|
#
|
||||||
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||||
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
|
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
|
||||||
|
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
|
||||||
|
|
||||||
# if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in TGT_CUDA_ARCHS then we should
|
# if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in TGT_CUDA_ARCHS then we should
|
||||||
# remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
|
# remove 9.0a from SRC_CUDA_ARCHS and 9.0 from TGT_CUDA_ARCHS_, then add 9.0a to _CUDA_ARCHS
|
||||||
set(_CUDA_ARCHS)
|
set(_CUDA_ARCHS)
|
||||||
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
|
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
|
||||||
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
|
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
|
||||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
|
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
|
||||||
|
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
|
||||||
set(_CUDA_ARCHS "9.0a")
|
set(_CUDA_ARCHS "9.0a")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||||
|
|
||||||
# for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is
|
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
|
||||||
# less or eqault to ARCH
|
# is less than or equal to ARCH (but has the same major version, since SASS
|
||||||
foreach(_ARCH ${CUDA_ARCHS})
|
# binary compatibility only holds forward within the same major version).
|
||||||
|
foreach(_ARCH ${TGT_CUDA_ARCHS_})
|
||||||
set(_TMP_ARCH)
|
set(_TMP_ARCH)
|
||||||
|
# Extract the major version of the target arch
|
||||||
|
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
|
||||||
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
|
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
|
||||||
|
# Extract the major version of the source arch
|
||||||
|
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
|
||||||
|
# Check major-version match AND version-less-or-equal
|
||||||
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
||||||
set(_TMP_ARCH ${_SRC_ARCH})
|
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
||||||
|
set(_TMP_ARCH "${_SRC_ARCH}")
|
||||||
|
endif()
|
||||||
else()
|
else()
|
||||||
|
# If we hit a version greater than the target, we can break
|
||||||
break()
|
break()
|
||||||
endif()
|
endif()
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
|
# If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
|
||||||
if (_TMP_ARCH)
|
if (_TMP_ARCH)
|
||||||
list(APPEND _CUDA_ARCHS ${_TMP_ARCH})
|
list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
|
||||||
endif()
|
endif()
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user