diff --git a/CMakeLists.txt b/CMakeLists.txt index 5349b64a..55ac3c77 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -558,6 +558,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/moe_align_sum_kernels.cu" + "csrc/moe/moe_wna16.cu" "csrc/moe/topk_softmax_kernels.cu") set_gencode_flags_for_srcs( @@ -565,6 +566,13 @@ set_gencode_flags_for_srcs( CUDA_ARCHS "${CUDA_ARCHS}") if(VLLM_GPU_LANG STREQUAL "CUDA") + set(VLLM_MOE_WNA16_SRC + "csrc/moe/moe_wna16.cu") + + set_gencode_flags_for_srcs( + SRCS "${VLLM_MOE_WNA16_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) set(MARLIN_MOE_SRC diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 66bb5f41..371edb64 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -18,3 +18,13 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); + +torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, + torch::Tensor b_qweight, torch::Tensor b_scales, + std::optional b_qzeros, + std::optional topk_weights, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, int64_t top_k, + int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t BLOCK_SIZE_K, int64_t bit); diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu new file mode 100644 index 00000000..51ae76c1 --- /dev/null +++ b/csrc/moe/moe_wna16.cu @@ -0,0 +1,346 @@ + +#include +#include +#include +#include + +#include +#include +#include "moe_wna16_utils.h" + +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +template +__global__ void moe_wna16_gemm_kernel( + const scalar_t* __restrict__ input, scalar_t* __restrict__ output, + + const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales, + const uint32_t* __restrict__ qzeros, + + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_token_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ num_tokens_post_pad, + + uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m, + uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M, + uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp, + bool mul_topk_weight) { +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 + if constexpr (std::is_same::value) { + return; + } else { +#endif + + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + + if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return; + + const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x; + const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K; + + const int32_t expert_id = expert_ids[blockIdx.x]; + + int32_t num_valid_tokens = 0; + extern __shared__ uint16_t block_input_tmp[]; + scalar_t* block_input = reinterpret_cast(block_input_tmp); + scalar_t2* block_input_half2 = reinterpret_cast(block_input); + + // load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory + for (int m = 0; m < BLOCK_SIZE_M; m++) { + const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m; + const int32_t token_index = sorted_token_ids[offset_m]; + if (token_index / top_k >= size_m) break; + + num_valid_tokens = m + 1; + if (blockIdx.z == 0 && offset_n < size_n) + output[token_index * size_n + offset_n] = Dtype::int2num(0); + + if (expert_id != -1) { + int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N); + for (int i = 0; i < k_per_thread; i++) { + int k = BLOCK_SIZE_N * i + threadIdx.x; + if (k >= BLOCK_SIZE_K) break; + if (offset_k + k >= size_k) break; + + // load input to shared memory + // use a special layout to fit the layout of dequanted-weight + int origin_k; + if constexpr (bit == 4) { + // [0, 4, 1, 5, 2, 6, 3, 7] + int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2); + origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order; + } else { + // [0, 2, 1, 3] + int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2); + origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order; + } + + origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K; + block_input[m * BLOCK_SIZE_K + k] = input[origin_k]; + } + } + } + + if (expert_id == -1) return; + __syncthreads(); + if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return; + + float res[64]; // assume BLOCK_SIZE_M <= 64 + scalar_t2 res2; + scalar_t2 scale_f2; + scalar_t2 qzero_f2; + + // note that (size_n * size_k * expert_id) may greater than 2 ** 31 + constexpr int8_t pack_factor = 32 / bit; + const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id; + const uint32_t* expert_qweight = qweight + expert_offset / pack_factor; + const scalar_t* expert_scales = scales + expert_offset / group_size; + const uint32_t* expert_qzeros = + qzeros + expert_offset / group_size / pack_factor; + + // load 4*int32 one time: 4 int32 = 128 bit = 1 float4 + // weight would be loaded in loop + uint32_t expert_qweight_tmp[4]; + float4* expert_qweight_tmp_float4 = + reinterpret_cast(expert_qweight_tmp); + + // load all required scales one time + scalar_t expert_scales_groups[GROUPS]; + int scales_offset_tmp = + (offset_n * size_k + offset_k) / group_size / GROUPS; + if constexpr (GROUPS == 1) { + *expert_scales_groups = expert_scales[scales_offset_tmp]; + } else if constexpr (GROUPS == 2) { + float* expert_scales_groups_tmp = + reinterpret_cast(expert_scales_groups); + *expert_scales_groups_tmp = + reinterpret_cast(expert_scales)[scales_offset_tmp]; + } else if constexpr (GROUPS == 4) { + float2* expert_scales_groups_tmp = + reinterpret_cast(expert_scales_groups); + *expert_scales_groups_tmp = + reinterpret_cast(expert_scales)[scales_offset_tmp]; + } else if constexpr (GROUPS == 8) { + float4* expert_scales_groups_tmp = + reinterpret_cast(expert_scales_groups); + *expert_scales_groups_tmp = + reinterpret_cast(expert_scales)[scales_offset_tmp]; + } + + // load all required qzeros one time + uint8_t expert_qzeros_groups[GROUPS]; + if (!has_zp) { + if constexpr (bit == 4) { + qzero_f2 = Dtype::num2num2(Dtype::int2num(8)); + } else { + qzero_f2 = Dtype::num2num2(Dtype::int2num(128)); + } + } else { + int qzeros_offset_tmp = + (offset_n / (8 / bit)) * (size_k / group_size / GROUPS) + + offset_k / group_size / GROUPS; + if constexpr (GROUPS == 1) { + uint8_t* expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros)[qzeros_offset_tmp]; + } else if constexpr (GROUPS == 2) { + uint16_t* expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros)[qzeros_offset_tmp]; + } else if constexpr (GROUPS == 4) { + uint32_t* expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros)[qzeros_offset_tmp]; + } else if constexpr (GROUPS == 8) { + uint64_t* expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros)[qzeros_offset_tmp]; + } + } + + for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) { + int k = offset_k + tmp_k * pack_factor; + if (k >= size_k) break; + const int32_t weight_offset = offset_n * size_k + k; + + if (tmp_k % 4 == 0) { + *expert_qweight_tmp_float4 = reinterpret_cast( + expert_qweight)[weight_offset / pack_factor / 4]; + } + + if (tmp_k % (group_size / pack_factor) == 0) { + scalar_t scale_f = + expert_scales_groups[tmp_k / (group_size / pack_factor)]; + scale_f2 = Dtype::num2num2(scale_f); + + if (has_zp) { + uint8_t qzero = + expert_qzeros_groups[tmp_k / (group_size / pack_factor)]; + if constexpr (bit == 4) { + qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF; + } + qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero)); + } + } + + scalar_t2 weight_half2[16 / bit]; + dequant(expert_qweight_tmp[tmp_k % 4], weight_half2); + + for (int m = 0; m < num_valid_tokens; m++) { + res2 = {}; + +#pragma unroll + for (int i = 0; i < 16 / bit; i++) { + int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i; + res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2), + block_input_half2[offset_input], res2); + } + + if (tmp_k == 0) { + res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y); + } else { + res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y); + } + } + } + + for (int m = 0; m < num_valid_tokens; ++m) { + const int32_t token_index = + sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m]; + if (mul_topk_weight) { + res[m] *= topk_weights[token_index]; + } + atomicAdd(&output[token_index * size_n + offset_n], + Dtype::float2num(res[m])); + } + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 + } +#endif +} + +template +void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output, + const uint32_t* b_qweight, const scalar_t* b_scales, + const uint32_t* b_qzeros, const float* topk_weights, + const int32_t* sorted_token_ids, + const int32_t* expert_ids, + const int32_t* num_tokens_post_pad, int num_experts, + int group_size, int num_token_blocks, int top_k, + int size_m, int size_n, int size_k, int BLOCK_SIZE_M, + int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit, + bool has_zp, bool mul_topk_weight) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_SIZE_N; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = num_token_blocks; + gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N); + gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K); + + auto kernel = moe_wna16_gemm_kernel; + if (bit == 4) { + if (BLOCK_SIZE_K / group_size == 2) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 4) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 8) { + kernel = moe_wna16_gemm_kernel; + } + } else { + if (BLOCK_SIZE_K / group_size == 1) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 2) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 4) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 8) { + kernel = moe_wna16_gemm_kernel; + } + } + + const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>( + input, output, b_qweight, b_scales, b_qzeros, topk_weights, + sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts, + group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, has_zp, mul_topk_weight); +} + +torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, + torch::Tensor b_qweight, torch::Tensor b_scales, + std::optional b_qzeros, + std::optional topk_weights, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, int64_t top_k, + int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t BLOCK_SIZE_K, int64_t bit) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + auto options = + torch::TensorOptions().dtype(input.dtype()).device(input.device()); + + const int num_experts = b_qweight.size(0); + const int size_m = input.size(0); + const int size_n = b_qweight.size(1); + const int size_k = input.size(1); + const int group_size = size_k / b_scales.size(2); + + int64_t EM = sorted_token_ids.size(0); + if (size_m <= BLOCK_SIZE_M) { + EM = min(EM, size_m * BLOCK_SIZE_M * top_k); + } + const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M; + + const uint32_t* b_qzeros_ptr; + if (b_qzeros.has_value()) + b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr(); + const float* topk_weights_ptr; + if (topk_weights.has_value()) + topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); + + int groups_per_block_row = BLOCK_SIZE_K / group_size; + TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); + TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, + "size_k must divisible by BLOCK_SIZE_K"); + TORCH_CHECK(BLOCK_SIZE_K % group_size == 0, + "BLOCK_SIZE_K must divisible by group_size"); + TORCH_CHECK(BLOCK_SIZE_M <= 64, "BLOCK_SIZE_M must less or equal to 64"); + TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || + groups_per_block_row == 4 || groups_per_block_row == 8, + "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); + + if (input.scalar_type() == at::ScalarType::Half) { + run_moe_wna16_gemm( + (const half*)input.data_ptr(), + (half*)output.data_ptr(), + (const uint32_t*)b_qweight.data_ptr(), + (const half*)b_scales.data_ptr(), b_qzeros_ptr, + topk_weights_ptr, sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), num_tokens_post_pad.data_ptr(), + num_experts, group_size, num_token_blocks, top_k, size_m, size_n, + size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit, + b_qzeros.has_value(), topk_weights.has_value()); + } else if (input.scalar_type() == at::ScalarType::BFloat16) { + run_moe_wna16_gemm( + (const nv_bfloat16*)input.data_ptr(), + (nv_bfloat16*)output.data_ptr(), + (const uint32_t*)b_qweight.data_ptr(), + (const nv_bfloat16*)b_scales.data_ptr(), b_qzeros_ptr, + topk_weights_ptr, sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), num_tokens_post_pad.data_ptr(), + num_experts, group_size, num_token_blocks, top_k, size_m, size_n, + size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit, + b_qzeros.has_value(), topk_weights.has_value()); + } else { + TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16"); + } + return output; +} diff --git a/csrc/moe/moe_wna16_utils.h b/csrc/moe/moe_wna16_utils.h new file mode 100644 index 00000000..4396b802 --- /dev/null +++ b/csrc/moe/moe_wna16_utils.h @@ -0,0 +1,200 @@ + +#include +#include + +template +class ScalarType {}; + +template <> +class ScalarType { + public: + using scalar_t = half; + using scalar_t2 = half2; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } + + static __host__ __device__ half inline int2num(const float x) { + return __int2half_rn(x); + } + + static __host__ __device__ float2 inline num22float2(const half2 x) { + return __half22float2(x); + } + + static __host__ __device__ half2 inline float22num2(const float2 x) { + return __float22half2_rn(x); + } +}; + +template <> +class ScalarType { + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } + + static __host__ __device__ nv_bfloat16 inline int2num(const float x) { + return __int2bfloat16_rn(x); + } + + static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) { + return __bfloat1622float2(x); + } + + static __host__ __device__ nv_bfloat162 inline float22num2(const float2 x) { + return __float22bfloat162_rn(x); + } +#endif +}; + +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline void dequant(int q, scalar_t2* res) {} + +template <> +__device__ inline void dequant(int q, half2* res) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + + int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); + int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + q >>= 8; + int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); + int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + + res[0] = __hsub2(*reinterpret_cast(&lo0), + *reinterpret_cast(&SUB)); + res[1] = __hfma2(*reinterpret_cast(&hi0), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[2] = __hsub2(*reinterpret_cast(&lo1), + *reinterpret_cast(&SUB)); + res[3] = __hfma2(*reinterpret_cast(&hi1), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, half2* res) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + res[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + res[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +__device__ inline void dequant(int q, nv_bfloat162* res) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + q >>= 4; + int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + q >>= 4; + int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + q >>= 4; + int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC300C300; + + res[0] = __hfma2(*reinterpret_cast(&lo0), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[1] = __hfma2(*reinterpret_cast(&hi0), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[2] = __hfma2(*reinterpret_cast(&lo1), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[3] = __hfma2(*reinterpret_cast(&hi1), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* res) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(res); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); +} +#endif diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 8540633d..d2c03c4d 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -31,6 +31,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " Tensor! num_tokens_post_pad) -> ()"); m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size); + m.def( + "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " + "Tensor b_scales, Tensor? b_qzeros, " + "Tensor? topk_weights, Tensor sorted_token_ids, " + "Tensor expert_ids, Tensor num_tokens_post_pad, " + "int top_k, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K, " + "int bit) -> Tensor"); + + m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm); + #ifndef USE_ROCM m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d25e7944..53065dd0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1098,6 +1098,21 @@ def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, experts_ids, num_tokens_post_pad) +def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, + b_qweight: torch.Tensor, b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, top_k: int, + BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int, + bit: int) -> torch.Tensor: + torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales, + b_qzeros, topk_weights, sorted_token_ids, + experts_ids, num_tokens_post_pad, top_k, + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, + bit) + + def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indicies: torch.Tensor, gating_output: float) -> None: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5336b3c1..89ceba12 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -719,6 +719,33 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 + use_moe_wna16_cuda = should_moe_wna16_use_cuda( + num_valid_tokens=topk_ids.numel(), + group_size=block_shape[1], + num_experts=B.shape[0], + bit=4 if use_int4_w4a16 else 8) + config = config.copy() + config.update( + get_moe_wna16_block_config(config=config, + use_moe_wna16_cuda=use_moe_wna16_cuda, + num_valid_tokens=topk_ids.numel(), + size_k=A.shape[1], + size_n=B.shape[1], + num_experts=B.shape[1], + group_size=block_shape[1], + real_top_k=topk_ids.shape[1], + block_size_m=config["BLOCK_SIZE_M"])) + + if use_moe_wna16_cuda: + bit = 4 if use_int4_w4a16 else 8 + ops.moe_wna16_gemm(A, C, B, B_scale, B_zp, + topk_weights if mul_routed_weight else None, + sorted_token_ids, expert_ids, + num_tokens_post_padded, top_k, + config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], bit) + return + fused_moe_kernel_gptq_awq[grid]( A, B, @@ -852,6 +879,70 @@ def get_moe_configs( return None +def get_moe_wna16_block_config(config: Dict[str, + int], use_moe_wna16_cuda: bool, + num_valid_tokens: int, size_k: int, size_n: int, + num_experts: int, group_size: int, + real_top_k: int, block_size_m: int): + if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config: + # optimal block config is set + return {} + if not use_moe_wna16_cuda: + # triton moe wna16 kernel + if num_valid_tokens // real_top_k == 1: + # if bs=1, use a smaller BLOCK_SIZE_N + return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64} + else: + return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32} + else: + # cuda moe wna16 kernel + # set default block_size 128, and increase them when num_blocks + # is too large. + block_size_n = 128 + block_size_k = 128 + if block_size_k <= group_size: + block_size_k = group_size + + num_n_blocks = size_k // block_size_k + num_k_blocks = size_n // block_size_k + num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \ + num_experts + if num_valid_tokens // real_top_k <= block_size_m: + num_m_blocks = min(num_m_blocks, num_valid_tokens) + num_blocks = num_m_blocks * num_n_blocks * num_k_blocks + + if size_k % 256 == 0 and num_blocks >= 256 and \ + block_size_k < 256: + block_size_k = 256 + num_blocks = num_blocks // (256 // block_size_k) + + if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \ + size_k % (block_size_k * 2) == 0 and block_size_k <= 512 and \ + num_blocks >= 512: + block_size_k = block_size_k * 2 + num_blocks = num_blocks // 2 + + if num_blocks > 1024: + block_size_n = 256 + num_n_blocks = num_n_blocks // 2 + num_blocks = num_blocks // 2 + + if size_n <= 1024 and num_blocks >= 1024: + # The kernel performance got much better with BLOCK_SIZE_N=1024 + # when num_blocks is large, event when N is small. + # Not sure why, maybe it force the CUDA SM process only one block + # at the same time. + block_size_n = 1024 + + return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k} + + +def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int, + num_experts: int, bit: int): + return bit == 4 and group_size in [32, 64, 128] and \ + num_valid_tokens / num_experts <= 6 + + def get_default_config( M: int, E: int, @@ -873,6 +964,21 @@ def get_default_config( "num_warps": 4, "num_stages": 3, } + elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None: + # moe wna16 kernels + # only set BLOCK_SIZE_M + # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later + bit = 4 if dtype == "int4_w4a16" else 8 + use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, + block_shape[1], E, bit) + if use_moe_wna16_cuda: + config = {"BLOCK_SIZE_M": min(16, M)} + elif M <= 20: + config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1} + elif M <= 40: + config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} + else: + config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} else: config = { "BLOCK_SIZE_M": 64, @@ -907,6 +1013,8 @@ def try_get_optimal_moe_config( else: # First try to load optimal config from the file E, _, N = w2_shape + if dtype == "int4_w4a16": + N = N * 2 block_n = block_shape[0] if block_shape else 0 block_k = block_shape[1] if block_shape else 0 configs = get_moe_configs(E, N, dtype, block_n, block_k) @@ -1027,7 +1135,7 @@ def get_config_dtype_str(dtype: torch.dtype, elif use_int8_w8a16: return "int8_w8a16" elif use_int4_w4a16: - return "int4_w8a16" + return "int4_w4a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs