[Kernel] CUTLASS grouped gemm fp8 MoE kernel (#13972)
Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: Lucas Wilkinson <wilkinson.lucas@gmail.com>
parent 7a6d45bc8a
commit 9239bf718e
CMakeLists.txt
@@ -461,6 +461,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
      set(FP4_ARCHS)
    endif()

    #
    # CUTLASS MoE kernels

    # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
    # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
    # to compile MoE kernels that use its output.
    cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
    if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
      set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
               "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
      set_gencode_flags_for_srcs(
        SRCS "${SRCS}"
        CUDA_ARCHS "${SCALED_MM_ARCHS}")
      list(APPEND VLLM_EXT_SRC "${SRCS}")
      list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
      message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
    else()
      if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
        message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
                       "not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
                       "if you intend on running FP8 quantized MoE models on Hopper.")
      else()
        message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
                       "in CUDA target architectures")
      endif()
    endif()

    #
    # Machete kernels
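The gate above makes MoE grouped GEMM support both a compile-time and a runtime question. As a minimal sketch (not part of this diff; the Python binding name is assumed to mirror the cutlass_group_gemm_supported declaration added to csrc/ops.h below), a caller can probe availability like this:

# Probe whether the CUTLASS grouped GEMM MoE path can be used on this device.
import torch
from vllm import _custom_ops as ops

major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor  # e.g. 90 on Hopper
if ops.cutlass_group_gemm_supported(capability):
    print("cutlass_moe_mm is available")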
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py (new file, 340 lines)
@@ -0,0 +1,340 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
                                                            fused_experts,
                                                            fused_topk)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = [
    "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
    "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
]
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def bench_run(results: list[benchmark.Measurement], model: str,
              num_experts: int, topk: int, per_act_token: bool,
              per_out_ch: bool, mkn: tuple[int, int, int]):
    label = "Quant Matmul"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
        "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
                          mkn))

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    _, a_scale = ops.scaled_fp8_quant(a)

    w1_q = torch.empty((num_experts, 2 * n, k),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w2_q = torch.empty((num_experts, k, n),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)

    ab_strides1 = torch.full((num_experts, ),
                             k,
                             device="cuda",
                             dtype=torch.int64)
    c_strides1 = torch.full((num_experts, ),
                            2 * n,
                            device="cuda",
                            dtype=torch.int64)
    ab_strides2 = torch.full((num_experts, ),
                             n,
                             device="cuda",
                             dtype=torch.int64)
    c_strides2 = torch.full((num_experts, ),
                            k,
                            device="cuda",
                            dtype=torch.int64)

    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
    w1_q_notransp = w1_q.clone()
    w2_q_notransp = w2_q.clone()
    w1_q = w1_q.transpose(1, 2)
    w2_q = w2_q.transpose(1, 2)

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

    def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                       topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                       w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                       a_scale: torch.Tensor, num_repeats: int):
        for _ in range(num_repeats):
            fused_experts(a,
                          w1,
                          w2,
                          topk_weights,
                          topk_ids,
                          use_fp8_w8a8=True,
                          w1_scale=w1_scale,
                          w2_scale=w2_scale,
                          a1_scale=a_scale)

    def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
                        w1: torch.Tensor, w2: torch.Tensor,
                        w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                        topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                        ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
                        ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
                        num_repeats: int):
        for _ in range(num_repeats):
            cutlass_moe_fp8(a,
                            w1,
                            w2,
                            w1_scale,
                            w2_scale,
                            topk_weights,
                            topk_ids,
                            ab_strides1,
                            c_strides1,
                            ab_strides2,
                            c_strides2,
                            a1_scale=a_scale)

    def run_cutlass_from_graph(
            a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
            w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
            ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
            ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return cutlass_moe_fp8(a,
                                   w1_q,
                                   w2_q,
                                   w1_scale,
                                   w2_scale,
                                   topk_weights,
                                   topk_ids,
                                   ab_strides1,
                                   c_strides1,
                                   ab_strides2,
                                   c_strides2,
                                   a1_scale=a_scale)

    def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
                              w2: torch.Tensor, topk_weights: torch.Tensor,
                              topk_ids: torch.Tensor, w1_scale: torch.Tensor,
                              w2_scale: torch.Tensor, a_scale: torch.Tensor):
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return fused_experts(a,
                                 w1,
                                 w2,
                                 topk_weights,
                                 topk_ids,
                                 use_fp8_w8a8=True,
                                 w1_scale=w1_scale,
                                 w2_scale=w2_scale,
                                 a1_scale=a_scale)

    def replay_graph(graph, num_repeats):
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
                               topk_weights, topk_ids, ab_strides1, c_strides1,
                               ab_strides2, c_strides2)
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
                              topk_ids, w1_scale, w2_scale, a_scale)
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_q_notransp": w1_q_notransp,
        "w2_q_notransp": w2_q_notransp,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "ab_strides1": ab_strides1,
        "c_strides1": c_strides1,
        "ab_strides2": ab_strides2,
        "c_strides2": c_strides2,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
                   w1_scale, w2_scale, a_scale, num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
                    topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2,
                    num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            bench_run(results, model, num_experts, topk,
                                      per_act_token, per_out_ch, mkn)

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark CUTLASS grouped GEMM MoE across specified "
        "models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_TP_SIZES)
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token",
                        nargs="+",
                        type=int,
                        default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
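As a usage sketch (hypothetical values, equivalent to invoking the script with the flags defined above), the benchmark can also be driven programmatically:

# Run the benchmark for one model at TP=1 with a reduced batch sweep.
from types import SimpleNamespace

args = SimpleNamespace(models=["nm-testing/deepseekv2-lite"],
                       tp_sizes=[1],
                       batch_sizes=[1, 64],
                       limit_k=[],
                       limit_n=[])
main(args)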
benchmarks/kernels/benchmark_shapes.py
@@ -75,3 +75,19 @@ WEIGHT_SHAPES = {
        [7168, 8192],
    ],
}

WEIGHT_SHAPES_MOE = {
    "nm-testing/Mixtral-8x7B-Instruct-v0.1": [
        [8, 2, 4096, 28672],
        [8, 2, 14336, 4096],
    ],
    "nm-testing/deepseekv2-lite": [
        [64, 6, 2048, 1408],
    ],
    "ibm-granite/granite-3.0-1b-a400m": [
        [32, 8, 1024, 1024],
    ],
    "ibm-granite/granite-3.0-3b-a800m": [
        [40, 8, 1024, 1536],
    ],
}
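Each WEIGHT_SHAPES_MOE entry is consumed by main() in the benchmark above as [num_experts, topk, size_k, size_n], with size_n divided by the tensor-parallel size. For example:

# Unpacking the first Mixtral-8x7B entry (tp = 1).
num_experts, topk, size_k, size_n = [8, 2, 4096, 28672]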
csrc/cutlass_extensions/common.hpp
@@ -48,4 +48,14 @@ struct enable_sm90_or_later : Kernel {
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm90_only : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};
csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp (new file, 457 lines)
@@ -0,0 +1,457 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 *reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"

namespace cutlass::epilogue::fusion {

using namespace cute;
using namespace detail;

// Row vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcastArray {
  static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
  static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});

  struct SharedStorage {
    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
  };

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_row is null.
  struct Arguments {
    const Element* const* ptr_row_array = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static bool
  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
    return true;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
      CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcastArray() { }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
      : params(params)
      , smem(const_cast<Element*>(shared_storage.smem.data())) { }

  Params params;
  Element *smem = nullptr;

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0));
  }

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(
          GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
          GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
          SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
          CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
          int group, Params const& params_)
      : tGS_gRow(tGS_gRow_)
      , tGS_sRow(tGS_sRow_)
      , tGS_cRow(tGS_cRow_)
      , tiled_G2S(tiled_g2s_)
      , tSR_sRow(tSR_sRow_)
      , tSR_rRow(tSR_rRow_)
      , tCcRow(tCcRow_)
      , residue_tCcRow(residue_tCcRow_)
      , group(group)
      , params(params_) {}

    GS_GTensor tGS_gRow;  // (CPY,CPY_M,CPY_N)
    GS_STensor tGS_sRow;  // (CPY,CPY_M,CPY_N)
    GS_CTensor tGS_cRow;  // (CPY,CPY_M,CPY_N)
    Tiled_G2S tiled_G2S;

    SR_STensor tSR_sRow;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    SR_RTensor tSR_rRow;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    CTensor tCcRow;       // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    ThrResidue residue_tCcRow;  // (m, n)
    ThrNum thr_num;
    int group;
    Params const& params;

    CUTLASS_DEVICE void
    begin() {
      if (!params.row_broadcast) {
        fill(tSR_rRow, *(params.ptr_row_array[group]));
        return;
      }

      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));

      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
          continue; // OOB of SMEM,
        }
        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
          tGS_sRow_flt(i) = tGS_gRow_flt(i);
        }
        else {
          tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
        }
      }
      synchronize();
    }

    CUTLASS_DEVICE void
    begin_loop(int epi_m, int epi_n) {
      if (epi_m == 0) { // Assumes M-major subtile loop
        if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
        copy(tSR_sRow_flt, tSR_rRow_flt);
      }
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_row;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
      }

      return frg_row;
    }
  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;
    using ThreadCount = decltype(size(args.tiled_copy));

    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
    Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n));  // (CTA_M, CTA_N)
    Tensor sRow = make_tensor(make_smem_ptr(smem),
        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{}));  // (CTA_M, CTA_N)
    //// G2S: Gmem to Smem
    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                                     Layout< Shape<_1, ThreadCount>,
                                             Stride<_0, _1>>{},
                                     Layout<_1>{});
    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
    Tensor tGS_sRow = thr_g2s.partition_D(sRow);

    //// G2S: Coord
    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
    Tensor tGS_cRow = thr_g2s.partition_S(cRow);

    //// S2R: Smem to Reg
    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow));  // (CPY,CPY_M,CPY_N)

    return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
      tGS_gRow,
      tGS_sRow,
      tGS_cRow, tiled_g2s,
      tSR_sRow,
      tSR_rRow,
      args.tCcD,
      args.residue_cD,
      ThreadCount{},
      l,
      params);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcastArray {
  static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
  static_assert(
    (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
    (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>));  // batched col vector broadcast, e.g. batched per-row bias

  // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
  struct SharedStorage { };

  // This struct has been modified to have a bool indicating that ptr_col is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_col is null.
  struct Arguments {
    const Element* const* ptr_col_array = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static bool
  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
    return true;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
      CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0));
  }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcastArray() { }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
      : params(params) { }

  Params params;

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  template<class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(
      GTensor&& tCgCol,
      RTensor&& tCrCol,
      CTensor&& tCcCol,
      ProblemShape problem_shape,
      int group,
      Params const& params
    ):
      tCgCol(cute::forward<GTensor>(tCgCol)),
      tCrCol(cute::forward<RTensor>(tCrCol)),
      tCcCol(cute::forward<CTensor>(tCcCol)),
      m(get<0>(problem_shape)),
      group(group),
      params(params) {}

    GTensor tCgCol;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    RTensor tCrCol;
    CTensor tCcCol;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    Params const& params;
    int m;
    int group;

    CUTLASS_DEVICE void
    begin() {
      Tensor pred = make_tensor<bool>(shape(tCgCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tCcCol(i)) < m;
      }

      if (!params.col_broadcast) {
        fill(tCrCol, *(params.ptr_col_array[group]));
        return;
      }

      // Filter so we don't issue redundant copies over stride-0 modes
      // (only works if 0-strides are in same location, which is by construction)
      copy_if(pred, filter(tCgCol), filter(tCrCol));
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_col;
      Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
      }

      return frg_col;
    }

  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;

    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrCol = make_tensor_like(tCgCol);                   // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    // Generate an identity tensor matching the shape of the global tensor and
    // partition the same way, this will be used to generate the predicate
    // tensor for loading
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);

    return ConsumerStoreCallbacks(
      cute::move(tCgCol),
      cute::move(tCrCol),
      cute::move(tCcCol),
      args.problem_shape_mnkl,
      l,
      params
    );
  }
};

}
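In effect, each visitor above loads, for group l, either a per-group row (or column) vector of scales or a single per-group scalar, selected by a runtime bool rather than by compiling two kernels. A rough PyTorch view of the row case (names illustrative, not from this file):

# What Sm90RowOrScalarBroadcastArray contributes for group g of an (M, N) tile.
import torch

def row_or_scalar(ptr_row_array: list[torch.Tensor], row_broadcast: bool,
                  g: int, M: int, N: int) -> torch.Tensor:
    if row_broadcast:
        # per-column scales: a length-N row vector broadcast down the M rows
        return ptr_row_array[g].view(1, N).expand(M, N)
    # scalar case: a single value broadcast over the whole tile
    return ptr_row_array[g].view(1, 1).expand(M, N)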
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -1,6 +1,7 @@
#pragma once

#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"

/*
   This file defines custom epilogues for fusing channel scales, token scales,
@@ -69,6 +70,16 @@ struct ScaledEpilogueBase {
          0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
          128 / sizeof_bits_v<T>, EnableNullPtr>;

  template <typename T>
  using ColOrScalarLoadArray =
      cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
          0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoadArray =
      cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
          0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
@@ -96,6 +107,14 @@ struct ScaledEpilogueBase {
        std::is_same_v<Descriptor, RowLoad<T, true>>);
    return Arguments{data_ptr};
  }

  template <typename Descriptor, typename T>
  static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) {
    using Arguments = typename Descriptor::Arguments;
    static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
    return Arguments{data_ptr, do_broadcast};
  }
};

/*
@@ -381,4 +400,51 @@ struct ScaledEpilogueBiasAzpToken
  }
};

/*
   This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers
   to arrays containing different scales used in group gemm. The number of
   pointers in ScaleA and the number of pointers in ScaleB are equal to the
   group size.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueArray
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  using ScaleAArray = typename SUPER::template ColOrScalarLoadArray<float>;
  using ScaleBArray = typename SUPER::template RowOrScalarLoadArray<float>;

  static ArgumentType prepare_args(float const* const* a_scales_ptr,
                                   float const* const* b_scales_ptr,
                                   bool a_col_broadcast, bool b_row_broadcast) {
    auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
        a_scales_ptr, a_col_broadcast);
    auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
        b_scales_ptr, b_row_broadcast);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    return ArgumentType{a_args, evt0_args, {}};
  }
};

};  // namespace vllm::c3x
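The EVT built by ScaledEpilogueArray therefore computes, per group g, D = a_scale * (b_scale * accum), where a_scale broadcasts down columns (per-token or scalar) and b_scale broadcasts across rows (per-channel or scalar). A reference computation in PyTorch (illustrative only, not part of the diff):

# Reference for one group's epilogue; accum is the (M, N) fp32 accumulator.
import torch

def scaled_epilogue_array_ref(accum: torch.Tensor, a_scale: torch.Tensor,
                              b_scale: torch.Tensor,
                              out_dtype=torch.bfloat16) -> torch.Tensor:
    # a_scale: (M, 1) per-token or (1, 1) scalar; b_scale: (1, N) or (1, 1)
    return (a_scale * (b_scale * accum)).to(out_dtype)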
csrc/ops.h (+14 lines)
@@ -164,6 +164,7 @@ int64_t ggml_moe_get_block_size(int64_t type);
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);

void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                           torch::Tensor const& B, torch::Tensor const& A_sf,
@@ -175,6 +176,19 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b_scales,
                       std::optional<torch::Tensor> const& bias);

void cutlass_moe_mm(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides);

void get_cutlass_moe_mm_data(
    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
    const int64_t num_experts, const int64_t n, const int64_t k);

void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
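The two new ops form a plan/execute pair: get_cutlass_moe_mm_data turns the router's topk_ids into per-expert offsets, problem sizes, and token permutations, and cutlass_moe_mm then runs the grouped GEMM over that layout. A rough sketch of the flow (Python binding names assumed to mirror the C++ declarations; tensor allocation elided):

# 1. Plan: derive the per-expert layout from the router output.
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
                            problem_sizes2, input_permutation,
                            output_permutation, num_experts, n, k)
# 2. Execute: one grouped GEMM per projection (cutlass_moe_fp8 in the
#    benchmark above issues this twice, once for w1 and once for w2).
ops.cutlass_moe_mm(out_tensors, a_fp8, b_fp8, a_scales, b_scales,
                   expert_offsets, problem_sizes1, a_strides, b_strides,
                   c_strides)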
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh (new file, 80 lines)
@@ -0,0 +1,80 @@
#pragma once

#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>

#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"

template <typename ElementAB, typename ElementC, typename ElementAccumulator>
__global__ void get_group_gemm_starts(
    int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
    ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
    ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
    ElementAB* b_base_as_int, ElementC* out_base_as_int,
    ElementAccumulator* a_scales_base_as_int,
    ElementAccumulator* b_scales_base_as_int, int64_t n, int64_t k,
    bool per_act_token, bool per_out_ch) {
  int expert_id = threadIdx.x;

  int64_t expert_offset = expert_offsets[expert_id];

  a_offsets[expert_id] = a_base_as_int + expert_offset * k;
  b_offsets[expert_id] = b_base_as_int + expert_id * k * n;
  out_offsets[expert_id] = out_base_as_int + expert_offset * n;
  a_scales_offsets[expert_id] =
      a_scales_base_as_int + (per_act_token ? expert_offset : 0);
  b_scales_offsets[expert_id] =
      b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id);
}

#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE)                   \
  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                        \
    get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float>           \
        <<<1, num_experts, 0, stream>>>(                                  \
            static_cast<int32_t*>(expert_offsets.data_ptr()),             \
            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),      \
            static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),      \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                   \
            static_cast<float**>(a_scales_ptrs.data_ptr()),               \
            static_cast<float**>(b_scales_ptrs.data_ptr()),               \
            static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),    \
            static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),    \
            static_cast<C_TYPE*>(out_tensors.data_ptr()),                 \
            static_cast<float*>(a_scales.data_ptr()),                     \
            static_cast<float*>(b_scales.data_ptr()), out_tensors.size(1), \
            a_tensors.size(1), per_act_token, per_out_ch);                \
  }

namespace {

void run_get_group_gemm_starts(
    torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
    torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
    torch::Tensor& out_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  int num_experts = static_cast<int>(expert_offsets.size(0));
  bool per_act_token = a_scales.numel() != 1;
  bool per_out_ch = b_scales.numel() != num_experts;

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  if (false) {
  }
  __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
  __CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}

}  // namespace
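get_group_gemm_starts launches one thread per expert and only does index arithmetic: every expert's A rows, weights, outputs, and scales are slices of large contiguous tensors. The same arithmetic in plain Python (element indices rather than raw pointers; a sketch, not part of the diff):

# Per-expert start offsets into the flattened A/B/out/scale tensors.
def group_starts(expert_offsets, n, k, per_act_token, per_out_ch):
    starts = []
    for e, off in enumerate(expert_offsets):
        starts.append({
            "a": off * k,        # expert e's input rows start at row `off`
            "b": e * k * n,      # each expert owns a full (k, n) weight block
            "out": off * n,      # output rows mirror the permuted input rows
            "a_scale": off if per_act_token else 0,
            "b_scale": n * e if per_out_ch else e,
        })
    return starts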
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu (new file, 160 lines)
@@ -0,0 +1,160 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"

using namespace cute;

namespace {

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (16, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M16 {
  // M in [1, 16]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_K8192 {
  // K in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_N8192 {
  // N in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
  TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
  TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");

  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
              "A tensors must be of type float8_e4m3fn.");
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
              "B tensors must be of type float8_e4m3fn.");

  using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmM16 = typename sm90_fp8_config_M16<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;

  uint32_t const m = a_tensors.size(0);
  uint32_t const n = out_tensors.size(1);
  uint32_t const k = a_tensors.size(1);

  if (n >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  } else if (k >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  } else if (m <= 16) {
    cutlass_group_gemm_caller<Cutlass3xGemmM16>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  } else {
    cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  }
}

void dispatch_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  if (out_tensors.dtype() == torch::kBFloat16) {
    run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  } else {
    run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  }
}

}  // namespace

void cutlass_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                       expert_offsets, problem_sizes, a_strides, b_strides,
                       c_strides);
}
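The tile-configuration choice in run_cutlass_moe_mm_sm90 is a fixed decision tree over the problem shape; the same heuristic written out in Python:

# Mirrors the if/else chain above: large-N first, then large-K, then small-M.
def pick_sm90_fp8_config(m: int, n: int, k: int) -> str:
    if n >= 8192:
        return "sm90_fp8_config_N8192"
    if k >= 8192:
        return "sm90_fp8_config_K8192"
    if m <= 16:
        return "sm90_fp8_config_M16"
    return "sm90_fp8_config_default"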
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh (new file, 149 lines)
@@ -0,0 +1,149 @@
#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "get_group_starts.cuh"

using namespace cute;

namespace {

using ProblemShape =
    cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;

using ElementAccumulator = float;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;

template <typename ElementAB_, typename ElementC_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_group_gemm {
  using ElementAB = ElementAB_;
  using ElementC = void;
  using ElementD = ElementC_;
  using ElementAccumulator = float;

  using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;

  using StrideC =
      cute::remove_pointer_t<cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;

  static constexpr int AlignmentAB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
          ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
          LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB,
          LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
          Stages, KernelSchedule>::CollectiveOp;

  using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
      ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;

  struct GemmKernel : public KernelType {};
};

template <typename Gemm>
void cutlass_group_gemm_caller(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int num_experts = static_cast<int>(expert_offsets.size(0));
  int k_size = a_tensors.size(1);
  int n_size = out_tensors.size(1);

  bool per_act_token = a_scales.numel() != 1;
  bool per_out_ch = b_scales.numel() != num_experts;

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  auto options_int =
      torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());

  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);

  run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
                            a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
                            out_tensors, a_scales, b_scales);

  using GemmKernel = typename Gemm::GemmKernel;
  using StrideA = Stride<int64_t, Int<1>, Int<0>>;
  using StrideB = Stride<int64_t, Int<1>, Int<0>>;
  using StrideC = typename GemmKernel::InternalStrideC;

  ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
      static_cast<ProblemShape::UnderlyingProblemShape*>(
          problem_sizes.data_ptr());
  ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};

  typename GemmKernel::MainloopArguments mainloop_args{
      static_cast<const ElementAB**>(a_ptrs.data_ptr()),
      static_cast<StrideA*>(a_strides.data_ptr()),
      static_cast<const ElementAB**>(b_ptrs.data_ptr()),
      static_cast<StrideB*>(b_strides.data_ptr())};

  // Currently, we are only able to do broadcast on either all or none a_scales
  // and on either all or none b_scales
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
          static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
          per_act_token, per_out_ch),
      nullptr, static_cast<StrideC*>(c_strides.data_ptr()),
      static_cast<ElementD**>(out_ptrs.data_ptr()),
      static_cast<StrideC*>(c_strides.data_ptr())};

  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
      epilogue_args};

  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

}  // namespace
90
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
Normal file
90
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
Normal file
@ -0,0 +1,90 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <iostream>

constexpr uint64_t THREADS_PER_EXPERT = 512;

__global__ void compute_problem_sizes(const int* __restrict__ topk_ids,
                                      int32_t* problem_sizes1,
                                      int32_t* problem_sizes2,
                                      int32_t* atomic_buffer,
                                      const int topk_length, const int n,
                                      const int k) {
  int expert_id = blockIdx.x;

  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    occurrences += (topk_ids[i] == expert_id);
  }
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();

  if (threadIdx.x == 0) {
    int final_occurrences = atomic_buffer[expert_id];
    problem_sizes1[expert_id * 3] = final_occurrences;
    problem_sizes1[expert_id * 3 + 1] = 2 * n;
    problem_sizes1[expert_id * 3 + 2] = k;
    problem_sizes2[expert_id * 3] = final_occurrences;
    problem_sizes2[expert_id * 3 + 1] = k;
    problem_sizes2[expert_id * 3 + 2] = n;
  }
}
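compute_problem_sizes assigns one thread block per expert and counts how many (token, expert) pairs route to it; expert e's first grouped GEMM is then (count_e, 2N, K) and its second (count_e, K, N). A rough NumPy equivalent, for intuition only:

# Rough NumPy equivalent of compute_problem_sizes, for intuition only.
import numpy as np

def problem_sizes_ref(topk_ids: np.ndarray, num_experts: int, n: int, k: int):
    counts = np.bincount(topk_ids.ravel(), minlength=num_experts)
    # First grouped GEMM: (tokens_e) x (2N) x K; second: (tokens_e) x K x N.
    ps1 = np.stack([counts, np.full_like(counts, 2 * n),
                    np.full_like(counts, k)], axis=1)
    ps2 = np.stack([counts, np.full_like(counts, k),
                    np.full_like(counts, n)], axis=1)
    return ps1, ps2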

__global__ void compute_expert_offsets(
    const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
    int32_t* atomic_buffer, const int num_experts) {
  int32_t tot_offset = 0;
  expert_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
    atomic_buffer[i] = tot_offset;
    tot_offset += problem_sizes1[i * 3];
    expert_offsets[i + 1] = tot_offset;
  }
}

__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
                                  int32_t* input_permutation,
                                  int32_t* output_permutation,
                                  int32_t* atomic_buffer,
                                  const int topk_length, const int topk) {
  int expert_id = blockIdx.x;

  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    if (topk_ids[i] == expert_id) {
      int start = atomicAdd(&atomic_buffer[expert_id], 1);
      input_permutation[start] = i / topk;
      output_permutation[i] = start;
    }
  }
}
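compute_arg_sorts builds the gather/scatter maps: input_permutation[slot] names the source token that fills expert-contiguous slot `slot` (tokens are replicated topk times), and output_permutation[i] records where flattened pair i landed. A sequential sketch, assuming topk_ids of shape [num_tokens, topk]; the kernel assigns slots atomically, so ordering within an expert may differ:

# Sequential sketch of compute_arg_sorts. Illustrative only: the CUDA kernel
# assigns slots with atomics, so within-expert order is not guaranteed.
import numpy as np

def arg_sorts_ref(topk_ids: np.ndarray, expert_offsets: np.ndarray):
    num_tokens, topk = topk_ids.shape
    flat = topk_ids.ravel()
    cursors = expert_offsets[:-1].copy()
    input_perm = np.empty(flat.size, dtype=np.int32)
    output_perm = np.empty(flat.size, dtype=np.int32)
    for i, e in enumerate(flat):
        slot = cursors[e]
        cursors[e] += 1
        input_perm[slot] = i // topk   # which source token to replicate
        output_perm[i] = slot          # where this (token, expert) pair lands
    return input_perm, output_perm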

void get_cutlass_moe_mm_data_caller(
    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
    const int64_t num_experts, const int64_t n, const int64_t k) {
  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
  auto options_int32 =
      torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);

  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
  compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
  compute_expert_offsets<<<1, 1, 0, stream>>>(
      static_cast<const int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(expert_offsets.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
  compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(input_permutation.data_ptr()),
      static_cast<int32_t*>(output_permutation.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
      topk_ids.size(1));
}
@ -29,6 +29,20 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias);

void cutlass_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides);

void get_cutlass_moe_mm_data_caller(
    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
    const int64_t num_experts, const int64_t n, const int64_t k);

#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
@ -102,6 +116,19 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
  return false;
}

bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
  // CUTLASS grouped FP8 kernels need at least CUDA 12.3
  // and SM90 (Hopper)

#if defined CUDA_VERSION
  if (cuda_device_capability == 90) {
    return CUDA_VERSION >= 12030;
  }
#endif

  return false;
}
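CUDA_VERSION packs the toolkit version as major * 1000 + minor * 10, so the 12030 threshold above is CUDA 12.3. A one-liner to decode it:

# CUDA_VERSION encodes major * 1000 + minor * 10, so 12030 == CUDA 12.3.
def decode_cuda_version(v: int) -> tuple[int, int]:
    return v // 1000, (v % 1000) // 10

assert decode_cuda_version(12030) == (12, 3)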

void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
@ -168,6 +195,46 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
      version_num);
}

void cutlass_moe_mm(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
  cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                      expert_offsets, problem_sizes, a_strides, b_strides,
                      c_strides);
  return;
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_moe_mm for CUDA device capability: ", version_num,
      ". Required capability: 90");
}

void get_cutlass_moe_mm_data(
    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
    const int64_t num_experts, const int64_t n, const int64_t k) {
  // This function currently gets compiled only if we have a valid cutlass moe
  // mm to run it for.
  int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
  get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
                                 problem_sizes2, input_permutation,
                                 output_permutation, num_experts, n, k);
  return;
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_moe_mm_data: no cutlass_moe_mm kernel for "
      "CUDA device capability: ",
      version_num, ". Required capability: 90");
}

void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
@ -365,6 +365,35 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
  ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

  // Check if cutlass grouped gemm is supported for CUDA devices of the given
  // capability
  ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
  ops.impl("cutlass_group_gemm_supported", &cutlass_group_gemm_supported);

  // CUTLASS w8a8 grouped GEMM
  ops.def(
      "cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
      " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
      " Tensor problem_sizes, Tensor a_strides, "
      " Tensor b_strides, Tensor c_strides) -> ()",
      {stride_tag});
  ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);

  // A function that computes data required to run fused MoE with w8a8 grouped
  // GEMM. It takes topk_ids as an input, and computes expert_offsets
  // (token start indices of each expert). In addition to this, it computes
  // problem sizes for each expert's multiplication used by the two grouped
  // MMs called from the fused MoE operation, and arrays with permutations
  // required to shuffle and de-shuffle the input/output of the fused
  // operation.
  ops.def(
      "get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
      " Tensor! problem_sizes1, Tensor! problem_sizes2, "
      " Tensor! input_permutation, "
      " Tensor! output_permutation, int num_experts, "
      " int n, int k) -> ()",
      {stride_tag});
  ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);

  // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
  ops.def(
      "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
@ -3,6 +3,7 @@

Run `pytest tests/kernels/test_cutlass.py`.
"""
import random

import pytest
import torch
@ -507,3 +508,136 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):

def test_cutlass_support_opcheck():
    opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, ))


@pytest.mark.parametrize("num_experts", [8, 64])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
                                per_out_ch: bool, use_bias: bool):

    # Device and dtype setup
    device = "cuda"
    out_dtype = torch.half

    # Create separate A, B, C tensors for each group
    a_tensors = []
    b_tensors = []
    a_scales_tensors = []
    b_scales_tensors = []
    baseline_tensors = []

    expert_offsets = torch.zeros((num_experts + 1),
                                 device=device,
                                 dtype=torch.int32)

    problem_sizes = torch.zeros((num_experts, 3),
                                device=device,
                                dtype=torch.int32)

    if not per_act_token:
        one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32)

    alignment = 16  # 128 // 8
    # N and K are shared across groups; M varies per group for variation.
    n_g = alignment * random.randint(1, 64)
    k_g = alignment * random.randint(1, 64)
    for g in range(num_experts):
        m_g = alignment * random.randint(1, 64)

        expert_offsets[g + 1] = expert_offsets[g] + m_g
        problem_sizes[g][0] = m_g
        problem_sizes[g][1] = n_g
        problem_sizes[g][2] = k_g

        m_a_scales = m_g if per_act_token else 1
        n_b_scales = n_g if per_out_ch else 1

        print("shape:", m_g, n_g, k_g)

        # Create group-specific A and B (FP8) and output (FP16/FP32)
        a_g = to_fp8(torch.randn((m_g, k_g), device=device))
        b_g = to_fp8(torch.randn((n_g, k_g), device=device).t())
        a_tensors.append(a_g)
        b_tensors.append(b_g)

        # Set up A/B scales
        scale_b = torch.randn((1, n_b_scales),
                              device=device,
                              dtype=torch.float32)
        b_scales_tensors.append(scale_b)

        if per_act_token:
            scale_a = torch.randn((m_a_scales, 1),
                                  device=device,
                                  dtype=torch.float32)
            a_scales_tensors.append(scale_a)
        else:
            scale_a = one_scale_a

        # Compute baseline result for this group
        baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype,
                                        None)
        baseline_tensors.append(baseline_g)

    a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g),
                                    device=device,
                                    dtype=torch.float8_e4m3fn)
    b_tensors_stacked = torch.empty((num_experts, n_g, k_g),
                                    device=device,
                                    dtype=torch.float8_e4m3fn)

    for g in range(num_experts):
        a_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] = \
            a_tensors[g]
        b_tensors_stacked[g] = b_tensors[g].t()
    b_tensors_stacked = b_tensors_stacked.transpose(1, 2)

    if per_act_token:
        a_scales_tensors_stacked = torch.empty(
            (expert_offsets[num_experts], 1),
            device=device,
            dtype=torch.float32)
        for g in range(num_experts):
            a_scales_tensors_stacked[
                expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g]
    else:
        a_scales_tensors_stacked = one_scale_a

    b_scales_tensors_stacked = torch.empty((num_experts, n_b_scales),
                                           device=device,
                                           dtype=torch.float32)
    for g in range(num_experts):
        b_scales_tensors_stacked[g] = b_scales_tensors[g]

    out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g),
                                      device=device,
                                      dtype=out_dtype)

    ab_strides = torch.full((num_experts, ),
                            a_tensors_stacked.stride(0),
                            device="cuda",
                            dtype=torch.int64)
    c_strides = torch.full((num_experts, ),
                           out_tensors_stacked.stride(0),
                           device="cuda",
                           dtype=torch.int64)

    ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked,
                       b_tensors_stacked, a_scales_tensors_stacked,
                       b_scales_tensors_stacked, expert_offsets[:-1],
                       problem_sizes, ab_strides, ab_strides, c_strides)

    # Validate each group's result against the baseline
    for g in range(num_experts):
        baseline = baseline_tensors[g]
        c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]]
        print(baseline)
        print(c)
        print("*")
        torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4)

244
tests/kernels/test_cutlass_moe.py
Normal file
@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
                                                            fused_experts,
                                                            fused_topk)
from vllm.platforms import current_platform

NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]


def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
        w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
        topk_weights: torch.Tensor, topk_ids: torch.Tensor,
        ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
        ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):
        return cutlass_moe_fp8(a,
                               w1_q,
                               w2_q,
                               w1_scale,
                               w2_scale,
                               topk_weights,
                               topk_ids,
                               ab_strides1,
                               c_strides1,
                               ab_strides2,
                               c_strides2,
                               a1_scale=a_scale)


@pytest.mark.parametrize("m", [2, 64, 224])
@pytest.mark.parametrize("n", [1024, 3072])
@pytest.mark.parametrize("k", [1024, 1536])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
):
    current_platform.seed_everything(7)
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):

        dtype = torch.half

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

        # Get the right scale for tests.
        _, a_scale1 = ops.scaled_fp8_quant(
            a, use_per_token_if_dynamic=per_act_token)
        a_q, _ = ops.scaled_fp8_quant(a,
                                      a_scale1,
                                      use_per_token_if_dynamic=per_act_token)

        a_d = a_q.float().mul(a_scale1).to(dtype)

        n_b_scales = 2 * n if per_out_ch else 1
        k_b_scales = k if per_out_ch else 1

        w1_q = torch.empty((e, 2 * n, k),
                           device="cuda",
                           dtype=torch.float8_e4m3fn)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
        w1_scale = torch.empty((e, n_b_scales, 1),
                               device="cuda",
                               dtype=torch.float32)
        w2_scale = torch.empty((e, k_b_scales, 1),
                               device="cuda",
                               dtype=torch.float32)

        ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
        c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
        ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
        c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)

        for expert in range(e):
            w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
                w1[expert], use_per_token_if_dynamic=per_out_ch)
            w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
                w2[expert], use_per_token_if_dynamic=per_out_ch)
        w1_q = w1_q.transpose(1, 2)
        w2_q = w2_q.transpose(1, 2)

        ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
        c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
        ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
        c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)

        w1_d = torch.empty_like(w1)
        w2_d = torch.empty_like(w2)
        for expert in range(e):
            w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half()
            w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half()

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

        triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids)

        cutlass_output = cutlass_moe_fp8(a,
                                         w1_q,
                                         w2_q,
                                         w1_scale,
                                         w2_scale,
                                         topk_weights,
                                         topk_ids,
                                         ab_strides1,
                                         c_strides1,
                                         ab_strides2,
                                         c_strides2,
                                         a1_scale=a_scale1)

        print(triton_output)
        print(cutlass_output)
        print("*")

        torch.testing.assert_close(triton_output,
                                   cutlass_output,
                                   atol=5e-2,
                                   rtol=1e-2)


@pytest.mark.parametrize("m", [2, 64, 224])
@pytest.mark.parametrize("n", [1024, 3072])
@pytest.mark.parametrize("k", [1024, 1536])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_cuda_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
):
    current_platform.seed_everything(7)
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):

        dtype = torch.half

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

        # Get the right scale for tests.
        _, a_scale1 = ops.scaled_fp8_quant(
            a, use_per_token_if_dynamic=per_act_token)
        a_q, _ = ops.scaled_fp8_quant(a,
                                      a_scale1,
                                      use_per_token_if_dynamic=per_act_token)

        a_d = a_q.float().mul(a_scale1).to(dtype)

        n_b_scales = 2 * n if per_out_ch else 1
        k_b_scales = k if per_out_ch else 1

        w1_q = torch.empty((e, 2 * n, k),
                           device="cuda",
                           dtype=torch.float8_e4m3fn)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
        w1_scale = torch.empty((e, n_b_scales, 1),
                               device="cuda",
                               dtype=torch.float32)
        w2_scale = torch.empty((e, k_b_scales, 1),
                               device="cuda",
                               dtype=torch.float32)

        ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
        c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
        ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
        c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)

        for expert in range(e):
            w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
                w1[expert], use_per_token_if_dynamic=per_out_ch)
            w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
                w2[expert], use_per_token_if_dynamic=per_out_ch)
        w1_q = w1_q.transpose(1, 2)
        w2_q = w2_q.transpose(1, 2)

        ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
        c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
        ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
        c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)

        w1_d = torch.empty_like(w1)
        w2_d = torch.empty_like(w2)
        for expert in range(e):
            w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half()
            w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half()

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

        triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids)

        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            cutlass_output = run(a, a_scale1, w1_q, w2_q, w1_scale, w2_scale,
                                 topk_weights, topk_ids, ab_strides1,
                                 c_strides1, ab_strides2, c_strides2)
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

        print(triton_output)
        print(cutlass_output)
        print("*")

        torch.testing.assert_close(triton_output,
                                   cutlass_output,
                                   atol=9e-2,
                                   rtol=1e-2)
@ -587,6 +587,9 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
                                                  cuda_device_capability)


def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
    return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability)


def cutlass_sparse_compress(a: torch.Tensor) \
        -> tuple[torch.Tensor, torch.Tensor]:
    """
@ -677,6 +680,56 @@ def cutlass_scaled_sparse_mm(
    return out


def get_cutlass_moe_mm_data(
        topk_ids: torch.Tensor, expert_offsets: torch.Tensor,
        problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor,
        input_permutation: torch.Tensor, output_permutation: torch.Tensor,
        num_experts: int, n: int, k: int):
    """
    Prepare data necessary to perform CUTLASS grouped matrix multiplications
    used in CUTLASS-based fused MoE.

    The function takes in topk_ids (token-expert mapping) and uses it to
    compute:
    - expert_offsets: Indices that mark at which token index each expert begins
                      its computation after the input is sorted with
                      input_permutation. The number of tokens computed with
                      expert E is expert_offsets[E + 1] - expert_offsets[E]
    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
                                      multiplication in two grouped MMs used in
                                      the fused MoE operation.
    - input_permutation: Permutation that must be used to shuffle the input
                         before executing the MMs.
    - output_permutation: Permutation that must be used to shuffle the output
                          after executing the MMs.
    """
    torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
                                         problem_sizes1, problem_sizes2,
                                         input_permutation, output_permutation,
                                         num_experts, n, k)
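A minimal allocation-and-call sketch for this op, mirroring how cutlass_moe_fp8 below uses it; buffer names are illustrative:

# Allocation-and-call sketch for get_cutlass_moe_mm_data. Shapes follow the
# op schema above; names are illustrative.
import torch
from vllm import _custom_ops as ops

def prepare_moe_mm_data(topk_ids: torch.Tensor, num_experts: int, n: int,
                        k: int):
    dev = topk_ids.device
    expert_offsets = torch.empty(num_experts + 1,
                                 dtype=torch.int32, device=dev)
    ps1 = torch.empty((num_experts, 3), dtype=torch.int32, device=dev)
    ps2 = torch.empty((num_experts, 3), dtype=torch.int32, device=dev)
    in_perm = torch.empty(topk_ids.numel(), dtype=torch.int32, device=dev)
    out_perm = torch.empty(topk_ids.numel(), dtype=torch.int32, device=dev)
    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, ps1, ps2, in_perm,
                                out_perm, num_experts, n, k)
    return expert_offsets, ps1, ps2, in_perm, out_perm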


def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
                   b_tensors: torch.Tensor, a_scales: torch.Tensor,
                   b_scales: torch.Tensor, expert_offsets: torch.Tensor,
                   problem_sizes: torch.Tensor, a_strides: torch.Tensor,
                   b_strides: torch.Tensor, c_strides: torch.Tensor):
    """
    A single grouped matrix multiplication used in CUTLASS-based fused MoE.
    The function executes an fp8-quantized OUT = AB matrix multiplication.

    - expert_offsets: Indices that mark at which token index each expert begins
                      its computation. The number of tokens computed with
                      expert E is expert_offsets[E + 1] - expert_offsets[E]
    - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
                     MMs used in the fused MoE operation.
    - a/b/c_strides: The data strides passed to grouped matrix multiplication.
    """
    torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales,
                                b_scales, expert_offsets, problem_sizes,
                                a_strides, b_strides, c_strides)


# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
@ -36,8 +36,8 @@ if HAS_TRITON:
    import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
    import vllm.model_executor.layers.fused_moe.fused_moe  # noqa
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        fused_experts, fused_moe, fused_topk, get_config_file_name,
        grouped_topk)
        cutlass_moe_fp8, fused_experts, fused_moe, fused_topk,
        get_config_file_name, grouped_topk)

    __all__ += [
        "fused_moe",
@ -45,4 +45,5 @@ if HAS_TRITON:
        "fused_experts",
        "get_config_file_name",
        "grouped_topk",
        "cutlass_moe_fp8",
    ]
@ -1623,3 +1623,140 @@ def fused_moe(
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape)


# TODO make the grouped gemm kernel consistent with the scaled gemm kernel
def cutlass_moe_fp8(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    ab_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    ab_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    out_dtype: torch.dtype = torch.half,
) -> torch.Tensor:
    """
    This function computes a w8a8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and a top-k gating
    mechanism. The matrix multiplications are implemented with CUTLASS
    grouped gemm.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
        Shape: [num_experts, K, 2N] (the weights are passed transposed)
    - w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
        Shape: [num_experts, N, K] (the weights are passed transposed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts] or [num_experts, 2N]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts] or [num_experts, K]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids (torch.Tensor): The indices of the experts selected for each
        token.
    - ab_strides1 (torch.Tensor): The input and weights strides of the first
        grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - ab_strides2 (torch.Tensor): The input and weights strides of the second
        grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
        Shape: scalar or [M]
    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
        quantize the intermediate result between the gemms.
        Shape: scalar or [M]
    - out_dtype (torch.dtype): The output tensor type.

    Returns:
    - torch.Tensor: The fp16 output tensor after applying the MoE layer.
    """

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.float8_e4m3fn
    assert w2_q.dtype == torch.float8_e4m3fn
    assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
    assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert (a1_scale is None or a1_scale.dim() == 0 or a1_scale.shape[0] == 1
            or a1_scale.shape[0] == a.shape[0]), "Input scale shape mismatch"
    assert (w1_scale.dim() == 1 or w1_scale.shape[1] == 1
            or w1_scale.shape[1] == w1_q.shape[2]), "W1 scale shape mismatch"
    assert (w2_scale.dim() == 1 or w2_scale.shape[1] == 1
            or w2_scale.shape[1] == w2_q.shape[2]), "W2 scale shape mismatch"
    assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], \
        "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[0], \
        "w2 scales expert number mismatch"
    assert (a2_scale is None or a1_scale is None
            or a2_scale.shape == a1_scale.shape), \
        "Intermediate scale shape mismatch"
    assert ab_strides1.shape[0] == w1_q.shape[0], \
        "AB Strides 1 expert number mismatch"
    assert c_strides1.shape[0] == w1_q.shape[0], \
        "C Strides 1 expert number mismatch"
    assert ab_strides2.shape[0] == w2_q.shape[0], \
        "AB Strides 2 expert number mismatch"
    assert c_strides2.shape[0] == w2_q.shape[0], \
        "C Strides 2 expert number mismatch"
    assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"

    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(1)
    n = w2_q.size(1)

    topk = topk_ids.size(1)
    per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
        a2_scale.numel() != 1 if a2_scale is not None else False)

    a_q, a1_scale = ops.scaled_fp8_quant(
        a, a1_scale, use_per_token_if_dynamic=per_act_token)
    device = a_q.device

    expert_offsets = torch.empty((num_experts + 1),
                                 dtype=torch.int32,
                                 device=device)
    problem_sizes1 = torch.empty((num_experts, 3),
                                 dtype=torch.int32,
                                 device=device)
    problem_sizes2 = torch.empty((num_experts, 3),
                                 dtype=torch.int32,
                                 device=device)

    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
                                problem_sizes2, a_map, c_map, num_experts, n,
                                k)

    rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
    rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale

    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
    c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)

    ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
                       expert_offsets[:-1], problem_sizes1, ab_strides1,
                       ab_strides1, c_strides1)

    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
    torch.ops._C.silu_and_mul(intermediate, c1)

    intermediate_q, a2_scale = ops.scaled_fp8_quant(
        intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)

    ops.cutlass_moe_mm(c2, intermediate_q, w2_q, a2_scale, w2_scale,
                       expert_offsets[:-1], problem_sizes2, ab_strides2,
                       ab_strides2, c_strides2)

    return (c2[c_map].view(m, topk, k) *
            topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
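For checking the routing and permutation logic against the fused path, a dequantized PyTorch reference of the same two-GEMM flow can help. This sketch assumes unquantized weights w1 of shape [E, 2N, K] and w2 of shape [E, K, N], and is illustrative only:

# Dequantized PyTorch reference for the two-GEMM MoE flow above (no CUTLASS,
# no fp8). Assumes w1: [E, 2N, K], w2: [E, K, N]. Illustrative only.
import torch

def moe_reference(a, w1, w2, topk_weights, topk_ids):
    M, K = a.shape
    topk = topk_ids.shape[1]
    idx = torch.arange(M, device=a.device).repeat_interleave(topk)
    rep = a[idx]                              # replicate each token topk times
    flat = topk_ids.reshape(-1)
    h = torch.empty(M * topk, w1.shape[1], device=a.device, dtype=a.dtype)
    for e in range(w1.shape[0]):
        mask = flat == e
        h[mask] = rep[mask] @ w1[e].t()       # first (up/gate) GEMM
    gate, up = h.chunk(2, dim=-1)
    act = torch.nn.functional.silu(gate) * up  # silu_and_mul
    out = torch.empty(M * topk, K, device=a.device, dtype=a.dtype)
    for e in range(w2.shape[0]):
        mask = flat == e
        out[mask] = act[mask] @ w2[e].t()     # second (down) GEMM
    return (out.view(M, topk, K) * topk_weights.view(M, topk, 1)).sum(dim=1)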

@ -96,7 +96,8 @@ class CompressedTensorsConfig(QuantizationConfig):
        if isinstance(layer, Attention):
            return CompressedTensorsKVCacheMethod(self)
        if isinstance(layer, FusedMoE):
            return CompressedTensorsMoEMethod.get_moe_method(self)
            return CompressedTensorsMoEMethod.get_moe_method(
                self, layer.activation, layer.expert_map)
        return None

    @classmethod
@ -191,17 +192,26 @@ class CompressedTensorsConfig(QuantizationConfig):

    def _check_scheme_supported(self,
                                min_capability: int,
                                error: bool = True) -> bool:
                                error: bool = True,
                                match_exact: bool = False) -> bool:
        capability_tuple = current_platform.get_device_capability()

        if capability_tuple is not None:
            capability = capability_tuple.to_int()
            supported = capability >= min_capability
            if error and not supported:
                raise RuntimeError(
                    "Quantization scheme is not supported for ",
                    f"the current GPU. Min capability: {min_capability}. ",
                    f"Current capability: {capability}.")
            if match_exact:
                supported = capability == min_capability
                if error and not supported:
                    raise RuntimeError(
                        "Quantization scheme is not supported for ",
                        "the current GPU. Required capability: ",
                        f"{min_capability}. Current capability: {capability}.")
            else:
                supported = capability >= min_capability
                if error and not supported:
                    raise RuntimeError(
                        "Quantization scheme is not supported for ",
                        f"the current GPU. Min capability: {min_capability}. ",
                        f"Current capability: {capability}.")
            return supported
        else:
            return False
@ -262,6 +272,11 @@ class CompressedTensorsConfig(QuantizationConfig):
                input_quant.strategy == QuantizationStrategy.TENSOR)
        return is_symmetric_activation and is_per_tensor_activation

    def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
                          input_quant: BaseModel) -> bool:
        return (self._check_scheme_supported(90, error=False, match_exact=True)
                and self._is_fp8_w8a8(weight_quant, input_quant))

    def _is_fp8_w8a16(self, weight_quant: BaseModel,
                      input_quant: BaseModel) -> bool:
        # Confirm weights quantized.
@ -31,6 +31,7 @@ class GPTQMarlinState(Enum):

__all__ = [
    "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
    "CompressedTensorsW8A8Fp8MoECutlassMethod",
    "CompressedTensorsWNA16MoEMethod"
]

@ -39,7 +40,9 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):

    @staticmethod
    def get_moe_method(
        quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
        activation: str,
        expert_map: Optional[torch.Tensor],
    ) -> "CompressedTensorsMoEMethod":
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
@ -49,6 +52,9 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):

        if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
            return CompressedTensorsWNA16MoEMethod(quant_config)
        elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
              and activation == "silu" and expert_map is None):
            return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
        elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
            return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
        else:
@ -250,6 +256,200 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            a2_scale=layer.w2_input_scale)


class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):

    def __init__(
        self,
        quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
        self.quant_config = quant_config
        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
            "weights")
        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
            "input_activations")

        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
            raise ValueError(
                "For FP8 Fused MoE layers, only per-tensor scales "
                "for weights and activations are supported. Found "
                f"{self.weight_quant}, {self.input_quant}")

        self.static_input_scales = not self.input_quant.dynamic

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.static_input_scales:
            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        device = w13_weight.device
        # TODO strides can be shared across multiple layers
        self.ab_strides1 = torch.full((num_experts, ),
                                      hidden_size,
                                      device=device,
                                      dtype=torch.int64)
        self.c_strides1 = torch.full((num_experts, ),
                                     2 * intermediate_size_per_partition,
                                     device=device,
                                     dtype=torch.int64)
        self.ab_strides2 = torch.full((num_experts, ),
                                      intermediate_size_per_partition,
                                      device=device,
                                      dtype=torch.int64)
        self.c_strides2 = torch.full((num_experts, ),
                                     hidden_size,
                                     device=device,
                                     dtype=torch.int64)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.static_input_scales:
            if (layer.w13_input_scale is None
                    or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer.")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.local_num_experts):
            start = 0
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start + shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][
                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id])
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
    ) -> torch.Tensor:

        assert activation == "silu"
        assert global_num_experts == layer.w13_weight.shape[0]
        assert expert_map is None

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8

        return cutlass_moe_fp8(
            x,
            layer.w13_weight.transpose(1, 2),
            layer.w2_weight.transpose(1, 2),
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            topk_weights,
            topk_ids,
            self.ab_strides1,
            self.c_strides1,
            self.ab_strides2,
            self.c_strides2,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            out_dtype=x.dtype,
        )


class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):

    def __init__(
@ -50,6 +50,16 @@ def cutlass_block_fp8_supported() -> bool:
    return ops.cutlass_scaled_mm_supports_block_fp8(capability)


def cutlass_group_gemm_supported() -> bool:
    if not current_platform.is_cuda():
        return False

    capability_tuple = current_platform.get_device_capability()
    capability = -1 if capability_tuple is None else capability_tuple.to_int()

    return ops.cutlass_group_gemm_supported(capability)


CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported()
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
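A typical call-site guard, with an illustrative fallback; the helper above already folds in the platform and capability checks:

# Illustrative dispatch using the helper defined above: prefer the CUTLASS
# grouped GEMM on supported GPUs (SM90, CUDA >= 12.3), else the Triton path.
def choose_moe_backend() -> str:
    return "cutlass" if cutlass_group_gemm_supported() else "triton"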
@ -1568,18 +1568,21 @@ class ClassRegistry(UserDict[Type[T], _V]):
        return any(cls in self.data for cls in key.mro())


def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
def weak_ref_tensor(tensor: Any) -> Any:
    """
    Create a weak reference to a tensor.
    The new tensor will share the same data as the original tensor,
    but will not keep the original tensor alive.
    """
    return torch.ops._C.weak_ref_tensor(tensor)
    if isinstance(tensor, torch.Tensor):
        return torch.ops._C.weak_ref_tensor(tensor)
    else:
        return tensor


def weak_ref_tensors(
    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]:
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
    """
    Convenience function to create weak references to tensors,
    for single tensor, list of tensors or tuple of tensors.