[Misc][Kernel]: Add GPTQAllSpark Quantization (#12931)
This commit is contained in:
parent
6a84164add
commit
6a92ff93e1
16
CMakeLists.txt
Executable file → Normal file
16
CMakeLists.txt
Executable file → Normal file
@ -317,6 +317,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
" in CUDA target architectures")
|
" in CUDA target architectures")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Only build AllSpark kernels if we are building for at least some compatible archs.
|
||||||
|
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
|
||||||
|
if (ALLSPARK_ARCHS)
|
||||||
|
set(ALLSPARK_SRCS
|
||||||
|
"csrc/quantization/gptq_allspark/allspark_repack.cu"
|
||||||
|
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${ALLSPARK_SRCS}"
|
||||||
|
CUDA_ARCHS "${ALLSPARK_ARCHS}")
|
||||||
|
list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}")
|
||||||
|
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building AllSpark kernels as no compatible archs found"
|
||||||
|
" in CUDA target architectures")
|
||||||
|
endif()
|
||||||
|
|
||||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||||
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
|
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
|
||||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||||
|
@ -10,6 +10,8 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||||
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
|
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||||
|
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||||
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
|
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
|
||||||
@ -18,12 +20,12 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||||
marlin_24_quantize)
|
marlin_24_quantize)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
gptq_pack, gptq_quantize_weights, sort_weights)
|
gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
|
||||||
from vllm.scalar_type import ScalarType
|
from vllm.scalar_type import ScalarType
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
|
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
|
||||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||||
|
|
||||||
ACT_ORDER_OPTS = [False, True]
|
ACT_ORDER_OPTS = [False, True]
|
||||||
K_FULL_OPTS = [False, True]
|
K_FULL_OPTS = [False, True]
|
||||||
@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
|||||||
GPTQ_MARLIN_24_MAX_PARALLEL)
|
GPTQ_MARLIN_24_MAX_PARALLEL)
|
||||||
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
|
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
|
||||||
|
|
||||||
|
# AllSpark W8A16 quant
|
||||||
|
as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
|
||||||
|
and group_size == -1 and not act_order and is_k_full)
|
||||||
|
if as_supported_case:
|
||||||
|
properties = torch.cuda.get_device_properties(b.device.index)
|
||||||
|
sm_count = properties.multi_processor_count
|
||||||
|
sm_version = properties.major * 10 + properties.minor
|
||||||
|
|
||||||
|
supported_arch = (sm_version >= 80 and sm_version < 90)
|
||||||
|
as_supported_case = as_supported_case and supported_arch
|
||||||
|
if supported_arch:
|
||||||
|
has_zp = False
|
||||||
|
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
|
||||||
|
has_zp)
|
||||||
|
qw = qw.to(torch.uint8)
|
||||||
|
|
||||||
|
qw_reorder, s_reorder, zp_reorder = \
|
||||||
|
ops.allspark_repack_weight(
|
||||||
|
qw, s, zp, has_zp)
|
||||||
|
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
|
||||||
|
|
||||||
globals = {
|
globals = {
|
||||||
# Gen params
|
# Gen params
|
||||||
"quant_type": quant_type,
|
"quant_type": quant_type,
|
||||||
@ -109,10 +132,19 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
|||||||
# GPTQ params
|
# GPTQ params
|
||||||
"q_w_gptq": q_w_gptq,
|
"q_w_gptq": q_w_gptq,
|
||||||
"repack_sort_indices": repack_sort_indices,
|
"repack_sort_indices": repack_sort_indices,
|
||||||
|
# AllSpark W8A16 params
|
||||||
|
"qw_reorder": qw_reorder if as_supported_case else None,
|
||||||
|
"s_reorder": s_reorder if as_supported_case else None,
|
||||||
|
"zp_reorder": zp_reorder if as_supported_case else None,
|
||||||
|
"sm_count": sm_count if as_supported_case else None,
|
||||||
|
"sm_version": sm_version if as_supported_case else None,
|
||||||
|
"CUBLAS_M_THRESHOLD":
|
||||||
|
CUBLAS_M_THRESHOLD if as_supported_case else None,
|
||||||
# Kernels
|
# Kernels
|
||||||
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
||||||
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
||||||
"gptq_marlin_repack": ops.gptq_marlin_repack,
|
"gptq_marlin_repack": ops.gptq_marlin_repack,
|
||||||
|
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
|
||||||
}
|
}
|
||||||
|
|
||||||
min_run_time = 1
|
min_run_time = 1
|
||||||
@ -172,6 +204,17 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
|||||||
description="gptq_marlin_repack",
|
description="gptq_marlin_repack",
|
||||||
).blocked_autorange(min_run_time=min_run_time))
|
).blocked_autorange(min_run_time=min_run_time))
|
||||||
|
|
||||||
|
if as_supported_case:
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt=
|
||||||
|
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="allspark_w8a16_gemm_fp32",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time))
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
print("Benchmarking models:")
|
print("Benchmarking models:")
|
||||||
|
1008
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
Normal file
1008
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
Normal file
File diff suppressed because it is too large
Load Diff
163
csrc/quantization/gptq_allspark/allspark_repack.cu
Normal file
163
csrc/quantization/gptq_allspark/allspark_repack.cu
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
#include "allspark_utils.cuh"
|
||||||
|
#include <torch/all.h>
|
||||||
|
#include "core/registration.h"
|
||||||
|
|
||||||
|
namespace allspark {
|
||||||
|
|
||||||
|
// Rearrange B to facilitate Ampere Tensor Core load data:
// reorder B from (K, N) to (N_32align / 4, K * 4).
// Preconditions: K % 16 == 0, N % 16 == 0, N_32align % 32 == 0.
//
// Work split, as implemented below:
//   * blocks with blockIdx.x != gridDim.x - 1 each reorder a 64(k) x 128(n)
//     tile of the quantized weight B;
//   * the blocks in the last blockIdx.x slice only permute B_scale (and
//     B_zero when it is non-null) into B_scale_result / B_zero_result.
// The kernel is compiled for 128 threads per block (4 warps).
//
// B_zero / B_zero_result may be nullptr, in which case zero points are
// skipped entirely.
template <typename FType>
__global__ void __launch_bounds__(128)
    rearrange_kn_weight_as_n32k16_order_ldg16_kernel(
        const uint8_t* B, const FType* B_scale, const FType* B_zero,
        uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
        const int K, const int N, const int N_32align) {
  const int lane_id = threadIdx.x % 32;
  const int warp_id = threadIdx.x / 32;

  if (blockIdx.x != gridDim.x - 1) {
    // Load B
    // per block process 64(k) * 128(n) B elements
    // per warp process 16(k) * 128 B elements
    const int src_row_base_idx =
        blockIdx.x * 64 + warp_id * 16 + ((lane_id % 8) / 2) * 2;
    const int src_col_idx =
        blockIdx.y * 128 + (lane_id / 8) * 32 + (lane_id % 2) * 16;
    uint8_t B_frag[4][16];
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      int src_row_idx = src_row_base_idx + (i / 2) * 8 + (i % 2);
      // 64-bit offset so K * N larger than INT_MAX cannot wrap.
      const int64_t src_offset =
          static_cast<int64_t>(src_row_idx) * N + src_col_idx;
      bool guard = src_row_idx < K && src_col_idx < N;
      // Out-of-range fragments are zero-filled by ldg128_cg_0.
      ldg128_cg_0(*reinterpret_cast<uint32_t*>(B_frag[i]),
                  *(reinterpret_cast<uint32_t*>(B_frag[i]) + 1),
                  *(reinterpret_cast<uint32_t*>(B_frag[i]) + 2),
                  *(reinterpret_cast<uint32_t*>(B_frag[i]) + 3), B + src_offset,
                  guard);
    }

    // reorder B: transpose the 4 x 16 fragment into 8 rows of 8 bytes
    uint8_t B_reorder_frag[8][8];
#pragma unroll
    for (int i = 0; i < 4; ++i) {
#pragma unroll
      for (int j = 0; j < 16; ++j) {
        int dst_i = j % 8;
        int dst_j = i + (j / 8) * 4;
        B_reorder_frag[dst_i][dst_j] = B_frag[i][j];
      }
    }

    // Store B
    const int dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8;
    const int dst_col_idx =
        blockIdx.x * (64 * 4) + warp_id * 64 + (lane_id % 8) * 8;
    for (int i = 0; i < 8; ++i) {
      int dst_row_idx = dst_row_base_idx + i;
      const int64_t dst_offset =
          static_cast<int64_t>(dst_row_idx) * K * 4 + dst_col_idx;
      // Guard on the row actually written (not just the base row of the
      // 8-row group) so the store stays in bounds even if N_32align / 4 is
      // not a multiple of 8.
      bool guard = (dst_row_idx < N_32align / 4) && (dst_col_idx < K * 4);
      if (guard) {
        *reinterpret_cast<int2*>(B_result + dst_offset) =
            *reinterpret_cast<int2*>(B_reorder_frag[i]);
      }
    }
  } else {
    // Load B_scale and B_zero
    FType b_scale_reg, b_zero_reg;
    int src_offset = blockIdx.y * 128 + threadIdx.x;
    ldg16_cg_0(b_scale_reg, B_scale + src_offset, src_offset < N);
    if (B_zero != nullptr)
      ldg16_cg_0(b_zero_reg, B_zero + src_offset, src_offset < N);
    // Permute within each group of 32 columns: (lane % 8) * 4 + lane / 8.
    int dst_offset =
        blockIdx.y * 128 + warp_id * 32 + (lane_id % 8) * 4 + lane_id / 8;
    if (dst_offset < N_32align) {
      B_scale_result[dst_offset] = b_scale_reg;
      // b_zero_reg is only initialised (and only read) when B_zero is set.
      if (B_zero != nullptr) B_zero_result[dst_offset] = b_zero_reg;
    }
  }
}
|
||||||
|
|
||||||
|
// Host-side launcher for rearrange_kn_weight_as_n32k16_order_ldg16_kernel.
//
// Launches one 128-thread block per 64(k) x 128(n) tile of B, plus one extra
// slice of blocks along grid.x (the "+ 1") that the kernel uses to permute
// the scales / zero points.
//
// Throws (via TORCH_CHECK) when N or K is not a multiple of 16; previously
// this case only printed a message to stderr and still launched the kernel.
template <typename FType>
void rearrange_kn_weight_as_n32k16_order_ldg16(
    const uint8_t* B, const FType* B_scale, const FType* B_zero,
    uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
    const int64_t K, const int64_t N, const int64_t N_32align,
    cudaStream_t stream) {
  TORCH_CHECK(N % 16 == 0 && K % 16 == 0,
              "AllSpark weight repack only supports N and K that are "
              "multiples of 16, got N=",
              N, ", K=", K);
  if (N == 0) {
    // Nothing to reorder; a zero-sized grid.y would be an invalid launch.
    return;
  }

  const int BLOCK = 128;
  int grid_x = (K + 64 - 1) / 64 + 1;
  int grid_y = (N + 128 - 1) / 128;
  dim3 grid(grid_x, grid_y);

  rearrange_kn_weight_as_n32k16_order_ldg16_kernel<FType>
      <<<grid, BLOCK, 0, stream>>>(B, B_scale, B_zero, B_result, B_scale_result,
                                   B_zero_result, K, N, N_32align);
}
|
||||||
|
} // namespace allspark
|
||||||
|
|
||||||
|
// Torch op entry point: validates the tensors and dispatches the AllSpark
// weight reorder on the current CUDA stream.
//
// b_qweight / b_scales (and b_zeros when has_zp is true) are the inputs; the
// *_reorder tensors receive the rearranged data. All tensors must be
// contiguous CUDA tensors. b_scales must be float16 or bfloat16; any other
// dtype now raises instead of silently doing nothing.
void rearrange_kn_weight_as_n32k16_order(
    torch::Tensor const& b_qweight, torch::Tensor const& b_scales,
    c10::optional<torch::Tensor> const& b_zeros, bool has_zp,
    torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder,
    c10::optional<torch::Tensor> const& b_zeros_reorder, const int64_t K,
    const int64_t N, const int64_t N_32align) {
  // Verify device and strides
  TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU");
  TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous");

  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

  TORCH_CHECK(b_qweight_reorder.device().is_cuda(),
              "b_qweight_reorder is not on GPU");
  TORCH_CHECK(b_qweight_reorder.is_contiguous(),
              "b_qweight_reorder is not contiguous");

  TORCH_CHECK(b_scales_reorder.device().is_cuda(),
              "b_scales_reorder is not on GPU");
  TORCH_CHECK(b_scales_reorder.is_contiguous(),
              "b_scales_reorder is not contiguous");

  if (has_zp) {
    // Fail with a readable message instead of std::bad_optional_access.
    TORCH_CHECK(b_zeros.has_value(),
                "has_zp is true but b_zeros was not provided");
    TORCH_CHECK(b_zeros_reorder.has_value(),
                "has_zp is true but b_zeros_reorder was not provided");

    TORCH_CHECK(b_zeros.value().device().is_cuda(), "b_zeros is not on GPU");
    TORCH_CHECK(b_zeros.value().is_contiguous(), "b_zeros is not contiguous");

    TORCH_CHECK(b_zeros_reorder.value().device().is_cuda(),
                "b_zeros_reorder is not on GPU");
    TORCH_CHECK(b_zeros_reorder.value().is_contiguous(),
                "b_zeros_reorder is not contiguous");
  }

  const auto scale_type = b_scales.scalar_type();
  TORCH_CHECK(scale_type == at::ScalarType::Half ||
                  scale_type == at::ScalarType::BFloat16,
              "b_scales must be float16 or bfloat16, got ", scale_type);

  // Make the tensors' device current so the stream below belongs to it.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_qweight));

  const uint8_t* matB = reinterpret_cast<const uint8_t*>(b_qweight.data_ptr());
  const void* b_scale = b_scales.data_ptr();
  const void* b_zero = has_zp ? b_zeros.value().data_ptr() : nullptr;

  uint8_t* matB_reorder =
      reinterpret_cast<uint8_t*>(b_qweight_reorder.data_ptr());
  void* b_scale_reorder = b_scales_reorder.data_ptr();
  void* b_zero_reorder = has_zp ? b_zeros_reorder.value().data_ptr() : nullptr;

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (scale_type == at::ScalarType::Half) {
    allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__half>(
        matB, reinterpret_cast<const __half*>(b_scale),
        reinterpret_cast<const __half*>(b_zero), matB_reorder,
        reinterpret_cast<__half*>(b_scale_reorder),
        reinterpret_cast<__half*>(b_zero_reorder), K, N, N_32align, stream);
  } else {
    // BFloat16 (guaranteed by the dtype check above).
    allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__nv_bfloat16>(
        matB, reinterpret_cast<const __nv_bfloat16*>(b_scale),
        reinterpret_cast<const __nv_bfloat16*>(b_zero), matB_reorder,
        reinterpret_cast<__nv_bfloat16*>(b_scale_reorder),
        reinterpret_cast<__nv_bfloat16*>(b_zero_reorder), K, N, N_32align,
        stream);
  }
}
|
||||||
|
|
||||||
|
// Register the CUDA implementation of the
// "rearrange_kn_weight_as_n32k16_order" op with the extension's library.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("rearrange_kn_weight_as_n32k16_order",
|
||||||
|
&rearrange_kn_weight_as_n32k16_order);
|
||||||
|
}
|
408
csrc/quantization/gptq_allspark/allspark_utils.cuh
Normal file
408
csrc/quantization/gptq_allspark/allspark_utils.cuh
Normal file
@ -0,0 +1,408 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/all.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
namespace allspark {
|
||||||
|
|
||||||
|
// Evaluate a CUDA runtime call; on failure print "Failed: <file>:<line>
// <error string>" to stderr and terminate the process with exit(-1).
// The error string is streamed directly (no std::string temporary, so no
// <string> dependency) and the message is newline-terminated, matching
// CHECK_CUBLAS.
#define CHECK_CUDA(cmd)                                             \
  do {                                                              \
    cudaError_t cuda_status = (cmd);                                \
    if (cuda_status != cudaSuccess) {                               \
      std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \
                << cudaGetErrorString(cuda_status) << std::endl;    \
      exit(-1);                                                     \
    }                                                               \
  } while (0)
|
||||||
|
|
||||||
|
// Evaluate a cuBLAS call; if it does not return CUBLAS_STATUS_SUCCESS, print
// "Failed: <file>:<line> <status>" to stderr and terminate with exit(-1).
#define CHECK_CUBLAS(cmd)                                                \
  do {                                                                   \
    cublasStatus_t cublas_status = cmd;                                  \
    if (cublas_status == CUBLAS_STATUS_SUCCESS) {                        \
      break; /* success: leave the do/while without reporting */         \
    }                                                                    \
    std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " "        \
              << cublas_status << std::endl;                             \
    exit(-1);                                                            \
  } while (0)
|
||||||
|
|
||||||
|
// Plain parameter bundle (pointers + sizes) for the split-K W8A16 GEMM.
// FType is the 16-bit floating type of A / scales / zeros / C; QType is the
// quantized weight element type. Field order is part of the layout and must
// not change.
// NOTE(review): exact meaning of GroupCnt / GroupSize is defined by the GEMM
// kernels, which are not visible here.
template <typename FType, typename QType>
struct SM8x_GEMM_W8A16_Splitk_Params {
  // Inputs
  const FType* A_ptr;
  const QType* B_ptr;
  const FType* B_scale_ptr;
  const FType* B_zero_ptr;
  // Output
  FType* C_ptr;
  // Problem sizes
  int M, N, K;
  // Split-K / grouping configuration
  int SplitK, GroupCnt, GroupSize;
  // Workspaces
  FType* C_split_ptr;       // for non-fused splitk reduce
  float* C_tmp_ptr;         // for fused splitk reduce
  uint32_t* red_count_ptr;  // for fused splitk reduce
};
|
||||||
|
|
||||||
|
// Tiling choice for one GEMM launch: block tile extents along M and N, the
// split-K factor, and whether the fused split-K reduce is enabled.
// 16-byte aligned; field order is part of the layout.
struct alignas(16) BlockTileSplitkParams {
  int Mtile, Ntile;
  int SplitK;
  bool EnableFuse;
};
|
||||||
|
|
||||||
|
// Element-wise reduction of split-K partial results.
//
// C_split holds the partial matrices back to back (each matrix_size
// elements); each thread sums one element position across all of them and
// writes the total to C. The sum is accumulated in FType.
//
// N_MATRIX > 0 fixes the number of partial matrices at compile time;
// otherwise the runtime value n_matrix is used. The parameter `n` is not
// read by this kernel.
template <typename FType, int BLOCK, int N_MATRIX>
__global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
                                              uint32_t n, uint32_t n_matrix,
                                              uint32_t matrix_size) {
  int idx = blockIdx.x * BLOCK + threadIdx.x;

  // Only threads that map to a valid element do any work.
  if (idx < matrix_size) {
    const int partial_count = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;

    FType acc(0);
    for (int part = 0; part < partial_count; ++part) {
      acc += C_split[idx + part * matrix_size];
    }

    C[idx] = acc;
  }
}
|
||||||
|
|
||||||
|
// Launch the split-K reduction over an m x n result on `stream`.
//
// For n_matrix in [4, 12] a kernel instantiation with the count baked in as
// a template argument is selected; every other value uses the generic
// instantiation (N_MATRIX = -1) that reads n_matrix at run time.
template <typename FType>
void f16_gemm_splitk_reduce(const FType* C_split, FType* C, const uint32_t m,
                            const uint32_t n, const uint32_t n_matrix,
                            cudaStream_t stream) {
  const int BLOCK = 128;
  const uint32_t matrix_size = m * n;
  const int grid = (matrix_size + BLOCK - 1) / BLOCK;

  using ReduceKernelFn =
      void (*)(const FType*, FType*, uint32_t, uint32_t, uint32_t);

  // Start from the generic kernel and specialise when a fixed-count
  // instantiation exists.
  ReduceKernelFn reduce_fn = f16_gemm_splitk_reduce_kernel<FType, BLOCK, -1>;
  switch (n_matrix) {
    case 4:
      reduce_fn = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 4>;
      break;
    case 5:
      reduce_fn = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 5>;
      break;
    case 6:
      reduce_fn = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 6>;
      break;
    case 7:
      reduce_fn = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 7>;
      break;
    case 8:
      reduce_fn = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 8>;
      break;
    case 9:
      reduce_fn = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 9>;
      break;
    case 10:
      reduce_fn = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 10>;
      break;
    case 11:
      reduce_fn = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 11>;
      break;
    case 12:
      reduce_fn = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 12>;
      break;
    default:
      break;
  }

  reduce_fn<<<grid, BLOCK, 0, stream>>>(C_split, C, n, n_matrix, matrix_size);
}
|
||||||
|
|
||||||
|
// Type trait mapping a 16-bit floating type to its scalar (T1) and packed
// two-element (T2) CUDA types. Only the two specializations below exist.
template <typename T>
struct HalfType;

