[Build] Only build 9.0a for scaled_mm and sparse kernels (#12339)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
ce69f7f754
commit
103bd17ac5
@ -275,7 +275,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# Only build Marlin kernels if we are building for at least some compatible archs.
|
||||
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
||||
# are not supported by Machete yet.
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS})
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_ARCHS)
|
||||
set(MARLIN_SRCS
|
||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||
@ -296,8 +296,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
|
||||
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
@ -351,7 +351,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# 2:4 Sparse Kernels
|
||||
|
||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||
# require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
|
||||
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
|
@ -259,7 +259,7 @@ endmacro()
|
||||
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
|
||||
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
|
||||
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
|
||||
# 9.0a to the result.
|
||||
# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS).
|
||||
# The result is stored in `OUT_CUDA_ARCHS`.
|
||||
#
|
||||
# Example:
|
||||
@ -270,34 +270,47 @@ endmacro()
|
||||
#
|
||||
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
|
||||
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
|
||||
|
||||
# if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
|
||||
# remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
|
||||
set(_CUDA_ARCHS)
|
||||
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
|
||||
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
|
||||
set(_CUDA_ARCHS "9.0a")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
|
||||
# for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is
|
||||
# less or eqault to ARCH
|
||||
foreach(_ARCH ${CUDA_ARCHS})
|
||||
set(_TMP_ARCH)
|
||||
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
|
||||
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
||||
set(_TMP_ARCH ${_SRC_ARCH})
|
||||
else()
|
||||
break()
|
||||
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
|
||||
# is less or equal to ARCH (but has the same major version since SASS binary
|
||||
# compatibility is only forward compatible within the same major version).
|
||||
foreach(_ARCH ${TGT_CUDA_ARCHS_})
|
||||
set(_TMP_ARCH)
|
||||
# Extract the major version of the target arch
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
|
||||
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
|
||||
# Extract the major version of the source arch
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
|
||||
# Check major-version match AND version-less-or-equal
|
||||
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
||||
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
||||
set(_TMP_ARCH "${_SRC_ARCH}")
|
||||
endif()
|
||||
else()
|
||||
# If we hit a version greater than the target, we can break
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
|
||||
if (_TMP_ARCH)
|
||||
list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
|
||||
endif()
|
||||
endforeach()
|
||||
if (_TMP_ARCH)
|
||||
list(APPEND _CUDA_ARCHS ${_TMP_ARCH})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||
|
Loading…
x
Reference in New Issue
Block a user