# SPDX-License-Identifier: Apache-2.0

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
                                                            fused_experts,
                                                            fused_topk)
from vllm.utils import FlexibleArgumentParser
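
# Benchmarks the fp8 CUTLASS grouped-GEMM MoE kernel (cutlass_moe_fp8)
# against the Triton fused-experts path (fused_experts), both eagerly and
# replayed through CUDA graphs, over the MoE layer shapes listed in
# WEIGHT_SHAPES_MOE.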
DEFAULT_MODELS = [
    "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
    "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
]
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]
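# Note: only the per-tensor (False) variants are swept; per_act_token and
# per_out_ch currently affect only the benchmark sub-labels.
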
def to_fp8(tensor: torch.Tensor):
    # Clamp to the representable e4m3 range, round, and cast. This helper is
    # not used below; quantization goes through ops.scaled_fp8_quant instead.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
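
# A minimal usage sketch for to_fp8 (hypothetical shapes):
#   x = torch.randn((16, 32), device="cuda", dtype=torch.half)
#   x_fp8 = to_fp8(x)
#   assert x_fp8.dtype == torch.float8_e4m3fn
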
def bench_run(results: list[benchmark.Measurement], model: str,
              num_experts: int, topk: int, per_act_token: bool,
              per_out_ch: bool, mkn: tuple[int, int, int]):
    label = "Quant Matmul"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={}, per_out_ch={}, "
        "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
                          mkn))

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    # ops.scaled_fp8_quant returns (quantized tensor, scale); only the
    # activation scale is kept, since both benchmarked paths take the fp16
    # activations together with a1_scale.
    _, a_scale = ops.scaled_fp8_quant(a)

    w1_q = torch.empty((num_experts, 2 * n, k),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w2_q = torch.empty((num_experts, k, n),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)
    # Per-expert leading-dimension strides for the two grouped GEMMs; every
    # expert shares the same (k, n) here, so each stride tensor is constant.
    ab_strides1 = torch.full((num_experts, ),
                             k,
                             device="cuda",
                             dtype=torch.int64)
    c_strides1 = torch.full((num_experts, ),
                            2 * n,
                            device="cuda",
                            dtype=torch.int64)
    ab_strides2 = torch.full((num_experts, ),
                             n,
                             device="cuda",
                             dtype=torch.int64)
    c_strides2 = torch.full((num_experts, ),
                            k,
                            device="cuda",
                            dtype=torch.int64)
    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
    w1_q_notransp = w1_q.clone()
    w2_q_notransp = w2_q.clone()
    w1_q = w1_q.transpose(1, 2)
    w2_q = w2_q.transpose(1, 2)
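    # transpose(1, 2) creates strided views (no data copy); the CUTLASS
    # kernel consumes the swapped layout, while the _notransp clones preserve
    # the original layout for the Triton path.
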
    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
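    # topk_weights/topk_ids now hold, per token, the top-k routing weights
    # and expert indices derived from the random gating scores.
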
    def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                       topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                       w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                       a_scale: torch.Tensor, num_repeats: int):
        for _ in range(num_repeats):
            fused_experts(a,
                          w1,
                          w2,
                          topk_weights,
                          topk_ids,
                          use_fp8_w8a8=True,
                          w1_scale=w1_scale,
                          w2_scale=w2_scale,
                          a1_scale=a_scale)
    def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
                        w1: torch.Tensor, w2: torch.Tensor,
                        w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                        topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                        ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
                        ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
                        num_repeats: int):
        for _ in range(num_repeats):
            cutlass_moe_fp8(a,
                            w1,
                            w2,
                            w1_scale,
                            w2_scale,
                            topk_weights,
                            topk_ids,
                            ab_strides1,
                            c_strides1,
                            ab_strides2,
                            c_strides2,
                            a1_scale=a_scale)
    def run_cutlass_from_graph(
            a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
            w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
            ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
            ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
        # The MoE ops consult the current vLLM config, so activate a minimal
        # one for graph capture.
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return cutlass_moe_fp8(a,
                                   w1_q,
                                   w2_q,
                                   w1_scale,
                                   w2_scale,
                                   topk_weights,
                                   topk_ids,
                                   ab_strides1,
                                   c_strides1,
                                   ab_strides2,
                                   c_strides2,
                                   a1_scale=a_scale)
    def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
                              w2: torch.Tensor, topk_weights: torch.Tensor,
                              topk_ids: torch.Tensor, w1_scale: torch.Tensor,
                              w2_scale: torch.Tensor, a_scale: torch.Tensor):
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return fused_experts(a,
                                 w1,
                                 w2,
                                 topk_weights,
                                 topk_ids,
                                 use_fp8_w8a8=True,
                                 w1_scale=w1_scale,
                                 w2_scale=w2_scale,
                                 a1_scale=a_scale)
    def replay_graph(graph, num_repeats):
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()
    # Capture each path once into a CUDA graph on a dedicated side stream;
    # replaying the graph measures kernel time without per-call Python
    # launch overhead.
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
                               topk_weights, topk_ids, ab_strides1, c_strides1,
                               ab_strides2, c_strides2)
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
                              topk_ids, w1_scale, w2_scale, a_scale)
    torch.cuda.synchronize()
    min_run_time = 5
    num_warmup = 5
    num_runs = 25
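    # Each timed statement below performs num_runs kernel invocations, and
    # blocked_autorange() keeps sampling until at least min_run_time seconds
    # of measurements accumulate, so each reported time covers num_runs calls.
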
    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_q_notransp": w1_q_notransp,
        "w2_q_notransp": w2_q_notransp,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "ab_strides1": ab_strides1,
        "c_strides1": c_strides1,
        "ab_strides2": ab_strides2,
        "c_strides2": c_strides2,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }
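    # This mapping is the namespace benchmark.Timer uses to resolve the
    # names referenced in the stmt strings below.
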
    # Warmup
    run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
                   w1_scale, w2_scale, a_scale, num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time))
    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))
    # Warmup
    run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
                    topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2,
                    num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time))
    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))
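    # results now holds four measurements for this shape: triton_moe,
    # triton_moe_cuda_graphs, grouped_gemm_moe, and
    # grouped_gemm_moe_cuda_graphs.

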
def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        # Sweep the requested batch sizes (the M dimension).
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            bench_run(results, model, num_experts, topk,
                                      per_act_token, per_out_ch, mkn)

    compare = benchmark.Compare(results)
    compare.print()
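# A minimal sketch of an invocation (the filename is illustrative; models
# must be keys of WEIGHT_SHAPES_MOE):
#   python benchmark_grouped_gemm_cutlass.py \
#       --models nm-testing/deepseekv2-lite --tp-sizes 1 --batch-sizes 1 16 64
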
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark CUTLASS grouped-GEMM fp8 MoE kernels across "
        "specified models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_TP_SIZES)
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    # The remaining limit flags are parsed but not currently used by main().
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token",
                        nargs="+",
                        type=int,
                        default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)