// float16
template <>
struct HalfType<half> {
  using T1 = __half;
  using T2 = __half2;
};

// bfloat16
template <>
struct HalfType<__nv_bfloat16> {
  using T1 = __nv_bfloat16;
  using T2 = __nv_bfloat162;
};
|
||||||
|
|
||||||
|
// Convert a 64-bit generic pointer into a 32-bit shared-memory address:
// cvta.to.shared produces the shared-window address, which is then narrowed
// to 32 bits.
__device__ __forceinline__ uint32_t smem_u32addr(const void* smem_ptr) {
  uint32_t shared_addr;
  asm("{.reg .u64 u64addr;\n"
      " cvta.to.shared.u64 u64addr, %1;\n"
      " cvt.u32.u64 %0, u64addr;}\n"
      : "=r"(shared_addr)
      : "l"(smem_ptr));
  return shared_addr;
}
|
||||||
|
|
||||||
|
// Predicated 16-bit global load: when `guard` is true r0 is loaded from
// `ptr`, otherwise r0 is set to 0. T must be a 2-byte type.
//
// The L2::128B prefetch variant needs CUDA >= 11.4 and sm_75+. The version
// test compares (major, minor) as a pair so that CUDA 12.0 - 12.3 (minor < 4)
// also select it.
template <typename T>
__device__ __forceinline__ void ldg16_cg_0(T& r0, const void* ptr, bool guard) {
  static_assert(sizeof(T) == 2, "ldg16_cg_0: invalid T");

  asm volatile(
      "{.reg .pred p;\n"
      " setp.ne.b32 p, %2, 0;\n"
      " @!p mov.b16 %0, 0;\n"
#if (__CUDACC_VER_MAJOR__ > 11 ||                                  \
     (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4)) && \
    __CUDA_ARCH__ >= 750
      " @p ld.global.cg.L2::128B.b16 {%0}, [%1];}\n"
#else
      " @p ld.global.ca.b16 {%0}, [%1];}\n"
#endif
      : "=h"(reinterpret_cast<uint16_t&>(r0))
      : "l"(ptr), "r"((int)guard));
}
|
||||||
|
|
||||||
|
// Predicated 64-bit global load into two 4-byte registers. Unlike the
// `_0` loaders, nothing is written when `guard` is false, so r0 / r1 must not
// be relied upon in that case.
//
// The L2::128B prefetch variant needs CUDA >= 11.4 and sm_75+. The version
// test compares (major, minor) as a pair so that CUDA 12.0 - 12.3 (minor < 4)
// also select it.
template <typename T>
__device__ __forceinline__ void ldg64_ca(T& r0, T& r1, const void* ptr,
                                         bool guard) {
  static_assert(sizeof(T) == 4, "ldg64_ca: invalid T");

  asm volatile(
      "{.reg .pred p;\n"
      " setp.ne.b32 p, %3, 0;\n"
#if (__CUDACC_VER_MAJOR__ > 11 ||                                  \
     (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4)) && \
    __CUDA_ARCH__ >= 750
      " @p ld.global.ca.L2::128B.v2.b32 {%0, %1}, [%2];}\n"
#else
      " @p ld.global.ca.v2.b32 {%0, %1}, [%2];}\n"
#endif
      : "=r"(reinterpret_cast<uint32_t&>(r0)),
        "=r"(reinterpret_cast<uint32_t&>(r1))
      : "l"(ptr), "r"((int)guard));
}
|
||||||
|
|
||||||
|
// Predicated 128-bit global load into four 4-byte registers: when `guard` is
// true r0..r3 are loaded from `ptr`, otherwise all four are set to 0.
//
// The L2::128B prefetch variant needs CUDA >= 11.4 and sm_75+. The version
// test compares (major, minor) as a pair so that CUDA 12.0 - 12.3 (minor < 4)
// also select it.
template <typename T>
__device__ __forceinline__ void ldg128_cg_0(T& r0, T& r1, T& r2, T& r3,
                                            const void* ptr, bool guard) {
  static_assert(sizeof(T) == 4, "ldg128_cg_0: invalid T");

  asm volatile(
      "{.reg .pred p;\n"
      " setp.ne.b32 p, %5, 0;\n"
      " @!p mov.b32 %0, 0;\n"
      " @!p mov.b32 %1, 0;\n"
      " @!p mov.b32 %2, 0;\n"
      " @!p mov.b32 %3, 0;\n"
#if (__CUDACC_VER_MAJOR__ > 11 ||                                  \
     (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4)) && \
    __CUDA_ARCH__ >= 750
      " @p ld.global.cg.L2::128B.v4.b32 {%0, %1, %2, %3}, [%4];}\n"
#else
      " @p ld.global.cg.v4.b32 {%0, %1, %2, %3}, [%4];}\n"
#endif
      : "=r"(reinterpret_cast<uint32_t&>(r0)),
        "=r"(reinterpret_cast<uint32_t&>(r1)),
        "=r"(reinterpret_cast<uint32_t&>(r2)),
        "=r"(reinterpret_cast<uint32_t&>(r3))
      : "l"(ptr), "r"((int)guard));
}
|
||||||
|
|
||||||
|
// Unconditional 128-bit shared-memory load into four 4-byte registers.
// `addr` is a 32-bit shared-memory address (as consumed by ld.shared).
template <typename T>
__device__ __forceinline__ void lds128(T& reg0, T& reg1, T& reg2, T& reg3,
                                       const uint32_t addr) {
  static_assert(sizeof(T) == 4, "lds128: invalid T");

  // View each output as a raw 32-bit register for the asm constraints.
  uint32_t& out0 = reinterpret_cast<uint32_t&>(reg0);
  uint32_t& out1 = reinterpret_cast<uint32_t&>(reg1);
  uint32_t& out2 = reinterpret_cast<uint32_t&>(reg2);
  uint32_t& out3 = reinterpret_cast<uint32_t&>(reg3);

  asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
               : "=r"(out0), "=r"(out1), "=r"(out2), "=r"(out3)
               : "r"(addr));
}
|
||||||
|
|
||||||
|
// Predicated 128-bit global store of four 4-byte registers: the store to
// `ptr` is issued only when `guard` is true.
template <typename T>
__device__ __forceinline__ void stg128(const T& r0, const T& r1, const T& r2,
                                       const T& r3, const void* ptr,
                                       bool guard) {
  static_assert(sizeof(T) == 4, "stg128: invalid T");

  // View each input as a raw 32-bit register for the asm constraints.
  const uint32_t& in0 = reinterpret_cast<const uint32_t&>(r0);
  const uint32_t& in1 = reinterpret_cast<const uint32_t&>(r1);
  const uint32_t& in2 = reinterpret_cast<const uint32_t&>(r2);
  const uint32_t& in3 = reinterpret_cast<const uint32_t&>(r3);

  asm volatile(
      "{.reg .pred p;\n"
      " setp.ne.b32 p, %1, 0;\n"
      " @p st.global.v4.b32 [%0], {%2, %3, %4, %5};}\n"
      :
      : "l"(ptr), "r"((int)guard), "r"(in0), "r"(in1), "r"(in2), "r"(in3));
}
|
||||||
|
|
||||||
|
// ldmatrix x4: load four 8x8 b16 matrices from shared memory at `addr` into
// r0..r3. Compiles to nothing unless targeting sm_75+ with CUDA >= 11.
template <typename T>
__device__ __forceinline__ void ldsm_4(T& r0, T& r1, T& r2, T& r3,
                                       const uint32_t& addr) {
  static_assert(sizeof(T) == 4, "ldsm_4: invalid T");
#if (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 11)
  // View each output as a raw 32-bit register for the asm constraints.
  uint32_t& out0 = reinterpret_cast<uint32_t&>(r0);
  uint32_t& out1 = reinterpret_cast<uint32_t&>(r1);
  uint32_t& out2 = reinterpret_cast<uint32_t&>(r2);
  uint32_t& out3 = reinterpret_cast<uint32_t&>(r3);

  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(out0), "=r"(out1), "=r"(out2), "=r"(out3)
      : "r"(addr));
