// clang-format will break include orders
// clang-format off
// NOTE(review): the header name after this #include is missing in this copy
// of the file (angle-bracket text appears to have been stripped throughout:
// template parameter/argument lists, std::forward<...>, static_cast<...>,
// std::is_same_v<...>). Restore from upstream before building.
#include
// The code below is compiled only when CUDA_VERSION is defined and >= 12020;
// the matching #endif lies outside this view.
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on

using namespace cute;
using namespace vllm;

// Dispatch policy used by the cutlass_gemm_sm90_*_dispatch functions below to
// run a sparse GEMM: `invoke` forwards all of its arguments to
// cutlass_sparse_gemm_caller and returns nothing (return_type is void).
// NOTE(review): presumably `invoke` takes the selected Cutlass3xGemm config as
// its first template argument and passes it on to cutlass_sparse_gemm_caller;
// those template lists are stripped in this copy — confirm against upstream.
struct GemmCallerTraits {
  using return_type = void;
  template
  static return_type invoke(Args&&... args) {
    return cutlass_sparse_gemm_caller(std::forward(args)...);
  }
};

// Dispatch policy that compresses a sparse operand: `invoke` forwards all of
// its arguments to cutlass_sparse_compress and returns its CompressorResult.
// NOTE(review): as above, the config template argument appears stripped here.
struct GemmCompressorTraits {
  using return_type = CompressorResult;
  template
  static return_type invoke(Args&&... args) {
    return cutlass_sparse_compress(std::forward(args)...);
  }
};

