[Misc][Kernel]: Add GPTQAllSpark Quantization (#12931)

YajieWang 2025-03-01 14:30:59 +08:00 committed by GitHub
parent 6a84164add
commit 6a92ff93e1
12 changed files with 2005 additions and 4 deletions

CMakeLists.txt Executable file → Normal file
View File

@@ -317,6 +317,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()
# Only build AllSpark kernels if we are building for at least some compatible archs.
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
if (ALLSPARK_ARCHS)
set(ALLSPARK_SRCS
"csrc/quantization/gptq_allspark/allspark_repack.cu"
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
set_gencode_flags_for_srcs(
SRCS "${ALLSPARK_SRCS}"
CUDA_ARCHS "${ALLSPARK_ARCHS}")
list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}")
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
else()
message(STATUS "Not building AllSpark kernels as no compatible archs found"
" in CUDA target architectures")
endif()
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")

View File

@@ -10,6 +10,8 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
@@ -18,12 +20,12 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
@@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
GPTQ_MARLIN_24_MAX_PARALLEL)
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
# AllSpark W8A16 quant
as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
and group_size == -1 and not act_order and is_k_full)
if as_supported_case:
properties = torch.cuda.get_device_properties(b.device.index)
sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor
supported_arch = (sm_version >= 80 and sm_version < 90)
as_supported_case = as_supported_case and supported_arch
if supported_arch:
has_zp = False
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
has_zp)
qw = qw.to(torch.uint8)
qw_reorder, s_reorder, zp_reorder = \
ops.allspark_repack_weight(
qw, s, zp, has_zp)
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
globals = {
# Gen params
"quant_type": quant_type,
@@ -109,10 +132,19 @@ def bench_run(results: List[benchmark.Measurement], model: str,
# GPTQ params
"q_w_gptq": q_w_gptq,
"repack_sort_indices": repack_sort_indices,
# AllSpark W8A16 params
"qw_reorder": qw_reorder if as_supported_case else None,
"s_reorder": s_reorder if as_supported_case else None,
"zp_reorder": zp_reorder if as_supported_case else None,
"sm_count": sm_count if as_supported_case else None,
"sm_version": sm_version if as_supported_case else None,
"CUBLAS_M_THRESHOLD":
CUBLAS_M_THRESHOLD if as_supported_case else None,
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
"gptq_marlin_repack": ops.gptq_marlin_repack,
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
}
min_run_time = 1
@@ -172,6 +204,17 @@ def bench_run(results: List[benchmark.Measurement], model: str,
description="gptq_marlin_repack",
).blocked_autorange(min_run_time=min_run_time))
if as_supported_case:
results.append(
benchmark.Timer(
stmt=
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="allspark_w8a16_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time))
def main(args):
print("Benchmarking models:")

File diff suppressed because it is too large

View File

@@ -0,0 +1,163 @@
#include "allspark_utils.cuh"
#include <torch/all.h>
#include "core/registration.h"
namespace allspark {
// Rearrange B to facilitate Ampere Tensor Core data loading
// reorder B from (K, N) to (N_32align / 4, K * 4)
// K % 16 == 0, N % 16 == 0, N_32align % 32 == 0
template <typename FType>
__global__ void __launch_bounds__(128)
rearrange_kn_weight_as_n32k16_order_ldg16_kernel(
const uint8_t* B, const FType* B_scale, const FType* B_zero,
uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
const int K, const int N, const int N_32align) {
const int lane_id = threadIdx.x % 32;
const int warp_id = threadIdx.x / 32;
if (blockIdx.x != gridDim.x - 1) {
// Load B
// per block process 64(k) * 128(n) B elements
// per warp process 16(k) * 128 B elements
const int src_row_base_idx =
blockIdx.x * 64 + warp_id * 16 + ((lane_id % 8) / 2) * 2;
const int src_col_idx =
blockIdx.y * 128 + (lane_id / 8) * 32 + (lane_id % 2) * 16;
uint8_t B_frag[4][16];
#pragma unroll
for (int i = 0; i < 4; ++i) {
int src_row_idx = src_row_base_idx + (i / 2) * 8 + (i % 2);
int src_offset = src_row_idx * N + src_col_idx;
bool guard = src_row_idx < K && src_col_idx < N;
ldg128_cg_0(*reinterpret_cast<uint32_t*>(B_frag[i]),
*(reinterpret_cast<uint32_t*>(B_frag[i]) + 1),
*(reinterpret_cast<uint32_t*>(B_frag[i]) + 2),
*(reinterpret_cast<uint32_t*>(B_frag[i]) + 3), B + src_offset,
guard);
}
// reorder B
uint8_t B_reorder_frag[8][8];
#pragma unroll
for (int i = 0; i < 4; ++i) {
#pragma unroll
for (int j = 0; j < 16; ++j) {
int dst_i = j % 8;
int dst_j = i + (j / 8) * 4;
B_reorder_frag[dst_i][dst_j] = B_frag[i][j];
}
}
// Store B
const int dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8;
const int dst_col_idx =
blockIdx.x * (64 * 4) + warp_id * 64 + (lane_id % 8) * 8;
for (int i = 0; i < 8; ++i) {
int dst_row_idx = dst_row_base_idx + i;
int dst_offset = dst_row_idx * K * 4 + dst_col_idx;
bool guard = (dst_row_base_idx < N_32align / 4) && (dst_col_idx < K * 4);
if (guard) {
*reinterpret_cast<int2*>(B_result + dst_offset) =
*reinterpret_cast<int2*>(B_reorder_frag[i]);
}
}
} else {
// Load B_scale and B_zero
FType b_scale_reg, b_zero_reg;
int src_offset = blockIdx.y * 128 + threadIdx.x;
ldg16_cg_0(b_scale_reg, B_scale + src_offset, src_offset < N);
if (B_zero != nullptr)
ldg16_cg_0(b_zero_reg, B_zero + src_offset, src_offset < N);
int dst_offset =
blockIdx.y * 128 + warp_id * 32 + (lane_id % 8) * 4 + lane_id / 8;
if (dst_offset < N_32align) {
B_scale_result[dst_offset] = b_scale_reg;
if (B_zero != nullptr) B_zero_result[dst_offset] = b_zero_reg;
}
}
}
template <typename FType>
void rearrange_kn_weight_as_n32k16_order_ldg16(
const uint8_t* B, const FType* B_scale, const FType* B_zero,
uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
const int64_t K, const int64_t N, const int64_t N_32align,
cudaStream_t stream) {
if (N % 16 != 0 || K % 16 != 0) {
std::cerr << "Now only support N and K is multiples of 16" << std::endl;
}
const int BLOCK = 128;
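// Note: grid_x reserves one extra block; for each blockIdx.y slice the last
// block along x rearranges B_scale/B_zero, while the remaining blocks
// rearrange the quantized weights (the blockIdx.x != gridDim.x - 1 branch).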
int grid_x = (K + 64 - 1) / 64 + 1;
int grid_y = (N + 128 - 1) / 128;
dim3 grid(grid_x, grid_y);
rearrange_kn_weight_as_n32k16_order_ldg16_kernel<FType>
<<<grid, BLOCK, 0, stream>>>(B, B_scale, B_zero, B_result, B_scale_result,
B_zero_result, K, N, N_32align);
}
} // namespace allspark
void rearrange_kn_weight_as_n32k16_order(
torch::Tensor const& b_qweight, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& b_zeros, bool has_zp,
torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder,
c10::optional<torch::Tensor> const& b_zeros_reorder, const int64_t K,
const int64_t N, const int64_t N_32align) {
// Verify device and strides
TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU");
TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous");
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
TORCH_CHECK(b_qweight_reorder.device().is_cuda(),
"b_qweight_reorder is not on GPU");
TORCH_CHECK(b_qweight_reorder.is_contiguous(),
"b_qweight_reorder is not contiguous");
TORCH_CHECK(b_scales_reorder.device().is_cuda(),
"b_scales_reorder is not on GPU");
TORCH_CHECK(b_scales_reorder.is_contiguous(),
"b_scales_reorder is not contiguous");
if (has_zp) {
TORCH_CHECK(b_zeros.value().device().is_cuda(), "b_zeros is not on GPU");
TORCH_CHECK(b_zeros.value().is_contiguous(), "b_zeros is not contiguous");
TORCH_CHECK(b_zeros_reorder.value().device().is_cuda(),
"b_zeros_reorder is not on GPU");
TORCH_CHECK(b_zeros_reorder.value().is_contiguous(),
"b_zeros_reorder is not contiguous");
}
const uint8_t* matB = reinterpret_cast<const uint8_t*>(b_qweight.data_ptr());
const void* b_scale = b_scales.data_ptr();
const void* b_zero = has_zp ? b_zeros.value().data_ptr() : nullptr;
uint8_t* matB_reorder =
reinterpret_cast<uint8_t*>(b_qweight_reorder.data_ptr());
void* b_scale_reorder = b_scales_reorder.data_ptr();
void* b_zero_reorder = has_zp ? b_zeros_reorder.value().data_ptr() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (b_scales.dtype() == at::ScalarType::Half) {
allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__half>(
matB, reinterpret_cast<const __half*>(b_scale),
reinterpret_cast<const __half*>(b_zero), matB_reorder,
reinterpret_cast<__half*>(b_scale_reorder),
reinterpret_cast<__half*>(b_zero_reorder), K, N, N_32align, stream);
} else if (b_scales.dtype() == at::ScalarType::BFloat16) {
allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__nv_bfloat16>(
matB, reinterpret_cast<const __nv_bfloat16*>(b_scale),
reinterpret_cast<const __nv_bfloat16*>(b_zero), matB_reorder,
reinterpret_cast<__nv_bfloat16*>(b_scale_reorder),
reinterpret_cast<__nv_bfloat16*>(b_zero_reorder), K, N, N_32align,
stream);
}
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("rearrange_kn_weight_as_n32k16_order",
&rearrange_kn_weight_as_n32k16_order);
}

View File

@@ -0,0 +1,408 @@
#pragma once
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <iostream>
namespace allspark {
#define CHECK_CUDA(cmd) \
do { \
cudaError_t cuda_status = cmd; \
if (cuda_status != cudaSuccess) { \
std::string err_str = cudaGetErrorString(cuda_status); \
std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \
<< err_str; \
exit(-1); \
} \
} while (0)
#define CHECK_CUBLAS(cmd) \
do { \
cublasStatus_t cublas_status = cmd; \
if (cublas_status != CUBLAS_STATUS_SUCCESS) { \
std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \
<< cublas_status << std::endl; \
exit(-1); \
} \
} while (0)
template <typename FType, typename QType>
struct SM8x_GEMM_W8A16_Splitk_Params {
const FType* A_ptr;
const QType* B_ptr;
const FType* B_scale_ptr;
const FType* B_zero_ptr;
FType* C_ptr;
int M;
int N;
int K;
int SplitK;
int GroupCnt;
int GroupSize;
FType* C_split_ptr; // for non-fused splitk reduce
float* C_tmp_ptr; // for fused splitk reduce
uint32_t* red_count_ptr; // for fused splitk reduce
};
struct alignas(16) BlockTileSplitkParams {
int Mtile;
int Ntile;
int SplitK;
bool EnableFuse;
};
template <typename FType, int BLOCK, int N_MATRIX>
__global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
uint32_t n, uint32_t n_matrix,
uint32_t matrix_size) {
int idx = blockIdx.x * BLOCK + threadIdx.x;
if (idx >= matrix_size) {
return;
}
FType sum(0);
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
for (int i = 0; i < n_mat; ++i) {
sum += C_split[idx + i * matrix_size];
}
C[idx] = sum;
}
template <typename FType>
void f16_gemm_splitk_reduce(const FType* C_split, FType* C, const uint32_t m,
const uint32_t n, const uint32_t n_matrix,
cudaStream_t stream) {
const int BLOCK = 128;
uint32_t matrix_size = m * n;
int grid = (matrix_size + BLOCK - 1) / BLOCK;
void (*kernel)(const FType*, FType*, uint32_t, uint32_t, uint32_t) = nullptr;
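// Dispatch to a specialization with a compile-time matrix count when
// possible; the N_MATRIX = -1 fallback reads n_matrix at runtime.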
switch (n_matrix) {
case 4:
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 4>;
break;
case 5:
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 5>;
break;
case 6:
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 6>;
break;
case 7:
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 7>;
break;
case 8:
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 8>;
break;
case 9:
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 9>;
break;
case 10:
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 10>;
break;
case 11:
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 11>;
break;
case 12:
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 12>;
break;
default:
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, -1>;
break;
}
kernel<<<grid, BLOCK, 0, stream>>>(C_split, C, n, n_matrix, matrix_size);
}
template <typename T>
struct HalfType;
template <>
struct HalfType<half> {
using T1 = __half;
using T2 = __half2;
};
template <>
struct HalfType<__nv_bfloat16> {
using T1 = __nv_bfloat16;
using T2 = __nv_bfloat162;
};
// convert 64-bit pointer to 32-bit smem addr
__device__ __forceinline__ uint32_t smem_u32addr(const void* smem_ptr) {
uint32_t addr;
asm("{.reg .u64 u64addr;\n"
" cvta.to.shared.u64 u64addr, %1;\n"
" cvt.u32.u64 %0, u64addr;}\n"
: "=r"(addr)
: "l"(smem_ptr));
return addr;
}
template <typename T>
__device__ __forceinline__ void ldg16_cg_0(T& r0, const void* ptr, bool guard) {
static_assert(sizeof(T) == 2, "ldg16_cg_0: invalid T");
asm volatile(
"{.reg .pred p;\n"
" setp.ne.b32 p, %2, 0;\n"
" @!p mov.b16 %0, 0;\n"
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
__CUDA_ARCH__ >= 750
" @p ld.global.cg.L2::128B.b16 {%0}, [%1];}\n"
#else
" @p ld.global.ca.b16 {%0}, [%1];}\n"
#endif
: "=h"(reinterpret_cast<uint16_t&>(r0))
: "l"(ptr), "r"((int)guard));
}
template <typename T>
__device__ __forceinline__ void ldg64_ca(T& r0, T& r1, const void* ptr,
bool guard) {
static_assert(sizeof(T) == 4, "ldg64_ca: invalid T");
asm volatile(
"{.reg .pred p;\n"
" setp.ne.b32 p, %3, 0;\n"
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
__CUDA_ARCH__ >= 750
" @p ld.global.ca.L2::128B.v2.b32 {%0, %1}, [%2];}\n"
#else
" @p ld.global.ca.v2.b32 {%0, %1}, [%2];}\n"
#endif
: "=r"(reinterpret_cast<uint32_t&>(r0)),
"=r"(reinterpret_cast<uint32_t&>(r1))
: "l"(ptr), "r"((int)guard));
}
template <typename T>
__device__ __forceinline__ void ldg128_cg_0(T& r0, T& r1, T& r2, T& r3,
const void* ptr, bool guard) {
static_assert(sizeof(T) == 4, "ldg128_cg_0: invalid T");
asm volatile(
"{.reg .pred p;\n"
" setp.ne.b32 p, %5, 0;\n"
" @!p mov.b32 %0, 0;\n"
" @!p mov.b32 %1, 0;\n"
" @!p mov.b32 %2, 0;\n"
" @!p mov.b32 %3, 0;\n"
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
__CUDA_ARCH__ >= 750
" @p ld.global.cg.L2::128B.v4.b32 {%0, %1, %2, %3}, [%4];}\n"
#else
" @p ld.global.cg.v4.b32 {%0, %1, %2, %3}, [%4];}\n"
#endif
: "=r"(reinterpret_cast<uint32_t&>(r0)),
"=r"(reinterpret_cast<uint32_t&>(r1)),
"=r"(reinterpret_cast<uint32_t&>(r2)),
"=r"(reinterpret_cast<uint32_t&>(r3))
: "l"(ptr), "r"((int)guard));
}
template <typename T>
__device__ __forceinline__ void lds128(T& reg0, T& reg1, T& reg2, T& reg3,
const uint32_t addr) {
static_assert(sizeof(T) == 4, "lds128: invalid T");
asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
: "=r"(reinterpret_cast<uint32_t&>(reg0)),
"=r"(reinterpret_cast<uint32_t&>(reg1)),
"=r"(reinterpret_cast<uint32_t&>(reg2)),
"=r"(reinterpret_cast<uint32_t&>(reg3))
: "r"(addr));
}
template <typename T>
__device__ __forceinline__ void stg128(const T& r0, const T& r1, const T& r2,
const T& r3, const void* ptr,
bool guard) {
static_assert(sizeof(T) == 4, "stg128: invalid T");
asm volatile(
"{.reg .pred p;\n"
" setp.ne.b32 p, %1, 0;\n"
" @p st.global.v4.b32 [%0], {%2, %3, %4, %5};}\n"
:
: "l"(ptr), "r"((int)guard), "r"(reinterpret_cast<const uint32_t&>(r0)),
"r"(reinterpret_cast<const uint32_t&>(r1)),
"r"(reinterpret_cast<const uint32_t&>(r2)),
"r"(reinterpret_cast<const uint32_t&>(r3)));
}
template <typename T>
__device__ __forceinline__ void ldsm_4(T& r0, T& r1, T& r2, T& r3,
const uint32_t& addr) {
static_assert(sizeof(T) == 4, "ldsm_4: invalid T");
#if (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 11)
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(reinterpret_cast<uint32_t&>(r0)),
"=r"(reinterpret_cast<uint32_t&>(r1)),
"=r"(reinterpret_cast<uint32_t&>(r2)),
"=r"(reinterpret_cast<uint32_t&>(r3))
: "r"(addr));
#endif
}
template <typename FType>
__device__ __forceinline__ void hmma16816_f32(float (&d)[4],
const uint32_t (&a)[4],
const uint32_t (&b)[2]);
template <>
__device__ __forceinline__ void hmma16816_f32<__half>(float (&d)[4],
const uint32_t (&a)[4],
const uint32_t (&b)[2]) {
#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
: "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
#endif
}
template <>
__device__ __forceinline__ void hmma16816_f32<__nv_bfloat16>(
float (&d)[4], const uint32_t (&a)[4], const uint32_t (&b)[2]) {
#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
: "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
#endif
}
template <int SIZE_IN_BYTES>
__device__ __forceinline__ void cp_async(const uint32_t smem_addr,
const void* gmem_ptr,
const int src_in_bytes, bool guard) {
static_assert(
(SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16),
"Size is not supported");
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm volatile(
"{.reg.pred p;\n"
" setp.ne.b32 p, %4, 0;\n"
#if __CUDACC_VER_MINOR__ >= 4
" @p cp.async.cg.shared.global.L2::256B [%0], [%1], %2, %3;}\n"
#else
" @p cp.async.cg.shared.global [%0], [%1], %2, %3;}\n"
#endif
::"r"(smem_addr),
"l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard));
#endif
}
template <int SIZE_IN_BYTES>
__device__ __forceinline__ void cp_async_ca(const uint32_t smem_addr,
const void* gmem_ptr,
const int src_in_bytes,
bool guard) {
static_assert(
(SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16),
"Size is not supported");
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm volatile(
"{.reg.pred p;\n"
" setp.ne.b32 p, %4, 0;\n"
#if __CUDACC_VER_MINOR__ >= 4
" @p cp.async.ca.shared.global.L2::256B [%0], [%1], %2, %3;}\n"
#else
" @p cp.async.ca.shared.global [%0], [%1], %2, %3;}\n"
#endif
::"r"(smem_addr),
"l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard));
#endif
}
__device__ __forceinline__ void cp_async_commit_group() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm volatile("cp.async.commit_group;\n");
#endif
}
template <int N>
__device__ __forceinline__ void cp_asyc_wait_group() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm volatile("cp.async.wait_group %0;\n" : : "n"(N));
#endif
}
template <typename T>
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128(const uint32_t& idata,
T* fdata);
template <>
// fast conversion: 4xuint8 to 4xhalf, subtracting bias = 128
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__half2>(
const uint32_t& idata, __half2* fdata) {
uint32_t i10, i32;
asm volatile(
"prmt.b32 %0, %2, 0x64, 0x4140;"
"prmt.b32 %1, %2, 0x64, 0x4342;"
: "=r"(i10), "=r"(i32)
: "r"(idata));
static constexpr uint32_t MAGIC_NUM = 0x64806480;
fdata[0] = __hsub2(reinterpret_cast<const __half2&>(i10),
reinterpret_cast<const __half2&>(MAGIC_NUM));
fdata[1] = __hsub2(reinterpret_cast<const __half2&>(i32),
reinterpret_cast<const __half2&>(MAGIC_NUM));
}
template <>
// fast conversion: 4xuint8 to 4xbfloat16, subtracting bias = 128
// reference from marlin fast implementation
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__nv_bfloat162>(
const uint32_t& idata, __nv_bfloat162* fdata) {
float fp32_imd[4];
uint32_t* fp32_imd_casted = reinterpret_cast<uint32_t*>(fp32_imd);
asm volatile(
"prmt.b32 %0, %4, 0x4B000000, 0x7650;"
"prmt.b32 %1, %4, 0x4B000000, 0x7651;"
"prmt.b32 %2, %4, 0x4B000000, 0x7652;"
"prmt.b32 %3, %4, 0x4B000000, 0x7653;"
: "=r"(fp32_imd_casted[0]), "=r"(fp32_imd_casted[1]),
"=r"(fp32_imd_casted[2]), "=r"(fp32_imd_casted[3])
: "r"(idata));
fp32_imd[0] -= 8388736.f;
fp32_imd[1] -= 8388736.f;
fp32_imd[2] -= 8388736.f;
fp32_imd[3] -= 8388736.f;
uint32_t* bf16_res = reinterpret_cast<uint32_t*>(fdata);
asm volatile(
"prmt.b32 %0, %2, %3, 0x7632;"
"prmt.b32 %1, %4, %5, 0x7632;"
: "=r"(bf16_res[0]), "=r"(bf16_res[1])
: "r"(fp32_imd_casted[0]), "r"(fp32_imd_casted[1]),
"r"(fp32_imd_casted[2]), "r"(fp32_imd_casted[3]));
}
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __bfloat162bfloat162(x);
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}
static __device__ half2 inline num2num2(const half x) {
return __half2half2(x);
}
} // namespace allspark

View File

@@ -447,6 +447,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);
#ifndef USE_ROCM
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def(
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
"Tensor? b_zeros, "
"bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, "
"Tensor!? b_zeros_reorder, "
"int K, int N, int N_32align) -> ()");
// conditionally compiled so impl in source file
// AllSpark quantization ops
ops.def(
"allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, "
"Tensor? b_qzeros, "
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
// conditionally compiled so impl in source file
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {

View File

@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
ALLSPARK_AMPERE_N_ALIGN)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
def is_gptq_allspark_supported(min_capability: int,
max_capability: int) -> bool:
if not current_platform.is_cuda():
return False
capability = current_platform.get_device_capability()
assert capability is not None
return capability.to_int() >= min_capability \
and capability.to_int() <= max_capability
MNK_FACTORS = [
(1, 4, 8),
(13, 17, 67),
(26, 37, 13),
(48, 16, 24),
(67, 13, 88),
(257, 13, 11),
(658, 13, 11),
(1033, 9, 17),
]
DTYPES = [torch.float16, torch.bfloat16]
HAS_ZP_OPTS = [False, True]
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))
def rand_data(shape, dtype=torch.float16):
return torch.randn(shape, dtype=dtype, device="cuda")
@pytest.mark.skipif(
not is_gptq_allspark_supported(80, 89),
reason="AllSpark Ampere kernel is not supported on this GPU type.")
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
m_factor, n_factor, k_factor = mnk_factors
m = m_factor
n = n_factor * ALLSPARK_AMPERE_N_ALIGN
k = k_factor * ALLSPARK_AMPERE_K_ALIGN
input = rand_data((m, k), dtype=dtype)
weight = rand_data((k, n), dtype=dtype)
# Quantize (and apply act_order if provided)
w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128,
group_size, has_zp)
qw = qw.to(torch.uint8)
if has_zp:
zp = zp.to(dtype)
properties = torch.cuda.get_device_properties(qw.device.index)
sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor
n_32align = (n + 32 - 1) // 32 * 32
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
qw, s, zp, has_zp)
opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order,
(qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n,
n_32align))
opcheck(torch.ops._C.allspark_w8a16_gemm,
(input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count,
sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder,
n, group_size, sm_count, sm_version,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp, True)
output_ref = torch.matmul(input, w_ref)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04

View File

@@ -215,8 +215,6 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
assert qkv_proj.scheme.group_size == (-1
if group is None else group)
assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
assert qkv_proj.scheme.pack_factor == pack_factor
llm.apply_model(check_model)

View File

@@ -404,6 +404,22 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
memory_format=torch.contiguous_format)
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
@register_fake("_C::allspark_w8a16_gemm")
def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
n: torch.SymInt, group_size: torch.SymInt,
sm_count: torch.SymInt,
sm_version: torch.SymInt,
CUBLAS_M_THRESHOLD: torch.SymInt,
has_zp: bool,
n32k16_reorder: bool) -> torch.Tensor:
m = a.size(0)
return torch.empty((m, n), device=a.device, dtype=a.dtype)
if hasattr(torch.ops._C, "ggml_dequantize"):
@register_fake("_C::ggml_dequantize")
@@ -881,6 +897,67 @@ def scaled_fp8_quant(
return output, scale
# gptq allspark
def allspark_repack_weight(
qweight: torch.Tensor,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor] = None,
has_zp: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Rearrange qweight, scale, and zero_point (if asymmetric) into n32k16 format
for the Ampere W8A16 fused GEMM kernel.
Args:
qweight: uint8 weight tensor, original k x n format.
scale: fp16/bf16 weight scale tensor, 1 x n format.
zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
Must be provided for asymmetric quantization.
has_zp: if use symmetric quantization, has_zp = False.
if use asymmetric quantization, has_zp = True.
Returns:
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
rearranged weight, scale, and optionally zero_point.
"""
K = qweight.shape[0]
N = qweight.shape[1]
N_32align = (N + 32 - 1) // 32 * 32
qweight_reorder = torch.empty((N_32align, K),
device=qweight.device,
dtype=qweight.dtype)
scale_reorder = torch.empty((1, N_32align),
device=scale.device,
dtype=scale.dtype)
zero_point_reorder = None
if has_zp:
assert zero_point is not None, (
"zero_point must be provided for asymmetric quantization.")
zero_point_reorder = torch.empty((1, N_32align),
device=zero_point.device,
dtype=zero_point.dtype)
torch.ops._C.rearrange_kn_weight_as_n32k16_order(
qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder,
zero_point_reorder, K, N, N_32align)
return qweight_reorder, scale_reorder, zero_point_reorder
def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor], n: int,
group_size: int, sm_count: int, sm_version: int,
CUBLAS_M_THRESHOLD: int, has_zp: bool,
n32k16_reorder: bool) -> torch.Tensor:
return torch.ops._C.allspark_w8a16_gemm(a, b_qweight, b_scales, b_qzeros,
n, group_size, sm_count,
sm_version, CUBLAS_M_THRESHOLD,
has_zp, n32k16_reorder)
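For reference, a minimal usage sketch of the two wrappers above (not part of this diff), assuming an Ampere GPU and per-channel symmetric uint8b128 weights; the illustrative shapes are hypothetical, and quantize_weights is reused from quant_utils as in this commit's test and benchmark code:

import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights)
from vllm.scalar_type import scalar_types

m, k, n = 16, 4096, 4096
a = torch.randn((m, k), dtype=torch.float16, device="cuda")
b = torch.randn((k, n), dtype=torch.float16, device="cuda")

# Per-channel symmetric quantization (group_size=-1, has_zp=False).
w_ref, qw, s, zp = quantize_weights(b, scalar_types.uint8b128, -1, False)
qw = qw.to(torch.uint8)

props = torch.cuda.get_device_properties(a.device.index)
qw_reorder, s_reorder, _ = ops.allspark_repack_weight(qw, s, None, False)
out = ops.allspark_w8a16_gemm(a, qw_reorder, s_reorder, None, n, -1,
                              props.multi_processor_count,
                              props.major * 10 + props.minor,
                              ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
                              False, True)  # out ~= a @ w_ref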
# int8
def scaled_int8_quant(
input: torch.Tensor,

View File

@@ -3,6 +3,8 @@
from typing import List, Optional, Type
import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
AllSparkLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
@@ -16,6 +18,7 @@ from vllm.platforms import current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel,
AllSparkLinearKernel,
MarlinLinearKernel,
ExllamaLinearKernel,
]

View File

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class AllSparkLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.has_g_idx:
return False, "Act reordering currently not supported by AllSpark"
if c.zero_points:
return False, "Zero points currently not supported by AllSpark"
return check_allspark_supported_dtype_shape(
c.partition_weight_shape[0], # in_features
c.partition_weight_shape[1], # out_features
c.group_size,
c.weight_type,
c.act_type)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
# prepare the parameters required for the kernel
properties = torch.cuda.get_device_properties(device.index)
sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor
gemm_args = {}
gemm_args['sm_count'] = sm_count
gemm_args['sm_version'] = sm_version
self.gemm_args = gemm_args
# transform param weight, scale
old_weight_param = getattr(layer, self.w_q_name)
old_scale_param = getattr(layer, self.w_s_name)
assert isinstance(old_weight_param, BasevLLMParameter)
permute_param_layout_(old_weight_param,
input_dim=0,
output_dim=1,
packed_dim=0)
assert isinstance(old_scale_param, BasevLLMParameter)
permute_param_layout_(old_scale_param, input_dim=0, output_dim=1)
# unpack weight from K / 4 x N int32 to K x N uint8
new_weight_param = torch.nn.Parameter(old_weight_param.data,
requires_grad=False)
new_weight_param.data = new_weight_param.data.t().contiguous().view(
dtype=torch.uint8)
new_weight_param.data = new_weight_param.data.t().contiguous()
new_scale_param = torch.nn.Parameter(old_scale_param.data,
requires_grad=False)
# reorder K x N weight as N32K16 format for Ampere W8A16
new_weight_param.data, new_scale_param.data, _ = \
ops.allspark_repack_weight(
new_weight_param.data, new_scale_param.data, None,
c.zero_points)
replace_parameter(layer, self.w_q_name, new_weight_param.data)
replace_parameter(layer, self.w_s_name, new_scale_param.data)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
gemm_args = self.gemm_args
w_q, w_s, _, _ = self._get_weight_params(layer)
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
output = ops.allspark_w8a16_gemm(
a=reshaped_x,
b_qweight=w_q,
b_scales=w_s,
b_qzeros=None,
n=c.partition_weight_shape[1],
group_size=c.group_size,
sm_count=gemm_args['sm_count'],
sm_version=gemm_args['sm_version'],
CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp=c.zero_points,
n32k16_reorder=True)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
import torch
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024
ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128]
ALLSPARK_AMPERE_N_ALIGN = 16
ALLSPARK_AMPERE_K_ALIGN = 16
def check_allspark_supported_dtype_shape(input_size_per_partition: int,
output_size_per_partition: int,
group_size: int,
weight_dtype: ScalarType,
act_dtype: torch.dtype):
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
# For Ampere GPU
if device_capability >= 80 and device_capability < 90:
if group_size != -1:
return False, \
"For Ampere GPU, AllSpark does not support group_size "\
f"= {group_size}. Only group_size = -1 is supported."
if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
return False, "For Ampere GPU, AllSpark does not support "\
f"quant type ({weight_dtype}). Only quant types "\
f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported."
if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \
or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0:
return False, \
"AllSpark needs input_size_per_partition % "\
f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\
f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\
"for Ampere GPU optimized kernels."
if act_dtype != torch.float16 and act_dtype != torch.bfloat16:
return False, \
"AllSpark only supports act_dtype = float16 or bfloat16 "\
f"for Ampere GPU, but got act_dtype = {act_dtype}."
else:
return False, "AllSpark currently does not support "\
f"device_capability = {device_capability}."
return True, None
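As a quick illustration of the support matrix encoded above (assuming an Ampere device, i.e. compute capability 8.x, so the first branch applies), a call like the following would pass the check; the shapes are hypothetical:

import torch
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    check_allspark_supported_dtype_shape)
from vllm.scalar_type import scalar_types

ok, msg = check_allspark_supported_dtype_shape(
    input_size_per_partition=4096,    # multiple of ALLSPARK_AMPERE_K_ALIGN
    output_size_per_partition=11008,  # multiple of ALLSPARK_AMPERE_N_ALIGN
    group_size=-1,                    # only per-channel quantization
    weight_dtype=scalar_types.uint8b128,
    act_dtype=torch.float16)
assert ok, msg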