[Kernel] GGUF MoE kernel (#14613)
Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com>
This commit is contained in:
parent
e392d85831
commit
e22ee1e7a2
@ -151,6 +151,14 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
|
||||
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
||||
int64_t row);
|
||||
|
||||
torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_padded, int64_t type,
|
||||
int64_t row, int64_t top_k, int64_t tokens);
|
||||
|
||||
int64_t ggml_moe_get_block_size(int64_t type);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B, torch::Tensor const& A_sf,
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "dequantize.cuh"
|
||||
#include "mmvq.cuh"
|
||||
#include "mmq.cuh"
|
||||
#include "moe.cuh"
|
||||
|
||||
// Q8 gemv
|
||||
template <typename scalar_t>
|
||||
@ -59,10 +60,14 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
|
||||
const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
|
||||
const int block_num_x =
|
||||
(kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
||||
const dim3 num_blocks(block_num_x, ky, 1);
|
||||
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
|
||||
quantize_q8_1<scalar_t>
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
|
||||
constexpr int MAX_BLOCK_SIZE = 65535;
|
||||
for (int off = 0; off < ky; off += MAX_BLOCK_SIZE) {
|
||||
const int num_blocks_y = std::min(ky, off + MAX_BLOCK_SIZE) - off;
|
||||
const dim3 num_blocks(block_num_x, num_blocks_y, 1);
|
||||
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
|
||||
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(
|
||||
&x[off * kx], (int32_t*)vy + off * (kx_padded / 32 * 9), kx, kx_padded);
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
|
||||
@ -263,3 +268,132 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
|
||||
});
|
||||
return Y;
|
||||
}
|
||||
|
||||
torch::Tensor ggml_moe_a8(torch::Tensor X, // input
|
||||
torch::Tensor W, // expert weights
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_padded, int64_t type,
|
||||
int64_t row, int64_t top_k, int64_t tokens) {
|
||||
int col = X.sizes()[1];
|
||||
int padded = (col + 512 - 1) / 512 * 512;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
|
||||
at::Tensor Y = torch::empty({tokens * top_k, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] {
|
||||
quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
col, tokens, stream);
|
||||
switch (type) {
|
||||
case 2:
|
||||
ggml_moe_q4_0_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 3:
|
||||
ggml_moe_q4_1_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 6:
|
||||
ggml_moe_q5_0_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 7:
|
||||
ggml_moe_q5_1_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 8:
|
||||
ggml_moe_q8_0_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 10:
|
||||
ggml_moe_q2_K_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 11:
|
||||
ggml_moe_q3_K_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 12:
|
||||
ggml_moe_q4_K_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 13:
|
||||
ggml_moe_q5_K_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 14:
|
||||
ggml_moe_q6_K_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
}
|
||||
});
|
||||
return Y;
|
||||
}
|
||||
|
||||
int64_t ggml_moe_get_block_size(int64_t type) {
|
||||
switch (type) {
|
||||
case 2:
|
||||
return MMQ_X_Q4_0;
|
||||
case 3:
|
||||
return MMQ_X_Q4_1;
|
||||
case 6:
|
||||
return MMQ_X_Q5_0;
|
||||
case 7:
|
||||
return MMQ_X_Q5_1;
|
||||
case 8:
|
||||
return MMQ_X_Q8_0;
|
||||
case 10:
|
||||
return MMQ_X_Q2_K;
|
||||
case 11:
|
||||
return MMQ_X_Q3_K;
|
||||
case 12:
|
||||
return MMQ_X_Q4_K;
|
||||
case 13:
|
||||
return MMQ_X_Q5_K;
|
||||
case 14:
|
||||
return MMQ_X_Q6_K;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
739
csrc/quantization/gguf/moe.cuh
Normal file
739
csrc/quantization/gguf/moe.cuh
Normal file
@ -0,0 +1,739 @@
|
||||
#include <cstdint>
|
||||
|
||||
/* Adapted from ./csrc/quantization/gguf/mmq.cuh
|
||||
based on ./vllm/model_executor/layers/fused_moe/fused_moe.py */
|
||||
template <typename scalar_t, int qk, int qr, int qi, bool need_sum,
|
||||
typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
||||
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles,
|
||||
int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
||||
static __device__ __forceinline__ void moe_q(
|
||||
const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* __restrict__ sorted_token_ids,
|
||||
const int* __restrict__ expert_ids,
|
||||
const int* __restrict__ num_tokens_post_padded, const int exp_stride,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
|
||||
const int nrows_dst, const int top_k) {
|
||||
const int blocks_per_row_x = ncols_x / qk;
|
||||
const int blocks_per_col_y = nrows_y / QK8_1;
|
||||
const int blocks_per_warp = WARP_SIZE_GGUF / qi;
|
||||
|
||||
const int ncols_dst = ncols_y * top_k;
|
||||
|
||||
const int row_dst_0 = blockIdx.x * mmq_y;
|
||||
const int& row_x_0 = row_dst_0;
|
||||
|
||||
const int col_dst_0 = blockIdx.y * mmq_x;
|
||||
|
||||
int token_offs[mmq_x / nwarps];
|
||||
for (int i = 0; i < mmq_x; i += nwarps) {
|
||||
token_offs[i / nwarps] = sorted_token_ids[col_dst_0 + threadIdx.y + i];
|
||||
}
|
||||
|
||||
const int exp_idx = expert_ids[blockIdx.y];
|
||||
if (exp_idx > 255 || exp_idx < 0) return;
|
||||
if (blockIdx.y * mmq_x > num_tokens_post_padded[0]) return;
|
||||
|
||||
const block_q_t* x = (const block_q_t*)((char*)vx + exp_idx * exp_stride);
|
||||
const block_q8_1* y = (const block_q8_1*)(vy);
|
||||
|
||||
int* tile_x_ql = nullptr;
|
||||
half2* tile_x_dm = nullptr;
|
||||
int* tile_x_qh = nullptr;
|
||||
int* tile_x_sc = nullptr;
|
||||
|
||||
allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
|
||||
|
||||
__shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF];
|
||||
__shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF / QI8_1];
|
||||
|
||||
float sum[mmq_y / WARP_SIZE_GGUF][mmq_x / nwarps] = {{0.0f}};
|
||||
|
||||
for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
|
||||
load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm,
|
||||
tile_x_qh, tile_x_sc, threadIdx.y, nrows_x - row_x_0 - 1,
|
||||
threadIdx.x, blocks_per_row_x);
|
||||
|
||||
const int n_per_r = ((qk * blocks_per_warp) / qr);
|
||||
#pragma unroll
|
||||
for (int ir = 0; ir < qr && ib0 * qk + ir * n_per_r < ncols_x; ++ir) {
|
||||
const int kqs = ir * WARP_SIZE_GGUF + threadIdx.x;
|
||||
const int kbxd = kqs / QI8_1;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < mmq_x; i += nwarps) {
|
||||
const int col_y_eff = token_offs[i / nwarps] / top_k;
|
||||
const int block_x = ib0 * (qk / QK8_1) + kbxd;
|
||||
if (col_y_eff < ncols_y && block_x < blocks_per_col_y) {
|
||||
const block_q8_1* by0 = &y[col_y_eff * blocks_per_col_y + block_x];
|
||||
const int index_y =
|
||||
(threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF;
|
||||
tile_y_qs[index_y] =
|
||||
get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x < n_per_r / QK8_1) {
|
||||
const int kby = threadIdx.x % (WARP_SIZE_GGUF / QI8_1);
|
||||
const int col_y_eff = token_offs[threadIdx.y] / top_k;
|
||||
const int block_x =
|
||||
ib0 * (qk / QK8_1) + ir * (WARP_SIZE_GGUF / QI8_1) + kby;
|
||||
|
||||
if (col_y_eff < ncols_y && block_x < blocks_per_col_y) {
|
||||
const half2* dsi_src = &y[col_y_eff * blocks_per_col_y + block_x].ds;
|
||||
half2* dsi_dst =
|
||||
&tile_y_ds[threadIdx.y * (WARP_SIZE_GGUF / QI8_1) + kby];
|
||||
|
||||
if (need_sum) {
|
||||
*dsi_dst = *dsi_src;
|
||||
} else {
|
||||
float* dfi_dst = (float*)dsi_dst;
|
||||
*dfi_dst = __low2float(*dsi_src);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// #pragma unroll // unrolling this loop causes too much register pressure
|
||||
for (int k = ir * WARP_SIZE_GGUF / qr; k < (ir + 1) * WARP_SIZE_GGUF / qr;
|
||||
k += vdr) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < mmq_x; j += nwarps) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
|
||||
sum[i / WARP_SIZE_GGUF][j / nwarps] +=
|
||||
vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs,
|
||||
tile_y_ds, threadIdx.x + i, threadIdx.y + j, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < mmq_x; j += nwarps) {
|
||||
const int col_dst = token_offs[j / nwarps];
|
||||
if (col_dst >= ncols_dst) {
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
|
||||
const int row_dst = row_dst_0 + threadIdx.x + i;
|
||||
if (row_dst >= nrows_dst) {
|
||||
continue;
|
||||
}
|
||||
dst[col_dst * nrows_dst + row_dst] = sum[i / WARP_SIZE_GGUF][j / nwarps];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q4_0 64
|
||||
#define MMQ_Y_Q4_0 128
|
||||
#define NWARPS_Q4_0 8
|
||||
#else
|
||||
#define MMQ_X_Q4_0 4
|
||||
#define MMQ_Y_Q4_0 32
|
||||
#define NWARPS_Q4_0 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2)
|
||||
#endif
|
||||
moe_q4_0(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q4_0;
|
||||
const int mmq_y = MMQ_Y_Q4_0;
|
||||
const int nwarps = NWARPS_Q4_0;
|
||||
|
||||
moe_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q4_0<mmq_y>, load_tiles_q4_0<mmq_y, nwarps, need_check>,
|
||||
VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q4_0_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
int mmq_x = MMQ_X_Q4_0;
|
||||
int mmq_y = MMQ_Y_Q4_0;
|
||||
int nwarps = NWARPS_Q4_0;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q4_1 64
|
||||
#define MMQ_Y_Q4_1 128
|
||||
#define NWARPS_Q4_1 8
|
||||
#else
|
||||
#define MMQ_X_Q4_1 4
|
||||
#define MMQ_Y_Q4_1 32
|
||||
#define NWARPS_Q4_1 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2)
|
||||
#endif
|
||||
moe_q4_1(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q4_1;
|
||||
const int mmq_y = MMQ_Y_Q4_1;
|
||||
const int nwarps = NWARPS_Q4_1;
|
||||
|
||||
moe_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q4_1<mmq_y>, load_tiles_q4_1<mmq_y, nwarps, need_check>,
|
||||
VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q4_1_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
int mmq_x = MMQ_X_Q4_1;
|
||||
int mmq_y = MMQ_Y_Q4_1;
|
||||
int nwarps = NWARPS_Q4_1;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q5_0 64
|
||||
#define MMQ_Y_Q5_0 128
|
||||
#define NWARPS_Q5_0 8
|
||||
#else
|
||||
#define MMQ_X_Q5_0 4
|
||||
#define MMQ_Y_Q5_0 32
|
||||
#define NWARPS_Q5_0 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2)
|
||||
#endif
|
||||
moe_q5_0(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q5_0;
|
||||
const int mmq_y = MMQ_Y_Q5_0;
|
||||
const int nwarps = NWARPS_Q5_0;
|
||||
|
||||
moe_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q5_0<mmq_y>, load_tiles_q5_0<mmq_y, nwarps, need_check>,
|
||||
VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q5_0_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q5_0;
|
||||
const int mmq_y = MMQ_Y_Q5_0;
|
||||
const int nwarps = NWARPS_Q5_0;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q5_1 64
|
||||
#define MMQ_Y_Q5_1 128
|
||||
#define NWARPS_Q5_1 8
|
||||
#else
|
||||
#define MMQ_X_Q5_1 4
|
||||
#define MMQ_Y_Q5_1 32
|
||||
#define NWARPS_Q5_1 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2)
|
||||
#endif
|
||||
moe_q5_1(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q5_1;
|
||||
const int mmq_y = MMQ_Y_Q5_1;
|
||||
const int nwarps = NWARPS_Q5_1;
|
||||
|
||||
moe_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q5_1<mmq_y>, load_tiles_q5_1<mmq_y, nwarps, need_check>,
|
||||
VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q5_1_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q5_1;
|
||||
const int mmq_y = MMQ_Y_Q5_1;
|
||||
const int nwarps = NWARPS_Q5_1;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q8_0 64
|
||||
#define MMQ_Y_Q8_0 128
|
||||
#define NWARPS_Q8_0 8
|
||||
#else
|
||||
#define MMQ_X_Q8_0 4
|
||||
#define MMQ_Y_Q8_0 32
|
||||
#define NWARPS_Q8_0 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2)
|
||||
#endif
|
||||
moe_q8_0(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q8_0;
|
||||
const int mmq_y = MMQ_Y_Q8_0;
|
||||
const int nwarps = NWARPS_Q8_0;
|
||||
|
||||
moe_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q8_0<mmq_y>, load_tiles_q8_0<mmq_y, nwarps, need_check>,
|
||||
VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q8_0_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q8_0;
|
||||
const int mmq_y = MMQ_Y_Q8_0;
|
||||
const int nwarps = NWARPS_Q8_0;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q2_K 64
|
||||
#define MMQ_Y_Q2_K 128
|
||||
#define NWARPS_Q2_K 8
|
||||
#else
|
||||
#define MMQ_X_Q2_K 4
|
||||
#define MMQ_Y_Q2_K 32
|
||||
#define NWARPS_Q2_K 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2)
|
||||
#endif
|
||||
moe_q2_K(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q2_K;
|
||||
const int mmq_y = MMQ_Y_Q2_K;
|
||||
const int nwarps = NWARPS_Q2_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q2_K<mmq_y>, load_tiles_q2_K<mmq_y, nwarps, need_check>,
|
||||
VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q2_K_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q2_K;
|
||||
const int mmq_y = MMQ_Y_Q2_K;
|
||||
const int nwarps = NWARPS_Q2_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q3_K 64
|
||||
#define MMQ_Y_Q3_K 128
|
||||
#define NWARPS_Q3_K 8
|
||||
#else
|
||||
#define MMQ_X_Q3_K 4
|
||||
#define MMQ_Y_Q3_K 32
|
||||
#define NWARPS_Q3_K 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2)
|
||||
#endif
|
||||
moe_q3_K(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q3_K;
|
||||
const int mmq_y = MMQ_Y_Q3_K;
|
||||
const int nwarps = NWARPS_Q3_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q3_K<mmq_y>, load_tiles_q3_K<mmq_y, nwarps, need_check>,
|
||||
VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q3_K_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q3_K;
|
||||
const int mmq_y = MMQ_Y_Q3_K;
|
||||
const int nwarps = NWARPS_Q3_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q4_K 64
|
||||
#define MMQ_Y_Q4_K 128
|
||||
#define NWARPS_Q4_K 8
|
||||
#else
|
||||
#define MMQ_X_Q4_K 4
|
||||
#define MMQ_Y_Q4_K 32
|
||||
#define NWARPS_Q4_K 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2)
|
||||
#endif
|
||||
moe_q4_K(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q4_K;
|
||||
const int mmq_y = MMQ_Y_Q4_K;
|
||||
const int nwarps = NWARPS_Q4_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q4_K<mmq_y>, load_tiles_q4_K<mmq_y, nwarps, need_check>,
|
||||
VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q4_K_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q4_K;
|
||||
const int mmq_y = MMQ_Y_Q4_K;
|
||||
const int nwarps = NWARPS_Q4_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q5_K 64
|
||||
#define MMQ_Y_Q5_K 128
|
||||
#define NWARPS_Q5_K 8
|
||||
#else
|
||||
#define MMQ_X_Q5_K 4
|
||||
#define MMQ_Y_Q5_K 32
|
||||
#define NWARPS_Q5_K 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2)
|
||||
#endif
|
||||
moe_q5_K(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q5_K;
|
||||
const int mmq_y = MMQ_Y_Q5_K;
|
||||
const int nwarps = NWARPS_Q5_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q5_K<mmq_y>, load_tiles_q5_K<mmq_y, nwarps, need_check>,
|
||||
VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q5_K_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q5_K;
|
||||
const int mmq_y = MMQ_Y_Q5_K;
|
||||
const int nwarps = NWARPS_Q5_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q6_K 64
|
||||
#define MMQ_Y_Q6_K 128
|
||||
#define NWARPS_Q6_K 8
|
||||
#else
|
||||
#define MMQ_X_Q6_K 4
|
||||
#define MMQ_Y_Q6_K 32
|
||||
#define NWARPS_Q6_K 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2)
|
||||
#endif
|
||||
moe_q6_K(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q6_K;
|
||||
const int mmq_y = MMQ_Y_Q6_K;
|
||||
const int nwarps = NWARPS_Q6_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q6_K<mmq_y>, load_tiles_q6_K<mmq_y, nwarps, need_check>,
|
||||
VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q6_K_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q6_K;
|
||||
const int mmq_y = MMQ_Y_Q6_K;
|
||||
const int nwarps = NWARPS_Q6_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
@ -305,6 +305,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
|
||||
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
|
||||
|
||||
// moe kernel for GGML.
|
||||
ops.def(
|
||||
"ggml_moe_a8(Tensor X, Tensor W, "
|
||||
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
|
||||
"num_tokens_post_padded, "
|
||||
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
|
||||
ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
|
||||
|
||||
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
|
||||
ops.def(
|
||||
|
@ -22,3 +22,16 @@ def test_ggml_opcheck(quant_type):
|
||||
(qweight, x, quant_type, qweight.shape[0]))
|
||||
opcheck(torch.ops._C.ggml_mul_mat_vec_a8,
|
||||
(qweight, x, quant_type, qweight.shape[0]))
|
||||
|
||||
shape = [256, 1024, 336]
|
||||
qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
|
||||
x = torch.rand((1, 1024), device='cuda', dtype=torch.float16)
|
||||
sorted_token_ids = torch.arange(776, device='cuda')
|
||||
expert_ids = torch.randint(0, 256, (194, ), device='cuda')
|
||||
num_tokens_post_padded = torch.tensor([1],
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
|
||||
opcheck(torch.ops._C.ggml_moe_a8,
|
||||
(x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
quant_type, qweight.shape[0], 1, x.shape[0]))
|
||||
|
@ -8,9 +8,13 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
|
||||
GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
|
||||
|
||||
|
||||
def get_gguf_sample_tensors(
|
||||
@ -22,6 +26,15 @@ def get_gguf_sample_tensors(
|
||||
return GGUFReader(sample_file).tensors
|
||||
|
||||
|
||||
def get_gguf_MoE_tensors(
|
||||
hidden_size: int,
|
||||
quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
|
||||
sample_dir = GGUF_SAMPLE_MOE
|
||||
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
|
||||
sample_file = Path(sample_dir) / filename
|
||||
return GGUFReader(sample_file).tensors
|
||||
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float32]
|
||||
# Hidden_size for testing, must match the sample file in HF repo,
|
||||
# we have `hidden_size = 256, 1024` for test in HF repo currently.
|
||||
@ -132,3 +145,54 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
|
||||
ref_output,
|
||||
atol=atols[dtype],
|
||||
rtol=rtols[dtype])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", [512])
|
||||
@pytest.mark.parametrize("top_k", [4, 8])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize(
|
||||
"quant_type",
|
||||
[
|
||||
# k-quants
|
||||
GGMLQuantizationType.Q2_K,
|
||||
GGMLQuantizationType.Q3_K,
|
||||
GGMLQuantizationType.Q4_K,
|
||||
GGMLQuantizationType.Q5_K,
|
||||
GGMLQuantizationType.Q6_K,
|
||||
# standard quants
|
||||
GGMLQuantizationType.Q4_0,
|
||||
GGMLQuantizationType.Q5_0,
|
||||
GGMLQuantizationType.Q8_0,
|
||||
])
|
||||
@torch.inference_mode()
|
||||
def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
|
||||
quant_type: GGMLQuantizationType, top_k: int):
|
||||
current_platform.seed_everything(0)
|
||||
H, E = 1024, 256
|
||||
|
||||
x = torch.rand((num_tokens, H), dtype=dtype, device="cuda")
|
||||
|
||||
topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype)
|
||||
topk_ids = torch.randint(0, E, (num_tokens, top_k), device="cuda")
|
||||
|
||||
tensors = get_gguf_MoE_tensors(hidden_size, quant_type)
|
||||
|
||||
w13 = tensors[0]
|
||||
w2 = tensors[1]
|
||||
|
||||
w13_dequant = torch.tensor(dequantize(w13.data, quant_type),
|
||||
device="cuda").to(dtype)
|
||||
|
||||
w2_dequant = torch.tensor(dequantize(w2.data, quant_type),
|
||||
device="cuda").to(dtype)
|
||||
act = SiluAndMul()
|
||||
|
||||
output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"),
|
||||
torch.tensor(w2.data,
|
||||
device="cuda"), topk_weights,
|
||||
topk_ids, quant_type, quant_type, act)
|
||||
|
||||
ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights,
|
||||
topk_ids).reshape(output.shape)
|
||||
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
|
||||
|
@ -448,6 +448,23 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
|
||||
batch = X.size(0)
|
||||
return torch.empty((batch, row), dtype=X.dtype, device=W.device)
|
||||
|
||||
@register_fake("_C::ggml_moe_a8")
|
||||
def _ggml_moe_a8_fake(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
quant_type: int,
|
||||
row: torch.SymInt,
|
||||
top_k: torch.SymInt,
|
||||
tokens: torch.SymInt,
|
||||
) -> torch.Tensor:
|
||||
tokens = X.size(0)
|
||||
return torch.empty((tokens * top_k, row),
|
||||
dtype=torch.float16,
|
||||
device=W.device)
|
||||
|
||||
|
||||
# cutlass
|
||||
def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
|
||||
@ -1034,6 +1051,26 @@ def ggml_mul_mat_a8(
|
||||
return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
|
||||
|
||||
|
||||
def ggml_moe_a8(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
quant_type: int,
|
||||
row: int,
|
||||
top_k: int,
|
||||
tokens: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops._C.ggml_moe_a8(X, W, sorted_token_ids, expert_ids,
|
||||
num_tokens_post_padded, quant_type, row,
|
||||
top_k, tokens)
|
||||
|
||||
|
||||
def ggml_moe_get_block_size(quant_type: int) -> int:
|
||||
return torch.ops._C.ggml_moe_get_block_size(quant_type)
|
||||
|
||||
|
||||
# mamba
|
||||
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
|
||||
bias_: Optional[torch.Tensor],
|
||||
|
@ -8,7 +8,9 @@ from gguf import GGMLQuantizationType as WeightType
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
@ -18,6 +20,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GGUFConfig(QuantizationConfig):
|
||||
"""Config class for GGUF."""
|
||||
@ -119,6 +123,59 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
|
||||
return y
|
||||
|
||||
|
||||
def _fused_moe_gguf(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
qweight_type: int,
|
||||
qweight_type2: int,
|
||||
act,
|
||||
) -> torch.Tensor:
|
||||
out_hidden_states = torch.empty_like(x)
|
||||
if qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES:
|
||||
num_tokens, _ = x.shape
|
||||
E, N, _ = w1.shape
|
||||
top_k = topk_ids.shape[1]
|
||||
BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = \
|
||||
moe_align_block_size(topk_ids, BLOCK_SIZE, E)
|
||||
out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids,
|
||||
num_tokens_post_padded, qweight_type, N, top_k,
|
||||
num_tokens)
|
||||
out = act(out)
|
||||
out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids,
|
||||
num_tokens_post_padded, qweight_type2,
|
||||
w2.shape[1], 1, num_tokens * top_k)
|
||||
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
|
||||
topk_weights.view(num_tokens, top_k, 1))
|
||||
ops.moe_sum(out, out_hidden_states)
|
||||
else:
|
||||
logger.warning_once("There is no support for fast MoE kernel "
|
||||
"for current quantization method. "
|
||||
"Falling back to slow implementation. ")
|
||||
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
|
||||
inp = x[tok].reshape((1, ) + x.shape[1:])
|
||||
current_hidden_state = None
|
||||
for ww, ii in zip(w, idx):
|
||||
expert_up = w1[ii]
|
||||
|
||||
out = _fuse_mul_mat(inp, expert_up, qweight_type)
|
||||
out = act(out)
|
||||
|
||||
expert_down = w2[ii]
|
||||
current_state = _fuse_mul_mat(out, expert_down,
|
||||
qweight_type2).mul_(ww)
|
||||
if current_hidden_state is None:
|
||||
current_hidden_state = current_state
|
||||
else:
|
||||
current_hidden_state.add_(current_state)
|
||||
out_hidden_states[tok] = current_hidden_state
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
class GGUFLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GGUF.
|
||||
|
||||
@ -285,27 +342,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
final_hidden_states = torch.empty_like(x)
|
||||
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
|
||||
inp = x[tok].reshape((1, ) + x.shape[1:])
|
||||
current_hidden_state = None
|
||||
for ww, ii in zip(w, idx):
|
||||
expert_up = layer.w13_qweight[ii]
|
||||
|
||||
out = _fuse_mul_mat(inp, expert_up,
|
||||
layer.w13_qweight_type.weight_type)
|
||||
out = self.act(out)
|
||||
|
||||
expert_down = layer.w2_qweight[ii]
|
||||
current_state = _fuse_mul_mat(
|
||||
out, expert_down,
|
||||
layer.w2_qweight_type.weight_type).mul_(ww)
|
||||
if current_hidden_state is None:
|
||||
current_hidden_state = current_state
|
||||
else:
|
||||
current_hidden_state.add_(current_state)
|
||||
final_hidden_states[tok] = current_hidden_state
|
||||
return final_hidden_states
|
||||
return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
|
||||
topk_weights, topk_ids,
|
||||
layer.w13_qweight_type.weight_type,
|
||||
layer.w2_qweight_type.weight_type, self.act)
|
||||
|
||||
|
||||
class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
|
Loading…
x
Reference in New Issue
Block a user