[Kernel] Zero point support in fused MarlinMoE kernel + AWQ Fused MoE (#8973)
Co-authored-by: Dipika <dipikasikka1@gmail.com>
Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
parent 0dcc8cbe5a
commit 05d686432f
@@ -433,6 +433,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
      "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
      "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
      "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
      "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
      "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
      "csrc/moe/marlin_moe_ops.cu")

  set_gencode_flags_for_srcs(
@@ -38,6 +38,7 @@ using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;  // quantization scales
using FragZP = Vec<half2, 4>;

// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
@@ -175,6 +176,46 @@ __device__ inline FragB dequant<vllm::kU8B128.id()>(int q) {
  return frag_b;
}

template <>
__device__ inline FragB dequant<vllm::kU4.id()>(int q) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);

  const int SUB = 0x64006400;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd400d400;
  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
  return frag_b;
}

template <>
__device__ inline FragB dequant<vllm::kU8.id()>(int q) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  return frag_b;
}

// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
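A note on the magic constants in `dequant<vllm::kU4.id()>`: `0x6400` is the fp16 bit pattern of 1024.0. OR-ing a 4-bit value into the low mantissa bits yields the fp16 number `1024 + q`, so a single `__hsub2` by 1024 (or, for the high nibbles, which land 16x higher in the mantissa, an fma by 1/16 with offset -64) finishes the conversion without any int-to-float instructions. A NumPy emulation of the low-nibble path (our own sketch, not part of the commit):

```python
import numpy as np

def dequant_u4_lo(q: int) -> np.ndarray:
    """Emulate dequant<kU4>'s low-nibble path: (q & LO) | EX, reinterpreted
    as two fp16 values, then 1024.0 subtracted."""
    lo = (np.uint32(q) & np.uint32(0x000F000F)) | np.uint32(0x64006400)
    halves = np.frombuffer(lo.tobytes(), dtype=np.float16)
    return halves - np.float16(1024.0)

print(dequant_u4_lo(0x000B0003))  # [ 3. 11.] -- nibbles at bits 0-3 and 16-19
```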
@@ -183,11 +224,10 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
  frag_b[1] = __hmul2(frag_b[1], s);
}

// Given 2 floats multiply by 2 scales (halves)
__device__ inline void scale_float(float* c, FragS& s) {
  __half* s_ptr = reinterpret_cast<__half*>(&s);
  c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
  c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) {
  half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]);
  frag_b[0] = __hsub2(frag_b[0], zp);
  frag_b[1] = __hsub2(frag_b[1], zp);
}

// Same as above, but for act_order (each K is multiplied individually)
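The new `sub_zp` helper implements the AWQ side of dequantization: the stored unsigned value is shifted by a per-group zero-point before the scale is applied, i.e. w = (q - zp) * s with one (zp, s) pair per group of K rows. A NumPy reference of that formula (our own sketch, not part of the commit):

```python
import numpy as np

def awq_dequant_ref(q, zp, s, group_size):
    """Reference dequantization: w[k, n] = (q[k, n] - zp[g, n]) * s[g, n],
    with g = k // group_size. Mirrors sub_zp followed by scale."""
    g = np.arange(q.shape[0]) // group_size
    return (q.astype(np.float32) - zp[g]) * s[g]

q = np.random.randint(0, 16, size=(64, 8))    # unsigned 4-bit weights
zp = np.random.randint(0, 16, size=(2, 8))    # one zero-point per group
s = np.random.rand(2, 8).astype(np.float32)   # one scale per group
w = awq_dequant_ref(q, zp, s, group_size=32)  # (64, 8) fp32 weights
```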
@@ -205,6 +245,13 @@ __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
  frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
}

// Given 2 floats multiply by 2 scales (halves)
__device__ inline void scale_float(float* c, FragS& s) {
  __half* s_ptr = reinterpret_cast<__half*>(&s);
  c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
  c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
@@ -248,10 +295,11 @@ template <const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
          const int stages,          // number of stages for the async global->shared
                                     // fetch pipeline
          const bool has_act_order,  // whether act_order is enabled
          const bool has_zp,         // whether zero-points are enabled
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__device__ inline void MarlinMoESingle(
__device__ void MarlinMoESingle(
    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // fp16 output buffer of shape mxn
@@ -259,6 +307,8 @@ __device__ inline void MarlinMoESingle(
    const float* __restrict__ topk_weights,  // float topk weights
    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                          // (k/groupsize)xn
    const int4* __restrict__ zp_ptr,  // 4bit packed zero-points of shape
                                      // (k/groupsize)x(n/pack_factor)
    const int* __restrict__ g_idx,    // int32 group indices of shape k
    const int* __restrict__ expert_offsets,
    int num_groups,  // number of scale groups per output channel
@@ -400,8 +450,12 @@ __device__ inline void MarlinMoESingle(
  int tb_n_warps = thread_n_blocks / 4;
  int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;

  constexpr int sorted_sh_stride = threads;
  constexpr int sorted_gl_stride = threads;
  // Zero-points sizes/strides
  int zp_gl_stride = (prob_n / pack_factor) / 4;
  constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
  constexpr int zp_tb_groups = s_tb_groups;
  constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
  int zp_gl_rd_delta = zp_gl_stride;

  // Global A read index of current thread.
  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
@@ -442,6 +496,19 @@ __device__ inline void MarlinMoESingle(
  int s_sh_wr = threadIdx.x;
  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;

  // Zero-points
  int zp_gl_rd;
  if constexpr (has_zp) {
    if constexpr (group_blocks == -1) {
      zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
    } else {
      zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
                 zp_sh_stride * slice_col + threadIdx.x;
    }
  }
  int zp_sh_wr = threadIdx.x;
  bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;

  // We use a different scale layout for grouped and column-wise quantization as
  // we scale a `half2` tile in column-major layout in the former and in
  // row-major in the latter case.
@@ -453,23 +520,29 @@ __device__ inline void MarlinMoESingle(
    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
              (threadIdx.x % 32) % 4;

  // Zero-points have the same read layout as the scales
  // (without column-wise case)
  constexpr int num_col_threads = 8;
  constexpr int num_row_threads = 4;
  constexpr int num_ints_per_thread = 8 / pack_factor;
  int zp_sh_rd;
  if constexpr (has_zp) {
    zp_sh_rd = num_ints_per_thread * num_col_threads *
                   ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
               num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
  }

  int sh_first_group_id = -1;
  int sh_num_groups = -1;
  constexpr int sh_max_num_groups = 32;

  int shs_size;
  if constexpr (has_act_order)
    shs_size = sh_max_num_groups * s_sh_stride + threads;
  else
    shs_size = group_blocks > 0 ? stages * s_sh_stage : threads;

  extern __shared__ int4 sh[];
  // Shared memory storage for global fetch pipelines.
  int4* sh_a = sh;
  int4* sh_b = sh_a + (stages * a_sh_stage);
  int4* sh_g_idx = sh_b + (stages * b_sh_stage);
  int4* sh_s = sh_g_idx + (stages * g_idx_stage);
  int* sh_sorted = (int*)(sh_s + shs_size);
  int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
  int4* sh_s = sh_zp + (stages * zp_sh_stage);

  // Precompute which thread should not read memory in which iterations; this is
  // needed if there are more threads than required for a certain tilesize or
@@ -525,8 +598,10 @@ __device__ inline void MarlinMoESingle(
  FragA frag_a[2][thread_m_blocks];
  I4 frag_b_quant[2][b_thread_vecs];
  FragC frag_c[thread_m_blocks][4][2];
  FragS frag_s[2][4];         // No act-order
  FragS act_frag_s[2][4][4];  // For act-order
  FragS frag_s[2][4];                    // No act-order
  FragS act_frag_s[2][4][4];             // For act-order
  int frag_qzp[2][num_ints_per_thread];  // Zero-points
  FragZP frag_zp;                        // Zero-points in fp16

  // Zero accumulators.
  auto zero_accums = [&]() {
@@ -633,6 +708,28 @@ __device__ inline void MarlinMoESingle(
          }
        }
      }

      if constexpr (has_zp && group_blocks != -1) {
        int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;

        if constexpr (group_blocks >= thread_k_blocks) {
          // Only fetch zero-points if this tile starts a new group
          if (pipe % (group_blocks / thread_k_blocks) == 0) {
            if (zp_sh_wr_pred) {
              cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
            }
            zp_gl_rd += zp_gl_rd_delta;
          }
        } else {
          for (int i = 0; i < zp_tb_groups; i++) {
            if (zp_sh_wr_pred) {
              cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
                        &zp_ptr[zp_gl_rd]);
            }
            zp_gl_rd += zp_gl_rd_delta;
          }
        }
      }
    }
  }
  // Insert a fence even when we are winding down the pipeline to ensure that
@@ -640,15 +737,9 @@ __device__ inline void MarlinMoESingle(
    cp_async_fence();
  };

  // TODO we are currently hitting illegal memory accesses when fetching
  // sorted_ids to shared data: fix this
  auto fetch_sorted_ids_to_shared = [&]() {
    const int mpt = ceildiv(prob_m, threads);
    for (int i = 0; i < mpt; i++) {
      if ((i * sorted_gl_stride) + threadIdx.x < prob_m) {
        sh_sorted[(i * sorted_sh_stride) + threadIdx.x] =
            sorted_ids[(i * sorted_gl_stride) + threadIdx.x];
      }
  auto fetch_zp_to_shared = [&]() {
    if (zp_sh_wr_pred) {
      cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
    }
  };
@@ -799,8 +890,83 @@ __device__ inline void MarlinMoESingle(
    }
  };

  auto fetch_zp_to_registers = [&](int k, int full_pipe) {
    // This code does not handle group_blocks == 0,
    // which signifies act_order.
    // has_zp implies AWQ, which doesn't have act_order,
    static_assert(!has_zp || group_blocks != 0);

    if constexpr (has_zp) {
      int pipe = full_pipe % stages;

      if constexpr (group_blocks == -1) {
        for (int i = 0; i < num_ints_per_thread; i++) {
          frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
        }

      } else if constexpr (group_blocks >= thread_k_blocks) {
        int4* sh_zp_stage =
            sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
                                   (pipe / (group_blocks / thread_k_blocks)));
        for (int i = 0; i < num_ints_per_thread; i++) {
          frag_qzp[k % 2][i] =
              (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
        }
      } else {
        int warp_id = threadIdx.x / 32;
        int n_warps = thread_n_blocks / 4;

        int warp_row = warp_id / n_warps;

        int cur_k = warp_row * 16;
        cur_k += k_iter_size * (k % b_sh_wr_iters);

        int k_blocks = cur_k / 16;
        int cur_group_id = 0;

        // Suppress bogus and persistent divide-by-zero warning
  #pragma nv_diagnostic push
  #pragma nv_diag_suppress divide_by_zero
        cur_group_id = k_blocks / group_blocks;
  #pragma nv_diagnostic pop

        int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;

        sh_zp_stage += cur_group_id * zp_sh_stride;

        for (int i = 0; i < num_ints_per_thread; i++) {
          frag_qzp[k % 2][i] =
              (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
        }
      }
    }
  };

  // Execute the actual tensor core matmul of a sub-tile.
  auto matmul = [&](int k) {
    if constexpr (has_zp) {
      FragB frag_zp_0;
      FragB frag_zp_1;
      int zp_quant_0, zp_quant_1;

      if constexpr (w_type.size_bits() == 4) {
        zp_quant_0 = frag_qzp[k % 2][0];
        zp_quant_1 = zp_quant_0 >> 8;
      } else {
        static_assert(w_type.size_bits() == 8);
        zp_quant_0 = frag_qzp[k % 2][0];
        zp_quant_1 = frag_qzp[k % 2][1];
      }

      frag_zp_0 = dequant<w_type_id>(zp_quant_0);
      frag_zp_1 = dequant<w_type_id>(zp_quant_1);

      frag_zp[0] = frag_zp_0[0];
      frag_zp[1] = frag_zp_0[1];
      frag_zp[2] = frag_zp_1[0];
      frag_zp[3] = frag_zp_1[1];
    }

  // We have the m dimension as the inner loop in order to encourage overlapping
  // dequantization and matmul operations.
  #pragma unroll
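Why `zp_quant_1 = zp_quant_0 >> 8` suffices for the 4-bit path: `dequant<kU4>` consumes the nibbles at bit positions {0-3, 16-19} (low fragment) and {4-7, 20-23} (high fragment), so shifting the same word right by 8 exposes the remaining four nibbles to a second `dequant` call. Sketched in Python (our own illustration, not part of the commit):

```python
def ku4_nibble_order(q: int) -> list:
    """Nibbles consumed by dequant<kU4>(q) followed by dequant<kU4>(q >> 8)."""
    def nib(x, sh):
        return (x >> sh) & 0xF
    return ([nib(q, s) for s in (0, 16, 4, 20)] +
            [nib(q >> 8, s) for s in (0, 16, 4, 20)])

# All eight packed 4-bit zero-points come out of one 32-bit word:
assert sorted(ku4_nibble_order(0x76543210)) == list(range(8))
```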
@@ -818,6 +984,10 @@ __device__ inline void MarlinMoESingle(

        FragB frag_b0 = dequant<w_type_id>(b_quant_0);
        FragB frag_b1 = dequant<w_type_id>(b_quant_1);
        // Apply zero-point to frag_b0
        if constexpr (has_zp) {
          sub_zp(frag_b0, frag_zp[j], 0);
        }

        // Apply scale to frag_b0
        if constexpr (has_act_order) {
@@ -829,6 +999,11 @@ __device__ inline void MarlinMoESingle(
          }
        }

        // Apply zero-point to frag_b1
        if constexpr (has_zp) {
          sub_zp(frag_b1, frag_zp[j], 1);
        }

        // Apply scale to frag_b1
        if constexpr (has_act_order) {
          scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
@@ -1062,9 +1237,6 @@ __device__ inline void MarlinMoESingle(

  // Start global fetch and register load pipelines.
  auto start_pipes = [&]() {
    // TODO re-enable after fixing this function
    // fetch_sorted_ids_to_shared();
    // __syncthreads();

  #pragma unroll
    for (int i = 0; i < stages - 1; i++) {
@@ -1075,6 +1247,12 @@ __device__ inline void MarlinMoESingle(
        }
        fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
      }

      if constexpr (has_zp && group_blocks == -1) {
        if (i == 0) {
          fetch_zp_to_shared();
        }
      }
      fetch_to_shared(i, i, i < slice_iters);
    }

@@ -1083,6 +1261,7 @@ __device__ inline void MarlinMoESingle(
    init_same_group(0);
    fetch_to_registers(0, 0);
    fetch_scales_to_registers(0, 0);
    fetch_zp_to_registers(0, 0);
    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
    slice_k_start_shared_fetch += tb_k * (stages - 1);
  };
@@ -1102,6 +1281,7 @@ __device__ inline void MarlinMoESingle(
    for (int k = 0; k < b_sh_wr_iters; k++) {
      fetch_to_registers(k + 1, pipe % stages);
      fetch_scales_to_registers(k + 1, pipe);
      fetch_zp_to_registers(k + 1, pipe);
      if (k == b_sh_wr_iters - 2) {
        fetch_to_shared((pipe + stages - 1) % stages, pipe,
                        slice_iters >= stages);
@@ -1236,7 +1416,9 @@ __device__ inline void MarlinMoESingle(

      } else {
        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
        zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
      }

      start_pipes();
    }
  }
@@ -1250,6 +1432,7 @@ template <const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
          const int stages,          // number of stages for the async global->shared
                                     // fetch pipeline
          const bool has_act_order,  // whether act_order is enabled
          const bool has_zp,         // whether zero-points are enabled
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
@@ -1261,6 +1444,8 @@ __global__ void MarlinMoE(
    const float* __restrict__ topk_weights,  // float topk weights
    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                          // (k/groupsize)xn
    const int4* __restrict__ zp_ptr,  // 4bit packed zero-points of shape
                                      // (k/groupsize)x(n/pack_factor)
    const int* __restrict__ g_idx,    // int32 group indices of shape k
    const int* __restrict__ expert_offsets,
    int num_groups,  // number of scale groups per output channel
@@ -1309,29 +1494,29 @@ __global__ void MarlinMoE(

  if (max_block == 1) {
    MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
                    stages, has_act_order, group_blocks>(
        A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
                    stages, has_act_order, has_zp, group_blocks>(
        A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
        expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
        prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
        current_m_block);
  } else if (max_block == 2) {
    MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
                    stages, has_act_order, group_blocks>(
        A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
                    stages, has_act_order, has_zp, group_blocks>(
        A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
        expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
        prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
        current_m_block);
  } else if (max_block == 3) {
    MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
                    stages, has_act_order, group_blocks>(
        A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
                    stages, has_act_order, has_zp, group_blocks>(
        A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
        expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
        prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
        current_m_block);
  } else {
    MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
                    stages, has_act_order, group_blocks>(
        A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
                    stages, has_act_order, has_zp, group_blocks>(
        A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
        expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
        prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
        current_m_block);
@@ -1347,6 +1532,7 @@ template <const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
          const int stages,          // number of stages for the async global->shared
                                     // fetch pipeline
          const bool has_act_order,  // whether act_order is enabled
          const bool has_zp,         // whether zero-points are enabled
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
@@ -1358,6 +1544,8 @@ __global__ void MarlinMoE(
    const float* __restrict__ topk_weights,  // float topk weights
    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                          // (k/groupsize)xn
    const int4* __restrict__ zp_ptr,  // 4bit packed zero-points of shape
                                      // (k/groupsize)x(n/pack_factor)
    const int* __restrict__ g_idx,    // int32 group indices of shape k
    const int* __restrict__ expert_offsets,
    int num_groups,  // number of scale groups per output channel
@@ -1374,7 +1562,6 @@ __global__ void MarlinMoE(
    int current_m_block,  // current m block to start kernel computation from
    int max_par,          // maximum parallelism
    int cfg_max_m_blocks  // upper bound on m blocks

) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
@@ -1389,37 +1576,41 @@ __global__ void MarlinMoE(
const int USER_THREADS =
    256;               // Note: This is only used with user-provided thread_k/n
const int STAGES = 4;  // 4 pipeline stages fit into shared memory
// const int SHARED_MEM =
//     96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
                      GROUP_BLOCKS, NUM_THREADS)                               \
                      HAS_ZP, GROUP_BLOCKS, NUM_THREADS)                       \
  else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS &&           \
           thread_k_blocks == THREAD_K_BLOCKS &&                               \
           has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS &&   \
           num_threads == NUM_THREADS) {                                       \
           has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP &&               \
           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {       \
    cudaFuncSetAttribute(                                                      \
        MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,  \
                  STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>,                        \
                  STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>,                \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);          \
    MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,      \
              STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>                             \
              STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>                     \
        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                     \
            A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr,      \
            g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx,             \
            zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx,     \
            num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks,           \
            replicate_input, apply_weights, m_block, max_par,                  \
            cfg_max_m_blocks);                                                 \
  }

#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)   \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)          \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)

#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)          \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)

}  // namespace marlin_moe
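To summarize the dispatch above: `__CALL_IF_MOE` now keys on `has_zp` in addition to act-order, group size, and thread shape, and the two wrapper macros pin the combinations each quantization scheme can reach. As data (our own summary of the macro expansions, not code from the commit):

```python
# (has_act_order, has_zp, group_blocks); group_size = 16 * group_blocks,
# -1 means channelwise. GPTQ never uses zero-points; AWQ never uses act_order.
GPTQ_CONFIGS = [(True, False, 0), (False, False, -1), (False, False, 2),
                (False, False, 4), (False, False, 8)]
AWQ_CONFIGS = [(False, True, -1), (False, True, 2), (False, True, 4),
               (False, True, 8)]
```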
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu (new file, 31 lines)
@@ -0,0 +1,31 @@
#include "marlin_moe_kernel_ku4.h"

namespace marlin_moe {

// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
    vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
    int m_block, int max_par, int cfg_max_m_blocks) {
  bool has_zp = true;

  if (false) {
  }
  AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256)
  AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256)
  AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128)
  AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128)
  else {
    return false;
  }
  return true;
}

}  // namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h (new file, 20 lines)
@@ -0,0 +1,20 @@
#pragma once

#include "marlin_moe_kernel.h"

namespace marlin_moe {

// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
    vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
    int m_block, int max_par, int cfg_max_m_blocks);

}  // namespace marlin_moe
@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku4b8(
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
    int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
    int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
    bool replicate_input, bool apply_weights, int m_block, int max_par,
    int cfg_max_m_blocks) {
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
    int m_block, int max_par, int cfg_max_m_blocks) {
  bool has_zp = false;

  if (false) {
  }
  GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
@@ -11,10 +11,10 @@ bool call_marlin_moe_kernel_ku4b8(
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
    int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
    int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
    bool replicate_input, bool apply_weights, int m_block, int max_par,
    int cfg_max_m_blocks);
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
    int m_block, int max_par, int cfg_max_m_blocks);

}  // namespace marlin_moe
@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku8b128(
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
    int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
    int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
    bool replicate_input, bool apply_weights, int m_block, int max_par,
    int cfg_max_m_blocks) {
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
    int m_block, int max_par, int cfg_max_m_blocks) {
  bool has_zp = false;

  if (false) {
  }
  GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
@@ -9,10 +9,10 @@ bool call_marlin_moe_kernel_ku8b128(
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
    int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
    int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
    bool replicate_input, bool apply_weights, int m_block, int max_par,
    int cfg_max_m_blocks);
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
    int m_block, int max_par, int cfg_max_m_blocks);

}
@@ -30,6 +30,7 @@
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"

template <typename T>
inline std::string str(T x) {
@@ -157,6 +158,7 @@ thread_config_t small_batch_thread_configs[] = {
    {128, 64, 128},   // Reduce N 2X, same K
    {64, 256, 256},   // Reduce K 2X, increase N 2X
    {64, 128, 128},   // Reduce K 2X, same N
    {64, 64, 128},    // Reduce both 2X
};

thread_config_t large_batch_thread_configs[] = {
@@ -167,6 +169,7 @@ thread_config_t large_batch_thread_configs[] = {
    {128, 128, 256},  // Reduce N 2X, increase K 2X
    {64, 128, 128},   // Reduce N 2X, same K
    {128, 64, 128},   // Reduce N 4X, increase K 2X
    {64, 64, 128},    // Reduce N 4X, same K
};

int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
@@ -312,27 +315,28 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
  return exec_config_t{0, {-1, -1, -1}};
}

#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION)                              \
  else if (KERNEL_FUNCTION(q_type, thread_n_blocks, thread_k_blocks,           \
                           has_act_order, group_blocks, num_threads, blocks,   \
                           max_shared_mem, stream, A_ptr, B_ptr, C_ptr,        \
                           sorted_ids_ptr, topk_weights_ptr, s_ptr, g_idx_ptr, \
                           expert_offsets_ptr, num_groups, expert_idx,         \
                           num_experts, topk, prob_m, prob_n, prob_k, tot_m,   \
                           locks, replicate_input, apply_weights, m_block,     \
                           max_par, exec_cfg.max_m_blocks)) {                  \
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION)                              \
  else if (KERNEL_FUNCTION(                                                    \
               q_type, thread_n_blocks, thread_k_blocks, has_act_order,        \
               group_blocks, num_threads, blocks, max_shared_mem, stream,      \
               A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr,   \
               zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx,  \
               num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks,        \
               replicate_input, apply_weights, m_block, max_par,               \
               exec_cfg.max_m_blocks)) {                                       \
  }

void marlin_mm_moe(const void* A, const void* B, void* C,
                   const void* sorted_ids, const void* topk_weights,
                   const void* topk_ids, const void* s, const void* g_idx,
                   const void* perm, void* a_tmp, void* expert_offsets,
                   int prob_m, int prob_n, int prob_k, void* workspace,
                   vllm::ScalarType const& q_type, bool has_act_order,
                   bool is_k_full, int num_groups, int group_size,
                   int num_experts, int topk, int moe_block_size, int dev,
                   cudaStream_t stream, int thread_k, int thread_n, int sms,
                   int max_par, bool replicate_input, bool apply_weights) {
                   const void* topk_ids, const void* s, void* zp,
                   const void* g_idx, const void* perm, void* a_tmp,
                   void* expert_offsets, int prob_m, int prob_n, int prob_k,
                   void* workspace, vllm::ScalarType const& q_type,
                   bool has_act_order, bool is_k_full, bool has_zp,
                   int num_groups, int group_size, int num_experts, int topk,
                   int moe_block_size, int dev, cudaStream_t stream,
                   int thread_k, int thread_n, int sms, int max_par,
                   bool replicate_input, bool apply_weights) {
  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");
@@ -436,6 +440,8 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
    const float* topk_weights_ptr = (const float*)topk_weights;
    const int* sorted_ids_ptr = (const int*)sorted_ids;
    const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
    const int4* zp_ptr =
        (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
    const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
    const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
    int* locks = (int*)workspace;
@@ -456,6 +462,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
    }
    CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
    CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
    CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
    else {
      TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
                         str(prob_n) + ", " + str(prob_k) + "]" +
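The per-expert stride in the `zp_ptr` computation above: each expert owns `num_groups * prob_n / pack_factor` packed int32 zero-points, and the pointer is typed `int4*` (four int32 per element), hence the extra division by 4. Checked numerically (our own example, with illustrative sizes):

```python
num_groups, prob_n, pack_factor = 16, 2048, 8
zeros_int32_per_expert = num_groups * prob_n // pack_factor  # 4096 int32 words
zeros_int4_per_expert = zeros_int32_per_expert // 4          # 1024 int4 vectors
assert zeros_int4_per_expert == num_groups * prob_n // (pack_factor * 4)
```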
@@ -475,13 +482,21 @@ torch::Tensor marlin_gemm_moe(
    const torch::Tensor& a, const torch::Tensor& b_q_weights,
    const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
    const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
    const torch::Tensor& g_idx, const torch::Tensor& perm,
    torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
    int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
    int64_t num_experts, int64_t topk, int64_t moe_block_size,
    bool replicate_input, bool apply_weights) {
  TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
              "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());
    torch::Tensor& b_zeros, const torch::Tensor& g_idx,
    const torch::Tensor& perm, torch::Tensor& workspace,
    vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
    int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
    int64_t moe_block_size, bool replicate_input, bool apply_weights) {
  bool has_zp = b_zeros.size(1) != 0;
  if (has_zp) {
    TORCH_CHECK(
        *b_q_type == vllm::kU4,
        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str());
  } else {
    TORCH_CHECK(
        *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
        "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());
  }

  int pack_factor = 32 / b_q_type->size_bits();
@@ -543,14 +558,27 @@ torch::Tensor marlin_gemm_moe(
    }
  }

  // Verify b_zeros
  if (has_zp) {
    int rank = b_zeros.sizes().size();
    TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
    TORCH_CHECK(b_zeros.size(1) == num_groups,
                "b_zeros dim 1 = ", b_zeros.size(1),
                " is not num_groups = ", num_groups);
    TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
                "b_zeros dim 2 = ", b_zeros.size(2),
                " is not size_n / pack_factor = ", size_n / pack_factor);
  }

  marlin_moe::marlin_mm_moe(
      a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
      topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
      g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
      b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
      expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
      *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts,
      topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
      thread_n, sms, max_par, replicate_input, apply_weights);
      *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
      num_experts, topk, moe_block_size, dev,
      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
      replicate_input, apply_weights);
  return c;
}
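The shape contract enforced above, worked through for a typical 4-bit AWQ MoE layer (numbers are illustrative):

```python
num_bits = 4
pack_factor = 32 // num_bits  # 8 int4 zero-points packed per int32
num_experts, num_groups, size_n = 8, 16, 2048
# marlin_gemm_moe expects b_zeros of rank 3:
b_zeros_shape = (num_experts, num_groups, size_n // pack_factor)  # (8, 16, 256)
# has_zp is inferred from dim 1, so GPTQ callers pass an empty (0, 0) tensor.
```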
@@ -12,7 +12,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
  m.def(
      "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
      "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
      "g_idx, Tensor! perm, Tensor! workspace, "
      "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
      "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
      "int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
      "int moe_block_size, bool replicate_input, bool apply_weights)"
@@ -2260,7 +2260,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                "b_zeros dim 0 = ", b_zeros.size(0),
                " is not num_groups = ", num_groups);
    TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
                "b_zeros dim 1 = ", b_scales.size(1),
                "b_zeros dim 1 = ", b_zeros.size(1),
                " is not size_n / pack_factor = ", size_n / pack_factor);
  }
tests/kernels/test_awq_marlin.py (new file, 160 lines)
@@ -0,0 +1,160 @@
"""Test AWQ with fused MoE Marlin kernels.

Run `pytest tests/kernels/test_awq_marlin.py`.
"""
import pytest
import torch

from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
                                 torch_moe_single)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    awq_marlin_quantize)
from vllm.scalar_type import scalar_types


@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
def test_fused_marlin_moe_awq(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
):
    torch.manual_seed(7)

    num_bits = 4
    quant_type = scalar_types.uint4
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    w_ref1_l = []
    qweights1_l = []
    scales1_l = []
    zp1_l = []

    for i in range(w1.shape[0]):
        w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize(
            w1[i].transpose(1, 0), quant_type, group_size)
        w_ref1_l.append(w_ref1)
        qweights1_l.append(qweight1)
        scales1_l.append(scales1)
        zp1_l.append(zp1)

    w_ref1 = stack_and_dev(w_ref1_l)
    qweight1 = stack_and_dev(qweights1_l).contiguous()
    scales1 = stack_and_dev(scales1_l)
    zp1 = stack_and_dev(zp1_l)

    w_ref2_l = []
    qweights2_l = []
    scales2_l = []
    zp2_l = []

    for i in range(w2.shape[0]):
        w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize(
            w2[i].transpose(1, 0), quant_type, group_size)
        w_ref2_l.append(w_ref2)
        qweights2_l.append(qweight2)
        scales2_l.append(scales2)
        zp2_l.append(zp2)

    w_ref2 = stack_and_dev(w_ref2_l)
    qweight2 = stack_and_dev(qweights2_l).contiguous()
    scales2 = stack_and_dev(scales2_l)
    zp2 = stack_and_dev(zp2_l)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, False)
    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        w1_zeros=zp1,
        w2_zeros=zp2,
        num_bits=num_bits,
    )

    torch_output = torch_moe(
        a,
        w_ref1.transpose(1, 2),
        w_ref2.transpose(1, 2),
        score,
        topk,
    )

    assert compute_max_diff(marlin_output, torch_output) < 4e-2


@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
def test_single_marlin_moe_multiply_awq(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
):
    torch.manual_seed(7)

    num_bits = 4
    quant_type = scalar_types.uint4
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    w_ref_l = []
    qweights_l = []
    scales_l = []
    zp_l = []

    for i in range(w.shape[0]):
        w_ref, qweight, scales, zp = awq_marlin_quantize(
            w[i].transpose(1, 0), quant_type, group_size)
        w_ref_l.append(w_ref)
        qweights_l.append(qweight)
        scales_l.append(scales)
        zp_l.append(zp)

    w_ref = stack_and_dev(w_ref_l)
    qweight = stack_and_dev(qweights_l).contiguous()
    scales = stack_and_dev(scales_l).contiguous()
    zp = stack_and_dev(zp_l).contiguous()

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    marlin_output = single_marlin_moe(a,
                                      qweight,
                                      scales,
                                      score,
                                      topk,
                                      renormalize=False,
                                      w_zeros=zp,
                                      num_bits=num_bits)

    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, torch_output) < 1e-2
@@ -2,16 +2,14 @@

Run `pytest tests/kernels/test_moe.py`.
"""
from typing import List

import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from tests.kernels.utils import opcheck
from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
                                 torch_moe, torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    fused_marlin_moe, single_marlin_moe)
@@ -24,37 +22,6 @@ from vllm.scalar_type import scalar_types
from vllm.utils import seed_everything


def torch_moe(a, w1, w2, score, topk):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


def torch_moe_single(a, w, score, topk):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    _, topk_ids = torch.topk(score, topk)
    topk_ids = topk_ids.view(-1)
    for i in range(w.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = a[mask] @ w[i].transpose(0, 1)
    return (out.view(B, -1, w.shape[1])).sum(dim=1)


@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@@ -127,20 +94,10 @@ def test_mixtral_moe(dtype: torch.dtype):
                   atol=mixtral_moe_tol[dtype])


def stack_and_dev(tensors: List[torch.Tensor]):
    dev = tensors[0].device
    return torch.stack(tensors, dim=0).to(dev)


def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))


@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@@ -159,9 +116,6 @@ def test_fused_marlin_moe(
):
    seed_everything(7)

    if topk > e:
        return

    # Filter act_order
    if act_order:
        if group_size == -1:
@@ -241,15 +195,15 @@ def test_fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        g_idx1,
        g_idx2,
        sort_indices1,
        sort_indices2,
        topk_weights,
        topk_ids,
        w1_scale=scales1,
        w2_scale=scales2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        num_bits=num_bits,
        is_k_full=is_k_full,
    )
@@ -280,9 +234,13 @@ def test_fused_marlin_moe(
                            device="cuda",
                            requires_grad=False)

    zp = torch.empty((0, 0),
                     dtype=dtype,
                     device="cuda",
                     requires_grad=False)
    opcheck(torch.ops._moe_C.marlin_gemm_moe,
            (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
             scales1, g_idx1, sort_indices1, workspace, quant_type, m,
             scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
             2 * n, k, True, e, topk, block_size_m, True, False))


@@ -291,7 +249,7 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@@ -308,8 +266,6 @@ def test_single_marlin_moe_multiply(
    num_bits: int,
    is_k_full: bool,
):
    if topk > e:
        return

    # Filter act_order
    if act_order:
@@ -355,13 +311,14 @@ def test_single_marlin_moe_multiply(
        qweight,
        scales,
        score,
        g_idx,
        sort_indices,
        topk,
        renormalize=False,
        g_idx=g_idx,
        sort_indices=sort_indices,
        num_bits=num_bits,
        is_k_full=is_k_full,
    )

    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, torch_output) < 1e-2
@@ -12,6 +12,7 @@ import torch
from torch._prims_common import TensorLikeType

from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
                        make_tensor_with_pad)

@@ -974,6 +975,50 @@ def fp8_allclose(
                equal_nan=equal_nan)).item())


# Marlin MoE test utils


def stack_and_dev(tensors: List[torch.Tensor]):
    dev = tensors[0].device
    return torch.stack(tensors, dim=0).to(dev)


def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))


def torch_moe(a, w1, w2, score, topk):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


def torch_moe_single(a, w, score, topk):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    _, topk_ids = torch.topk(score, topk)
    topk_ids = topk_ids.view(-1)
    for i in range(w.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = a[mask] @ w[i].transpose(0, 1)
    return (out.view(B, -1, w.shape[1])).sum(dim=1)


# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
@@ -3,3 +3,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
@@ -1,7 +1,20 @@
#!/bin/bash
SUCCESS=0

IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt"
while getopts "c:" OPT; do
  case ${OPT} in
    c )
      CONFIG="$OPTARG"
      ;;
    \? )
      usage
      exit 1
      ;;
  esac
done


IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG

for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
do
@@ -568,6 +568,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
    return output


def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                          size_k: int, size_n: int,
                          num_bits: int) -> torch.Tensor:
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)
    for e in range(num_experts):
        output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k,
                                                   size_n, num_bits)
    return output


def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
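A sketch of how the new repack helper might be driven (shapes are our own illustrative assumptions: AWQ packs eight 4-bit values per int32 along N, and `perm` is not consumed by the AWQ path shown above):

```python
import torch
from vllm import _custom_ops as ops

e, size_k, size_n, num_bits = 8, 1024, 2048, 4
pack_factor = 32 // num_bits  # eight 4-bit values per int32
awq_qweight = torch.randint(-2**31, 2**31 - 1,
                            (e, size_k, size_n // pack_factor),
                            dtype=torch.int32, device="cuda")
perm = torch.empty(0, dtype=torch.int32, device="cuda")  # unused for AWQ
marlin_qweight = ops.awq_marlin_moe_repack(awq_qweight, perm,
                                           size_k=size_k, size_n=size_n,
                                           num_bits=num_bits)
assert marlin_qweight.shape == (e, size_k // 16, size_n * (num_bits // 2))
```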
@@ -828,11 +842,12 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
                             sorted_ids: torch.Tensor,
                             topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor, b_scales: torch.Tensor,
                             g_idx: torch.Tensor, perm: torch.Tensor,
                             workspace: torch.Tensor, b_q_type: ScalarType,
                             size_m: int, size_n: int, size_k: int,
                             is_k_full: bool, num_experts: int, topk: int,
                             moe_block_size: int, replicate_input: bool,
                             b_zero_points: torch.Tensor, g_idx: torch.Tensor,
                             perm: torch.Tensor, workspace: torch.Tensor,
                             b_q_type: ScalarType, size_m: int, size_n: int,
                             size_k: int, is_k_full: bool, num_experts: int,
                             topk: int, moe_block_size: int,
                             replicate_input: bool,
                             apply_weights: bool) -> torch.Tensor:
        return torch.empty((size_m, topk, size_n),
                           dtype=a.dtype,
@@ -10,15 +10,24 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.scalar_type import scalar_types


def get_scalar_type(num_bits: int, has_zp: bool):
    if has_zp:
        assert num_bits == 4
        return scalar_types.uint4
    else:
        return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128


def single_marlin_moe(
    hidden_states: torch.Tensor,
    w: torch.Tensor,
    scales: torch.Tensor,
    gating_output: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    topk: int,
    renormalize: bool,
    g_idx: Optional[torch.Tensor] = None,
    sort_indices: Optional[torch.Tensor] = None,
    w_zeros: Optional[torch.Tensor] = None,
    override_config: Optional[Dict[str, Any]] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
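The `get_scalar_type` helper encodes the convention used throughout this commit: GPTQ weights use the biased types (the symmetric offset is folded into the type), while an explicit zero-point implies AWQ's plain `uint4`. A quick usage check (our own example):

```python
from vllm.scalar_type import scalar_types

assert get_scalar_type(4, has_zp=True) == scalar_types.uint4       # AWQ
assert get_scalar_type(4, has_zp=False) == scalar_types.uint4b8    # GPTQ, 4-bit
assert get_scalar_type(8, has_zp=False) == scalar_types.uint8b128  # GPTQ, 8-bit
```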
@@ -34,10 +43,12 @@ def single_marlin_moe(
    - scales (torch.Tensor): The quantization scales.
    - gating_output (torch.Tensor): The output of the gating operation
      (before softmax).
    - g_idx (torch.Tensor): The act_order indices.
    - perm (torch.Tensor): The act_order input permutation.
    - g_idx (Optional[torch.Tensor]): Optional act_order indices.
    - sort_indices (Optional[torch.Tensor]): Optional act_order input
      permutation.
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
    - override_config (Optional[Dict[str, Any]]): Optional override
      for the kernel configuration.
    - num_bits (int): The number of bits in expert weights quantization.
@ -79,16 +90,34 @@ def single_marlin_moe(
|
||||
max_workspace_size = (N // 64) * 16
|
||||
workspace = torch.zeros(max_workspace_size,
|
||||
dtype=torch.int,
|
||||
device="cuda",
|
||||
device=hidden_states.device,
|
||||
requires_grad=False)

    scalar_type = (scalar_types.uint4b8
                   if num_bits == 4 else scalar_types.uint8b128)
    has_zero_point = w_zeros is not None
    if w_zeros is None:
        w_zeros = torch.empty((0, 0),
                              dtype=hidden_states.dtype,
                              device=hidden_states.device,
                              requires_grad=False)

    if g_idx is None:
        g_idx = torch.empty((0, 0),
                            dtype=torch.int32,
                            device=hidden_states.device,
                            requires_grad=False)

    if sort_indices is None:
        sort_indices = torch.empty((0),
                                   dtype=torch.int32,
                                   device=hidden_states.device,
                                   requires_grad=False)

    scalar_type = get_scalar_type(num_bits, has_zero_point)

    intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
        hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
        g_idx, perm, workspace, scalar_type, M, N, K, is_k_full, E, topk,
        block_size_m, True, False)
        w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
        is_k_full, E, topk, block_size_m, True, False)

    return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
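
All of the None-handling above follows one convention: optional tensor inputs are replaced by zero-sized placeholders so that the custom op always receives real tensors. A generic sketch of the pattern (helper name hypothetical):

import torch

def _or_empty(t, dtype, device):
    # Custom ops registered via torch.ops take tensors, not Optionals, so
    # a 0-sized tensor stands in for "absent"; the kernel side can detect
    # the empty case from the tensor's size.
    if t is not None:
        return t
    return torch.empty((0, 0), dtype=dtype, device=device,
                       requires_grad=False)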

@@ -97,16 +126,18 @@ def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    g_idx1: torch.Tensor,
    g_idx2: torch.Tensor,
    perm1: torch.Tensor,
    perm2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    override_config: Optional[Dict[str, Any]] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
) -> torch.Tensor:

@@ -118,21 +149,22 @@ def fused_marlin_moe(
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - w1_scale (torch.Tensor): Scale to be used for w1.
    - w2_scale (torch.Tensor): Scale to be used for w2.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - g_idx1 (torch.Tensor): The first set of act_order indices.
    - g_idx2 (torch.Tensor): The second set of act_order indices.
    - perm1 (torch.Tensor): The first act_order input permutation.
    - perm2 (torch.Tensor): The second act_order input permutation.
    - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
    - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
    - sort_indices1 (Optional[torch.Tensor]): The first act_order input
        permutation.
    - sort_indices2 (Optional[torch.Tensor]): The second act_order input
        permutation.
    - topk_weights (torch.Tensor): Top-k weights.
    - topk_ids (torch.Tensor): Indices of top-k elements.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
        w2.
    - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
    - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
    - num_bits (int): The number of bits in expert weights quantization.

    Returns:

@@ -152,6 +184,20 @@ def fused_marlin_moe(
    assert hidden_states.dtype == torch.float16
    assert num_bits in [4, 8]

    has_no_act_order = (g_idx1 is None and g_idx2 is None
                        and sort_indices1 is None and sort_indices2 is None)
    has_all_act_order = (g_idx1 is not None and g_idx2 is not None
                         and sort_indices1 is not None
                         and sort_indices2 is not None)
    assert has_no_act_order or has_all_act_order, (
        "g_idx and sorted_indices "
        "must be all not None or must be all None")

    has_no_zp = w1_zeros is None and w2_zeros is None
    has_all_zp = w1_zeros is not None and w2_zeros is not None
    assert has_no_zp or has_all_zp, ("zero points must be both not None or "
                                     "must be both None")

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16

@@ -172,14 +218,42 @@ def fused_marlin_moe(

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
    max_workspace_size = (max(2 * N, K) // 64) * 16
    workspace = torch.zeros(max_workspace_size,
                            dtype=torch.int,
                            device="cuda",
                            requires_grad=False)

    scalar_type = (scalar_types.uint4b8
                   if num_bits == 4 else scalar_types.uint8b128)
    if has_no_zp:
        w1_zeros = torch.empty((0, 0),
                               dtype=hidden_states.dtype,
                               device=hidden_states.device,
                               requires_grad=False)
        w2_zeros = torch.empty((0, 0),
                               dtype=hidden_states.dtype,
                               device=hidden_states.device,
                               requires_grad=False)

    if has_no_act_order:
        g_idx1 = torch.empty((0, 0),
                             dtype=torch.int32,
                             device=hidden_states.device,
                             requires_grad=False)
        g_idx2 = torch.empty((0, 0),
                             dtype=torch.int32,
                             device=hidden_states.device,
                             requires_grad=False)
        sort_indices1 = torch.empty((0),
                                    dtype=torch.int32,
                                    device=hidden_states.device,
                                    requires_grad=False)
        sort_indices2 = torch.empty((0, 0),
                                    dtype=torch.int32,
                                    device=hidden_states.device,
                                    requires_grad=False)

    scalar_type1 = get_scalar_type(num_bits, has_all_zp)
    scalar_type2 = get_scalar_type(num_bits, has_all_zp)

    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N),

@@ -194,10 +268,11 @@ def fused_marlin_moe(
        topk_weights,
        topk_ids,
        w1_scale,
        w1_zeros,
        g_idx1,
        perm1,
        sort_indices1,
        workspace,
        scalar_type,
        scalar_type1,
        M,
        2 * N,
        K,

@@ -218,10 +293,11 @@ def fused_marlin_moe(
        topk_weights,
        topk_ids,
        w2_scale,
        w2_zeros,
        g_idx2,
        perm2,
        sort_indices2,
        workspace,
        scalar_type,
        scalar_type2,
        M,
        K,
        N,
@@ -1,16 +1,21 @@
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.nn import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
    marlin_permute_scales, moe_awq_to_marlin_zero_points,
    verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,

@@ -35,12 +40,13 @@ class AWQMarlinConfig(QuantizationConfig):
        self.group_size = group_size
        self.has_zp = has_zp
        self.lm_head_quantized = lm_head_quantized
        self.weight_bits = weight_bits

        if weight_bits not in self.TYPE_MAP:
            raise ValueError(f"Unsupported num_bits = {weight_bits}. "
        if self.weight_bits not in self.TYPE_MAP:
            raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
                             f"Supported num_bits = {self.TYPE_MAP.keys()}")

        self.quant_type = self.TYPE_MAP[weight_bits]
        self.quant_type = self.TYPE_MAP[self.weight_bits]

        verify_marlin_supported(self.quant_type,
                                group_size=self.group_size,
@@ -98,10 +104,12 @@ class AWQMarlinConfig(QuantizationConfig):
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["AWQMarlinLinearMethod"]:
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if (isinstance(layer, LinearBase) or
                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return AWQMarlinLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return AWQMoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:

@@ -271,4 +279,182 @@ class AWQMarlinLinearMethod(LinearMethodBase):
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias)
            bias=bias)


class AWQMoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config: AWQMarlinConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        extra_weight_attrs.update({
            "is_transposed":
            True,
            "quant_method":
            FusedMoeWeightScaleSupported.GROUP.value,
        })

        w13_qweight = Parameter(torch.empty(num_experts,
                                            hidden_size,
                                            2 * intermediate_size //
                                            self.quant_config.pack_factor,
                                            dtype=torch.int32),
                                requires_grad=False)
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        w2_qweight = Parameter(torch.empty(num_experts,
                                           intermediate_size,
                                           hidden_size //
                                           self.quant_config.pack_factor,
                                           dtype=torch.int32),
                               requires_grad=False)
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        num_groups_w13 = hidden_size // self.quant_config.group_size
        num_groups_w2 = intermediate_size // self.quant_config.group_size

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        w13_scales = Parameter(torch.empty(num_experts,
                                           num_groups_w13,
                                           intermediate_size * 2,
                                           dtype=params_dtype),
                               requires_grad=False)
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = Parameter(torch.empty(num_experts,
                                          num_groups_w2,
                                          hidden_size,
                                          dtype=params_dtype),
                              requires_grad=False)
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        # WEIGHT_ZERO_POINT
        # Allocate 2 zero points for w1 and w3 respectively.
        w13_qzeros = Parameter(torch.empty(num_experts,
                                           num_groups_w13,
                                           2 * intermediate_size //
                                           self.quant_config.pack_factor,
                                           dtype=torch.int32),
                               requires_grad=False)
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)

        w2_qzeros = Parameter(torch.empty(num_experts,
                                          num_groups_w2,
                                          hidden_size //
                                          self.quant_config.pack_factor,
                                          dtype=torch.int32),
                              requires_grad=False)
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)
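
To make the packed shapes above concrete, the arithmetic for a hypothetical 4-bit AWQ layer (hidden_size=4096, intermediate_size=14336, group_size=128, 8 experts; pack_factor = 32 // num_bits = 8; values illustrative, not from this commit):

E, H, I, pack, groups = 8, 4096, 14336, 8, 4096 // 128
w13_qweight_shape = (E, H, 2 * I // pack)        # (8, 4096, 3584)
w2_qweight_shape = (E, I, H // pack)             # (8, 14336, 512)
w13_scales_shape = (E, groups, 2 * I)            # (8, 32, 28672)
w13_qzeros_shape = (E, groups, 2 * I // pack)    # (8, 32, 3584)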

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        num_experts = layer.w13_qweight.shape[0]
        device = layer.w13_qweight.device

        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )

        marlin_w13_qweight = ops.awq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            size_k=layer.w13_qweight.shape[1],
            size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

        marlin_w2_qweight = ops.awq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            size_k=layer.w2_qweight.shape[1],
            size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

        # Why does this take the intermediate size for size_k?
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )

        replace_parameter(layer, "w13_scales", marlin_w13_scales)

        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

        marlin_w13_zp = moe_awq_to_marlin_zero_points(
            layer.w13_qzeros,
            size_k=layer.w13_qzeros.shape[1],
            size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits)
        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)

        marlin_w2_zp = moe_awq_to_marlin_zero_points(
            layer.w2_qzeros,
            size_k=layer.w2_qzeros.shape[1],
            size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits)
        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:

        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            fused_marlin_moe)

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)

        return fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            w1_zeros=layer.w13_qzeros,
            w2_zeros=layer.w2_qzeros,
            num_bits=self.quant_config.weight_bits,
        )
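
Note the design choice here: the AWQ path passes w1_zeros/w2_zeros but no g_idx or sort_indices, so fused_marlin_moe takes its zero-point branch and both GEMMs run with the uint4 scalar type. An illustrative check, reusing the helper from this diff:

assert get_scalar_type(4, has_zp=True) == scalar_types.uint4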

@@ -498,14 +498,14 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            layer.w13_g_idx,
            layer.w2_g_idx,
            layer.w13_g_idx_sort_indices,
            layer.w2_g_idx_sort_indices,
            topk_weights,
            topk_ids,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g_idx1=layer.w13_g_idx,
            g_idx2=layer.w2_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.num_bits,
        )

@@ -557,14 +557,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            layer.w13_g_idx,
            layer.w2_g_idx,
            layer.w13_g_idx_sort_indices,
            layer.w2_g_idx_sort_indices,
            topk_weights,
            topk_ids,
            w1_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
            g_idx1=layer.w13_g_idx,
            g_idx2=layer.w2_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.quant_config.quant_type.size_bits,
        ).to(orig_dtype)
@@ -208,6 +208,7 @@ def marlin_moe_permute_scales(
        device=s.device,
        dtype=s.dtype,
    )

    for e in range(num_experts):
        output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
    return output

@@ -258,6 +259,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
    return marlin_zp


def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                                  size_n: int, num_bits: int):
    num_experts = q_zp_packed.shape[0]
    output = torch.empty(
        (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
        device=q_zp_packed.device,
        dtype=q_zp_packed.dtype,
    )
    for e in range(num_experts):
        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
                                              num_bits)
    return output
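
An illustrative use of the per-expert conversion above (shapes hypothetical and CUDA assumed; the output keeps the input's packed shape, only the internal packing is rearranged to what the Marlin kernel expects):

import torch

E, groups, N, pack = 8, 32, 14336, 8
zp = torch.zeros((E, groups, N // pack), dtype=torch.int32, device="cuda")
marlin_zp = moe_awq_to_marlin_zero_points(zp, size_k=groups, size_n=N,
                                          num_bits=4)
assert marlin_zp.shape == zp.shape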


def apply_gptq_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
@@ -23,7 +23,9 @@ def get_model_architecture(
    architectures = getattr(model_config.hf_config, "architectures", [])
    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"]
    mixtral_supported = [
        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
    ]

    if (model_config.quantization is not None
            and model_config.quantization not in mixtral_supported