#endif
}
|
||||||
|
|
||||||
|
// Warp-level m16n8k16 tensor-core multiply-accumulate with fp32 accumulators:
// d = a * b + d, where `a` and `b` hold packed 16-bit operands of type FType.
// Specialized for __half and __nv_bfloat16; each body compiles to nothing
// unless targeting sm_80+ with CUDA >= 11.
template <typename FType>
__device__ __forceinline__ void hmma16816_f32(float (&d)[4],
                                              const uint32_t (&a)[4],
                                              const uint32_t (&b)[2]);

// float16 operands
template <>
__device__ __forceinline__ void hmma16816_f32<__half>(float (&d)[4],
                                                      const uint32_t (&a)[4],
                                                      const uint32_t (&b)[2]) {
#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, "
      "{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
      // accumulators: read and written in place
      : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
      // A fragment, then B fragment
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
#endif
}

// bfloat16 operands
template <>
__device__ __forceinline__ void hmma16816_f32<__nv_bfloat16>(
    float (&d)[4], const uint32_t (&a)[4], const uint32_t (&b)[2]) {
#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, "
      "{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
      // accumulators: read and written in place
      : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
      // A fragment, then B fragment
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
#endif
}
|
||||||
|
|
||||||
|
// Predicated asynchronous global -> shared copy (cp.async.cg).
// SIZE_IN_BYTES (4, 8 or 16) is the destination size; `src_in_bytes` is the
// source size operand passed to cp.async. The copy is issued only when
// `guard` is true. Compiles to nothing unless targeting sm_80+ with
// CUDA >= 11.
//
// The L2::256B prefetch variant needs CUDA >= 11.4; the inner version test
// accepts any major version above 11 so that CUDA 12.0 - 12.3 (minor < 4)
// also select it.
template <int SIZE_IN_BYTES>
__device__ __forceinline__ void cp_async(const uint32_t smem_addr,
                                         const void* gmem_ptr,
                                         const int src_in_bytes, bool guard) {
  static_assert(
      (SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16),
      "Size is not supported");
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  asm volatile(
      "{.reg.pred p;\n"
      " setp.ne.b32 p, %4, 0;\n"
  #if __CUDACC_VER_MAJOR__ > 11 || __CUDACC_VER_MINOR__ >= 4
      " @p cp.async.cg.shared.global.L2::256B [%0], [%1], %2, %3;}\n"
  #else
      " @p cp.async.cg.shared.global [%0], [%1], %2, %3;}\n"
  #endif
      ::"r"(smem_addr),
      "l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard));
#endif
}
|
||||||
|
|
||||||
|
// Predicated asynchronous global -> shared copy using the .ca cache operator
// (cp.async.ca). SIZE_IN_BYTES (4, 8 or 16) is the destination size;
// `src_in_bytes` is the source size operand passed to cp.async. The copy is
// issued only when `guard` is true. Compiles to nothing unless targeting
// sm_80+ with CUDA >= 11.
//
// The L2::256B prefetch variant needs CUDA >= 11.4; the inner version test
// accepts any major version above 11 so that CUDA 12.0 - 12.3 (minor < 4)
// also select it.
template <int SIZE_IN_BYTES>
__device__ __forceinline__ void cp_async_ca(const uint32_t smem_addr,
                                            const void* gmem_ptr,
                                            const int src_in_bytes,
                                            bool guard) {
  static_assert(
      (SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16),
      "Size is not supported");
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  asm volatile(
      "{.reg.pred p;\n"
      " setp.ne.b32 p, %4, 0;\n"
  #if __CUDACC_VER_MAJOR__ > 11 || __CUDACC_VER_MINOR__ >= 4
      " @p cp.async.ca.shared.global.L2::256B [%0], [%1], %2, %3;}\n"
  #else
      " @p cp.async.ca.shared.global [%0], [%1], %2, %3;}\n"
  #endif
      ::"r"(smem_addr),
      "l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard));
#endif
}
|
||||||
|
|
||||||
|
// Issue cp.async.commit_group. Compiles to nothing unless targeting sm_80+
// with CUDA >= 11.
__device__ __forceinline__ void cp_async_commit_group() {
#if __CUDA_ARCH__ >= 800 && __CUDACC_VER_MAJOR__ >= 11
  asm volatile("cp.async.commit_group;\n");
#endif
}
|
||||||
|
|
||||||
|
// Issue cp.async.wait_group with the compile-time immediate N. Compiles to
// nothing unless targeting sm_80+ with CUDA >= 11.
// (The "asyc" spelling is the established name used by callers.)
template <int N>
__device__ __forceinline__ void cp_asyc_wait_group() {
#if __CUDA_ARCH__ >= 800 && __CUDACC_VER_MAJOR__ >= 11
  asm volatile("cp.async.wait_group %0;\n" : : "n"(N));
#endif
}
|
||||||
|
|
||||||
|
// Convert the four uint8 values packed in `idata` into four 16-bit floats,
// subtracting a bias of 128 from each. The result is written as two packed
// pairs, fdata[0] and fdata[1]. Specialized for __half2 and __nv_bfloat162.
template <typename T>
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128(const uint32_t& idata,
                                                          T* fdata);

// fast conversion: 4xuint8 to 4xhalf, subtracting bias = 128
//
// prmt places each input byte under a 0x64 high byte, giving the fp16 bit
// pattern 0x64XX (= 1024 + XX). Subtracting 0x6480 (= 1024 + 128) from both
// halves then yields XX - 128.
template <>
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__half2>(
    const uint32_t& idata, __half2* fdata) {
  uint32_t pair_lo, pair_hi;  // bytes {0, 1} and bytes {2, 3} of idata
  asm volatile(
      "prmt.b32 %0, %2, 0x64, 0x4140;"
      "prmt.b32 %1, %2, 0x64, 0x4342;"
      : "=r"(pair_lo), "=r"(pair_hi)
      : "r"(idata));

  static constexpr uint32_t MAGIC_NUM = 0x64806480;
  const __half2& bias = reinterpret_cast<const __half2&>(MAGIC_NUM);
  fdata[0] = __hsub2(reinterpret_cast<const __half2&>(pair_lo), bias);
  fdata[1] = __hsub2(reinterpret_cast<const __half2&>(pair_hi), bias);
}

// fast conversion: 4xuint8 to 4xbfloat16, subtracting bias = 128
// reference from marlin fast implementation
//
// Each byte is first placed in the low byte of the fp32 pattern 0x4B0000XX
// (= 8388608 + XX); subtracting 8388736 (= 8388608 + 128) gives XX - 128 as
// fp32. The final prmt pair keeps the upper 16 bits of each fp32 value,
// packing two bf16 results per output word.
template <>
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__nv_bfloat162>(
    const uint32_t& idata, __nv_bfloat162* fdata) {
  float fp32_imd[4];
  uint32_t* fp32_bits = reinterpret_cast<uint32_t*>(fp32_imd);
  asm volatile(
      "prmt.b32 %0, %4, 0x4B000000, 0x7650;"
      "prmt.b32 %1, %4, 0x4B000000, 0x7651;"
      "prmt.b32 %2, %4, 0x4B000000, 0x7652;"
      "prmt.b32 %3, %4, 0x4B000000, 0x7653;"
      : "=r"(fp32_bits[0]), "=r"(fp32_bits[1]), "=r"(fp32_bits[2]),
        "=r"(fp32_bits[3])
      : "r"(idata));

#pragma unroll
  for (int lane = 0; lane < 4; ++lane) {
    fp32_imd[lane] -= 8388736.f;
  }

  uint32_t* packed_out = reinterpret_cast<uint32_t*>(fdata);
  asm volatile(
      "prmt.b32 %0, %2, %3, 0x7632;"
      "prmt.b32 %1, %4, %5, 0x7632;"
      : "=r"(packed_out[0]), "=r"(packed_out[1])
      : "r"(fp32_bits[0]), "r"(fp32_bits[1]), "r"(fp32_bits[2]),
        "r"(fp32_bits[3]));
}
|
||||||
|
|
||||||
|
// Broadcast a scalar bfloat16 into both lanes of a bfloat162.
// On device targets below sm_80 this asserts instead of converting.
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  return __bfloat162bfloat162(x);
#else
  assert(false);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Broadcast a scalar half into both lanes of a half2.
static __device__ half2 inline num2num2(const half x) {
  const half2 both_lanes = __half2half2(x);
  return both_lanes;
}
|
||||||
|
|
||||||
|
} // namespace allspark
|
@ -447,6 +447,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"Tensor!? azp) -> ()");
|
"Tensor!? azp) -> ()");
|
||||||
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
|
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
|
||||||
&dynamic_scaled_int8_quant);
|
&dynamic_scaled_int8_quant);
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||||
|
ops.def(
|
||||||
|
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
|
||||||
|
"Tensor? b_zeros, "
|
||||||
|
"bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, "
|
||||||
|
"Tensor!? b_zeros_reorder, "
|
||||||
|
"int K, int N, int N_32align) -> ()");
|
||||||
|
// conditionally compiled so impl in source file
|
||||||
|
|
||||||
|
// AllSpark quantization ops
|
||||||
|
ops.def(
|
||||||
|
"allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, "
|
||||||
|
"Tensor? b_qzeros, "
|
||||||
|
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
|
||||||
|
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
|
||||||
|
// conditionally compiled so impl in source file
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||||
|
100
tests/kernels/test_allspark_gemm.py
Normal file
100
tests/kernels/test_allspark_gemm.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||||
|
ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||||
|
ALLSPARK_AMPERE_N_ALIGN)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
quantize_weights)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.scalar_type import scalar_types
|
||||||
|
|
||||||
|
|
||||||
|
def is_gptq_allspark_supported(min_capability: int,
                               max_capability: int) -> bool:
    """Report whether the current device falls in a capability range.

    Returns ``False`` on non-CUDA platforms. Otherwise the device capability
    is converted to an integer with ``to_int()`` and checked against the
    inclusive range ``[min_capability, max_capability]``.
    """
    if not current_platform.is_cuda():
        return False

    capability = current_platform.get_device_capability()
    assert capability is not None

    capability_int = capability.to_int()
    return min_capability <= capability_int <= max_capability
|
||||||
|
|
||||||
|
|
||||||
|
# (m, n_factor, k_factor) triples for the GEMM test in this file: m is used
# as-is, while n and k are obtained by multiplying the factors with
# ALLSPARK_AMPERE_N_ALIGN and ALLSPARK_AMPERE_K_ALIGN respectively.
MNK_FACTORS = [
    (1, 4, 8),
    (13, 17, 67),
    (26, 37, 13),
    (48, 16, 24),
    (67, 13, 88),
    (257, 13, 11),
    (658, 13, 11),
    (1033, 9, 17),
]

# Activation/weight dtypes the GEMM test is parametrized over.
DTYPES = [torch.float16, torch.bfloat16]
# Values of the `has_zp` flag (zero points absent / present) to test.
HAS_ZP_OPTS = [False, True]
|
||||||
|
|
||||||
|
|
||||||
|
def compute_max_diff(output, output_ref):
    """Return the mean absolute error of ``output`` relative to the reference.

    Despite the name this is not a maximum: the result is
    ``mean(|output - output_ref|) / mean(|output_ref|)``.
    """
    mean_abs_error = torch.mean(torch.abs(output - output_ref))
    mean_abs_reference = torch.mean(torch.abs(output_ref))
    return mean_abs_error / mean_abs_reference
