[Bugfix][Build/CI] Fix sparse CUTLASS compilation on CUDA [12.0, 12.2) (#11311)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
fdea8ec167
commit
5a9da2e6e9
@ -273,15 +273,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
#
|
||||
# The cutlass_scaled_mm cutlass_scaled_sparse_mm, and cutlass_compressor kernels
|
||||
# For Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
|
||||
"csrc/sparse/cutlass/sparse_compressor_c3x.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
|
||||
@ -290,12 +286,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||
message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
|
||||
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running FP8 sparse or quantized models on "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building cutlass_c3x as no compatible archs found "
|
||||
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
@ -329,6 +325,31 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# 2:4 Sparse Kernels
|
||||
|
||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||
# require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
|
||||
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
||||
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
|
||||
"if you intend on running FP8 sparse quantized models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# Machete kernels
|
||||
|
@ -163,6 +163,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
|
||||
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& e,
|
||||
torch::Tensor const& a_scales,
|
||||
|
@ -2,6 +2,7 @@
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
|
||||
#include "sparse_scaled_mm_c3x.cuh"
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
@ -161,3 +162,4 @@ bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
@ -5,7 +5,7 @@
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
|
||||
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
|
||||
torch::Tensor const& a);
|
||||
#endif
|
||||
@ -28,7 +28,7 @@ bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
|
||||
if (version_num >= 90) {
|
||||
return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
|
||||
}
|
||||
|
@ -2,7 +2,7 @@
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
|
||||
#include "sparse_scaled_mm_c3x.cuh"
|
||||
// clang-format on
|
||||
|
||||
|
@ -5,7 +5,18 @@
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
|
||||
// sparse CUTLASS kernels need at least
|
||||
// CUDA 12.2 and SM90 (Hopper)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
return CUDA_VERSION >= 12020 && cuda_device_capability >= 90;
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
|
||||
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& e,
|
||||
@ -43,7 +54,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
|
||||
if (version_num >= 90) {
|
||||
cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
|
||||
bias);
|
||||
|
@ -321,6 +321,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
||||
|
||||
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
|
||||
// given capability
|
||||
ops.def(
|
||||
"cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_sparse_scaled_mm_supported",
|
||||
&cutlass_sparse_scaled_mm_supported);
|
||||
|
||||
// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization, as well as bias
|
||||
ops.def(
|
||||
|
@ -8,6 +8,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
sparse_cutlass_supported)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
CUDA_DEVICES = [
|
||||
@ -102,10 +104,11 @@ def baseline_scaled_mm(a: torch.Tensor,
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse FP8 is not yet supported on this GPU type.")
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
def test_cutlass_sparse_subset():
|
||||
|
||||
big_m = 1024
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
|
@ -14,6 +14,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
|
||||
CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
sparse_cutlass_supported)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@ -212,7 +214,7 @@ def test_compressed_tensors_kv_cache(vllm_runner):
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse FP8 is not yet supported on this GPU type.")
|
||||
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
|
||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||
@ -254,7 +256,7 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse FP8 is not yet supported on this GPU type.")
|
||||
@pytest.mark.parametrize("args_2of4", [
|
||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
|
||||
@ -279,7 +281,7 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse FP8 is not yet supported on this GPU type.")
|
||||
@pytest.mark.parametrize(
|
||||
"args_2of4",
|
||||
|
@ -552,6 +552,11 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
|
||||
return out
|
||||
|
||||
|
||||
def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
|
||||
return torch.ops._C.cutlass_sparse_scaled_mm_supported(
|
||||
cuda_device_capability)
|
||||
|
||||
|
||||
def cutlass_sparse_compress(a: torch.Tensor) \
|
||||
-> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
|
@ -9,7 +9,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise)
|
||||
convert_to_channelwise, sparse_cutlass_supported)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
@ -40,6 +40,11 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
if not sparse_cutlass_supported():
|
||||
raise ValueError(
|
||||
"Sparse CUTLASS not supported. vLLM must be built with"
|
||||
"CUDA 12.2 or later to use this feature")
|
||||
|
||||
self.output_dtype = params_dtype
|
||||
layer.logical_widths = output_partition_sizes
|
||||
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
|
||||
|
@ -10,6 +10,17 @@ from vllm.platforms import current_platform
|
||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
||||
|
||||
|
||||
def sparse_cutlass_supported() -> bool:
|
||||
# sparse cutlass is not supported on Rocm
|
||||
if current_platform.is_rocm():
|
||||
return False
|
||||
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
|
||||
return ops.cutlass_sparse_scaled_mm_supported(capability)
|
||||
|
||||
|
||||
def cutlass_fp8_supported() -> bool:
|
||||
# cutlass is not supported on Rocm
|
||||
if current_platform.is_rocm():
|
||||
|
Loading…
x
Reference in New Issue
Block a user