From 45f3f3f59e4898a11baf9bcb8d6ec5db2581af34 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Wed, 12 Mar 2025 05:00:28 -0700
Subject: [PATCH] [ROCm][Bugfix] Ensure that the moe_wna16_gemm kernel is not
 built on ROCm platforms. (#14629)

Signed-off-by: Sage Moore
---
 CMakeLists.txt              | 2 +-
 csrc/moe/moe_ops.h          | 3 ++-
 csrc/moe/torch_bindings.cpp | 2 +-
 vllm/_custom_ops.py         | 4 ++++
 4 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index e028bf59..ea6d5237 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -559,7 +559,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
 set(VLLM_MOE_EXT_SRC
   "csrc/moe/torch_bindings.cpp"
   "csrc/moe/moe_align_sum_kernels.cu"
-  "csrc/moe/moe_wna16.cu"
   "csrc/moe/topk_softmax_kernels.cu")
 
 set_gencode_flags_for_srcs(
@@ -574,6 +573,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     SRCS "${VLLM_MOE_WNA16_SRC}"
     CUDA_ARCHS "${CUDA_ARCHS}")
 
+  list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
   cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
   if (MARLIN_MOE_ARCHS)
     set(MARLIN_MOE_SRC
diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h
index 371edb64..0bae119a 100644
--- a/csrc/moe/moe_ops.h
+++ b/csrc/moe/moe_ops.h
@@ -18,7 +18,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                               torch::Tensor sorted_token_ids,
                               torch::Tensor experts_ids,
                               torch::Tensor num_tokens_post_pad);
-
+#ifndef USE_ROCM
 torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor b_qweight, torch::Tensor b_scales,
                              std::optional<torch::Tensor> b_qzeros,
@@ -28,3 +28,4 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor num_tokens_post_pad, int64_t top_k,
                              int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                              int64_t BLOCK_SIZE_K, int64_t bit);
+#endif
\ No newline at end of file
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index d2c03c4d..957ac765 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -31,6 +31,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       " Tensor! num_tokens_post_pad) -> ()");
   m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
 
+#ifndef USE_ROCM
   m.def(
       "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
       "Tensor b_scales, Tensor? b_qzeros, "
@@ -41,7 +42,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
 
   m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
 
-#ifndef USE_ROCM
   m.def(
       "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
      "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 64175cc4..d68c097f 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1146,6 +1146,10 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
                    num_tokens_post_pad: torch.Tensor, top_k: int,
                    BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int,
                    bit: int) -> torch.Tensor:
+    if not current_platform.is_cuda():
+        raise NotImplementedError(
+            "The optimized moe_wna16_gemm kernel is only "
+            "available on CUDA platforms")
     torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales,
                                     b_qzeros, topk_weights, sorted_token_ids,
                                     experts_ids, num_tokens_post_pad, top_k,
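
Note (illustrative, not part of the patch): since csrc/moe/moe_wna16.cu is now compiled only for CUDA builds, the Python-side guard above makes ROCm and CPU builds fail fast with a clear NotImplementedError instead of failing on a missing custom-op symbol at dispatch time. Below is a minimal standalone sketch of the same guard pattern; it substitutes torch.version introspection for vLLM's current_platform.is_cuda() helper, so treat it as an approximation of the check the patch installs, not the exact code.

    # Sketch only: mirrors the guard added to vllm/_custom_ops.py, using
    # torch.version introspection instead of vLLM's current_platform helper.
    import torch

    def moe_wna16_gemm_guarded(*args, **kwargs):
        # torch.version.hip is a version string on ROCm builds of PyTorch,
        # and torch.version.cuda is None on any non-CUDA build.
        if torch.version.hip is not None or torch.version.cuda is None:
            raise NotImplementedError(
                "The optimized moe_wna16_gemm kernel is only "
                "available on CUDA platforms")
        # Requires vLLM's _moe_C extension (built only for CUDA) to be loaded.
        return torch.ops._moe_C.moe_wna16_gemm(*args, **kwargs)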