|
||||||
|
|
||||||
|
|
||||||
|
def rand_data(shape, dtype=torch.float16):
    """Return a ``torch.randn`` tensor of ``shape`` and ``dtype`` on "cuda"."""
    sample = torch.randn(shape, dtype=dtype, device="cuda")
    return sample
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not is_gptq_allspark_supported(80, 89),
    reason="AllSpark Ampere kernel is not supported on this GPU type.")
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
    """Compare the AllSpark W8A16 GEMM against a dense matmul reference.

    Random weights are quantized to uint8b128, repacked with
    ``ops.allspark_repack_weight``, and multiplied through the custom op.
    The result must stay within a 0.04 relative mean error of
    ``input @ w_ref``. Both custom ops are also run through ``opcheck``.
    """
    m, n_factor, k_factor = mnk_factors
    n = n_factor * ALLSPARK_AMPERE_N_ALIGN
    k = k_factor * ALLSPARK_AMPERE_K_ALIGN

    # Activations are drawn before the weights, keeping the RNG order fixed.
    activations = rand_data((m, k), dtype=dtype)
    weight = rand_data((k, n), dtype=dtype)

    # Quantize the dense weight; w_ref is the reference weight used for the
    # dense matmul below.
    w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128,
                                        group_size, has_zp)
    qw = qw.to(torch.uint8)
    if has_zp:
        zp = zp.to(dtype)

    # Device facts the kernel takes as explicit arguments.
    device_props = torch.cuda.get_device_properties(qw.device.index)
    sm_count = device_props.multi_processor_count
    sm_version = device_props.major * 10 + device_props.minor

    # N rounded up to the next multiple of 32.
    n_32align = (n + 32 - 1) // 32 * 32

    qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
        qw, s, zp, has_zp)
    opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order,
            (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n,
             n_32align))

    # The same positional arguments feed both the opcheck and the real call.
    gemm_inputs = (activations, qw_reorder, s_reorder, zp_reorder, n,
                   group_size, sm_count, sm_version,
                   ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True)
    opcheck(torch.ops._C.allspark_w8a16_gemm,
            gemm_inputs,
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)
    output = ops.allspark_w8a16_gemm(*gemm_inputs)

    output_ref = torch.matmul(activations, w_ref)
    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    assert max_diff < 0.04
