# SPDX-License-Identifier: Apache-2.0
import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
                                                            fused_experts,
                                                            fused_topk)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = [
    "nm-testing/Mixtral-8x7B-Instruct-v0.1",
    "nm-testing/deepseekv2-lite",
    "ibm-granite/granite-3.0-1b-a400m",
    "ibm-granite/granite-3.0-3b-a800m",
]
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]


def to_fp8(tensor: torch.Tensor):
    # Clamp into the representable float8_e4m3fn range before casting.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def bench_run(results: list[benchmark.Measurement], model: str,
              num_experts: int, topk: int, per_act_token: bool,
              per_out_ch: bool, mkn: tuple[int, int, int]):
    label = "Quant Matmul"

    sub_label = ("{}, num_experts={}, topk={}, per_act_token={}, "
                 "per_out_ch={}, MKN=({})".format(model, num_experts, topk,
                                                  per_act_token, per_out_ch,
                                                  mkn))

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    _, a_scale = ops.scaled_fp8_quant(a)

    w1_q = torch.empty((num_experts, 2 * n, k),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w2_q = torch.empty((num_experts, k, n),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)

    # Per-expert leading-dimension strides for the two grouped GEMMs.
    ab_strides1 = torch.full((num_experts, ),
                             k,
                             device="cuda",
                             dtype=torch.int64)
    c_strides1 = torch.full((num_experts, ),
                            2 * n,
                            device="cuda",
                            dtype=torch.int64)
    ab_strides2 = torch.full((num_experts, ),
                             n,
                             device="cuda",
                             dtype=torch.int64)
    c_strides2 = torch.full((num_experts, ),
                            k,
                            device="cuda",
                            dtype=torch.int64)

    # Quantize each expert's weights to FP8 with a per-tensor scale.
    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])

    # The Triton path consumes the weights as produced above, while the
    # CUTLASS path here takes them transposed, so keep both layouts around.
    w1_q_notransp = w1_q.clone()
    w2_q_notransp = w2_q.clone()
    w1_q = w1_q.transpose(1, 2)
    w2_q = w2_q.transpose(1, 2)

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

    def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                       topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                       w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                       a_scale: torch.Tensor, num_repeats: int):
        for _ in range(num_repeats):
            fused_experts(a,
                          w1,
                          w2,
                          topk_weights,
                          topk_ids,
                          use_fp8_w8a8=True,
                          w1_scale=w1_scale,
                          w2_scale=w2_scale,
                          a1_scale=a_scale)

    def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
                        w1: torch.Tensor, w2: torch.Tensor,
                        w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                        topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                        ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
                        ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
                        num_repeats: int):
        for _ in range(num_repeats):
            cutlass_moe_fp8(a,
                            w1,
                            w2,
                            w1_scale,
                            w2_scale,
                            topk_weights,
                            topk_ids,
                            ab_strides1,
                            c_strides1,
                            ab_strides2,
                            c_strides2,
                            a1_scale=a_scale)
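
    # The *_from_graph wrappers below run a single MoE forward inside
    # set_current_vllm_config so it can be captured into a CUDA graph;
    # replaying the captured graph then measures kernel time with the
    # per-launch overhead amortized away.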
    def run_cutlass_from_graph(
            a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
            w2_q: torch.Tensor, w1_scale: torch.Tensor,
            w2_scale: torch.Tensor, topk_weights: torch.Tensor,
            topk_ids: torch.Tensor, ab_strides1: torch.Tensor,
            c_strides1: torch.Tensor, ab_strides2: torch.Tensor,
            c_strides2: torch.Tensor):
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return cutlass_moe_fp8(a,
                                   w1_q,
                                   w2_q,
                                   w1_scale,
                                   w2_scale,
                                   topk_weights,
                                   topk_ids,
                                   ab_strides1,
                                   c_strides1,
                                   ab_strides2,
                                   c_strides2,
                                   a1_scale=a_scale)

    def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
                              w2: torch.Tensor, topk_weights: torch.Tensor,
                              topk_ids: torch.Tensor, w1_scale: torch.Tensor,
                              w2_scale: torch.Tensor, a_scale: torch.Tensor):
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return fused_experts(a,
                                 w1,
                                 w2,
                                 topk_weights,
                                 topk_ids,
                                 use_fp8_w8a8=True,
                                 w1_scale=w1_scale,
                                 w2_scale=w2_scale,
                                 a1_scale=a_scale)

    def replay_graph(graph, num_repeats):
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
                               topk_weights, topk_ids, ab_strides1,
                               c_strides1, ab_strides2, c_strides2)
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
                              topk_ids, w1_scale, w2_scale, a_scale)
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_q_notransp": w1_q_notransp,
        "w2_q_notransp": w2_q_notransp,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "ab_strides1": ab_strides1,
        "c_strides1": c_strides1,
        "ab_strides2": ab_strides2,
        "c_strides2": c_strides2,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
                   w1_scale, w2_scale, a_scale, num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
                    topk_ids, ab_strides1, c_strides1, ab_strides2,
                    c_strides2, num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time))
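
    # Note on interpreting results: every timed statement executes num_runs
    # (25) iterations per invocation, so blocked_autorange reports the time
    # for 25 calls rather than one. Relative comparisons between the Triton,
    # CUTLASS, and CUDA-graph variants are unaffected.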
    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            bench_run(results, model, num_experts, topk,
                                      per_act_token, per_out_ch, mkn)

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark CUTLASS grouped-GEMM FP8 MoE across "
        "specified models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_TP_SIZES)
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token",
                        nargs="+",
                        type=int,
                        default=[])
    parser.add_argument("--limit-per-out-ch",
                        nargs="+",
                        type=int,
                        default=[])

    args = parser.parse_args()
    main(args)
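
# Example invocation, assuming a CUDA device and a vLLM build that ships the
# CUTLASS grouped-GEMM FP8 MoE kernels (the filename below is illustrative):
#
#   python benchmark_grouped_gemm_cutlass.py \
#       --models nm-testing/deepseekv2-lite \
#       --tp-sizes 1 --batch-sizes 1 64 512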