#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass/cutlass.h"

#include "grouped_mm_c3x.cuh"

using namespace cute;

namespace {

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (16, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_128, cute::_128>;
  using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M16 {
  // M in [1, 16]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_K8192 {
  // K in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_N8192 {
  // N in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
  TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
  TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");

  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
              "A tensors must be of type float8_e4m3fn.");
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
              "B tensors must be of type float8_e4m3fn.");

  using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmM16 = typename sm90_fp8_config_M16<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;

  uint32_t const m = a_tensors.size(0);
  uint32_t const n = out_tensors.size(1);
  uint32_t const k = a_tensors.size(1);

  // Pick a tile/cluster configuration from the problem shape: large-N and
  // large-K problems get dedicated configs, small-M problems get a small
  // tile, and everything else falls through to the default.
  if (n >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  } else if (k >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  } else if (m <= 16) {
    cutlass_group_gemm_caller<Cutlass3xGemmM16>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  } else {
    cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  }
}

// Dispatches on the output dtype; inputs are always fp8 e4m3, outputs are
// either bf16 or fp16.
void dispatch_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  if (out_tensors.dtype() == torch::kBFloat16) {
    run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  } else {
    run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  }
}

}  // namespace

void cutlass_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                       expert_offsets, problem_sizes, a_strides, b_strides,
                       c_strides);
}
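
// Illustrative usage sketch only: one plausible way a host-side caller might
// invoke cutlass_moe_mm_sm90, assuming fp8 e4m3 inputs and a bf16 output.
// The tensor names and shapes below are hypothetical (m/n/k/e and the
// scale/offset/stride tensors are assumed to be prepared by the caller); in
// vLLM this entry point is reached through the torch extension bindings
// rather than called directly.
//
//   auto opts = torch::TensorOptions().device(torch::kCUDA);
//   torch::Tensor a = torch::randn({m, k}, opts.dtype(torch::kFloat32))
//                         .to(torch::kFloat8_e4m3fn);   // activations
//   torch::Tensor b = torch::randn({e, k, n}, opts.dtype(torch::kFloat32))
//                         .to(torch::kFloat8_e4m3fn);   // expert weights
//   torch::Tensor out = torch::empty({m, n}, opts.dtype(torch::kBFloat16));
//   cutlass_moe_mm_sm90(out, a, b, a_scales, b_scales, expert_offsets,
//                       problem_sizes, a_strides, b_strides, c_strides);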