|
@ -215,8 +215,6 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
|
|||||||
assert qkv_proj.scheme.group_size == (-1
|
assert qkv_proj.scheme.group_size == (-1
|
||||||
if group is None else group)
|
if group is None else group)
|
||||||
|
|
||||||
assert qkv_proj.weight_packed.dtype is torch.int32
|
|
||||||
assert qkv_proj.weight_scale.dtype is torch.float16
|
|
||||||
assert qkv_proj.scheme.pack_factor == pack_factor
|
assert qkv_proj.scheme.pack_factor == pack_factor
|
||||||
|
|
||||||
llm.apply_model(check_model)
|
llm.apply_model(check_model)
|
||||||
|
@ -404,6 +404,22 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
|||||||
memory_format=torch.contiguous_format)
|
memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
|
|
||||||
|
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):

    @register_fake("_C::allspark_w8a16_gemm")
    def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor,
                                  b_scales: torch.Tensor,
                                  b_qzeros: Optional[torch.Tensor],
                                  n: torch.SymInt, group_size: torch.SymInt,
                                  sm_count: torch.SymInt,
                                  sm_version: torch.SymInt,
                                  CUBLAS_M_THRESHOLD: torch.SymInt,
                                  has_zp: bool,
                                  n32k16_reorder: bool) -> torch.Tensor:
        """Fake implementation of ``_C::allspark_w8a16_gemm``.

        Only produces an uninitialized ``(a.size(0), n)`` tensor with the
        device and dtype of ``a``; all other arguments are ignored.
        """
        out_shape = (a.size(0), n)
        return torch.empty(out_shape, device=a.device, dtype=a.dtype)
