[Model][Quantization] HQQ support through Marlin kernel expansion (#9766)
Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
This commit is contained in:
parent efa9084628
commit b00b33d77e
@@ -210,7 +210,8 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
size_m=bt.a.shape[0],
size_n=bt.w_ref.shape[1],
size_k=bt.w_ref.shape[0],
is_k_full=True)
is_k_full=True,
is_zp_float=False)
else:
assert bt.a.dtype == torch.int8
assert bt.wtype == scalar_types.uint4b8
@@ -131,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@@ -141,7 +141,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
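Note: the two timer statements above differ only in the use_fp32_reduce flag, and the newly appended trailing False is the new is_zp_float argument. As a hedged sketch (the benchmark's local variables are assumed), the positional arguments map onto the parameter names of the updated Python wrapper later in this diff as follows:

# Sketch only: positional-argument mapping for the stmt strings above.
output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx,
                          marlin_sort_indices, marlin_workspace.scratch,
                          quant_type, size_m, size_n, size_k, is_k_full,
                          False,  # has_zp
                          False,  # use_fp32_reduce (True in the second timer)
                          False)  # is_zp_float, newly added in this commit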
@@ -55,8 +55,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
const int group_blocks = -1, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@@ -82,7 +83,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& workspace,
vllm::ScalarTypeId const b_q_type_id,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp) {
bool is_k_full, bool has_zp, bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
@@ -518,8 +519,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
const int group_blocks = -1, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@@ -692,8 +694,10 @@ __global__ void Marlin(
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;

// Zero-points sizes/strides
int zp_gl_stride = (prob_n / pack_factor) / 4;
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4;
constexpr int zp_sh_stride = is_zp_float
? 16 * thread_n_blocks / 8
: ((16 * thread_n_blocks) / pack_factor) / 4;
constexpr int zp_tb_groups = s_tb_groups;
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
int zp_gl_rd_delta = zp_gl_stride;
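Note: as a standalone sanity check on the new stride arithmetic (not kernel code): an int4 vector is 16 bytes, which holds four packed int32s or eight fp16 values, so the fp16 (HQQ) zero-point path needs prob_n / 8 int4s per group row instead of (prob_n / pack_factor) / 4. A minimal Python sketch, using an illustrative prob_n of 4096 and 4-bit weights (pack_factor = 8):

# Hedged sketch of the zp_gl_stride computation above; prob_n = 4096 is an
# example value, not taken from the diff.
def zp_gl_stride(prob_n: int, pack_factor: int, is_zp_float: bool) -> int:
    if is_zp_float:
        return prob_n // 8               # fp16 zero points: 8 halves per int4
    return (prob_n // pack_factor) // 4  # packed zero points: 4 int32s per int4

assert zp_gl_stride(4096, pack_factor=8, is_zp_float=False) == 128
assert zp_gl_stride(4096, pack_factor=8, is_zp_float=True) == 512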
@@ -768,10 +772,17 @@ __global__ void Marlin(
constexpr int num_ints_per_thread = 8 / pack_factor;
int zp_sh_rd;
if constexpr (has_zp) {
if constexpr (is_zp_float) {
if constexpr (group_blocks != -1) {
zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
}
} else {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
}
}

// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
@@ -832,6 +843,7 @@ __global__ void Marlin(
FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ

// Zero accumulators.
auto zero_accums = [&]() {
@@ -1126,7 +1138,7 @@ __global__ void Marlin(
// has_zp implies AWQ, which doesn't have act_order,
static_assert(!has_zp || group_blocks != 0);

if constexpr (has_zp) {
if constexpr (has_zp && !is_zp_float) {
int pipe = full_pipe % stages;

if constexpr (group_blocks == -1) {
@@ -1170,11 +1182,44 @@ __global__ void Marlin(
}
}
}

else if constexpr (has_zp && is_zp_float) {
int pipe = full_pipe % stages;

if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;

int warp_row = warp_id / n_warps;

int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);

int k_blocks = cur_k / 16;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
int cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop

int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;

reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride];
}
}
}
};

// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) {
if constexpr (has_zp) {
if constexpr (has_zp && !is_zp_float) {
FragB frag_zp_0;
FragB frag_zp_1;
int zp_quant_0, zp_quant_1;
@@ -1219,10 +1264,14 @@ __global__ void Marlin(
frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);

// Apply zero-point to frag_b0
if constexpr (has_zp) {
if constexpr (has_zp && !is_zp_float) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
}

else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
sub_zp<scalar_t>(frag_b0, frag_zpf[k % 2][j], 0);
}

// Apply scale to frag_b0
if constexpr (has_act_order) {
scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],
@@ -1235,10 +1284,14 @@ __global__ void Marlin(
}

// Apply zero-point to frag_b1
if constexpr (has_zp) {
if constexpr (has_zp && !is_zp_float) {
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
}

else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
sub_zp<scalar_t>(frag_b1, frag_zpf[k % 2][j], 1);
}

// Apply scale to frag_b1
if constexpr (has_act_order) {
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
@@ -1510,7 +1563,7 @@ __global__ void Marlin(
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
}

if constexpr (has_zp && group_blocks == -1) {
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
if (i == 0) {
fetch_zp_to_shared();
}
@@ -1697,23 +1750,27 @@ __global__ void Marlin(
}

#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \
IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS>, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS> \
HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
} \
}

typedef struct {
@@ -1905,51 +1962,96 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
}

#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false)

#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false)

// We currently have 4-bit models only with group_blocks == 4
#define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
true) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
true) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
true) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true)
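Note: group_blocks counts 16-wide blocks of K per quantization group, so GROUP_BLOCKS = 4 in HQQ_CALL_IF above corresponds to group size 64 — the only size accepted by the new HQQMarlinConfig and exercised via HQQ_SUPPORTED_GROUP_SIZES in the test changes below:

# group_size = group_blocks * 16; the HQQ instantiations above cover 64 only.
assert 4 * 16 == 64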

template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
@@ -1958,7 +2060,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, bool has_zp, int num_groups, int group_size,
int dev, cudaStream_t stream, int thread_k, int thread_n,
int sms, int max_par, bool use_fp32_reduce) {
int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
@@ -2111,6 +2213,11 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
AWQ_CALL_IF(vllm::kU8, 4, 8, 128)

HQQ_CALL_IF(vllm::kU4, 16, 4, 256)
HQQ_CALL_IF(vllm::kU4, 8, 8, 256)
HQQ_CALL_IF(vllm::kU4, 8, 4, 128)
HQQ_CALL_IF(vllm::kU4, 4, 8, 128)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
", ", prob_k, "]", ", has_act_order = ", has_act_order,
@@ -2135,7 +2242,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
vllm::ScalarTypeId const& b_q_type_id,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp,
bool use_fp32_reduce) {
bool use_fp32_reduce, bool is_zp_float) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
if (has_zp) {
TORCH_CHECK(
@@ -2148,6 +2255,12 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
b_q_type.str());
}

if (has_zp && is_zp_float) {
TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
"Computation type must be float16 (half) when using float zero "
"points.");
}

int pack_factor = 32 / b_q_type.size_bits();

// Verify A
@@ -2257,6 +2370,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
if (is_zp_float) {
TORCH_CHECK(b_zeros.size(1) == size_n,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not size_n = ", size_n);
TORCH_CHECK(num_groups == b_zeros.size(0),
"b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups);
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
} else {
TORCH_CHECK(b_zeros.size(0) == num_groups,
"b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups);
@@ -2264,6 +2386,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
}

// Verify workspace size
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
@@ -2282,7 +2405,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
@@ -2291,7 +2414,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
}
@@ -244,7 +244,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor");
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file

// gptq_marlin repack from GPTQ.
@@ -29,6 +29,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import
marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
from vllm.scalar_type import scalar_types

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
@@ -40,6 +41,8 @@ MARLIN_N_CHUNKS = [64, 256]
MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

HQQ_SUPPORTED_GROUP_SIZES = [64]

MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
@@ -226,7 +229,7 @@ def test_gptq_marlin_gemm(
torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce),
a_input.shape[1], is_k_full, False, use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)

output = ops.gptq_marlin_gemm(
@@ -244,6 +247,7 @@ def test_gptq_marlin_gemm(
is_k_full=is_k_full,
has_zp=False,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)

@@ -441,6 +445,7 @@ def test_awq_marlin_gemm(
is_k_full=is_k_full,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)
@@ -451,6 +456,87 @@ def test_awq_marlin_gemm(
assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
k_chunk,
n_chunk,
group_size,
mnk_factors,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

quant_type = scalar_types.uint4

a_input = rand_data((size_m, size_k))
dev = a_input.device

b_weight = torch.randint(0,
10, (size_n, size_k),
dtype=torch.uint8,
device=dev)
scale = rand_data((size_n, size_k // group_size))
zero = rand_data((size_n, size_k // group_size))

gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

sort_indices = torch.empty(0, dtype=torch.int, device=dev)
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
4).to(dev)
marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
group_size).to(dev)
marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
group_size).to(dev)

g_idx = marlin_make_empty_g_idx(dev)
g_idx_sort_indices = marlin_make_empty_g_idx(dev)

workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

output = ops.gptq_marlin_gemm(
a_input,
marlin_w_q,
marlin_s,
marlin_zp,
g_idx,
g_idx_sort_indices,
workspace.scratch,
quant_type,
a_input.shape[0],
b_weight.shape[0],
a_input.shape[1],
is_k_full=True,
has_zp=True,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=True,
)

b_flat = b_weight.reshape(-1, group_size)
zp_flat = zero.reshape(-1, 1)
s_flat = scale.reshape(-1, 1)
dequant = (b_flat - zp_flat) * s_flat

output_ref = torch.matmul(a_input,
dequant.reshape(b_weight.shape).transpose(1, 0))

torch.cuda.synchronize()

max_diff = compute_max_diff(output, output_ref)

assert max_diff < 0.04
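Note: a standalone restatement of the reference computation above — HQQ dequantizes each group of group_size weights as (q - zero) * scale, with one fp16 zero point and one scale per group. The shapes below are toy values, not the test's sizes:

import torch

group_size = 64
w_q = torch.randint(0, 16, (128, 256), dtype=torch.uint8)  # quantized weights
scale = torch.rand(128, 256 // group_size)                 # one scale per group
zero = torch.rand(128, 256 // group_size)                  # one zero point per group

dequant = ((w_q.reshape(-1, group_size) - zero.reshape(-1, 1)) *
           scale.reshape(-1, 1)).reshape(w_q.shape)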

@pytest.mark.skipif(not is_quant_method_supported("qqq"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@@ -28,3 +28,4 @@ marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
qqq, HandH1998/QQQ-Llama-3-8b-g128, main
qqq, HandH1998/QQQ-Llama-3-8b, main
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
@@ -343,7 +343,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
size_k: torch.SymInt,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@register_fake("_C::ggml_dequantize")
@@ -601,11 +602,12 @@ def gptq_marlin_gemm(a: torch.Tensor,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, b_q_type.id,
size_m, size_n, size_k, is_k_full,
has_zp, use_fp32_reduce)
has_zp, use_fp32_reduce, is_zp_float)
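Note: for HQQ weights this wrapper is invoked with has_zp=True and the new is_zp_float=True, mirroring the call added in HQQMarlinMethod.apply later in this diff. A hedged sketch — the marlin_* tensors, workspace and sizes are placeholders assumed to have been prepared via gptq_marlin_repack and marlin_permute_scales, as in the new HQQ test:

# Sketch only: HQQ-style call through the updated wrapper.
out = ops.gptq_marlin_gemm(x, marlin_qweight, marlin_scales, marlin_zeros,
                           g_idx, g_idx_sort_indices, workspace_scratch,
                           scalar_types.uint4, size_m, size_n, size_k,
                           is_k_full=True,
                           has_zp=True,           # HQQ uses zero points
                           use_fp32_reduce=True,
                           is_zp_float=True)      # fp16 zero points (HQQ)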

# fp8 marlin
@@ -27,7 +27,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod"
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
"HQQMarlinMethod"
]
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig
from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
@@ -48,6 +49,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
vllm/model_executor/layers/quantization/hqq_marlin.py (new file, 325 lines)
@@ -0,0 +1,325 @@
from typing import Any, Dict, List, Optional

import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
marlin_make_empty_g_idx, marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack
from vllm.model_executor.parameter import (BasevLLMParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)


class HQQMarlinConfig(QuantizationConfig):
"""Config class for HQQ Marlin"""

def __init__(
self,
weight_bits: int,
group_size: int,
skip_modules: Optional[List[str]] = None,
) -> None:
assert group_size == 64, ("The only supported HQQ group size is "
"currently 64.")
assert weight_bits == 4, ("The only supported HQQ quantization "
"bitsize is currently 4.")

self.weight_bits = weight_bits
self.group_size = group_size
self.pack_factor = 32 // weight_bits # packed into int32 in GPTQ format
self.quant_type = scalar_types.uint4
self.skip_modules = skip_modules

def __repr__(self) -> str:
return (f"HQQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size})")

@classmethod
def get_name(cls) -> str:
return "hqq"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig":
wq_params = (config["quant_config"]["weight_quant_params"])
weight_bits = cls.get_from_keys(wq_params, ["nbits"])
group_size = cls.get_from_keys(wq_params, ["group_size"])
skip_modules = config["skip_modules"]
return cls(weight_bits, group_size, skip_modules)
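Note: a hedged sketch of the minimal config structure from_config expects; the concrete values are illustrative, not taken from a real checkpoint:

hqq_config = {
    "quant_config": {
        "weight_quant_params": {
            "nbits": 4,        # only 4-bit is accepted
            "group_size": 64,  # only group size 64 is accepted
        },
    },
    "skip_modules": ["lm_head"],
}
cfg = HQQMarlinConfig.from_config(hqq_config)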

def is_layer_skipped(self, prefix: str) -> bool:
# Split the prefix into its dot-separated components
components = prefix.split('.')

# Check if any of the skip modules exactly matches any component
return self.skip_modules is not None and any(
module_name in components for module_name in self.skip_modules)
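Note: matching is per dot-separated component, not substring-based; a small sketch with hypothetical values:

skip_modules = ["lm_head", "visual"]
assert any(m in "lm_head".split('.') for m in skip_modules)
assert not any(m in "model.layers.0.self_attn.qkv_proj".split('.')
               for m in skip_modules)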

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if self.is_layer_skipped(prefix):
return UnquantizedLinearMethod()
return HQQMarlinMethod(self)
return None


# Empty HQQ parameter, will be ignored during loading
class HQQEmptyParameter(BasevLLMParameter):

def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
pass

def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
pass

def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
pass


def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
raise ValueError("No loader provided for HQQ parameter!")


# HQQ packing creates issues with sharding - therefore, prior to loading, we
# repack to GPTQ. We also reshape the weights to their proper GPTQ shape.
class HQQweightParameter(PackedvLLMParameter):

# unpack function from https://github.com/mobiusml/hqq
def unpack_4bit_u8(self,
W_q: torch.Tensor) -> torch.Tensor: # uint8/2 > uint8
assert self.weight_bits == 4, "Unsupported quant bitsize (must be 4)"

dtype = torch.uint8
step = W_q.shape[0]
tmp = torch.empty([2 * step, W_q.shape[1]],
dtype=dtype,
device=W_q.device)
tmp[:step] = (W_q & 0b11110000) >> 4
tmp[step:] = W_q & 0b00001111
return tmp
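Note: a round-trip sketch of the packing that unpack_4bit_u8 reverses — HQQ stores two 4-bit rows per uint8 row, the first half in the high nibble and the second half in the low nibble (toy shapes):

import torch

vals = torch.randint(0, 16, (6, 3), dtype=torch.uint8)  # toy 4-bit values
step = vals.shape[0] // 2
packed = (vals[:step] << 4) | vals[step:]                # what HQQ stores

unpacked = torch.empty_like(vals)
unpacked[:step] = (packed & 0b11110000) >> 4
unpacked[step:] = packed & 0b00001111
assert torch.equal(unpacked, vals)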

def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int,
**kwargs):
super().__init__(packed_factor, packed_dim, None, **kwargs)
self.weight_bits = weight_bits
self.input_shape = self.shape[self.input_dim] * self.packed_factor
self.output_shape = self.shape[self.output_dim]

def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
loaded_weight = self.unpack_4bit_u8(loaded_weight)
loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(
1, 0)
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
loaded_weight.shape[0],
loaded_weight.shape[1])
super().load_merged_column_weight(loaded_weight, **kwargs)

def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
loaded_weight = self.unpack_4bit_u8(loaded_weight)
loaded_weight = loaded_weight.reshape(self.output_shape,
-1).transpose(1, 0)
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
loaded_weight.shape[0],
loaded_weight.shape[1])
super().load_row_parallel_weight(loaded_weight)

def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
loaded_weight = self.unpack_4bit_u8(loaded_weight)
loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(
1, 0)
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
loaded_weight.shape[0],
loaded_weight.shape[1])
super().load_qkv_weight(loaded_weight, **kwargs)


# Zero points and scales in HQQ must also be reshaped to correspond to W_q's
# GPTQ shape (transposed - we transpose them too when processing weights).
class HQQZeroScaleParameter(GroupQuantScaleParameter):

def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
loaded_weight = loaded_weight.reshape(-1, self.shape[1])
super().load_merged_column_weight(loaded_weight, **kwargs)

def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
loaded_weight = loaded_weight.reshape(self.shape[0], -1)
super().load_row_parallel_weight(loaded_weight)

def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
loaded_weight = loaded_weight.reshape(-1, self.shape[1])
super().load_qkv_weight(loaded_weight, **kwargs)


class HQQMarlinMethod(LinearMethodBase):
"""Linear method for HQQ Marlin.
"""

def __init__(
self,
quant_config: HQQMarlinConfig,
):
self.quant_config = quant_config

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
self.output_size_per_partition = sum(output_partition_sizes)
self.input_size_per_partition = input_size_per_partition

weight_loader = extra_weight_attrs.get("weight_loader", error_loader)

self.scales_and_zp_size = (input_size_per_partition //
self.quant_config.group_size)

qweight = HQQweightParameter(
data=torch.empty(
self.input_size_per_partition // self.quant_config.pack_factor,
self.output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_bits=self.quant_config.weight_bits,
weight_loader=weight_loader)

zeros = HQQZeroScaleParameter(data=torch.empty(
self.output_size_per_partition,
self.scales_and_zp_size,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

scales = HQQZeroScaleParameter(data=torch.empty(
self.output_size_per_partition,
self.scales_and_zp_size,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

layer.register_parameter("W_q", qweight)
layer.register_parameter("zero", zeros)
layer.register_parameter("scale", scales)

# Ignore extra parameters in the HQQ model.
# To be added as needed.
ignore_parameters = ("axis", "channel_wise", "compute_dtype",
"encoded_state_dict", "group_size", "nbits",
"offload_meta", "optimize", "packing",
"quant_scale", "quant_zero", "round_zero",
"shape", "stores_quant_config",
"unpack_view_dtype", "view_as_float")
for name in ignore_parameters:
layer.register_parameter(
name,
HQQEmptyParameter(data=torch.empty(0),
weight_loader=weight_loader))

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
dev = layer.W_q.device

# Repack to Marlin
sort_indices = torch.empty(0, dtype=torch.int, device=dev)
marlin_w_q = ops.gptq_marlin_repack(
layer.W_q,
sort_indices,
self.input_size_per_partition,
self.output_size_per_partition,
self.quant_config.weight_bits,
).to(dev)
marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0),
self.input_size_per_partition,
self.output_size_per_partition,
self.quant_config.group_size).to(dev)
marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0),
self.input_size_per_partition,
self.output_size_per_partition,
self.quant_config.group_size).to(dev)

layer.g_idx = marlin_make_empty_g_idx(dev)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev)

layer.marlin_qweight = marlin_w_q
layer.marlin_zeros = marlin_zp
layer.marlin_scales = marlin_s

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
workspace = MarlinWorkspace(self.output_size_per_partition,
GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

scales = layer.marlin_scales
zeros = layer.marlin_zeros
orig_type = x.dtype

if orig_type != torch.float16:
x = x.to(torch.float16)
scales = scales.to(torch.float16)
zeros = zeros.to(torch.float16)

marlin_out = ops.gptq_marlin_gemm(
x,
layer.marlin_qweight,
scales,
zeros,
layer.g_idx,
layer.g_idx_sort_indices,
workspace.scratch,
scalar_types.uint4,
x.shape[0],
self.output_size_per_partition,
self.input_size_per_partition,
True, # is_k_full
True, # has_zp
True, # use 32-bit reduce
True, # use float zp
)

if orig_type != torch.float16:
marlin_out = marlin_out.to(orig_type)

if bias is not None:
marlin_out.add_(bias)

return marlin_out
@@ -303,7 +303,8 @@ def apply_gptq_marlin_linear(
size_k=input_size_per_partition,
is_k_full=is_k_full,
has_zp=False,
use_fp32_reduce=use_fp32_reduce)
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)

if bias is not None:
output.add_(bias) # In-place add
@@ -340,7 +341,8 @@ def apply_awq_marlin_linear(
size_k=input_size_per_partition,
is_k_full=True,
has_zp=True,
use_fp32_reduce=use_fp32_reduce)
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)

if bias is not None:
output.add_(bias) # In-place add