[Kernel] moe wna16 cuda kernel (#13321)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent 04421dff8a
commit 90e88ab756
CMakeLists.txt
@@ -558,6 +558,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC
  "csrc/moe/torch_bindings.cpp"
  "csrc/moe/moe_align_sum_kernels.cu"
  "csrc/moe/moe_wna16.cu"
  "csrc/moe/topk_softmax_kernels.cu")

set_gencode_flags_for_srcs(
@@ -565,6 +566,13 @@ set_gencode_flags_for_srcs(
  CUDA_ARCHS "${CUDA_ARCHS}")

if(VLLM_GPU_LANG STREQUAL "CUDA")
  set(VLLM_MOE_WNA16_SRC
    "csrc/moe/moe_wna16.cu")

  set_gencode_flags_for_srcs(
    SRCS "${VLLM_MOE_WNA16_SRC}"
    CUDA_ARCHS "${CUDA_ARCHS}")

  cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
  if (MARLIN_MOE_ARCHS)
    set(MARLIN_MOE_SRC
csrc/moe/moe_ops.h
@@ -18,3 +18,13 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad);

torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                             torch::Tensor b_qweight, torch::Tensor b_scales,
                             std::optional<torch::Tensor> b_qzeros,
                             std::optional<torch::Tensor> topk_weights,
                             torch::Tensor sorted_token_ids,
                             torch::Tensor expert_ids,
                             torch::Tensor num_tokens_post_pad, int64_t top_k,
                             int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                             int64_t BLOCK_SIZE_K, int64_t bit);
csrc/moe/moe_wna16.cu (new file, 346 lines)
@@ -0,0 +1,346 @@
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "moe_wna16_utils.h"

#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

template <typename scalar_t, int bit, int GROUPS>
__global__ void moe_wna16_gemm_kernel(
    const scalar_t* __restrict__ input, scalar_t* __restrict__ output,

    const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
    const uint32_t* __restrict__ qzeros,

    const float* __restrict__ topk_weights,
    const int32_t* __restrict__ sorted_token_ids,
    const int32_t* __restrict__ expert_ids,
    const int32_t* __restrict__ num_tokens_post_pad,

    uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m,
    uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M,
    uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp,
    bool mul_topk_weight) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
    return;
  } else {
#endif

    using Dtype = ScalarType<scalar_t>;
    using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;

    if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return;

    const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x;
    const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K;

    const int32_t expert_id = expert_ids[blockIdx.x];

    int32_t num_valid_tokens = 0;
    extern __shared__ uint16_t block_input_tmp[];
    scalar_t* block_input = reinterpret_cast<scalar_t*>(block_input_tmp);
    scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(block_input);

    // load BLOCK_SIZE_M * BLOCK_SIZE_K input values into shared memory
    for (int m = 0; m < BLOCK_SIZE_M; m++) {
      const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m;
      const int32_t token_index = sorted_token_ids[offset_m];
      if (token_index / top_k >= size_m) break;

      num_valid_tokens = m + 1;
      if (blockIdx.z == 0 && offset_n < size_n)
        output[token_index * size_n + offset_n] = Dtype::int2num(0);

      if (expert_id != -1) {
        int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
        for (int i = 0; i < k_per_thread; i++) {
          int k = BLOCK_SIZE_N * i + threadIdx.x;
          if (k >= BLOCK_SIZE_K) break;
          if (offset_k + k >= size_k) break;

          // load input to shared memory
          // use a special layout that matches the layout of the dequantized
          // weight
          int origin_k;
          if constexpr (bit == 4) {
            // [0, 4, 1, 5, 2, 6, 3, 7]
            int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
            origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
          } else {
            // [0, 2, 1, 3]
            int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2);
            origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
          }

          origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K;
          block_input[m * BLOCK_SIZE_K + k] = input[origin_k];
        }
      }
    }

    if (expert_id == -1) return;
    __syncthreads();
    if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return;

    float res[64];  // assumes BLOCK_SIZE_M <= 64
    scalar_t2 res2;
    scalar_t2 scale_f2;
    scalar_t2 qzero_f2;

    // note that (size_n * size_k * expert_id) may be greater than 2 ** 31
    constexpr int8_t pack_factor = 32 / bit;
    const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id;
    const uint32_t* expert_qweight = qweight + expert_offset / pack_factor;
    const scalar_t* expert_scales = scales + expert_offset / group_size;
    const uint32_t* expert_qzeros =
        qzeros + expert_offset / group_size / pack_factor;

    // load 4 x int32 at a time: 4 int32 = 128 bit = 1 float4;
    // the weights themselves are loaded inside the k-loop below
    uint32_t expert_qweight_tmp[4];
    float4* expert_qweight_tmp_float4 =
        reinterpret_cast<float4*>(expert_qweight_tmp);

    // load all required scales at once
    scalar_t expert_scales_groups[GROUPS];
    int scales_offset_tmp =
        (offset_n * size_k + offset_k) / group_size / GROUPS;
    if constexpr (GROUPS == 1) {
      *expert_scales_groups = expert_scales[scales_offset_tmp];
    } else if constexpr (GROUPS == 2) {
      float* expert_scales_groups_tmp =
          reinterpret_cast<float*>(expert_scales_groups);
      *expert_scales_groups_tmp =
          reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp];
    } else if constexpr (GROUPS == 4) {
      float2* expert_scales_groups_tmp =
          reinterpret_cast<float2*>(expert_scales_groups);
      *expert_scales_groups_tmp =
          reinterpret_cast<const float2*>(expert_scales)[scales_offset_tmp];
    } else if constexpr (GROUPS == 8) {
      float4* expert_scales_groups_tmp =
          reinterpret_cast<float4*>(expert_scales_groups);
      *expert_scales_groups_tmp =
          reinterpret_cast<const float4*>(expert_scales)[scales_offset_tmp];
    }

    // load all required qzeros at once
    uint8_t expert_qzeros_groups[GROUPS];
    if (!has_zp) {
      if constexpr (bit == 4) {
        qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
      } else {
        qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
      }
    } else {
      int qzeros_offset_tmp =
          (offset_n / (8 / bit)) * (size_k / group_size / GROUPS) +
          offset_k / group_size / GROUPS;
      if constexpr (GROUPS == 1) {
        uint8_t* expert_qzeros_groups_tmp =
            reinterpret_cast<uint8_t*>(expert_qzeros_groups);
        *expert_qzeros_groups_tmp =
            reinterpret_cast<const uint8_t*>(expert_qzeros)[qzeros_offset_tmp];
      } else if constexpr (GROUPS == 2) {
        uint16_t* expert_qzeros_groups_tmp =
            reinterpret_cast<uint16_t*>(expert_qzeros_groups);
        *expert_qzeros_groups_tmp =
            reinterpret_cast<const uint16_t*>(expert_qzeros)[qzeros_offset_tmp];
      } else if constexpr (GROUPS == 4) {
        uint32_t* expert_qzeros_groups_tmp =
            reinterpret_cast<uint32_t*>(expert_qzeros_groups);
        *expert_qzeros_groups_tmp =
            reinterpret_cast<const uint32_t*>(expert_qzeros)[qzeros_offset_tmp];
      } else if constexpr (GROUPS == 8) {
        uint64_t* expert_qzeros_groups_tmp =
            reinterpret_cast<uint64_t*>(expert_qzeros_groups);
        *expert_qzeros_groups_tmp =
            reinterpret_cast<const uint64_t*>(expert_qzeros)[qzeros_offset_tmp];
      }
    }

    for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) {
      int k = offset_k + tmp_k * pack_factor;
      if (k >= size_k) break;
      const int32_t weight_offset = offset_n * size_k + k;

      if (tmp_k % 4 == 0) {
        *expert_qweight_tmp_float4 = reinterpret_cast<const float4*>(
            expert_qweight)[weight_offset / pack_factor / 4];
      }

      if (tmp_k % (group_size / pack_factor) == 0) {
        scalar_t scale_f =
            expert_scales_groups[tmp_k / (group_size / pack_factor)];
        scale_f2 = Dtype::num2num2(scale_f);

        if (has_zp) {
          uint8_t qzero =
              expert_qzeros_groups[tmp_k / (group_size / pack_factor)];
          if constexpr (bit == 4) {
            qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
          }
          qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero));
        }
      }

      scalar_t2 weight_half2[16 / bit];
      dequant<scalar_t2, bit>(expert_qweight_tmp[tmp_k % 4], weight_half2);

      for (int m = 0; m < num_valid_tokens; m++) {
        res2 = {};

#pragma unroll
        for (int i = 0; i < 16 / bit; i++) {
          int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i;
          res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2),
                         block_input_half2[offset_input], res2);
        }

        if (tmp_k == 0) {
          res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
        } else {
          res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
        }
      }
    }

    for (int m = 0; m < num_valid_tokens; ++m) {
      const int32_t token_index =
          sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m];
      if (mul_topk_weight) {
        res[m] *= topk_weights[token_index];
      }
      atomicAdd(&output[token_index * size_n + offset_n],
                Dtype::float2num(res[m]));
    }

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  }
#endif
}

template <typename scalar_t>
void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output,
                        const uint32_t* b_qweight, const scalar_t* b_scales,
                        const uint32_t* b_qzeros, const float* topk_weights,
                        const int32_t* sorted_token_ids,
                        const int32_t* expert_ids,
                        const int32_t* num_tokens_post_pad, int num_experts,
                        int group_size, int num_token_blocks, int top_k,
                        int size_m, int size_n, int size_k, int BLOCK_SIZE_M,
                        int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit,
                        bool has_zp, bool mul_topk_weight) {
  dim3 blockDim, gridDim;
  blockDim.x = BLOCK_SIZE_N;
  blockDim.y = 1;
  blockDim.z = 1;
  gridDim.x = num_token_blocks;
  gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N);
  gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K);

  auto kernel = moe_wna16_gemm_kernel<scalar_t, 4, 1>;
  if (bit == 4) {
    if (BLOCK_SIZE_K / group_size == 2) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 2>;
    } else if (BLOCK_SIZE_K / group_size == 4) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 4>;
    } else if (BLOCK_SIZE_K / group_size == 8) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 8>;
    }
  } else {
    if (BLOCK_SIZE_K / group_size == 1) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 1>;
    } else if (BLOCK_SIZE_K / group_size == 2) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 2>;
    } else if (BLOCK_SIZE_K / group_size == 4) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 4>;
    } else if (BLOCK_SIZE_K / group_size == 8) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 8>;
    }
  }

  const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  kernel<<<gridDim, blockDim, shared_mem_size, stream>>>(
      input, output, b_qweight, b_scales, b_qzeros, topk_weights,
      sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts,
      group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N,
      BLOCK_SIZE_K, has_zp, mul_topk_weight);
}

torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                             torch::Tensor b_qweight, torch::Tensor b_scales,
                             std::optional<torch::Tensor> b_qzeros,
                             std::optional<torch::Tensor> topk_weights,
                             torch::Tensor sorted_token_ids,
                             torch::Tensor expert_ids,
                             torch::Tensor num_tokens_post_pad, int64_t top_k,
                             int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                             int64_t BLOCK_SIZE_K, int64_t bit) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto options =
      torch::TensorOptions().dtype(input.dtype()).device(input.device());

  const int num_experts = b_qweight.size(0);
  const int size_m = input.size(0);
  const int size_n = b_qweight.size(1);
  const int size_k = input.size(1);
  const int group_size = size_k / b_scales.size(2);

  int64_t EM = sorted_token_ids.size(0);
  if (size_m <= BLOCK_SIZE_M) {
    EM = min(EM, size_m * BLOCK_SIZE_M * top_k);
  }
  const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;

  const uint32_t* b_qzeros_ptr = nullptr;
  if (b_qzeros.has_value())
    b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
  const float* topk_weights_ptr = nullptr;
  if (topk_weights.has_value())
    topk_weights_ptr = (const float*)topk_weights.value().data_ptr();

  int groups_per_block_row = BLOCK_SIZE_K / group_size;
  TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
  TORCH_CHECK(size_k % BLOCK_SIZE_K == 0,
              "size_k must be divisible by BLOCK_SIZE_K");
  TORCH_CHECK(BLOCK_SIZE_K % group_size == 0,
              "BLOCK_SIZE_K must be divisible by group_size");
  TORCH_CHECK(BLOCK_SIZE_M <= 64,
              "BLOCK_SIZE_M must be less than or equal to 64");
  TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 ||
                  groups_per_block_row == 4 || groups_per_block_row == 8,
              "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]");

  if (input.scalar_type() == at::ScalarType::Half) {
    run_moe_wna16_gemm<half>(
        (const half*)input.data_ptr<at::Half>(),
        (half*)output.data_ptr<at::Half>(),
        (const uint32_t*)b_qweight.data_ptr<uint8_t>(),
        (const half*)b_scales.data_ptr<at::Half>(), b_qzeros_ptr,
        topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
        expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
        size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
        b_qzeros.has_value(), topk_weights.has_value());
  } else if (input.scalar_type() == at::ScalarType::BFloat16) {
    run_moe_wna16_gemm<nv_bfloat16>(
        (const nv_bfloat16*)input.data_ptr<at::BFloat16>(),
        (nv_bfloat16*)output.data_ptr<at::BFloat16>(),
        (const uint32_t*)b_qweight.data_ptr<uint8_t>(),
        (const nv_bfloat16*)b_scales.data_ptr<at::BFloat16>(), b_qzeros_ptr,
        topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
        expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
        size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
        b_qzeros.has_value(), topk_weights.has_value());
  } else {
    TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16");
  }
  return output;
}
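For illustration only (not part of this commit): the shared-memory load in moe_wna16_gemm_kernel stores the activations in a permuted order so that a plain sequential read of block_input_half2 lines up with the pairs emitted by dequant<scalar_t2, bit>(). A minimal Python sketch of that index math; the helper name load_order is hypothetical and only checks the [0, 4, 1, 5, 2, 6, 3, 7] and [0, 2, 1, 3] comments in the kernel.

def load_order(bit):
    # Relative source index chosen by each of the 32 // bit lanes that share
    # one packed uint32 of weights (mirrors the origin_k computation above).
    order = []
    for t in range(32 // bit):
        if bit == 4:
            order.append((t % 2) * 4 + (t % 8) // 2)
        else:
            order.append((t % 2) * 2 + (t % 4) // 2)
    return order

assert load_order(4) == [0, 4, 1, 5, 2, 6, 3, 7]
assert load_order(8) == [0, 2, 1, 3]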
csrc/moe/moe_wna16_utils.h (new file, 200 lines)
@@ -0,0 +1,200 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }

  static __host__ __device__ half inline int2num(const float x) {
    return __int2half_rn(x);
  }

  static __host__ __device__ float2 inline num22float2(const half2 x) {
    return __half22float2(x);
  }

  static __host__ __device__ half2 inline float22num2(const float2 x) {
    return __float22half2_rn(x);
  }
};

template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }

  static __host__ __device__ nv_bfloat16 inline int2num(const float x) {
    return __int2bfloat16_rn(x);
  }

  static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
    return __bfloat1622float2(x);
  }

  static __host__ __device__ nv_bfloat162 inline float22num2(const float2 x) {
    return __float22bfloat162_rn(x);
  }
#endif
};

template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t res;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(res)
               : "r"(a), "n"(start_byte), "n"(mask));
  return res;
}

template <typename scalar_t2, int bit>
__device__ inline void dequant(int q, scalar_t2* res) {}

template <>
__device__ inline void dequant<half2, 4>(int q, half2* res) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  const int SUB = 0x64006400;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd400d400;

  int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
  int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
  q >>= 8;
  int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
  int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);

  res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0),
                   *reinterpret_cast<const half2*>(&SUB));
  res[1] = __hfma2(*reinterpret_cast<half2*>(&hi0),
                   *reinterpret_cast<const half2*>(&MUL),
                   *reinterpret_cast<const half2*>(&ADD));
  res[2] = __hsub2(*reinterpret_cast<half2*>(&lo1),
                   *reinterpret_cast<const half2*>(&SUB));
  res[3] = __hfma2(*reinterpret_cast<half2*>(&hi1),
                   *reinterpret_cast<const half2*>(&MUL),
                   *reinterpret_cast<const half2*>(&ADD));
}

template <>
__device__ inline void dequant<half2, 8>(int q, half2* res) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;

  res[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                   *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  res[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
                   *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
__device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
  static constexpr uint32_t MASK = 0x000f000f;
  static constexpr uint32_t EX = 0x43004300;

  int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
  q >>= 4;
  int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
  q >>= 4;
  int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
  q >>= 4;
  int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);

  static constexpr uint32_t MUL = 0x3F803F80;
  static constexpr uint32_t ADD = 0xC300C300;

  res[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo0),
                   *reinterpret_cast<const nv_bfloat162*>(&MUL),
                   *reinterpret_cast<const nv_bfloat162*>(&ADD));
  res[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi0),
                   *reinterpret_cast<const nv_bfloat162*>(&MUL),
                   *reinterpret_cast<const nv_bfloat162*>(&ADD));
  res[2] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo1),
                   *reinterpret_cast<const nv_bfloat162*>(&MUL),
                   *reinterpret_cast<const nv_bfloat162*>(&ADD));
  res[3] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi1),
                   *reinterpret_cast<const nv_bfloat162*>(&MUL),
                   *reinterpret_cast<const nv_bfloat162*>(&ADD));
}

template <>
__device__ inline void dequant<nv_bfloat162, 8>(int q, nv_bfloat162* res) {
  float fp32_intermediates[4];
  uint32_t* fp32_intermediates_casted =
      reinterpret_cast<uint32_t*>(fp32_intermediates);

  static constexpr uint32_t fp32_base = 0x4B000000;
  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);

  fp32_intermediates[0] -= 8388608.f;
  fp32_intermediates[1] -= 8388608.f;
  fp32_intermediates[2] -= 8388608.f;
  fp32_intermediates[3] -= 8388608.f;

  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(res);
  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
                                   fp32_intermediates_casted[1], 0x7632);
  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
                                   fp32_intermediates_casted[3], 0x7632);
}
#endif
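For reference (not part of this commit), a short Python check of the magic constants used by the dequant specializations above, assuming standard IEEE fp16/fp32 bit layouts: or-ing a 4-bit value v into 0x6400 encodes 1024 + v in fp16, and or-ing a byte b into the fp32 pattern 0x4B000000 (2^23) encodes 8388608 + b, so a subtraction (or a fused multiply-add for the high nibble) recovers the integer exactly.

import struct

def h(bits):
    # Decode a 16-bit pattern as an IEEE half-precision value.
    return struct.unpack('<e', struct.pack('<H', bits))[0]

for v in range(16):
    # low-nibble path: (q & 0x000f) | 0x6400 is 1024 + v; minus 1024 gives v
    assert h(0x6400 | v) - h(0x6400) == v
    # high-nibble path: multiply by 1/16 (0x2c00) and add -64 (0xd400)
    assert h(0x6400 | (v << 4)) * h(0x2c00) + h(0xd400) == v

for b in (0, 1, 127, 128, 255):
    # bf16/int8 path: __byte_perm puts b into the low mantissa byte of 2**23
    f = struct.unpack('<f', struct.pack('<I', 0x4B000000 | b))[0]
    assert f - 8388608.0 == b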
csrc/moe/torch_bindings.cpp
@@ -31,6 +31,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
      " Tensor! num_tokens_post_pad) -> ()");
  m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);

  m.def(
      "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
      "Tensor b_scales, Tensor? b_qzeros, "
      "Tensor? topk_weights, Tensor sorted_token_ids, "
      "Tensor expert_ids, Tensor num_tokens_post_pad, "
      "int top_k, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K, "
      "int bit) -> Tensor");

  m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);

#ifndef USE_ROCM
  m.def(
      "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
vllm/_custom_ops.py
@@ -1098,6 +1098,21 @@ def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                                              experts_ids, num_tokens_post_pad)


def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
                   b_qweight: torch.Tensor, b_scales: torch.Tensor,
                   b_qzeros: Optional[torch.Tensor],
                   topk_weights: Optional[torch.Tensor],
                   sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor,
                   num_tokens_post_pad: torch.Tensor, top_k: int,
                   BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int,
                   bit: int) -> torch.Tensor:
    torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales,
                                    b_qzeros, topk_weights, sorted_token_ids,
                                    experts_ids, num_tokens_post_pad, top_k,
                                    BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
                                    bit)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: float) -> None:
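As a side note (not part of this commit): the wrapper above dispatches to the C++ moe_wna16_gemm, whose TORCH_CHECKs constrain the block shape. A small hypothetical helper that restates those preconditions in plain Python, convenient when choosing block sizes by hand:

def check_moe_wna16_block_shape(size_k, group_size, bit, BLOCK_SIZE_M,
                                BLOCK_SIZE_K):
    # Mirrors the TORCH_CHECKs in csrc/moe/moe_wna16.cu.
    assert bit in (4, 8), "bit must be 4 or 8"
    assert size_k % BLOCK_SIZE_K == 0, \
        "size_k must be divisible by BLOCK_SIZE_K"
    assert BLOCK_SIZE_K % group_size == 0, \
        "BLOCK_SIZE_K must be divisible by group_size"
    assert BLOCK_SIZE_M <= 64, \
        "the kernel accumulates at most 64 rows per block"
    assert BLOCK_SIZE_K // group_size in (1, 2, 4, 8), \
        "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"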
vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -719,6 +719,33 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
        assert B_scale is not None and B_scale.ndim == 3
        assert B_zp is None or B_zp.ndim == 3

        use_moe_wna16_cuda = should_moe_wna16_use_cuda(
            num_valid_tokens=topk_ids.numel(),
            group_size=block_shape[1],
            num_experts=B.shape[0],
            bit=4 if use_int4_w4a16 else 8)
        config = config.copy()
        config.update(
            get_moe_wna16_block_config(config=config,
                                       use_moe_wna16_cuda=use_moe_wna16_cuda,
                                       num_valid_tokens=topk_ids.numel(),
                                       size_k=A.shape[1],
                                       size_n=B.shape[1],
                                       num_experts=B.shape[1],
                                       group_size=block_shape[1],
                                       real_top_k=topk_ids.shape[1],
                                       block_size_m=config["BLOCK_SIZE_M"]))

        if use_moe_wna16_cuda:
            bit = 4 if use_int4_w4a16 else 8
            ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
                               topk_weights if mul_routed_weight else None,
                               sorted_token_ids, expert_ids,
                               num_tokens_post_padded, top_k,
                               config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"],
                               config["BLOCK_SIZE_K"], bit)
            return

        fused_moe_kernel_gptq_awq[grid](
            A,
            B,
@@ -852,6 +879,70 @@ def get_moe_configs(
    return None


def get_moe_wna16_block_config(config: Dict[str,
                                            int], use_moe_wna16_cuda: bool,
                               num_valid_tokens: int, size_k: int, size_n: int,
                               num_experts: int, group_size: int,
                               real_top_k: int, block_size_m: int):
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        # optimal block config is already set
        return {}
    if not use_moe_wna16_cuda:
        # triton moe wna16 kernel
        if num_valid_tokens // real_top_k == 1:
            # if bs=1, use a smaller BLOCK_SIZE_N
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        else:
            return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
    else:
        # cuda moe wna16 kernel
        # set the default block sizes to 128 and increase them when
        # num_blocks is too large.
        block_size_n = 128
        block_size_k = 128
        if block_size_k <= group_size:
            block_size_k = group_size

        num_n_blocks = size_k // block_size_k
        num_k_blocks = size_n // block_size_k
        num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \
            num_experts
        if num_valid_tokens // real_top_k <= block_size_m:
            num_m_blocks = min(num_m_blocks, num_valid_tokens)
        num_blocks = num_m_blocks * num_n_blocks * num_k_blocks

        if size_k % 256 == 0 and num_blocks >= 256 and \
                block_size_k < 256:
            block_size_k = 256
            num_blocks = num_blocks // (256 // block_size_k)

        if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \
                size_k % (block_size_k * 2) == 0 and block_size_k <= 512 and \
                num_blocks >= 512:
            block_size_k = block_size_k * 2
            num_blocks = num_blocks // 2

        if num_blocks > 1024:
            block_size_n = 256
            num_n_blocks = num_n_blocks // 2
            num_blocks = num_blocks // 2

        if size_n <= 1024 and num_blocks >= 1024:
            # The kernel performance got much better with BLOCK_SIZE_N=1024
            # when num_blocks is large, even when N is small.
            # Not sure why; maybe it forces each CUDA SM to process only one
            # block at a time.
            block_size_n = 1024

        return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}


def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                              num_experts: int, bit: int):
    return bit == 4 and group_size in [32, 64, 128] and \
        num_valid_tokens / num_experts <= 6


def get_default_config(
    M: int,
    E: int,
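For illustration (not part of this commit): should_moe_wna16_use_cuda gates the CUDA kernel on 4-bit weights, a common group size, and few routed tokens per expert; everything else stays on the Triton path. A hypothetical helper applying the same rule, with num_valid_tokens = num_tokens * top_k to match the topk_ids.numel() value passed at the call site:

def pick_wna16_path(num_tokens, top_k, num_experts, group_size, bit):
    num_valid_tokens = num_tokens * top_k
    use_cuda = (bit == 4 and group_size in [32, 64, 128]
                and num_valid_tokens / num_experts <= 6)
    return "cuda" if use_cuda else "triton"

# decode-like batch: 4 tokens, top_k=8, 64 experts, 4-bit, group_size=128
assert pick_wna16_path(4, 8, 64, 128, 4) == "cuda"       # 32 / 64 <= 6
# large prefill: 4096 tokens -> 32768 routed tokens across 64 experts
assert pick_wna16_path(4096, 8, 64, 128, 4) == "triton"  # 512 > 6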
@@ -873,6 +964,21 @@ def get_default_config(
            "num_warps": 4,
            "num_stages": 3,
        }
    elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
        # moe wna16 kernels: only set BLOCK_SIZE_M here;
        # BLOCK_SIZE_N and BLOCK_SIZE_K will be set later
        bit = 4 if dtype == "int4_w4a16" else 8
        use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
                                                       block_shape[1], E, bit)
        if use_moe_wna16_cuda:
            config = {"BLOCK_SIZE_M": min(16, M)}
        elif M <= 20:
            config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
        elif M <= 40:
            config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
        else:
            config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
    else:
        config = {
            "BLOCK_SIZE_M": 64,
@@ -907,6 +1013,8 @@ def try_get_optimal_moe_config(
    else:
        # First try to load optimal config from the file
        E, _, N = w2_shape
        if dtype == "int4_w4a16":
            N = N * 2
        block_n = block_shape[0] if block_shape else 0
        block_k = block_shape[1] if block_shape else 0
        configs = get_moe_configs(E, N, dtype, block_n, block_k)
@@ -1027,7 +1135,7 @@ def get_config_dtype_str(dtype: torch.dtype,
     elif use_int8_w8a16:
         return "int8_w8a16"
     elif use_int4_w4a16:
-        return "int4_w8a16"
+        return "int4_w4a16"
     elif dtype == torch.float:
         # avoiding cases where kernel fails when float32 MoE
         # use fp16/bfloat16 configs