From d1e82408759067eca0ae55e548f6243a9e0aa12d Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 22 Oct 2024 18:41:13 -0400 Subject: [PATCH] [Bugfix] Fix spurious "No compiled cutlass_scaled_mm ..." for W8A8 on Turing (#9487) --- CMakeLists.txt | 4 ++-- csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7f6d1c66..a53a8575 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -252,7 +252,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") else() message(STATUS "Not building Marlin kernels as no compatible archs found" - "in CUDA target architectures") + " in CUDA target architectures") endif() # @@ -432,7 +432,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") else() message(STATUS "Not building Marlin MOE kernels as no compatible archs found" - "in CUDA target architectures") + " in CUDA target architectures") endif() endif() diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 1657f7d0..97a969cf 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -137,9 +137,11 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, return; } - // Turing - TORCH_CHECK(version_num >= 75); - cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); + if (version_num >= 75) { + // Turing + cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); + return; + } #endif TORCH_CHECK_NOT_IMPLEMENTED(