[Misc][Kernel]: Add GPTQAllSpark Quantization (#12931)
This commit is contained in:
parent
6a84164add
commit
6a92ff93e1
16
CMakeLists.txt
Executable file → Normal file
16
CMakeLists.txt
Executable file → Normal file
@ -317,6 +317,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
# Only build AllSpark kernels if we are building for at least some compatible archs.
|
||||
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
|
||||
if (ALLSPARK_ARCHS)
|
||||
set(ALLSPARK_SRCS
|
||||
"csrc/quantization/gptq_allspark/allspark_repack.cu"
|
||||
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${ALLSPARK_SRCS}"
|
||||
CUDA_ARCHS "${ALLSPARK_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}")
|
||||
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building AllSpark kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
|
@ -10,6 +10,8 @@ from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
|
||||
@ -18,12 +20,12 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
marlin_24_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
gptq_pack, gptq_quantize_weights, sort_weights)
|
||||
gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
|
||||
from vllm.scalar_type import ScalarType
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
|
||||
ACT_ORDER_OPTS = [False, True]
|
||||
K_FULL_OPTS = [False, True]
|
||||
@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL)
|
||||
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
|
||||
|
||||
# AllSpark W8A16 quant
|
||||
as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
|
||||
and group_size == -1 and not act_order and is_k_full)
|
||||
if as_supported_case:
|
||||
properties = torch.cuda.get_device_properties(b.device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
|
||||
supported_arch = (sm_version >= 80 and sm_version < 90)
|
||||
as_supported_case = as_supported_case and supported_arch
|
||||
if supported_arch:
|
||||
has_zp = False
|
||||
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
|
||||
has_zp)
|
||||
qw = qw.to(torch.uint8)
|
||||
|
||||
qw_reorder, s_reorder, zp_reorder = \
|
||||
ops.allspark_repack_weight(
|
||||
qw, s, zp, has_zp)
|
||||
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
|
||||
|
||||
globals = {
|
||||
# Gen params
|
||||
"quant_type": quant_type,
|
||||
@ -109,10 +132,19 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
# GPTQ params
|
||||
"q_w_gptq": q_w_gptq,
|
||||
"repack_sort_indices": repack_sort_indices,
|
||||
# AllSpark W8A16 params
|
||||
"qw_reorder": qw_reorder if as_supported_case else None,
|
||||
"s_reorder": s_reorder if as_supported_case else None,
|
||||
"zp_reorder": zp_reorder if as_supported_case else None,
|
||||
"sm_count": sm_count if as_supported_case else None,
|
||||
"sm_version": sm_version if as_supported_case else None,
|
||||
"CUBLAS_M_THRESHOLD":
|
||||
CUBLAS_M_THRESHOLD if as_supported_case else None,
|
||||
# Kernels
|
||||
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
||||
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
||||
"gptq_marlin_repack": ops.gptq_marlin_repack,
|
||||
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
|
||||
}
|
||||
|
||||
min_run_time = 1
|
||||
@ -172,6 +204,17 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
description="gptq_marlin_repack",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
if as_supported_case:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="allspark_w8a16_gemm_fp32",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
|
||||
def main(args):
|
||||
print("Benchmarking models:")
|
||||
|
1008
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
Normal file
1008
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
Normal file
File diff suppressed because it is too large
Load Diff
163
csrc/quantization/gptq_allspark/allspark_repack.cu
Normal file
163
csrc/quantization/gptq_allspark/allspark_repack.cu
Normal file
@ -0,0 +1,163 @@
|
||||
#include "allspark_utils.cuh"
|
||||
#include <torch/all.h>
|
||||
#include "core/registration.h"
|
||||
|
||||
namespace allspark {
|
||||
|
||||
// Rearrange B to facilitate Ampere Tensor Core load data
|
||||
// reorder B from (K, N) to (N_32align / 4, K * 4)
|
||||
// K % 16 == 0, N % 16 == 0, N_32align % 32 == 0
|
||||
template <typename FType>
|
||||
__global__ void __launch_bounds__(128)
|
||||
rearrange_kn_weight_as_n32k16_order_ldg16_kernel(
|
||||
const uint8_t* B, const FType* B_scale, const FType* B_zero,
|
||||
uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
|
||||
const int K, const int N, const int N_32align) {
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int warp_id = threadIdx.x / 32;
|
||||
|
||||
if (blockIdx.x != gridDim.x - 1) {
|
||||
// Load B
|
||||
// per block process 64(k) * 128(n) B elements
|
||||
// per warp process 16(k) * 128 B elements
|
||||
const int src_row_base_idx =
|
||||
blockIdx.x * 64 + warp_id * 16 + ((lane_id % 8) / 2) * 2;
|
||||
const int src_col_idx =
|
||||
blockIdx.y * 128 + (lane_id / 8) * 32 + (lane_id % 2) * 16;
|
||||
uint8_t B_frag[4][16];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
int src_row_idx = src_row_base_idx + (i / 2) * 8 + (i % 2);
|
||||
int src_offset = src_row_idx * N + src_col_idx;
|
||||
bool guard = src_row_idx < K && src_col_idx < N;
|
||||
ldg128_cg_0(*reinterpret_cast<uint32_t*>(B_frag[i]),
|
||||
*(reinterpret_cast<uint32_t*>(B_frag[i]) + 1),
|
||||
*(reinterpret_cast<uint32_t*>(B_frag[i]) + 2),
|
||||
*(reinterpret_cast<uint32_t*>(B_frag[i]) + 3), B + src_offset,
|
||||
guard);
|
||||
}
|
||||
|
||||
// reorder B
|
||||
uint8_t B_reorder_frag[8][8];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
int dst_i = j % 8;
|
||||
int dst_j = i + (j / 8) * 4;
|
||||
B_reorder_frag[dst_i][dst_j] = B_frag[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
// Store B
|
||||
const int dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8;
|
||||
const int dst_col_idx =
|
||||
blockIdx.x * (64 * 4) + warp_id * 64 + (lane_id % 8) * 8;
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
int dst_row_idx = dst_row_base_idx + i;
|
||||
int dst_offset = dst_row_idx * K * 4 + dst_col_idx;
|
||||
bool guard = (dst_row_base_idx < N_32align / 4) && (dst_col_idx < K * 4);
|
||||
if (guard) {
|
||||
*reinterpret_cast<int2*>(B_result + dst_offset) =
|
||||
*reinterpret_cast<int2*>(B_reorder_frag[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Load B_scale and B_zero
|
||||
FType b_scale_reg, b_zero_reg;
|
||||
int src_offset = blockIdx.y * 128 + threadIdx.x;
|
||||
ldg16_cg_0(b_scale_reg, B_scale + src_offset, src_offset < N);
|
||||
if (B_zero != nullptr)
|
||||
ldg16_cg_0(b_zero_reg, B_zero + src_offset, src_offset < N);
|
||||
int dst_offset =
|
||||
blockIdx.y * 128 + warp_id * 32 + (lane_id % 8) * 4 + lane_id / 8;
|
||||
if (dst_offset < N_32align) {
|
||||
B_scale_result[dst_offset] = b_scale_reg;
|
||||
if (B_zero != nullptr) B_zero_result[dst_offset] = b_zero_reg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FType>
|
||||
void rearrange_kn_weight_as_n32k16_order_ldg16(
|
||||
const uint8_t* B, const FType* B_scale, const FType* B_zero,
|
||||
uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
|
||||
const int64_t K, const int64_t N, const int64_t N_32align,
|
||||
cudaStream_t stream) {
|
||||
if (N % 16 != 0 || K % 16 != 0) {
|
||||
std::cerr << "Now only support N and K is multiples of 16" << std::endl;
|
||||
}
|
||||
const int BLOCK = 128;
|
||||
int grid_x = (K + 64 - 1) / 64 + 1;
|
||||
int grid_y = (N + 128 - 1) / 128;
|
||||
dim3 grid(grid_x, grid_y);
|
||||
|
||||
rearrange_kn_weight_as_n32k16_order_ldg16_kernel<FType>
|
||||
<<<grid, BLOCK, 0, stream>>>(B, B_scale, B_zero, B_result, B_scale_result,
|
||||
B_zero_result, K, N, N_32align);
|
||||
}
|
||||
} // namespace allspark
|
||||
|
||||
void rearrange_kn_weight_as_n32k16_order(
|
||||
torch::Tensor const& b_qweight, torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& b_zeros, bool has_zp,
|
||||
torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder,
|
||||
c10::optional<torch::Tensor> const& b_zeros_reorder, const int64_t K,
|
||||
const int64_t N, const int64_t N_32align) {
|
||||
// Verify device and strides
|
||||
TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU");
|
||||
TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
||||
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_qweight_reorder.device().is_cuda(),
|
||||
"b_qweight_reorder is not on GPU");
|
||||
TORCH_CHECK(b_qweight_reorder.is_contiguous(),
|
||||
"b_qweight_reorder is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_scales_reorder.device().is_cuda(),
|
||||
"b_scales_reorder is not on GPU");
|
||||
TORCH_CHECK(b_scales_reorder.is_contiguous(),
|
||||
"b_scales_reorder is not contiguous");
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(b_zeros.value().device().is_cuda(), "b_zeros is not on GPU");
|
||||
TORCH_CHECK(b_zeros.value().is_contiguous(), "b_zeros is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_zeros_reorder.value().device().is_cuda(),
|
||||
"b_zeros_reorder is not on GPU");
|
||||
TORCH_CHECK(b_zeros_reorder.value().is_contiguous(),
|
||||
"b_zeros_reorder is not contiguous");
|
||||
}
|
||||
|
||||
const uint8_t* matB = reinterpret_cast<const uint8_t*>(b_qweight.data_ptr());
|
||||
const void* b_scale = b_scales.data_ptr();
|
||||
const void* b_zero = has_zp ? b_zeros.value().data_ptr() : nullptr;
|
||||
|
||||
uint8_t* matB_reorder =
|
||||
reinterpret_cast<uint8_t*>(b_qweight_reorder.data_ptr());
|
||||
void* b_scale_reorder = b_scales_reorder.data_ptr();
|
||||
void* b_zero_reorder = has_zp ? b_zeros_reorder.value().data_ptr() : nullptr;
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
if (b_scales.dtype() == at::ScalarType::Half) {
|
||||
allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__half>(
|
||||
matB, reinterpret_cast<const __half*>(b_scale),
|
||||
reinterpret_cast<const __half*>(b_zero), matB_reorder,
|
||||
reinterpret_cast<__half*>(b_scale_reorder),
|
||||
reinterpret_cast<__half*>(b_zero_reorder), K, N, N_32align, stream);
|
||||
} else if (b_scales.dtype() == at::ScalarType::BFloat16) {
|
||||
allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__nv_bfloat16>(
|
||||
matB, reinterpret_cast<const __nv_bfloat16*>(b_scale),
|
||||
reinterpret_cast<const __nv_bfloat16*>(b_zero), matB_reorder,
|
||||
reinterpret_cast<__nv_bfloat16*>(b_scale_reorder),
|
||||
reinterpret_cast<__nv_bfloat16*>(b_zero_reorder), K, N, N_32align,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("rearrange_kn_weight_as_n32k16_order",
|
||||
&rearrange_kn_weight_as_n32k16_order);
|
||||
}
|
408
csrc/quantization/gptq_allspark/allspark_utils.cuh
Normal file
408
csrc/quantization/gptq_allspark/allspark_utils.cuh
Normal file
@ -0,0 +1,408 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace allspark {
|
||||
|
||||
#define CHECK_CUDA(cmd) \
|
||||
do { \
|
||||
cudaError_t cuda_status = cmd; \
|
||||
if (cuda_status != cudaSuccess) { \
|
||||
std::string err_str = cudaGetErrorString(cuda_status); \
|
||||
std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \
|
||||
<< err_str; \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define CHECK_CUBLAS(cmd) \
|
||||
do { \
|
||||
cublasStatus_t cublas_status = cmd; \
|
||||
if (cublas_status != CUBLAS_STATUS_SUCCESS) { \
|
||||
std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \
|
||||
<< cublas_status << std::endl; \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
template <typename FType, typename QType>
|
||||
struct SM8x_GEMM_W8A16_Splitk_Params {
|
||||
const FType* A_ptr;
|
||||
const QType* B_ptr;
|
||||
const FType* B_scale_ptr;
|
||||
const FType* B_zero_ptr;
|
||||
FType* C_ptr;
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
int SplitK;
|
||||
int GroupCnt;
|
||||
int GroupSize;
|
||||
FType* C_split_ptr; // for non-fused splitk reduce
|
||||
float* C_tmp_ptr; // for fused splitk reduce
|
||||
uint32_t* red_count_ptr; // for fused splitk reduce
|
||||
};
|
||||
|
||||
struct alignas(16) BlockTileSplitkParams {
|
||||
int Mtile;
|
||||
int Ntile;
|
||||
int SplitK;
|
||||
bool EnableFuse;
|
||||
};
|
||||
|
||||
template <typename FType, int BLOCK, int N_MATRIX>
|
||||
__global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
|
||||
uint32_t n, uint32_t n_matrix,
|
||||
uint32_t matrix_size) {
|
||||
int idx = blockIdx.x * BLOCK + threadIdx.x;
|
||||
|
||||
if (idx >= matrix_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
FType sum(0);
|
||||
|
||||
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
|
||||
for (int i = 0; i < n_mat; ++i) {
|
||||
sum += C_split[idx + i * matrix_size];
|
||||
}
|
||||
|
||||
C[idx] = sum;
|
||||
}
|
||||
|
||||
template <typename FType>
|
||||
void f16_gemm_splitk_reduce(const FType* C_split, FType* C, const uint32_t m,
|
||||
const uint32_t n, const uint32_t n_matrix,
|
||||
cudaStream_t stream) {
|
||||
const int BLOCK = 128;
|
||||
uint32_t matrix_size = m * n;
|
||||
int grid = (matrix_size + BLOCK - 1) / BLOCK;
|
||||
|
||||
void (*kernel)(const FType*, FType*, uint32_t, uint32_t, uint32_t) = nullptr;
|
||||
|
||||
switch (n_matrix) {
|
||||
case 4:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 4>;
|
||||
break;
|
||||
case 5:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 5>;
|
||||
break;
|
||||
case 6:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 6>;
|
||||
break;
|
||||
case 7:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 7>;
|
||||
break;
|
||||
case 8:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 8>;
|
||||
break;
|
||||
case 9:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 9>;
|
||||
break;
|
||||
case 10:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 10>;
|
||||
break;
|
||||
case 11:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 11>;
|
||||
break;
|
||||
case 12:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 12>;
|
||||
break;
|
||||
default:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, -1>;
|
||||
break;
|
||||
}
|
||||
|
||||
kernel<<<grid, BLOCK, 0, stream>>>(C_split, C, n, n_matrix, matrix_size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct HalfType;
|
||||
template <>
|
||||
struct HalfType<half> {
|
||||
using T1 = __half;
|
||||
using T2 = __half2;
|
||||
};
|
||||
template <>
|
||||
struct HalfType<__nv_bfloat16> {
|
||||
using T1 = __nv_bfloat16;
|
||||
using T2 = __nv_bfloat162;
|
||||
};
|
||||
|
||||
// convert 64-bit pointer to 32-bit smem addr
|
||||
__device__ __forceinline__ uint32_t smem_u32addr(const void* smem_ptr) {
|
||||
uint32_t addr;
|
||||
asm("{.reg .u64 u64addr;\n"
|
||||
" cvta.to.shared.u64 u64addr, %1;\n"
|
||||
" cvt.u32.u64 %0, u64addr;}\n"
|
||||
: "=r"(addr)
|
||||
: "l"(smem_ptr));
|
||||
|
||||
return addr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void ldg16_cg_0(T& r0, const void* ptr, bool guard) {
|
||||
static_assert(sizeof(T) == 2, "ldg16_cg_0: invalid T");
|
||||
|
||||
asm volatile(
|
||||
"{.reg .pred p;\n"
|
||||
" setp.ne.b32 p, %2, 0;\n"
|
||||
" @!p mov.b16 %0, 0;\n"
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
|
||||
__CUDA_ARCH__ >= 750
|
||||
" @p ld.global.cg.L2::128B.b16 {%0}, [%1];}\n"
|
||||
#else
|
||||
" @p ld.global.ca.b16 {%0}, [%1];}\n"
|
||||
#endif
|
||||
: "=h"(reinterpret_cast<uint16_t&>(r0))
|
||||
: "l"(ptr), "r"((int)guard));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void ldg64_ca(T& r0, T& r1, const void* ptr,
|
||||
bool guard) {
|
||||
static_assert(sizeof(T) == 4, "ldg64_ca: invalid T");
|
||||
|
||||
asm volatile(
|
||||
"{.reg .pred p;\n"
|
||||
" setp.ne.b32 p, %3, 0;\n"
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
|
||||
__CUDA_ARCH__ >= 750
|
||||
" @p ld.global.ca.L2::128B.v2.b32 {%0, %1}, [%2];}\n"
|
||||
#else
|
||||
" @p ld.global.ca.v2.b32 {%0, %1}, [%2];}\n"
|
||||
#endif
|
||||
: "=r"(reinterpret_cast<uint32_t&>(r0)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r1))
|
||||
: "l"(ptr), "r"((int)guard));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void ldg128_cg_0(T& r0, T& r1, T& r2, T& r3,
|
||||
const void* ptr, bool guard) {
|
||||
static_assert(sizeof(T) == 4, "ldg128_cg_0: invalid T");
|
||||
|
||||
asm volatile(
|
||||
"{.reg .pred p;\n"
|
||||
" setp.ne.b32 p, %5, 0;\n"
|
||||
" @!p mov.b32 %0, 0;\n"
|
||||
" @!p mov.b32 %1, 0;\n"
|
||||
" @!p mov.b32 %2, 0;\n"
|
||||
" @!p mov.b32 %3, 0;\n"
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
|
||||
__CUDA_ARCH__ >= 750
|
||||
" @p ld.global.cg.L2::128B.v4.b32 {%0, %1, %2, %3}, [%4];}\n"
|
||||
#else
|
||||
" @p ld.global.cg.v4.b32 {%0, %1, %2, %3}, [%4];}\n"
|
||||
#endif
|
||||
: "=r"(reinterpret_cast<uint32_t&>(r0)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r1)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r2)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r3))
|
||||
: "l"(ptr), "r"((int)guard));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void lds128(T& reg0, T& reg1, T& reg2, T& reg3,
|
||||
const uint32_t addr) {
|
||||
static_assert(sizeof(T) == 4, "lds128: invalid T");
|
||||
|
||||
asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(reinterpret_cast<uint32_t&>(reg0)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(reg1)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(reg2)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(reg3))
|
||||
: "r"(addr));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void stg128(const T& r0, const T& r1, const T& r2,
|
||||
const T& r3, const void* ptr,
|
||||
bool guard) {
|
||||
static_assert(sizeof(T) == 4, "stg128: invalid T");
|
||||
|
||||
asm volatile(
|
||||
"{.reg .pred p;\n"
|
||||
" setp.ne.b32 p, %1, 0;\n"
|
||||
" @p st.global.v4.b32 [%0], {%2, %3, %4, %5};}\n"
|
||||
:
|
||||
: "l"(ptr), "r"((int)guard), "r"(reinterpret_cast<const uint32_t&>(r0)),
|
||||
"r"(reinterpret_cast<const uint32_t&>(r1)),
|
||||
"r"(reinterpret_cast<const uint32_t&>(r2)),
|
||||
"r"(reinterpret_cast<const uint32_t&>(r3)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void ldsm_4(T& r0, T& r1, T& r2, T& r3,
|
||||
const uint32_t& addr) {
|
||||
static_assert(sizeof(T) == 4, "ldsm_4: invalid T");
|
||||
#if (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 11)
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(reinterpret_cast<uint32_t&>(r0)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r1)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r2)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r3))
|
||||
: "r"(addr));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename FType>
|
||||
__device__ __forceinline__ void hmma16816_f32(float (&d)[4],
|
||||
const uint32_t (&a)[4],
|
||||
const uint32_t (&b)[2]);
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void hmma16816_f32<__half>(float (&d)[4],
|
||||
const uint32_t (&a)[4],
|
||||
const uint32_t (&b)[2]) {
|
||||
#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, "
|
||||
"{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
|
||||
: "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void hmma16816_f32<__nv_bfloat16>(
|
||||
float (&d)[4], const uint32_t (&a)[4], const uint32_t (&b)[2]) {
|
||||
#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, "
|
||||
"{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
|
||||
: "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int SIZE_IN_BYTES>
|
||||
__device__ __forceinline__ void cp_async(const uint32_t smem_addr,
|
||||
const void* gmem_ptr,
|
||||
const int src_in_bytes, bool guard) {
|
||||
static_assert(
|
||||
(SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16),
|
||||
"Size is not supported");
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
asm volatile(
|
||||
"{.reg.pred p;\n"
|
||||
" setp.ne.b32 p, %4, 0;\n"
|
||||
#if __CUDACC_VER_MINOR__ >= 4
|
||||
" @p cp.async.cg.shared.global.L2::256B [%0], [%1], %2, %3;}\n"
|
||||
#else
|
||||
" @p cp.async.cg.shared.global [%0], [%1], %2, %3;}\n"
|
||||
#endif
|
||||
::"r"(smem_addr),
|
||||
"l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int SIZE_IN_BYTES>
|
||||
__device__ __forceinline__ void cp_async_ca(const uint32_t smem_addr,
|
||||
const void* gmem_ptr,
|
||||
const int src_in_bytes,
|
||||
bool guard) {
|
||||
static_assert(
|
||||
(SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16),
|
||||
"Size is not supported");
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
asm volatile(
|
||||
"{.reg.pred p;\n"
|
||||
" setp.ne.b32 p, %4, 0;\n"
|
||||
#if __CUDACC_VER_MINOR__ >= 4
|
||||
" @p cp.async.ca.shared.global.L2::256B [%0], [%1], %2, %3;}\n"
|
||||
#else
|
||||
" @p cp.async.ca.shared.global [%0], [%1], %2, %3;}\n"
|
||||
#endif
|
||||
::"r"(smem_addr),
|
||||
"l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard));
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void cp_async_commit_group() {
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cp.async.commit_group;\n");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ __forceinline__ void cp_asyc_wait_group() {
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cp.async.wait_group %0;\n" : : "n"(N));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128(const uint32_t& idata,
|
||||
T* fdata);
|
||||
|
||||
template <>
|
||||
// fast conversion: 4xuint8 to 4xhalf, subtracting bias = 128
|
||||
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__half2>(
|
||||
const uint32_t& idata, __half2* fdata) {
|
||||
uint32_t i10, i32;
|
||||
asm volatile(
|
||||
"prmt.b32 %0, %2, 0x64, 0x4140;"
|
||||
"prmt.b32 %1, %2, 0x64, 0x4342;"
|
||||
: "=r"(i10), "=r"(i32)
|
||||
: "r"(idata));
|
||||
|
||||
static constexpr uint32_t MAGIC_NUM = 0x64806480;
|
||||
fdata[0] = __hsub2(reinterpret_cast<const __half2&>(i10),
|
||||
reinterpret_cast<const __half2&>(MAGIC_NUM));
|
||||
fdata[1] = __hsub2(reinterpret_cast<const __half2&>(i32),
|
||||
reinterpret_cast<const __half2&>(MAGIC_NUM));
|
||||
}
|
||||
|
||||
template <>
|
||||
// fast conversion: 4xuint8 to 4xbfloat16, subtracting bias = 128
|
||||
// reference from marlin fast implementation
|
||||
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__nv_bfloat162>(
|
||||
const uint32_t& idata, __nv_bfloat162* fdata) {
|
||||
float fp32_imd[4];
|
||||
uint32_t* fp32_imd_casted = reinterpret_cast<uint32_t*>(fp32_imd);
|
||||
asm volatile(
|
||||
"prmt.b32 %0, %4, 0x4B000000, 0x7650;"
|
||||
"prmt.b32 %1, %4, 0x4B000000, 0x7651;"
|
||||
"prmt.b32 %2, %4, 0x4B000000, 0x7652;"
|
||||
"prmt.b32 %3, %4, 0x4B000000, 0x7653;"
|
||||
: "=r"(fp32_imd_casted[0]), "=r"(fp32_imd_casted[1]),
|
||||
"=r"(fp32_imd_casted[2]), "=r"(fp32_imd_casted[3])
|
||||
: "r"(idata));
|
||||
|
||||
fp32_imd[0] -= 8388736.f;
|
||||
fp32_imd[1] -= 8388736.f;
|
||||
fp32_imd[2] -= 8388736.f;
|
||||
fp32_imd[3] -= 8388736.f;
|
||||
|
||||
uint32_t* bf16_res = reinterpret_cast<uint32_t*>(fdata);
|
||||
asm volatile(
|
||||
"prmt.b32 %0, %2, %3, 0x7632;"
|
||||
"prmt.b32 %1, %4, %5, 0x7632;"
|
||||
: "=r"(bf16_res[0]), "=r"(bf16_res[1])
|
||||
: "r"(fp32_imd_casted[0]), "r"(fp32_imd_casted[1]),
|
||||
"r"(fp32_imd_casted[2]), "r"(fp32_imd_casted[3]));
|
||||
}
|
||||
|
||||
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __bfloat162bfloat162(x);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
static __device__ half2 inline num2num2(const half x) {
|
||||
return __half2half2(x);
|
||||
}
|
||||
|
||||
} // namespace allspark
|
@ -447,6 +447,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor!? azp) -> ()");
|
||||
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
|
||||
&dynamic_scaled_int8_quant);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||
ops.def(
|
||||
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
|
||||
"Tensor? b_zeros, "
|
||||
"bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, "
|
||||
"Tensor!? b_zeros_reorder, "
|
||||
"int K, int N, int N_32align) -> ()");
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// AllSpark quantization ops
|
||||
ops.def(
|
||||
"allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, "
|
||||
"Tensor? b_qzeros, "
|
||||
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
|
||||
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
|
||||
// conditionally compiled so impl in source file
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
100
tests/kernels/test_allspark_gemm.py
Normal file
100
tests/kernels/test_allspark_gemm.py
Normal file
@ -0,0 +1,100 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
ALLSPARK_AMPERE_N_ALIGN)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
quantize_weights)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
|
||||
def is_gptq_allspark_supported(min_capability: int,
|
||||
max_capability: int) -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
assert capability is not None
|
||||
|
||||
return capability.to_int() >= min_capability \
|
||||
and capability.to_int() <= max_capability
|
||||
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 4, 8),
|
||||
(13, 17, 67),
|
||||
(26, 37, 13),
|
||||
(48, 16, 24),
|
||||
(67, 13, 88),
|
||||
(257, 13, 11),
|
||||
(658, 13, 11),
|
||||
(1033, 9, 17),
|
||||
]
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
HAS_ZP_OPTS = [False, True]
|
||||
|
||||
|
||||
def compute_max_diff(output, output_ref):
|
||||
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
|
||||
torch.abs(output_ref))
|
||||
|
||||
|
||||
def rand_data(shape, dtype=torch.float16):
|
||||
return torch.randn(shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_gptq_allspark_supported(80, 89),
|
||||
reason="AllSpark Ampere kernel is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("group_size", [-1])
|
||||
@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
m = m_factor
|
||||
n = n_factor * ALLSPARK_AMPERE_N_ALIGN
|
||||
k = k_factor * ALLSPARK_AMPERE_K_ALIGN
|
||||
|
||||
input = rand_data((m, k), dtype=dtype)
|
||||
weight = rand_data((k, n), dtype=dtype)
|
||||
|
||||
# Quantize (and apply act_order if provided)
|
||||
w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128,
|
||||
group_size, has_zp)
|
||||
|
||||
qw = qw.to(torch.uint8)
|
||||
if has_zp:
|
||||
zp = zp.to(dtype)
|
||||
properties = torch.cuda.get_device_properties(qw.device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
|
||||
n_32align = (n + 32 - 1) // 32 * 32
|
||||
|
||||
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
|
||||
qw, s, zp, has_zp)
|
||||
opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order,
|
||||
(qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n,
|
||||
n_32align))
|
||||
|
||||
opcheck(torch.ops._C.allspark_w8a16_gemm,
|
||||
(input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count,
|
||||
sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder,
|
||||
n, group_size, sm_count, sm_version,
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
has_zp, True)
|
||||
|
||||
output_ref = torch.matmul(input, w_ref)
|
||||
torch.cuda.synchronize()
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
@ -215,8 +215,6 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
|
||||
assert qkv_proj.scheme.group_size == (-1
|
||||
if group is None else group)
|
||||
|
||||
assert qkv_proj.weight_packed.dtype is torch.int32
|
||||
assert qkv_proj.weight_scale.dtype is torch.float16
|
||||
assert qkv_proj.scheme.pack_factor == pack_factor
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
@ -404,6 +404,22 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
memory_format=torch.contiguous_format)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
|
||||
|
||||
@register_fake("_C::allspark_w8a16_gemm")
|
||||
def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor,
|
||||
b_scales: torch.Tensor,
|
||||
b_qzeros: Optional[torch.Tensor],
|
||||
n: torch.SymInt, group_size: torch.SymInt,
|
||||
sm_count: torch.SymInt,
|
||||
sm_version: torch.SymInt,
|
||||
CUBLAS_M_THRESHOLD: torch.SymInt,
|
||||
has_zp: bool,
|
||||
n32k16_reorder: bool) -> torch.Tensor:
|
||||
m = a.size(0)
|
||||
return torch.empty((m, n), device=a.device, dtype=a.dtype)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "ggml_dequantize"):
|
||||
|
||||
@register_fake("_C::ggml_dequantize")
|
||||
@ -881,6 +897,67 @@ def scaled_fp8_quant(
|
||||
return output, scale
|
||||
|
||||
|
||||
# gptq allspark
|
||||
def allspark_repack_weight(
|
||||
qweight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
zero_point: Optional[torch.Tensor] = None,
|
||||
has_zp: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format
|
||||
for Ampere W8A16 Fused Gemm kernel
|
||||
|
||||
Args:
|
||||
qweight: uint8 weight tensor, original k x n format.
|
||||
scale: fp16/bf16 weight scale tensor, 1 x n format.
|
||||
zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
|
||||
Must be provided for asymmetric quantization.
|
||||
has_zp: if use symmetric quantization, has_zp = False.
|
||||
if use asymmetric quantization, has_zp = True.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
|
||||
rearranged weight, scale, and optionally zero_point.
|
||||
"""
|
||||
K = qweight.shape[0]
|
||||
N = qweight.shape[1]
|
||||
N_32align = (N + 32 - 1) // 32 * 32
|
||||
|
||||
qweight_reorder = torch.empty((N_32align, K),
|
||||
device=qweight.device,
|
||||
dtype=qweight.dtype)
|
||||
scale_reorder = torch.empty((1, N_32align),
|
||||
device=scale.device,
|
||||
dtype=scale.dtype)
|
||||
zero_point_reorder = None
|
||||
if has_zp:
|
||||
assert zero_point is not None, (
|
||||
"zero_point must be provided for asymmetric quantization.")
|
||||
zero_point_reorder = torch.empty((1, N_32align),
|
||||
device=zero_point.device,
|
||||
dtype=zero_point.dtype)
|
||||
|
||||
torch.ops._C.rearrange_kn_weight_as_n32k16_order(
|
||||
qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder,
|
||||
zero_point_reorder, K, N, N_32align)
|
||||
|
||||
return qweight_reorder, scale_reorder, zero_point_reorder
|
||||
|
||||
|
||||
def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor,
|
||||
b_scales: torch.Tensor,
|
||||
b_qzeros: Optional[torch.Tensor], n: int,
|
||||
group_size: int, sm_count: int, sm_version: int,
|
||||
CUBLAS_M_THRESHOLD: int, has_zp: bool,
|
||||
n32k16_reorder: bool) -> torch.Tensor:
|
||||
|
||||
return torch.ops._C.allspark_w8a16_gemm(a, b_qweight, b_scales, b_qzeros,
|
||||
n, group_size, sm_count,
|
||||
sm_version, CUBLAS_M_THRESHOLD,
|
||||
has_zp, n32k16_reorder)
|
||||
|
||||
|
||||
# int8
|
||||
def scaled_int8_quant(
|
||||
input: torch.Tensor,
|
||||
|
@ -3,6 +3,8 @@
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
|
||||
AllSparkLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
|
||||
ExllamaLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
|
||||
@ -16,6 +18,7 @@ from vllm.platforms import current_platform
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
|
||||
MacheteLinearKernel,
|
||||
AllSparkLinearKernel,
|
||||
MarlinLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
]
|
||||
|
@ -0,0 +1,115 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class AllSparkLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
if c.has_g_idx:
|
||||
return False, "Act reordering currently not supported by AllSpark"
|
||||
|
||||
if c.zero_points:
|
||||
return False, "Zero points currently not supported by AllSpark"
|
||||
|
||||
return check_allspark_supported_dtype_shape(
|
||||
c.partition_weight_shape[0], # in_features
|
||||
c.partition_weight_shape[1], # out_features
|
||||
c.group_size,
|
||||
c.weight_type,
|
||||
c.act_type)
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = getattr(layer, self.w_q_name).device
|
||||
c = self.config
|
||||
|
||||
# prepare the parameters required for the kernel
|
||||
properties = torch.cuda.get_device_properties(device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
gemm_args = {}
|
||||
gemm_args['sm_count'] = sm_count
|
||||
gemm_args['sm_version'] = sm_version
|
||||
|
||||
self.gemm_args = gemm_args
|
||||
|
||||
# transform param weight, scale
|
||||
old_weight_param = getattr(layer, self.w_q_name)
|
||||
old_scale_param = getattr(layer, self.w_s_name)
|
||||
|
||||
assert isinstance(old_weight_param, BasevLLMParameter)
|
||||
permute_param_layout_(old_weight_param,
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0)
|
||||
|
||||
assert isinstance(old_scale_param, BasevLLMParameter)
|
||||
permute_param_layout_(old_scale_param, input_dim=0, output_dim=1)
|
||||
|
||||
# unpack weight from K / 4 x N int32 to K x N uint8
|
||||
new_weight_param = torch.nn.Parameter(old_weight_param.data,
|
||||
requires_grad=False)
|
||||
new_weight_param.data = new_weight_param.data.t().contiguous().view(
|
||||
dtype=torch.uint8)
|
||||
new_weight_param.data = new_weight_param.data.t().contiguous()
|
||||
|
||||
new_scale_param = torch.nn.Parameter(old_scale_param.data,
|
||||
requires_grad=False)
|
||||
|
||||
# reorder K x N weight as N32K16 format for Ampere W8A16
|
||||
new_weight_param.data, new_scale_param.data, _ = \
|
||||
ops.allspark_repack_weight(
|
||||
new_weight_param.data, new_scale_param.data, None,
|
||||
c.zero_points)
|
||||
|
||||
replace_parameter(layer, self.w_q_name, new_weight_param.data)
|
||||
replace_parameter(layer, self.w_s_name, new_scale_param.data)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
c = self.config
|
||||
gemm_args = self.gemm_args
|
||||
w_q, w_s, _, _ = self._get_weight_params(layer)
|
||||
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
|
||||
output = ops.allspark_w8a16_gemm(
|
||||
a=reshaped_x,
|
||||
b_qweight=w_q,
|
||||
b_scales=w_s,
|
||||
b_qzeros=None,
|
||||
n=c.partition_weight_shape[1],
|
||||
group_size=c.group_size,
|
||||
sm_count=gemm_args['sm_count'],
|
||||
sm_version=gemm_args['sm_version'],
|
||||
CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
has_zp=c.zero_points,
|
||||
n32k16_reorder=True)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
@ -0,0 +1,51 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024
|
||||
ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128]
|
||||
ALLSPARK_AMPERE_N_ALIGN = 16
|
||||
ALLSPARK_AMPERE_K_ALIGN = 16
|
||||
|
||||
|
||||
def check_allspark_supported_dtype_shape(input_size_per_partition: int,
|
||||
output_size_per_partition: int,
|
||||
group_size: int,
|
||||
weight_dtype: ScalarType,
|
||||
act_dtype: torch.dtype):
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
|
||||
# For Ampere GPU
|
||||
if device_capability >= 80 and device_capability < 90:
|
||||
if group_size != -1:
|
||||
return False, \
|
||||
"For Ampere GPU, AllSpark does not support group_size "\
|
||||
f"= {group_size}. Only group_size = -1 are supported."
|
||||
|
||||
if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
|
||||
return False, "For Ampere GPU, AllSpark does not support "\
|
||||
f"quant type ({weight_dtype}). Only quant type "\
|
||||
f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported."
|
||||
|
||||
if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \
|
||||
or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0:
|
||||
return False, \
|
||||
"AllSpark needs input_size_per_partition % "\
|
||||
f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\
|
||||
f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\
|
||||
"for Ampere GPU optimized kernels."
|
||||
|
||||
if act_dtype != torch.float16 and act_dtype != torch.bfloat16:
|
||||
return False, \
|
||||
"AllSpark only supports act_dtype = float16 or bfloat16,"\
|
||||
f"for Ampere GPU, but got act_dtype = {act_dtype}."
|
||||
else:
|
||||
return False, "AllSpark currently does not support "\
|
||||
f"device_capability = {device_capability}."
|
||||
|
||||
return True, None
|
Loading…
x
Reference in New Issue
Block a user