// Chooses an SM90 sparse-GEMM configuration for FP8 inputs based on the
// problem shape, then calls DispatchFunc::invoke with it and the forwarded
// `args`.
//
// Parameters:
//   m    - GEMM M extent. It is rounded to the next power of two and clamped
//          to at least 64 to form the bucket key `mp2`.
//   n    - GEMM N extent. The exact values 4096, 6144 and 28672 are matched
//          first, per `mp2` bucket. Any other N falls through to the generic
//          `mp2`-only heuristic.
//   args - forwarded unchanged to DispatchFunc::invoke.
// Returns: DispatchFunc::return_type, i.e. whatever DispatchFunc::invoke
//          returns (void for GemmCallerTraits, CompressorResult for
//          GemmCompressorTraits).
//
// NOTE(review): in this copy, these template arguments are missing:
//   - the one that selects a Cutlass3xGemm* alias in each `invoke(...)` call;
//   - the types compared by the static_assert.
// So the branch-to-config mapping cannot be read here. Presumably the
// special-cased N values are tuned for specific model shapes; confirm upstream.
template typename Epilogue, typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type cutlass_gemm_sm90_fp8_dispatch(
    uint32_t m, uint32_t n, Args&&... args) {
  static_assert(std::is_same_v);

  using Cutlass3xGemmDefault = typename sm90_config_default::Cutlass3xGemm;
  // Generic configs, presumably one per M bucket.
  using Cutlass3xGemmM64 = typename sm90_fp8_config_M64::Cutlass3xGemm;
  using Cutlass3xGemmM128 = typename sm90_fp8_config_M128::Cutlass3xGemm;
  using Cutlass3xGemmM256 = typename sm90_fp8_config_M256::Cutlass3xGemm;
  using Cutlass3xGemmM512 = typename sm90_fp8_config_M512::Cutlass3xGemm;
  // Configs used for the special-cased N values below.
  using Cutlass3xGemm1 = typename sm90_fp8_config_1::Cutlass3xGemm;
  using Cutlass3xGemm2 = typename sm90_fp8_config_2::Cutlass3xGemm;
  using Cutlass3xGemm3 = typename sm90_fp8_config_3::Cutlass3xGemm;
  using Cutlass3xGemm4 = typename sm90_fp8_config_4::Cutlass3xGemm;
  using Cutlass3xGemm5 = typename sm90_fp8_config_5::Cutlass3xGemm;
  using Cutlass3xGemm6 = typename sm90_fp8_config_6::Cutlass3xGemm;
  using Cutlass3xGemm7 = typename sm90_fp8_config_7::Cutlass3xGemm;
  using Cutlass3xGemm8 = typename sm90_fp8_config_8::Cutlass3xGemm;

  uint32_t const mp2 =
      std::max(static_cast(64), next_pow_2(m));  // next power of 2

  // Exact-N overrides per M bucket. An N not listed in a bucket falls through
  // to the generic heuristic below.
  if (mp2 <= 64) {
    if (n == 28672) {
      return DispatchFunc::template invoke(
          std::forward(args)...);
    } else if (n == 4096 || n == 6144) {
      return DispatchFunc::template invoke(
          std::forward(args)...);
    }
  } else if (mp2 <= 128) {
    if (n == 4096) {
      return DispatchFunc::template invoke(
          std::forward(args)...);
    } else if (n == 28672) {
      return DispatchFunc::template invoke(
          std::forward(args)...);
    } else if (n == 6144) {
      return DispatchFunc::template invoke(
          std::forward(args)...);
    }
  } else if (mp2 <= 256) {
    if (n == 4096) {
      return DispatchFunc::template invoke(
          std::forward(args)...);
    } else if (n == 28672) {
      return DispatchFunc::template invoke(
          std::forward(args)...);
    } else if (n == 6144) {
      return DispatchFunc::template invoke(
          std::forward(args)...);
    }
  } else {
    if (n == 6144 || n == 28672) {
      return DispatchFunc::template invoke(
          std::forward(args)...);
    } else if (n == 4096) {
      return DispatchFunc::template invoke(
          std::forward(args)...);
    }
  }

  // Otherwise the default heuristic.
  // The bucket key mp2 is derived from m, so the ranges below describe m.
  if (mp2 <= 64) {
    // m in [1, 64]
    return DispatchFunc::template invoke(
        std::forward(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return DispatchFunc::template invoke(
        std::forward(args)...);
  } else if (mp2 <= 256) {
    // m in (128, 256]
    return DispatchFunc::template invoke(
        std::forward(args)...);
  } else {
    // m in (256, inf)
    return DispatchFunc::template invoke(
        std::forward(args)...);
  }
}

// Dispatch for 16-bit inputs: `m` and `n` are not consulted. Every call goes
// straight to DispatchFunc::invoke with the forwarded `args`.
// Returns: DispatchFunc::return_type.
// NOTE(review): Cutlass3xGemmDefault is the only config alias declared here,
// so it is presumably the (stripped) template argument to invoke; confirm.
template typename Epilogue, typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type cutlass_gemm_sm90_16bit_dispatch(
    uint32_t m, uint32_t n, Args&&... args) {
  using Cutlass3xGemmDefault = typename sm90_config_default::Cutlass3xGemm;

  return DispatchFunc::template invoke(
      std::forward(args)...);
}

// Chooses an SM90 sparse-GEMM configuration for int8 inputs and calls
// DispatchFunc::invoke with it and the forwarded `args`.
//
// Parameters:
//   m    - GEMM M extent. It is rounded to the next power of two and clamped
//          to at least 32 to form the bucket key `mp2`.
//   n    - GEMM N extent. It is consulted only in the smallest M bucket, where
//          n < 8192 counts as "small N".
//   args - forwarded unchanged to DispatchFunc::invoke.
// Returns: DispatchFunc::return_type.
//
// NOTE(review): the static_assert type list and the config template argument
// of each invoke call are stripped in this copy; the alias names suggest
// M32NSmall/M32NBig/M64/M128/Default per bucket, but that cannot be confirmed
// here.
template typename Epilogue, typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type cutlass_gemm_sm90_int8_dispatch(
    uint32_t m, uint32_t n, Args&&... args) {
  static_assert(std::is_same_v);

  using Cutlass3xGemmDefault = typename sm90_config_default::Cutlass3xGemm;
  using Cutlass3xGemmM128 = typename sm90_int8_config_M128::Cutlass3xGemm;
  using Cutlass3xGemmM64 = typename sm90_int8_config_M64::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig = typename sm90_int8_config_M32_NBig::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall = typename sm90_int8_config_M32_NSmall::Cutlass3xGemm;

  bool const is_small_n = n < 8192;

  uint32_t const mp2 =
      std::max(static_cast(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return DispatchFunc::template invoke(
          std::forward(args)...);
    } else {
      return DispatchFunc::template invoke(
          std::forward(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return DispatchFunc::template invoke(
        std::forward(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return DispatchFunc::template invoke(
        std::forward(args)...);
  } else {
    // m in (128, inf)
    return DispatchFunc::template invoke(
        std::forward(args)...);
  }
}

// Dispatch to GEMM implementations based on element types
template