|
||||||
|
|
||||||
|
|
||||||
if hasattr(torch.ops._C, "ggml_dequantize"):
|
if hasattr(torch.ops._C, "ggml_dequantize"):
|
||||||
|
|
||||||
@register_fake("_C::ggml_dequantize")
|
@register_fake("_C::ggml_dequantize")
|
||||||
@ -881,6 +897,67 @@ def scaled_fp8_quant(
|
|||||||
return output, scale
|
return output, scale
|
||||||
|
|
||||||
|
|
||||||
|
# gptq allspark
|
||||||
|
def allspark_repack_weight(
    qweight: torch.Tensor,
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor] = None,
    has_zp: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Rearrange qweight, scale, and zero_point (if asymmetric) to n32k16 format
    for Ampere W8A16 Fused Gemm kernel

    Args:
        qweight: uint8 weight tensor, original k x n format.
        scale: fp16/bf16 weight scale tensor, 1 x n format.
        zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
            Must be provided for asymmetric quantization.
        has_zp: if use symmetric quantization, has_zp = False.
            if use asymmetric quantization, has_zp = True.

    Returns:
        The rearranged weight of shape (N_32align, K), the rearranged scale
        of shape (1, N_32align), and the rearranged zero_point of shape
        (1, N_32align), where N_32align is N rounded up to a multiple of 32.
        The zero_point entry is None when has_zp is False.

    Raises:
        AssertionError: if has_zp is True but zero_point is None.
    """
    K = qweight.shape[0]
    N = qweight.shape[1]
    # Round N up to the next multiple of 32.
    N_32align = (N + 32 - 1) // 32 * 32

    # Output buffers are allocated here and filled in place by the custom op.
    qweight_reorder = torch.empty((N_32align, K),
                                  device=qweight.device,
                                  dtype=qweight.dtype)
    scale_reorder = torch.empty((1, N_32align),
                                device=scale.device,
                                dtype=scale.dtype)
    zero_point_reorder = None
    if has_zp:
        assert zero_point is not None, (
            "zero_point must be provided for asymmetric quantization.")
        zero_point_reorder = torch.empty((1, N_32align),
                                         device=zero_point.device,
                                         dtype=zero_point.dtype)

    # zero_point is forwarded as given, even when has_zp is False.
    torch.ops._C.rearrange_kn_weight_as_n32k16_order(
        qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder,
        zero_point_reorder, K, N, N_32align)

    return qweight_reorder, scale_reorder, zero_point_reorder
|
||||||
|
|
||||||
|
|
||||||
|
def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor,
                        b_scales: torch.Tensor,
                        b_qzeros: Optional[torch.Tensor], n: int,
                        group_size: int, sm_count: int, sm_version: int,
                        CUBLAS_M_THRESHOLD: int, has_zp: bool,
                        n32k16_reorder: bool) -> torch.Tensor:
    """Forward all arguments, unchanged and in order, to the
    ``_C::allspark_w8a16_gemm`` custom op and return its result."""
    gemm_op = torch.ops._C.allspark_w8a16_gemm
    return gemm_op(a, b_qweight, b_scales, b_qzeros, n, group_size, sm_count,
                   sm_version, CUBLAS_M_THRESHOLD, has_zp, n32k16_reorder)
|
||||||
|
|
||||||
|
|
||||||
# int8
|
# int8
|
||||||
def scaled_int8_quant(
|
def scaled_int8_quant(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
from typing import List, Optional, Type
|
from typing import List, Optional, Type
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
|
||||||
|
AllSparkLinearKernel)
|
||||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
|
||||||
ExllamaLinearKernel)
|
ExllamaLinearKernel)
|
||||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
|
||||||
@ -16,6 +18,7 @@ from vllm.platforms import current_platform
|
|||||||
# in priority/performance order (when available)
|
# in priority/performance order (when available)
|
||||||
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
|
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
|
||||||
MacheteLinearKernel,
|
MacheteLinearKernel,
|
||||||
|
AllSparkLinearKernel,
|
||||||
MarlinLinearKernel,
|
MarlinLinearKernel,
|
||||||
ExllamaLinearKernel,
|
ExllamaLinearKernel,
|
||||||
]
|
]
|
||||||
|
@ -0,0 +1,115 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||||
|
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||||
|
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape)
|
||||||
|
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||||
|
permute_param_layout_)
|
||||||
|
|
||||||
|
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
|
class AllSparkLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel that runs the AllSpark W8A16 GEMM op.

    ``process_weights_after_loading`` repacks the layer's weight and scale
    once; ``apply_weights`` then calls the GEMM on the repacked tensors.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability for this kernel (80)."""
        return 80

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        """Return ``(ok, reason)`` for whether this kernel can serve ``c``.

        Configurations with activation reordering or zero points are
        rejected here; everything else is delegated to
        ``check_allspark_supported_dtype_shape``.
        """
        if c.has_g_idx:
            return False, "Act reordering currently not supported by AllSpark"

        if c.zero_points:
            return False, "Zero points currently not supported by AllSpark"

        in_features = c.partition_weight_shape[0]
        out_features = c.partition_weight_shape[1]
        return check_allspark_supported_dtype_shape(in_features, out_features,
                                                    c.group_size,
                                                    c.weight_type, c.act_type)

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack the layer's weight and scale for the AllSpark kernel.

        Also records the device's SM count and SM version in
        ``self.gemm_args`` for later use by ``apply_weights``.
        """
        weight_device = getattr(layer, self.w_q_name).device
        cfg = self.config

        # Device facts the GEMM op takes as explicit arguments.
        device_props = torch.cuda.get_device_properties(weight_device.index)
        self.gemm_args = {
            'sm_count': device_props.multi_processor_count,
            'sm_version': device_props.major * 10 + device_props.minor,
        }

        packed_weight = getattr(layer, self.w_q_name)
        scales = getattr(layer, self.w_s_name)

        # Bring both parameters into the layout noted above this method.
        assert isinstance(packed_weight, BasevLLMParameter)
        permute_param_layout_(packed_weight,
                              input_dim=0,
                              output_dim=1,
                              packed_dim=0)

        assert isinstance(scales, BasevLLMParameter)
        permute_param_layout_(scales, input_dim=0, output_dim=1)

        # unpack weight from K / 4 x N int32 to K x N uint8: transpose, view
        # the contiguous bytes as uint8, then transpose back.
        unpacked_weight = packed_weight.data.t().contiguous().view(
            dtype=torch.uint8)
        unpacked_weight = unpacked_weight.t().contiguous()

        # reorder K x N weight as N32K16 format for Ampere W8A16
        repacked_weight, repacked_scales, _ = ops.allspark_repack_weight(
            unpacked_weight, scales.data, None, cfg.zero_points)

        replace_parameter(layer, self.w_q_name, repacked_weight)
        replace_parameter(layer, self.w_s_name, repacked_scales)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the quantized linear layer output for ``x``.

        ``x`` is flattened to 2-D for the GEMM, ``bias`` (if given) is added
        in place, and the result is reshaped back to ``x``'s leading
        dimensions with ``partition_weight_shape[1]`` as the last dimension.
        """
        cfg = self.config
        gemm_args = self.gemm_args
        w_q, w_s, _, _ = self._get_weight_params(layer)

        out_features = cfg.partition_weight_shape[1]
        flat_x = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (out_features, )

        output = ops.allspark_w8a16_gemm(
            a=flat_x,
            b_qweight=w_q,
            b_scales=w_s,
            b_qzeros=None,
            n=out_features,
            group_size=cfg.group_size,
            sm_count=gemm_args['sm_count'],
            sm_version=gemm_args['sm_version'],
            CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
            has_zp=cfg.zero_points,
            n32k16_reorder=True)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
|
@ -0,0 +1,51 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
|
|
||||||
|
# Value passed as the `CUBLAS_M_THRESHOLD` argument of the AllSpark GEMM op.
# NOTE(review): presumably the M size above which the op falls back to
# cuBLAS -- confirm against the kernel source.
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024
# Weight quantization types accepted by check_allspark_supported_dtype_shape.
ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128]
# output_size_per_partition must be a multiple of this value.
ALLSPARK_AMPERE_N_ALIGN = 16
# input_size_per_partition must be a multiple of this value.
ALLSPARK_AMPERE_K_ALIGN = 16
|
||||||
|
|
||||||
|
|
||||||
|
def check_allspark_supported_dtype_shape(input_size_per_partition: int,
                                         output_size_per_partition: int,
                                         group_size: int,
                                         weight_dtype: ScalarType,
                                         act_dtype: torch.dtype):
    """Check whether AllSpark supports the given layer on this device.

    Args:
        input_size_per_partition: must be a multiple of
            ``ALLSPARK_AMPERE_K_ALIGN``.
        output_size_per_partition: must be a multiple of
            ``ALLSPARK_AMPERE_N_ALIGN``.
        group_size: quantization group size; only ``-1`` is accepted.
        weight_dtype: must be in ``ALLSPARK_SUPPORTED_QUANT_TYPES``.
        act_dtype: must be ``torch.float16`` or ``torch.bfloat16``.

    Returns:
        ``(True, None)`` when every check passes, otherwise
        ``(False, reason)`` with a message for the first failing check.
    """
    capability_tuple = current_platform.get_device_capability()
    device_capability = (-1 if capability_tuple is None else
                         capability_tuple.to_int())

    # Only capabilities in [80, 90) are handled; the messages below call
    # this range "Ampere". An unknown capability is reported as -1.
    if not 80 <= device_capability < 90:
        return False, ("AllSpark currently does not support "
                       f"device_capability = {device_capability}.")

    if group_size != -1:
        return False, (
            "For Ampere GPU, AllSpark does not support group_size "
            f"= {group_size}. Only group_size = -1 are supported.")

    if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
        return False, (
            "For Ampere GPU, AllSpark does not support "
            f"quant type ({weight_dtype}). Only quant type "
            f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported.")

    k_aligned = input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN == 0
    n_aligned = output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN == 0
    if not (k_aligned and n_aligned):
        return False, (
            "AllSpark needs input_size_per_partition % "
            f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "
            f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "
            "for Ampere GPU optimized kernels.")

    if act_dtype not in (torch.float16, torch.bfloat16):
        # A space follows the comma so the concatenated message no longer
        # reads "bfloat16,for Ampere GPU".
        return False, (
            "AllSpark only supports act_dtype = float16 or bfloat16, "
            f"for Ampere GPU, but got act_dtype = {act_dtype}.")

    return True, None
|
Loading…
x
Reference in New Issue
Block a user