[ROCm][Bugfix] Ensure that the moe_wna16_gemm kernel is not built on ROCm platforms. (#14629)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore, 2025-03-12 05:00:28 -07:00, committed by GitHub
parent ff47aab056
commit 45f3f3f59e
4 changed files with 8 additions and 3 deletions

CMakeLists.txt

@@ -559,7 +559,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
 set(VLLM_MOE_EXT_SRC
   "csrc/moe/torch_bindings.cpp"
   "csrc/moe/moe_align_sum_kernels.cu"
-  "csrc/moe/moe_wna16.cu"
   "csrc/moe/topk_softmax_kernels.cu")

 set_gencode_flags_for_srcs(
@@ -574,6 +573,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   set_gencode_flags_for_srcs(
     SRCS "${VLLM_MOE_WNA16_SRC}"
     CUDA_ARCHS "${CUDA_ARCHS}")
+  list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
   cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
   if (MARLIN_MOE_ARCHS)
     set(MARLIN_MOE_SRC

csrc/moe/moe_ops.h

@@ -18,7 +18,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                               torch::Tensor sorted_token_ids,
                               torch::Tensor experts_ids,
                               torch::Tensor num_tokens_post_pad);
-
+#ifndef USE_ROCM
 torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor b_qweight, torch::Tensor b_scales,
                              std::optional<torch::Tensor> b_qzeros,
@@ -28,3 +28,4 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor num_tokens_post_pad, int64_t top_k,
                              int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                              int64_t BLOCK_SIZE_K, int64_t bit);
+#endif

csrc/moe/torch_bindings.cpp

@@ -31,6 +31,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       " Tensor! num_tokens_post_pad) -> ()");
   m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
+#ifndef USE_ROCM
   m.def(
       "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
       "Tensor b_scales, Tensor? b_qzeros, "
@@ -41,7 +42,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
-#ifndef USE_ROCM
   m.def(
       "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
       "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "

vllm/_custom_ops.py

@@ -1146,6 +1146,10 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
                    num_tokens_post_pad: torch.Tensor, top_k: int,
                    BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int,
                    bit: int) -> torch.Tensor:
+    if not current_platform.is_cuda():
+        raise NotImplementedError(
+            "The optimized moe_wna16_gemm kernel is only "
+            "available on CUDA platforms")
     torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales,
                                     b_qzeros, topk_weights, sorted_token_ids,
                                     experts_ids, num_tokens_post_pad, top_k,
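
Why the Python-level check matters: with the kernel no longer compiled or registered on ROCm, calling torch.ops._moe_C.moe_wna16_gemm there would fail with an unhelpful attribute-lookup error deep inside torch.ops; the added guard fails fast with an actionable message instead. A minimal standalone sketch of that behavior (the _RocmOpsNamespace class and on_cuda flag are illustrative stand-ins, not vLLM code):

class _RocmOpsNamespace:
    """Stands in for torch.ops._moe_C on a ROCm build, where
    moe_wna16_gemm was never compiled or registered."""
    # intentionally no moe_wna16_gemm attribute


def moe_wna16_gemm_sketch(ops, on_cuda: bool):
    # Mirrors the guard added above: fail fast with a clear message
    # rather than letting the attribute lookup below blow up.
    if not on_cuda:
        raise NotImplementedError(
            "The optimized moe_wna16_gemm kernel is only "
            "available on CUDA platforms")
    return ops.moe_wna16_gemm  # would dispatch the CUDA kernel in vLLM


try:
    moe_wna16_gemm_sketch(_RocmOpsNamespace(), on_cuda=False)
except NotImplementedError as exc:
    print(f"caught: {exc}")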