[Model][Quantization] HQQ support through Marlin kernel expansion (#9766)

Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
This commit is contained in:
ElizaWszola 2024-11-19 22:31:12 +01:00 committed by GitHub
parent efa9084628
commit b00b33d77e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 632 additions and 89 deletions

View File

@ -210,7 +210,8 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
size_m=bt.a.shape[0],
size_n=bt.w_ref.shape[1],
size_k=bt.w_ref.shape[0],
is_k_full=True)
is_k_full=True,
is_zp_float=False)
else:
assert bt.a.dtype == torch.int8
assert bt.wtype == scalar_types.uint4b8

View File

@ -131,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@ -141,7 +141,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,

View File

@ -55,8 +55,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
const int group_blocks = -1, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@ -82,7 +83,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& workspace,
vllm::ScalarTypeId const b_q_type_id,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp) {
bool is_k_full, bool has_zp, bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
@ -518,8 +519,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
const int group_blocks = -1, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@ -692,8 +694,10 @@ __global__ void Marlin(
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
// Zero-points sizes/strides
int zp_gl_stride = (prob_n / pack_factor) / 4;
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4;
constexpr int zp_sh_stride = is_zp_float
? 16 * thread_n_blocks / 8
: ((16 * thread_n_blocks) / pack_factor) / 4;
constexpr int zp_tb_groups = s_tb_groups;
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
int zp_gl_rd_delta = zp_gl_stride;
@ -768,10 +772,17 @@ __global__ void Marlin(
constexpr int num_ints_per_thread = 8 / pack_factor;
int zp_sh_rd;
if constexpr (has_zp) {
if constexpr (is_zp_float) {
if constexpr (group_blocks != -1) {
zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
}
} else {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
}
}
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
@ -832,6 +843,7 @@ __global__ void Marlin(
FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ
// Zero accumulators.
auto zero_accums = [&]() {
@ -1126,7 +1138,7 @@ __global__ void Marlin(
// has_zp implies AWQ, which doesn't have act_order,
static_assert(!has_zp || group_blocks != 0);
if constexpr (has_zp) {
if constexpr (has_zp && !is_zp_float) {
int pipe = full_pipe % stages;
if constexpr (group_blocks == -1) {
@ -1170,11 +1182,44 @@ __global__ void Marlin(
}
}
}
else if constexpr (has_zp && is_zp_float) {
int pipe = full_pipe % stages;
if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
int cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) {
if constexpr (has_zp) {
if constexpr (has_zp && !is_zp_float) {
FragB frag_zp_0;
FragB frag_zp_1;
int zp_quant_0, zp_quant_1;
@ -1219,10 +1264,14 @@ __global__ void Marlin(
frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);
// Apply zero-point to frag_b0
if constexpr (has_zp) {
if constexpr (has_zp && !is_zp_float) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
}
else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
sub_zp<scalar_t>(frag_b0, frag_zpf[k % 2][j], 0);
}
// Apply scale to frag_b0
if constexpr (has_act_order) {
scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],
@ -1235,10 +1284,14 @@ __global__ void Marlin(
}
// Apply zero-point to frag_b1
if constexpr (has_zp) {
if constexpr (has_zp && !is_zp_float) {
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
}
else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
sub_zp<scalar_t>(frag_b1, frag_zpf[k % 2][j], 1);
}
// Apply scale to frag_b1
if constexpr (has_act_order) {
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
@ -1510,7 +1563,7 @@ __global__ void Marlin(
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
}
if constexpr (has_zp && group_blocks == -1) {
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
if (i == 0) {
fetch_zp_to_shared();
}
@ -1697,23 +1750,27 @@ __global__ void Marlin(
}
#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \
IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS>, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS> \
HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
} \
}
typedef struct {
@ -1905,51 +1962,96 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
}
#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false)
#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false)
// We currently have 4-bit models only with group_blocks == 4
#define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
true) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
true) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
true) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true)
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
@ -1958,7 +2060,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, bool has_zp, int num_groups, int group_size,
int dev, cudaStream_t stream, int thread_k, int thread_n,
int sms, int max_par, bool use_fp32_reduce) {
int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
@ -2111,6 +2213,11 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
HQQ_CALL_IF(vllm::kU4, 16, 4, 256)
HQQ_CALL_IF(vllm::kU4, 8, 8, 256)
HQQ_CALL_IF(vllm::kU4, 8, 4, 128)
HQQ_CALL_IF(vllm::kU4, 4, 8, 128)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
", ", prob_k, "]", ", has_act_order = ", has_act_order,
@ -2135,7 +2242,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
vllm::ScalarTypeId const& b_q_type_id,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp,
bool use_fp32_reduce) {
bool use_fp32_reduce, bool is_zp_float) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
if (has_zp) {
TORCH_CHECK(
@ -2148,6 +2255,12 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
b_q_type.str());
}
if (has_zp && is_zp_float) {
TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
"Computation type must be float16 (half) when using float zero "
"points.");
}
int pack_factor = 32 / b_q_type.size_bits();
// Verify A
@ -2257,6 +2370,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
if (is_zp_float) {
TORCH_CHECK(b_zeros.size(1) == size_n,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not size_n = ", size_n);
TORCH_CHECK(num_groups == b_zeros.size(0),
"b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups);
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
} else {
TORCH_CHECK(b_zeros.size(0) == num_groups,
"b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups);
@ -2264,6 +2386,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
}
// Verify workspace size
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
@ -2282,7 +2405,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
@ -2291,7 +2414,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
}

View File

@ -244,7 +244,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor");
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file
// gptq_marlin repack from GPTQ.

View File

@ -29,6 +29,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import
marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
from vllm.scalar_type import scalar_types
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
@ -40,6 +41,8 @@ MARLIN_N_CHUNKS = [64, 256]
MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]
HQQ_SUPPORTED_GROUP_SIZES = [64]
MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
@ -226,7 +229,7 @@ def test_gptq_marlin_gemm(
torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce),
a_input.shape[1], is_k_full, False, use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_gemm(
@ -244,6 +247,7 @@ def test_gptq_marlin_gemm(
is_k_full=is_k_full,
has_zp=False,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)
@ -441,6 +445,7 @@ def test_awq_marlin_gemm(
is_k_full=is_k_full,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)
@ -451,6 +456,87 @@ def test_awq_marlin_gemm(
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
k_chunk,
n_chunk,
group_size,
mnk_factors,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
quant_type = scalar_types.uint4
a_input = rand_data((size_m, size_k))
dev = a_input.device
b_weight = torch.randint(0,
10, (size_n, size_k),
dtype=torch.uint8,
device=dev)
scale = rand_data((size_n, size_k // group_size))
zero = rand_data((size_n, size_k // group_size))
gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)
sort_indices = torch.empty(0, dtype=torch.int, device=dev)
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
4).to(dev)
marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
group_size).to(dev)
marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
group_size).to(dev)
g_idx = marlin_make_empty_g_idx(dev)
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
output = ops.gptq_marlin_gemm(
a_input,
marlin_w_q,
marlin_s,
marlin_zp,
g_idx,
g_idx_sort_indices,
workspace.scratch,
quant_type,
a_input.shape[0],
b_weight.shape[0],
a_input.shape[1],
is_k_full=True,
has_zp=True,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=True,
)
b_flat = b_weight.reshape(-1, group_size)
zp_flat = zero.reshape(-1, 1)
s_flat = scale.reshape(-1, 1)
dequant = (b_flat - zp_flat) * s_flat
output_ref = torch.matmul(a_input,
dequant.reshape(b_weight.shape).transpose(1, 0))
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)

View File

@ -28,3 +28,4 @@ marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
qqq, HandH1998/QQQ-Llama-3-8b-g128, main
qqq, HandH1998/QQQ-Llama-3-8b, main
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main

View File

@ -343,7 +343,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
size_k: torch.SymInt,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@register_fake("_C::ggml_dequantize")
@ -601,11 +602,12 @@ def gptq_marlin_gemm(a: torch.Tensor,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, b_q_type.id,
size_m, size_n, size_k, is_k_full,
has_zp, use_fp32_reduce)
has_zp, use_fp32_reduce, is_zp_float)
# fp8 marlin

View File

@ -27,7 +27,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod"
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
"HQQMarlinMethod"
]

View File

@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig
from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
@ -48,6 +49,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,

View File

@ -0,0 +1,325 @@
from typing import Any, Dict, List, Optional
import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
marlin_make_empty_g_idx, marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack
from vllm.model_executor.parameter import (BasevLLMParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
class HQQMarlinConfig(QuantizationConfig):
"""Config class for HQQ Marlin"""
def __init__(
self,
weight_bits: int,
group_size: int,
skip_modules: Optional[List[str]] = None,
) -> None:
assert group_size == 64, ("The only supported HQQ group size is "
"currently 64.")
assert weight_bits == 4, ("The only supported HQQ quantization "
"bitsize is currently 4.")
self.weight_bits = weight_bits
self.group_size = group_size
self.pack_factor = 32 // weight_bits # packed into int32 in GPTQ format
self.quant_type = scalar_types.uint4
self.skip_modules = skip_modules
def __repr__(self) -> str:
return (f"HQQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size})")
@classmethod
def get_name(cls) -> str:
return "hqq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig":
wq_params = (config["quant_config"]["weight_quant_params"])
weight_bits = cls.get_from_keys(wq_params, ["nbits"])
group_size = cls.get_from_keys(wq_params, ["group_size"])
skip_modules = config["skip_modules"]
return cls(weight_bits, group_size, skip_modules)
def is_layer_skipped(self, prefix: str) -> bool:
# Split the prefix into its dot-separated components
components = prefix.split('.')
# Check if any of the skip modules exactly matches any component
return self.skip_modules is not None and any(
module_name in components for module_name in self.skip_modules)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if self.is_layer_skipped(prefix):
return UnquantizedLinearMethod()
return HQQMarlinMethod(self)
return None
# Empty HQQ parameter, will be ignored during loading
class HQQEmptyParameter(BasevLLMParameter):
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
pass
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
pass
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
pass
def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
raise ValueError("No loader provided for HQQ parameter!")
# HQQ packing creates issues with sharding - therefore, prior to loading, we
# repack to GPTQ. We also reshape the weights to their proper GPTQ shape.
class HQQweightParameter(PackedvLLMParameter):
# unpack function from https://github.com/mobiusml/hqq
def unpack_4bit_u8(self,
W_q: torch.Tensor) -> torch.Tensor: # uint8/2 > uint8
assert self.weight_bits == 4, "Unsupported quant bitsize (must be 4)"
dtype = torch.uint8
step = W_q.shape[0]
tmp = torch.empty([2 * step, W_q.shape[1]],
dtype=dtype,
device=W_q.device)
tmp[:step] = (W_q & 0b11110000) >> 4
tmp[step:] = W_q & 0b00001111
return tmp
def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int,
**kwargs):
super().__init__(packed_factor, packed_dim, None, **kwargs)
self.weight_bits = weight_bits
self.input_shape = self.shape[self.input_dim] * self.packed_factor
self.output_shape = self.shape[self.output_dim]
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
loaded_weight = self.unpack_4bit_u8(loaded_weight)
loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(
1, 0)
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
loaded_weight.shape[0],
loaded_weight.shape[1])
super().load_merged_column_weight(loaded_weight, **kwargs)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
loaded_weight = self.unpack_4bit_u8(loaded_weight)
loaded_weight = loaded_weight.reshape(self.output_shape,
-1).transpose(1, 0)
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
loaded_weight.shape[0],
loaded_weight.shape[1])
super().load_row_parallel_weight(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
loaded_weight = self.unpack_4bit_u8(loaded_weight)
loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(
1, 0)
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
loaded_weight.shape[0],
loaded_weight.shape[1])
super().load_qkv_weight(loaded_weight, **kwargs)
# Zero points and scales in HQQ must also be reshaped to correspond to W_q's
# GPTQ shape (transposed - we transpose them too when processing weights).
class HQQZeroScaleParameter(GroupQuantScaleParameter):
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
loaded_weight = loaded_weight.reshape(-1, self.shape[1])
super().load_merged_column_weight(loaded_weight, **kwargs)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
loaded_weight = loaded_weight.reshape(self.shape[0], -1)
super().load_row_parallel_weight(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
loaded_weight = loaded_weight.reshape(-1, self.shape[1])
super().load_qkv_weight(loaded_weight, **kwargs)
class HQQMarlinMethod(LinearMethodBase):
"""Linear method for HQQ Marlin.
"""
def __init__(
self,
quant_config: HQQMarlinConfig,
):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
self.output_size_per_partition = sum(output_partition_sizes)
self.input_size_per_partition = input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader", error_loader)
self.scales_and_zp_size = (input_size_per_partition //
self.quant_config.group_size)
qweight = HQQweightParameter(
data=torch.empty(
self.input_size_per_partition // self.quant_config.pack_factor,
self.output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_bits=self.quant_config.weight_bits,
weight_loader=weight_loader)
zeros = HQQZeroScaleParameter(data=torch.empty(
self.output_size_per_partition,
self.scales_and_zp_size,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
scales = HQQZeroScaleParameter(data=torch.empty(
self.output_size_per_partition,
self.scales_and_zp_size,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("W_q", qweight)
layer.register_parameter("zero", zeros)
layer.register_parameter("scale", scales)
# Ignore extra parameters in the HQQ model.
# To be added as needed.
ignore_parameters = ("axis", "channel_wise", "compute_dtype",
"encoded_state_dict", "group_size", "nbits",
"offload_meta", "optimize", "packing",
"quant_scale", "quant_zero", "round_zero",
"shape", "stores_quant_config",
"unpack_view_dtype", "view_as_float")
for name in ignore_parameters:
layer.register_parameter(
name,
HQQEmptyParameter(data=torch.empty(0),
weight_loader=weight_loader))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
dev = layer.W_q.device
# Repack to Marlin
sort_indices = torch.empty(0, dtype=torch.int, device=dev)
marlin_w_q = ops.gptq_marlin_repack(
layer.W_q,
sort_indices,
self.input_size_per_partition,
self.output_size_per_partition,
self.quant_config.weight_bits,
).to(dev)
marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0),
self.input_size_per_partition,
self.output_size_per_partition,
self.quant_config.group_size).to(dev)
marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0),
self.input_size_per_partition,
self.output_size_per_partition,
self.quant_config.group_size).to(dev)
layer.g_idx = marlin_make_empty_g_idx(dev)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev)
layer.marlin_qweight = marlin_w_q
layer.marlin_zeros = marlin_zp
layer.marlin_scales = marlin_s
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
workspace = MarlinWorkspace(self.output_size_per_partition,
GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
scales = layer.marlin_scales
zeros = layer.marlin_zeros
orig_type = x.dtype
if orig_type != torch.float16:
x = x.to(torch.float16)
scales = scales.to(torch.float16)
zeros = zeros.to(torch.float16)
marlin_out = ops.gptq_marlin_gemm(
x,
layer.marlin_qweight,
scales,
zeros,
layer.g_idx,
layer.g_idx_sort_indices,
workspace.scratch,
scalar_types.uint4,
x.shape[0],
self.output_size_per_partition,
self.input_size_per_partition,
True, # is_k_full
True, # has_zp
True, # use 32-bit reduce
True, # use float zp
)
if orig_type != torch.float16:
marlin_out = marlin_out.to(orig_type)
if bias is not None:
marlin_out.add_(bias)
return marlin_out

View File

@ -303,7 +303,8 @@ def apply_gptq_marlin_linear(
size_k=input_size_per_partition,
is_k_full=is_k_full,
has_zp=False,
use_fp32_reduce=use_fp32_reduce)
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
if bias is not None:
output.add_(bias) # In-place add
@ -340,7 +341,8 @@ def apply_awq_marlin_linear(
size_k=input_size_per_partition,
is_k_full=True,
has_zp=True,
use_fp32_reduce=use_fp32_reduce)
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
if bias is not None:
output.add_(bias) # In-place add