From 9239bf718e5ebb5ab871ac8ed09fb80ed02fa82b Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 27 Mar 2025 01:54:44 +0100 Subject: [PATCH] [Kernel] CUTLASS grouped gemm fp8 MoE kernel (#13972) Signed-off-by: ElizaWszola Signed-off-by: ElizaWszola Co-authored-by: Lucas Wilkinson --- CMakeLists.txt | 27 ++ .../kernels/benchmark_grouped_gemm_cutlass.py | 340 +++++++++++++ benchmarks/kernels/benchmark_shapes.py | 16 + csrc/cutlass_extensions/common.hpp | 12 +- .../broadcast_load_epilogue_array_c3x.hpp | 457 ++++++++++++++++++ .../epilogue/scaled_mm_epilogues_c3x.hpp | 66 +++ csrc/ops.h | 14 + .../cutlass_w8a8/moe/get_group_starts.cuh | 80 +++ .../cutlass_w8a8/moe/grouped_mm_c3x.cu | 160 ++++++ .../cutlass_w8a8/moe/grouped_mm_c3x.cuh | 149 ++++++ .../quantization/cutlass_w8a8/moe/moe_data.cu | 90 ++++ .../cutlass_w8a8/scaled_mm_entry.cu | 67 +++ csrc/torch_bindings.cpp | 29 ++ tests/kernels/test_cutlass.py | 134 +++++ tests/kernels/test_cutlass_moe.py | 244 ++++++++++ vllm/_custom_ops.py | 53 ++ .../layers/fused_moe/__init__.py | 5 +- .../layers/fused_moe/fused_moe.py | 137 ++++++ .../compressed_tensors/compressed_tensors.py | 31 +- .../compressed_tensors_moe.py | 202 +++++++- .../layers/quantization/utils/w8a8_utils.py | 10 + vllm/utils.py | 9 +- 22 files changed, 2317 insertions(+), 15 deletions(-) create mode 100644 benchmarks/kernels/benchmark_grouped_gemm_cutlass.py create mode 100644 csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp create mode 100644 csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh create mode 100644 csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu create mode 100644 csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh create mode 100644 csrc/quantization/cutlass_w8a8/moe/moe_data.cu create mode 100644 tests/kernels/test_cutlass_moe.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 65d1ddbe..e0f1fdf7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -461,6 +461,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(FP4_ARCHS) endif() + # + # CUTLASS MoE kernels + + # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works + # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible + # to compile MoE kernels that use its output. 
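The build-time gate above has a runtime counterpart: this patch also registers a cutlass_group_gemm_supported op (see scaled_mm_entry.cu and torch_bindings.cpp further down), so callers can check whether the grouped FP8 kernels were compiled in before dispatching to them. A minimal sketch of such a guard, using only calls that appear elsewhere in this patch; the helper name is illustrative and not part of the change:

from vllm import _custom_ops as ops
from vllm.platforms import current_platform


def cutlass_moe_available() -> bool:
    # Illustrative sketch, not part of the patch. True only on SM90 (Hopper)
    # when the extension was built against CUDA >= 12.3, mirroring the CMake
    # gate above.
    capability = current_platform.get_device_capability()
    return (capability is not None
            and ops.cutlass_group_gemm_supported(capability.to_int()))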
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" + "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1") + message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) + message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " + "if you intend on running FP8 quantized MoE models on Hopper.") + else() + message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + # # Machete kernels diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py new file mode 100644 index 00000000..bcdbf6c7 --- /dev/null +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -0,0 +1,340 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.utils.benchmark as benchmark +from benchmark_shapes import WEIGHT_SHAPES_MOE + +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, + fused_experts, + fused_topk) +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = [ + "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite", + "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m" +] +DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + +PER_ACT_TOKEN_OPTS = [False] +PER_OUT_CH_OPTS = [False] + + +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def bench_run(results: list[benchmark.Measurement], model: str, + num_experts: int, topk: int, per_act_token: bool, + per_out_ch: bool, mkn: tuple[int, int, int]): + label = "Quant Matmul" + + sub_label = ( + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, " + "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch, + mkn)) + + print(f"Testing: {sub_label}") + + (m, k, n) = mkn + + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10 + + _, a_scale = ops.scaled_fp8_quant(a) + + w1_q = torch.empty((num_experts, 2 * n, k), + device="cuda", + dtype=torch.float8_e4m3fn) + w2_q = torch.empty((num_experts, k, n), + device="cuda", + dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((num_experts, 1, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((num_experts, 1, 1), + device="cuda", + dtype=torch.float32) + + ab_strides1 = torch.full((num_experts, ), + k, + device="cuda", + dtype=torch.int64) + c_strides1 = torch.full((num_experts, ), + 2 * n, + device="cuda", + dtype=torch.int64) + ab_strides2 = torch.full((num_experts, ), + n, + device="cuda", + dtype=torch.int64) + c_strides2 = torch.full((num_experts, ), + k, + device="cuda", + 
dtype=torch.int64) + + for expert in range(num_experts): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) + w1_q_notransp = w1_q.clone() + w2_q_notransp = w2_q.clone() + w1_q = w1_q.transpose(1, 2) + w2_q = w2_q.transpose(1, 2) + + score = torch.randn((m, num_experts), device="cuda", dtype=dtype) + + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + w1_scale: torch.Tensor, w2_scale: torch.Tensor, + a_scale: torch.Tensor, num_repeats: int): + for _ in range(num_repeats): + fused_experts(a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale) + + def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, + w1_scale: torch.Tensor, w2_scale: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, c_strides2: torch.Tensor, + num_repeats: int): + for _ in range(num_repeats): + cutlass_moe_fp8(a, + w1, + w2, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale) + + def run_cutlass_from_graph( + a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, + w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, c_strides2: torch.Tensor): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale) + + def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, w1_scale: torch.Tensor, + w2_scale: torch.Tensor, a_scale: torch.Tensor): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return fused_experts(a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale) + + def replay_graph(graph, num_repeats): + for _ in range(num_repeats): + graph.replay() + torch.cuda.synchronize() + + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, + topk_weights, topk_ids, ab_strides1, c_strides1, + ab_strides2, c_strides2) + torch.cuda.synchronize() + + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights, + topk_ids, w1_scale, w2_scale, a_scale) + torch.cuda.synchronize() + + min_run_time = 5 + num_warmup = 5 + num_runs = 25 + + globals = { + # Baseline params + "w1": w1, + "w2": w2, + "score": score, + "topk": topk, + "w1_q_notransp": w1_q_notransp, + "w2_q_notransp": w2_q_notransp, + # Cutlass params + "a_scale": a_scale, + "w1_q": w1_q, + "w2_q": w2_q, + "w1_scale": w1_scale, + "w2_scale": w2_scale, + "ab_strides1": ab_strides1, + 
"c_strides1": c_strides1, + "ab_strides2": ab_strides2, + "c_strides2": c_strides2, + # cuda graph params + "cutlass_graph": cutlass_graph, + "triton_graph": triton_graph, + # Gen params + "a": a, + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "num_runs": num_runs, + # Kernels + "run_triton_moe": run_triton_moe, + "run_cutlass_moe": run_cutlass_moe, + "replay_graph": replay_graph, + } + + # Warmup + run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, + w1_scale, w2_scale, a_scale, num_warmup) + + results.append( + benchmark.Timer( + stmt= + "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + replay_graph(triton_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(triton_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, + topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, + num_warmup) + + results.append( + benchmark.Timer( + stmt= + "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="grouped_gemm_moe", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + replay_graph(cutlass_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(cutlass_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="grouped_gemm_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time)) + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results: list[benchmark.Measurement] = [] + + for model in args.models: + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in PER_ACT_TOKEN_OPTS: + for per_out_ch in PER_OUT_CH_OPTS: + for size_m in DEFAULT_BATCH_SIZES: + mkn = (size_m, size_k, size_n) + bench_run(results, model, num_experts, topk, + per_act_token, per_out_ch, mkn) + + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark Marlin across specified models/shapes/batches") + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-act-token", + nargs="+", + type=int, + default=[]) + parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) 
+ + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py index c375e61e..70190ba2 100644 --- a/benchmarks/kernels/benchmark_shapes.py +++ b/benchmarks/kernels/benchmark_shapes.py @@ -75,3 +75,19 @@ WEIGHT_SHAPES = { [7168, 8192], ], } + +WEIGHT_SHAPES_MOE = { + "nm-testing/Mixtral-8x7B-Instruct-v0.1": [ + [8, 2, 4096, 28672], + [8, 2, 14336, 4096], + ], + "nm-testing/deepseekv2-lite": [ + [64, 6, 2048, 1408], + ], + "ibm-granite/granite-3.0-1b-a400m": [ + [32, 8, 1024, 1024], + ], + "ibm-granite/granite-3.0-3b-a800m": [ + [40, 8, 1024, 1536], + ], +} diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index febc4ecc..dbe0e30f 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -48,4 +48,14 @@ struct enable_sm90_or_later : Kernel { Kernel::operator()(std::forward(args)...); #endif } -}; \ No newline at end of file +}; + +template +struct enable_sm90_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp new file mode 100644 index 00000000..5c1d6e3f --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -0,0 +1,457 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcastArray { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. 
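Stated without the CUTLASS machinery: for group g, this visitor reads through ptr_row_array[g] (one device pointer per group) and produces either a length-N row broadcast across all M rows of the tile, or a single scalar filling the tile. A rough Python reference of that behaviour follows; it is an illustrative sketch with an invented function name, not the kernel implementation, and the actual multiply happens later in the Sm90Compute nodes:

import torch


def row_or_scalar_operand(m: int, n: int, row_vals: torch.Tensor,
                          row_broadcast: bool) -> torch.Tensor:
    # row_vals stands in for the data behind ptr_row_array[g]:
    # shape (n,) when broadcasting a per-channel row, shape (1,) otherwise.
    if row_broadcast:
        return row_vals.view(1, n).expand(m, n)  # same row repeated down M
    return row_vals.expand(m, n)                 # one scalar everywhere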
+ struct Arguments { + const Element* const* ptr_row_array = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, + int group, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , group(group) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + int group; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row_array[group])); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. 
+ } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + l, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcastArray { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. 
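+  //
+  // In the grouped MoE path below, this column broadcast carries the
+  // activation scales: col_broadcast is set from `per_act_token`, so each
+  // group either reads one scale per routed token (a length-M column) or a
+  // single per-tensor scale through ptr_col_array[group].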
+ struct Arguments { + const Element* const* ptr_col_array = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + int group, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + group(group), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + int group; + + CUTLASS_DEVICE void + begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col_array[group])); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + l, + params + ); + } +}; + +} diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 0a812dc5..62b848a0 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -1,6 +1,7 @@ #pragma once #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" /* This file defines custom epilogues for fusing channel scales, token scales, @@ -69,6 +70,16 @@ struct ScaledEpilogueBase { 0 /*Stages*/, TileShape, T, T, Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + template + using ColOrScalarLoadArray = + cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray< + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoadArray = + cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray< + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; + // This utility function constructs the arguments for the load descriptors // from a tensor. It can handle both row and column, as well as row/column or // scalar cases. @@ -96,6 +107,14 @@ struct ScaledEpilogueBase { std::is_same_v>); return Arguments{data_ptr}; } + + template + static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) { + using Arguments = typename Descriptor::Arguments; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr, do_broadcast}; + } }; /* @@ -381,4 +400,51 @@ struct ScaledEpilogueBiasAzpToken } }; +/* + This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers + to arrays containing different scales used in group gemm. The number of + pointers in ScaleA and the number of pointers in ScaleB are equal to the + group size. 
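+
+ Per group g it computes D = a_scale[g] * (b_scale[g] * Accum), where
+ a_scale[g] is a per-token column vector (or a single scalar) and
+ b_scale[g] is a per-channel row vector (or a single scalar), both read
+ through device pointer arrays indexed by the group.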
+*/ +template +struct ScaledEpilogueArray + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoadArray; + using ScaleB = typename SUPER::template RowOrScalarLoadArray; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + using ScaleAArray = typename SUPER::template ColOrScalarLoadArray; + using ScaleBArray = typename SUPER::template RowOrScalarLoadArray; + + static ArgumentType prepare_args(float const* const* a_scales_ptr, + float const* const* b_scales_ptr, + bool a_col_broadcast, bool b_row_broadcast) { + auto a_args = SUPER::template args_from_tensor( + a_scales_ptr, a_col_broadcast); + auto b_args = SUPER::template args_from_tensor( + b_scales_ptr, b_row_broadcast); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + }; // namespace vllm::c3x diff --git a/csrc/ops.h b/csrc/ops.h index 7434aead..1ea9f465 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -164,6 +164,7 @@ int64_t ggml_moe_get_block_size(int64_t type); bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability); bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability); +bool cutlass_group_gemm_supported(int64_t cuda_device_capability); void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, torch::Tensor const& B, torch::Tensor const& A_sf, @@ -175,6 +176,19 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); +void cutlass_moe_mm( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides); + +void get_cutlass_moe_mm_data( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k); + void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh new file mode 100644 index 00000000..6c6e8979 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include + +#include "core/scalar_type.hpp" +#include "cutlass/bfloat16.h" +#include "cutlass/float8.h" + +template +__global__ void get_group_gemm_starts( + int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, + ElementC** out_offsets, ElementAccumulator** a_scales_offsets, + ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int, + ElementAB* 
b_base_as_int, ElementC* out_base_as_int, + ElementAccumulator* a_scales_base_as_int, + ElementAccumulator* b_scales_base_as_int, int64_t n, int64_t k, + bool per_act_token, bool per_out_ch) { + int expert_id = threadIdx.x; + + int64_t expert_offset = expert_offsets[expert_id]; + + a_offsets[expert_id] = a_base_as_int + expert_offset * k; + b_offsets[expert_id] = b_base_as_int + expert_id * k * n; + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + a_scales_offsets[expert_id] = + a_scales_base_as_int + (per_act_token ? expert_offset : 0); + b_scales_offsets[expert_id] = + b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id); +} + +#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + get_group_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), \ + static_cast(a_scales_ptrs.data_ptr()), \ + static_cast(b_scales_ptrs.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), out_tensors.size(1), \ + a_tensors.size(1), per_act_token, per_out_ch); \ + } + +namespace { + +void run_get_group_gemm_starts( + torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, + torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, + torch::Tensor& out_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + int num_experts = static_cast(expert_offsets.size(0)); + bool per_act_token = a_scales.numel() != 1; + bool per_out_ch = b_scales.numel() != num_experts; + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + if (false) { + } + __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) + __CALL_GET_STARTS_KERNEL(torch::kFloat16, half) + else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +} + +} // namespace \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu new file mode 100644 index 00000000..2b8bc3fb --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu @@ -0,0 +1,160 @@ +#include + +#include +#include + +#include "cutlass/cutlass.h" +#include "grouped_mm_c3x.cuh" + +using namespace cute; + +namespace { + +template typename Epilogue> +struct sm90_fp8_config_default { + // M in (16, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M16 { + // M in [1, 16] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = + 
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_K8192 { + // K in [8192, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_N8192 { + // N in [8192, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template +void run_cutlass_moe_mm_sm90( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); + TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); + TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); + + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn, + "A tensors must be of type float8_e4m3fn."); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, + "B tensors must be of type float8_e4m3fn."); + + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmM16 = typename sm90_fp8_config_M16< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmDefault = typename sm90_fp8_config_default< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + + uint32_t const m = a_tensors.size(0); + uint32_t const n = out_tensors.size(1); + uint32_t const k = a_tensors.size(1); + + if (n >= 8192) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } else if (k >= 8192) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } else if (m <= 16) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } else { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } +} + +void dispatch_moe_mm_sm90( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, 
torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + if (out_tensors.dtype() == torch::kBFloat16) { + run_cutlass_moe_mm_sm90( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } else { + run_cutlass_moe_mm_sm90( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } +} + +} // namespace + +void cutlass_moe_mm_sm90( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides); +} diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh new file mode 100644 index 00000000..db827b7c --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -0,0 +1,149 @@ +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "cutlass_extensions/common.hpp" +#include "get_group_starts.cuh" + +using namespace cute; + +namespace { + +using ProblemShape = + cutlass::gemm::GroupProblemShape>; + +using ElementAccumulator = float; +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_group_gemm { + using ElementAB = ElementAB_; + using ElementC = void; + using ElementD = ElementC_; + using ElementAccumulator = float; + + using Epilogue = Epilogue_; + + using StrideC = + cute::remove_pointer_t, cute::Int<0>>>; + + static constexpr int AlignmentAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, + LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB, + LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, + Stages, KernelSchedule>::CollectiveOp; + + using KernelType = enable_sm90_only>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_group_gemm_caller( + 
torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int num_experts = static_cast(expert_offsets.size(0)); + int k_size = a_tensors.size(1); + int n_size = out_tensors.size(1); + + bool per_act_token = a_scales.numel() != 1; + bool per_out_ch = b_scales.numel() != num_experts; + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + + run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, + a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors, + out_tensors, a_scales, b_scales); + + using GemmKernel = typename Gemm::GemmKernel; + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = typename GemmKernel::InternalStrideC; + + ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast( + problem_sizes.data_ptr()); + ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; + + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides.data_ptr())}; + + // Currently, we are only able to do broadcast on either all or none a_scales + // and on either all or none b_scales + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + static_cast(a_scales_ptrs.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + per_act_token, per_out_ch), + nullptr, static_cast(c_strides.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides.data_ptr())}; + + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, + epilogue_args}; + + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +} // namespace diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu new file mode 100644 index 00000000..2fb0417c --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -0,0 +1,90 @@ +#include + +#include +#include + +#include + +constexpr uint64_t THREADS_PER_EXPERT = 512; + +__global__ void compute_problem_sizes(const int* __restrict__ topk_ids, + int32_t* problem_sizes1, + int32_t* problem_sizes2, + int32_t* atomic_buffer, + const int topk_length, const int n, + const int k) { + int expert_id = blockIdx.x; + + int occurrences = 0; 
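+  // Each thread block handles one expert: threads stride over the flattened
+  // topk_ids, count the tokens routed to this expert, combine the partial
+  // counts with atomicAdd, and thread 0 then writes the per-expert problem
+  // shapes (m_e, 2n, k) for the first grouped GEMM and (m_e, k, n) for the
+  // second.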
+ for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { + occurrences += (topk_ids[i] == expert_id); + } + atomicAdd(&atomic_buffer[expert_id], occurrences); + __syncthreads(); + + if (threadIdx.x == 0) { + int final_occurrences = atomic_buffer[expert_id]; + problem_sizes1[expert_id * 3] = final_occurrences; + problem_sizes1[expert_id * 3 + 1] = 2 * n; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = final_occurrences; + problem_sizes2[expert_id * 3 + 1] = k; + problem_sizes2[expert_id * 3 + 2] = n; + } +} + +__global__ void compute_expert_offsets( + const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, + int32_t* atomic_buffer, const int num_experts) { + int32_t tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + atomic_buffer[i] = tot_offset; + tot_offset += problem_sizes1[i * 3]; + expert_offsets[i + 1] = tot_offset; + } +} + +__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, + int32_t* input_permutation, + int32_t* output_permutation, + int32_t* atomic_buffer, const int topk_length, + const int topk) { + int expert_id = blockIdx.x; + + for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { + if (topk_ids[i] == expert_id) { + int start = atomicAdd(&atomic_buffer[expert_id], 1); + input_permutation[start] = i / topk; + output_permutation[i] = start; + } + } +} + +void get_cutlass_moe_mm_data_caller( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k) { + auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); + auto options_int32 = + torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); + torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); + + int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); + compute_problem_sizes<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); + compute_expert_offsets<<<1, 1, 0, stream>>>( + static_cast(problem_sizes1.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(atomic_buffer.data_ptr()), num_experts); + compute_arg_sorts<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(input_permutation.data_ptr()), + static_cast(output_permutation.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), + topk_ids.size(1)); +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index b0838645..54b63894 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -29,6 +29,20 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, std::optional const& bias); + +void cutlass_moe_mm_sm90( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides); + +void get_cutlass_moe_mm_data_caller( + const torch::Tensor& topk_ids, torch::Tensor& 
expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k); + #endif #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100 @@ -102,6 +116,19 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { return false; } +bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { + // CUTLASS groped FP8 kernels need at least CUDA 12.3 + // and SM90 (Hopper) + +#if defined CUDA_VERSION + if (cuda_device_capability == 90) { + return CUDA_VERSION >= 12030; + } +#endif + + return false; +} + void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, @@ -168,6 +195,46 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, version_num); } +void cutlass_moe_mm( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 + cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, + ". Required capability: 90"); +} + +void get_cutlass_moe_mm_data( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k) { + // This function currently gets compiled only if we have a valid cutlass moe + // mm to run it for. + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 + get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, input_permutation, + output_permutation, num_experts, n, k); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " + "CUDA device capability: ", + version_num, ". Required capability: 90"); +} + void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index eb3a2c91..60ad6430 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -365,6 +365,35 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"); ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); + // Check if cutlass grouped gemm is supported for CUDA devices of the given + // capability + ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool"); + ops.impl("cutlass_group_gemm_supported", &cutlass_group_gemm_supported); + + // CUTLASS w8a8 grouped GEMM + ops.def( + "cutlass_moe_mm(Tensor! 
out_tensors, Tensor a_tensors, Tensor b_tensors, " + " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, " + " Tensor problem_sizes, Tensor a_strides, " + " Tensor b_strides, Tensor c_strides) -> ()", + {stride_tag}); + ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm); + + // A function that computes data required to run fused MoE with w8a8 grouped + // GEMM. It takes topk_ids as an input, and computes expert_offsets + // (token start indices of each expert). In addition to this, it computes + // problem sizes for each expert's multiplication used by the two mms called + // from fused MoE operation, and arrays with permutations required to shuffle + // and de-shuffle the input/output of the fused operation. + ops.def( + "get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, " + " Tensor! problem_sizes1, Tensor! problem_sizes2, " + " Tensor! input_permutation, " + " Tensor! output_permutation, int num_experts, " + " int n, int k) -> ()", + {stride_tag}); + ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); + // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3) ops.def( "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> " diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 72fc660a..f11ce6f4 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -3,6 +3,7 @@ Run `pytest tests/kernels/test_cutlass.py`. """ +import random import pytest import torch @@ -507,3 +508,136 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): def test_cutlass_support_opcheck(): opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) + + +@pytest.mark.parametrize("num_experts", [8, 64]) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("use_bias", [False]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, + per_out_ch: bool, use_bias: bool): + + # Device and dtype setup + device = "cuda" + out_dtype = torch.half + + # Create separate A, B, C tensors for each group + a_tensors = [] + b_tensors = [] + a_scales_tensors = [] + b_scales_tensors = [] + baseline_tensors = [] + + expert_offsets = torch.zeros((num_experts + 1), + device=device, + dtype=torch.int32) + + problem_sizes = torch.zeros((num_experts, 3), + device=device, + dtype=torch.int32) + + if not per_act_token: + one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32) + + alignment = 16 # 128 // 8 + # For variation, each group has dimensions + n_g = alignment * random.randint(1, 64) + k_g = alignment * random.randint(1, 64) + for g in range(num_experts): + m_g = alignment * random.randint(1, 64) + + expert_offsets[g + 1] = expert_offsets[g] + m_g + problem_sizes[g][0] = m_g + problem_sizes[g][1] = n_g + problem_sizes[g][2] = k_g + + m_a_scales = m_g if per_act_token else 1 + n_b_scales = n_g if per_out_ch else 1 + + print("shape:", m_g, n_g, k_g) + + # Create group-specific A and B (FP8) and output (FP16/FP32) + a_g = to_fp8(torch.randn((m_g, k_g), device=device)) + b_g = to_fp8(torch.randn((n_g, k_g), device=device).t()) + a_tensors.append(a_g) + b_tensors.append(b_g) + + # Set up A/B scales + scale_b = torch.randn((1, n_b_scales), + device=device, + 
dtype=torch.float32) + b_scales_tensors.append(scale_b) + + if per_act_token: + scale_a = torch.randn((m_a_scales, 1), + device=device, + dtype=torch.float32) + a_scales_tensors.append(scale_a) + else: + scale_a = one_scale_a + + # Compute baseline result for this group + baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, + None) + baseline_tensors.append(baseline_g) + + a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g), + device=device, + dtype=torch.float8_e4m3fn) + b_tensors_stacked = torch.empty((num_experts, n_g, k_g), + device=device, + dtype=torch.float8_e4m3fn) + + for g in range(num_experts): + a_tensors_stacked[expert_offsets[g]:expert_offsets[g + + 1]] = a_tensors[g] + b_tensors_stacked[g] = b_tensors[g].t() + b_tensors_stacked = b_tensors_stacked.transpose(1, 2) + + if per_act_token: + a_scales_tensors_stacked = torch.empty( + (expert_offsets[num_experts], 1), + device=device, + dtype=torch.float32) + for g in range(num_experts): + a_scales_tensors_stacked[ + expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g] + else: + a_scales_tensors_stacked = one_scale_a + + b_scales_tensors_stacked = torch.empty((num_experts, n_b_scales), + device=device, + dtype=torch.float32) + for g in range(num_experts): + b_scales_tensors_stacked[g] = b_scales_tensors[g] + + out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g), + device=device, + dtype=out_dtype) + + ab_strides = torch.full((num_experts, ), + a_tensors_stacked.stride(0), + device="cuda", + dtype=torch.int64) + c_strides = torch.full((num_experts, ), + out_tensors_stacked.stride(0), + device="cuda", + dtype=torch.int64) + + ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked, + b_tensors_stacked, a_scales_tensors_stacked, + b_scales_tensors_stacked, expert_offsets[:-1], + problem_sizes, ab_strides, ab_strides, c_strides) + + # Validate each group's result against the baseline + for g in range(num_experts): + baseline = baseline_tensors[g] + c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] + print(baseline) + print(c) + print("*") + torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py new file mode 100644 index 00000000..1652c72d --- /dev/null +++ b/tests/kernels/test_cutlass_moe.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, + fused_experts, + fused_topk) +from vllm.platforms import current_platform + +NUM_EXPERTS = [40, 64] +TOP_KS = [6, 8] + + +def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, + w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, c_strides2: torch.Tensor): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale) + + +@pytest.mark.parametrize("m", [2, 64, 224]) +@pytest.mark.parametrize("n", [1024, 3072]) +@pytest.mark.parametrize("k", [1024, 1536]) +@pytest.mark.parametrize("e", NUM_EXPERTS) 
+@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + # Get the right scale for tests. + _, a_scale1 = ops.scaled_fp8_quant( + a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(a, + a_scale1, + use_per_token_if_dynamic=per_act_token) + + a_d = a_q.float().mul(a_scale1).to(dtype) + + n_b_scales = 2 * n if per_out_ch else 1 + k_b_scales = k if per_out_ch else 1 + + w1_q = torch.empty((e, 2 * n, k), + device="cuda", + dtype=torch.float8_e4m3fn) + w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_ch) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_ch) + w1_q = w1_q.transpose(1, 2) + w2_q = w2_q.transpose(1, 2) + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + w1_d = torch.empty_like(w1) + w2_d = torch.empty_like(w2) + for expert in range(e): + w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) + + cutlass_output = cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale1) + + print(triton_output) + print(cutlass_output) + print("*") + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) + + +@pytest.mark.parametrize("m", [2, 64, 224]) +@pytest.mark.parametrize("n", [1024, 3072]) +@pytest.mark.parametrize("k", [1024, 1536]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif( + (lambda x: x is None 
or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_cuda_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + # Get the right scale for tests. + _, a_scale1 = ops.scaled_fp8_quant( + a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(a, + a_scale1, + use_per_token_if_dynamic=per_act_token) + + a_d = a_q.float().mul(a_scale1).to(dtype) + + n_b_scales = 2 * n if per_out_ch else 1 + k_b_scales = k if per_out_ch else 1 + + w1_q = torch.empty((e, 2 * n, k), + device="cuda", + dtype=torch.float8_e4m3fn) + w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_ch) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_ch) + w1_q = w1_q.transpose(1, 2) + w2_q = w2_q.transpose(1, 2) + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + w1_d = torch.empty_like(w1) + w2_d = torch.empty_like(w2) + for expert in range(e): + w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) + + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + cutlass_output = run(a, a_scale1, w1_q, w2_q, w1_scale, w2_scale, + topk_weights, topk_ids, ab_strides1, + c_strides1, ab_strides2, c_strides2) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + print(triton_output) + print(cutlass_output) + print("*") + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=9e-2, + rtol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index dc07bad4..2ffcef41 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -587,6 +587,9 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: cuda_device_capability) +def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability) + def 
cutlass_sparse_compress(a: torch.Tensor) \ -> tuple[torch.Tensor, torch.Tensor]: """ @@ -677,6 +680,56 @@ def cutlass_scaled_sparse_mm( return out +def get_cutlass_moe_mm_data( + topk_ids: torch.Tensor, expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor, + input_permutation: torch.Tensor, output_permutation: torch.Tensor, + num_experts: int, n: int, k: int): + """ + Prepare data necessary to perform CUTLASS grouped matrix multiplications + used in CUTLASS-based fused MoE. + + The function takes in topk_ids (token-expert mapping) and uses it to + compute: + - expert_offsets: Indices that mark at which token index each expert begins + its computation after the input is sorted with + input_permutation. The number of tokens computed with + expert E is expert_offsets[E + 1] - expert_offsets[E] + - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's + multiplication in two grouped MMs used in + the fused MoE operation. + - input_permutation: Permutation that must be used to shuffle the input + before executing the MMs. + - output_permutation: Permutation that must be used to shuffle the output + after executing the MMs. + """ + torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, + input_permutation, output_permutation, + num_experts, n, k) + + +def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, + b_tensors: torch.Tensor, a_scales: torch.Tensor, + b_scales: torch.Tensor, expert_offsets: torch.Tensor, + problem_sizes: torch.Tensor, a_strides: torch.Tensor, + b_strides: torch.Tensor, c_strides: torch.Tensor): + """ + A single grouped matrix multiplication used in CUTLASS-based fused MoE. + The function executes fp8-quantized OUT = AB matrix multiplication. + + - expert_offsets: Indices that mark at which token index each expert begins + its computation. The number of tokens computed with + expert E is expert_offsets[E + 1] - expert_offsets[E] + - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped + MMs used in the fused MoE operation. + - a/b/c_strides: The data strides passed to grouped matrix multiplication. 
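+    - Expected dtypes (as exercised by this op's tests and by cutlass_moe_fp8
+      below): a_tensors/b_tensors are fp8 (e4m3), out_tensors is fp16 or bf16,
+      and each stride tensor holds one int64 value per expert.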
+ """ + torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales, + b_scales, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides) + + # aqlm def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 6f933c3f..e096d14f 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -36,8 +36,8 @@ if HAS_TRITON: import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_moe, fused_topk, get_config_file_name, - grouped_topk) + cutlass_moe_fp8, fused_experts, fused_moe, fused_topk, + get_config_file_name, grouped_topk) __all__ += [ "fused_moe", @@ -45,4 +45,5 @@ if HAS_TRITON: "fused_experts", "get_config_file_name", "grouped_topk", + "cutlass_moe_fp8", ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index faaea6b4..0929530e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1623,3 +1623,140 @@ def fused_moe( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape) + + +#TODO make the grouped gemm kernel consistent with scaled gemm kernel +def cutlass_moe_fp8( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.half, +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with CUTLASS + grouped gemm. + + Parameters: + - a (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - ab_strides1 (torch.Tensor): The input and weights strides of the first + grouped gemm. + - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. + - ab_strides2 (torch.Tensor): The input and weights strides of the second + grouped gemm. + - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. 
+ Shape: scalar or [M] + - out_dtype (torch.dtype): The output tensor type. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert w1_q.dtype == torch.float8_e4m3fn + assert w2_q.dtype == torch.float8_e4m3fn + assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" + assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" + assert a1_scale is None or a1_scale.dim( + ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[ + 0], "Input scale shape mismatch" + assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ + 1] == w1_q.shape[2], "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ + 1] == w2_q.shape[2], "W2 scale shape mismatch" + assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[ + 0], "w1 scales expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[ + 0], "w2 scales expert number mismatch" + assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + assert ab_strides1.shape[0] == w1_q.shape[ + 0], "AB Strides 1 expert number mismatch" + assert c_strides1.shape[0] == w1_q.shape[ + 0], "C Strides 1 expert number mismatch" + assert ab_strides2.shape[0] == w2_q.shape[ + 0], "AB Strides 2 expert number mismatch" + assert c_strides2.shape[0] == w2_q.shape[ + 0], "C Strides 2 expert number mismatch" + assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" + + num_experts = w1_q.size(0) + m = a.size(0) + k = w1_q.size(1) + n = w2_q.size(1) + + topk = topk_ids.size(1) + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a_q, a1_scale = ops.scaled_fp8_quant( + a, a1_scale, use_per_token_if_dynamic=per_act_token) + device = a_q.device + + expert_offsets = torch.empty((num_experts + 1), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, a_map, c_map, num_experts, n, + k) + + rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) + rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale + + c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) + c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) + + ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, + expert_offsets[:-1], problem_sizes1, ab_strides1, + ab_strides1, c_strides1) + + intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) + torch.ops._C.silu_and_mul(intermediate, c1) + + intermediate_q, a2_scale = ops.scaled_fp8_quant( + intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) + + ops.cutlass_moe_mm(c2, intermediate_q, w2_q, a2_scale, w2_scale, + expert_offsets[:-1], problem_sizes2, ab_strides2, + ab_strides2, c_strides2) + + return (c2[c_map].view(m, topk, k) * + topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) diff --git
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index ce6c706f..4b2d7ca2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -96,7 +96,8 @@ class CompressedTensorsConfig(QuantizationConfig): if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) if isinstance(layer, FusedMoE): - return CompressedTensorsMoEMethod.get_moe_method(self) + return CompressedTensorsMoEMethod.get_moe_method( + self, layer.activation, layer.expert_map) return None @classmethod @@ -191,17 +192,26 @@ class CompressedTensorsConfig(QuantizationConfig): def _check_scheme_supported(self, min_capability: int, - error: bool = True) -> bool: + error: bool = True, + match_exact: bool = False) -> bool: capability_tuple = current_platform.get_device_capability() if capability_tuple is not None: capability = capability_tuple.to_int() - supported = capability >= min_capability - if error and not supported: - raise RuntimeError( - "Quantization scheme is not supported for ", - f"the current GPU. Min capability: {min_capability}. ", - f"Current capability: {capability}.") + if match_exact: + supported = capability == min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + "the current GPU. Required capability: ", + f"{min_capability}. Current capability: {capability}.") + else: + supported = capability >= min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. ", + f"Current capability: {capability}.") return supported else: return False @@ -262,6 +272,11 @@ class CompressedTensorsConfig(QuantizationConfig): input_quant.strategy == QuantizationStrategy.TENSOR) return is_symmetric_activation and is_per_tensor_activation + def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported(90, error=False, match_exact=True) + and self._is_fp8_w8a8(weight_quant, input_quant)) + def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: # Confirm weights quantized. diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ff381a4c..2e14845f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -31,6 +31,7 @@ class GPTQMarlinState(Enum): __all__ = [ "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsW8A8Fp8MoECutlassMethod", "CompressedTensorsWNA16MoEMethod" ] @@ -39,7 +40,9 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): @staticmethod def get_moe_method( - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + activation: str, + expert_map: Optional[torch.Tensor], ) -> "CompressedTensorsMoEMethod": # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. 
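For intuition about the host-side contract of the new data-preparation op, the following is a rough, plain-PyTorch reference of what get_cutlass_moe_mm_data is expected to produce, inferred from how cutlass_moe_fp8 consumes its outputs. The helper name and implementation are illustrative only (not part of this patch); the real op fills preallocated GPU tensors in a single kernel.

import torch

def reference_moe_mm_data(topk_ids: torch.Tensor, num_experts: int, n: int,
                          k: int):
    topk = topk_ids.size(1)
    flat = topk_ids.flatten()
    # Sort the (token, expert) slots by expert id so that each expert's rows
    # are contiguous in the grouped GEMM input.
    sorted_slots = torch.argsort(flat, stable=True)
    # For each row of the expert-sorted activation matrix: which row of the
    # original input it is replicated from.
    input_permutation = sorted_slots // topk
    # For each (token, expert) slot in token-major order: where its row ends
    # up in the expert-sorted layout (used to de-shuffle after the GEMMs).
    output_permutation = torch.empty_like(sorted_slots)
    output_permutation[sorted_slots] = torch.arange(sorted_slots.numel(),
                                                    device=flat.device)
    counts = torch.bincount(flat, minlength=num_experts)
    expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int64,
                                 device=flat.device)
    expert_offsets[1:] = torch.cumsum(counts, dim=0)
    ones = torch.ones_like(counts)
    # GEMM 1 per expert: (m_e x k) @ (k x 2n); GEMM 2: (m_e x n) @ (n x k).
    problem_sizes1 = torch.stack([counts, 2 * n * ones, k * ones], dim=1)
    problem_sizes2 = torch.stack([counts, k * ones, n * ones], dim=1)
    return (expert_offsets.to(torch.int32), problem_sizes1.to(torch.int32),
            problem_sizes2.to(torch.int32), input_permutation.to(torch.int32),
            output_permutation.to(torch.int32))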
@@ -49,6 +52,9 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): if quant_config._is_wNa16_group_channel(weight_quant, input_quant): return CompressedTensorsWNA16MoEMethod(quant_config) + elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) + and activation == "silu" and expert_map is None): + return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) else: @@ -250,6 +256,200 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): a2_scale=layer.w2_input_scale) +class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy == QuantizationStrategy.TENSOR): + raise ValueError( + "For FP8 Fused MoE layers, only per-tensor scales " + "for weights and activations are supported. Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
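+        # (With the per-tensor scheme enforced in __init__, w13_weight_scale
+        #  has shape (num_experts, 2): one scalar scale for the w1 shard and
+        #  one for the w3 shard of each expert.)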
+ w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + device = w13_weight.device + # TODO strides can be shared across multiple layers + self.ab_strides1 = torch.full((num_experts, ), + hidden_size, + device=device, + dtype=torch.int64) + self.c_strides1 = torch.full((num_experts, ), + 2 * intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.ab_strides2 = torch.full((num_experts, ), + intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.c_strides2 = torch.full((num_experts, ), + hidden_size, + device=device, + dtype=torch.int64) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer.") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. 
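+        # Requantizing both shards with the shared max scale keeps every
+        # value representable in fp8; the shard whose original scale was
+        # smaller only loses a little precision.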
+ assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + ) -> torch.Tensor: + + assert activation == "silu" + assert global_num_experts == layer.w13_weight.shape[0] + assert expert_map is None + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8 + + return cutlass_moe_fp8( + x, + layer.w13_weight.transpose(1, 2), + layer.w2_weight.transpose(1, 2), + layer.w13_weight_scale, + layer.w2_weight_scale, + topk_weights, + topk_ids, + self.ab_strides1, + self.c_strides1, + self.ab_strides2, + self.c_strides2, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + out_dtype=x.dtype, + ) + + class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): def __init__( diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 9de8e453..c2bd4bce 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -50,6 +50,16 @@ def cutlass_block_fp8_supported() -> bool: return ops.cutlass_scaled_mm_supports_block_fp8(capability) +def cutlass_group_gemm_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_group_gemm_supported(capability) + + CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported() CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported() diff --git a/vllm/utils.py b/vllm/utils.py index 10134233..73de8262 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1568,18 +1568,21 @@ class ClassRegistry(UserDict[Type[T], _V]): return any(cls in self.data for cls in key.mro()) -def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: +def weak_ref_tensor(tensor: Any) -> Any: """ Create a weak reference to a tensor. The new tensor will share the same data as the original tensor, but will not keep the original tensor alive. 
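+    Non-tensor inputs are returned unchanged.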
""" - return torch.ops._C.weak_ref_tensor(tensor) + if isinstance(tensor, torch.Tensor): + return torch.ops._C.weak_ref_tensor(tensor) + else: + return tensor def weak_ref_tensors( tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] -) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]: +) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: """ Convenience function to create weak references to tensors, for single tensor, list of tensors or tuple of tensors.