vllm/benchmarks/kernels/benchmark_moe.py

561 lines
20 KiB
Python
Raw Normal View History

# SPDX-License-Identifier: Apache-2.0
import argparse
import time
from datetime import datetime
from itertools import product
from typing import Any, Dict, List, Tuple, TypedDict
import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
) else torch.float8_e4m3fn
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
def benchmark_config(
config: BenchmarkConfig,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16:
w1 = torch.randint(-127,
127, (
num_experts,
shard_intermediate_size,
hidden_size,
),
dtype=torch.int8)
w2 = torch.randint(-127,
127, (
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8)
else:
w1 = torch.randn(num_experts,
shard_intermediate_size,
hidden_size,
dtype=init_dtype)
w2 = torch.randn(num_experts,
hidden_size,
shard_intermediate_size // 2,
dtype=init_dtype)
gating_output = torch.randn(num_iters,
num_tokens,
num_experts,
dtype=torch.float32)
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
dtype=torch.float32)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
w1 = w1.to(FP8_DTYPE)
w2 = w2.to(FP8_DTYPE)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
def prepare(i: int):
input_gating.copy_(gating_output[i])
def run():
from vllm.model_executor.layers.fused_moe import override_config
with override_config(config):
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
# JIT compilation & warmup
run()
torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for _ in range(10):
run()
torch.cuda.synchronize()
# Warmup
for _ in range(5):
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: List[float] = []
for i in range(num_iters):
prepare(i)
torch.cuda.synchronize()
start_event.record()
graph.replay()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset()
return avg
def get_rocm_tuning_space(use_fp16):
block_mn_range = [16, 32, 64, 128, 256]
block_k_range = [16, 32, 64, 128, 256]
if not use_fp16:
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
num_stage_range = [2]
waves_per_eu_range = [0]
matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
kpack_range = [1, 2] if use_fp16 else []
param_ranges = {
"BLOCK_SIZE_M": block_mn_range,
"BLOCK_SIZE_N": block_mn_range,
"BLOCK_SIZE_K": block_k_range,
"GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range,
"num_stages": num_stage_range,
"waves_per_eu": waves_per_eu_range,
}
if use_fp16:
param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
param_ranges["kpack"] = kpack_range
return param_ranges
def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
if current_platform.is_rocm():
param_ranges = get_rocm_tuning_space(use_fp16)
else:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [64, 128, 256]
num_warps_range = [4, 8]
group_m_range = [1, 16, 32, 64]
num_stage_range = [2, 3, 4, 5]
param_ranges = {
"BLOCK_SIZE_M": block_m_range,
"BLOCK_SIZE_N": block_n_range,
"BLOCK_SIZE_K": block_k_range,
"GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range,
"num_stages": num_stage_range,
}
keys, values = zip(*param_ranges.items())
for config_values in product(*values):
config = dict(zip(keys, config_values))
configs.append(config)
return configs
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
search_space, is_fp16):
N1, K1 = shard_intermediate_size, hidden_size
N2, K2 = hidden_size, shard_intermediate_size // 2
pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
is_fp16)
pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
is_fp16)
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
return search_space
# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
pruned_configs = []
elemBytes_a = 2 if is_fp16 else 1
elemBytes_b = 2 if is_fp16 else 1
mfma = 16 if M < 32 or N < 32 else 32
# TODO (zhanglx): figure out the boundary between large and small gemms
large_gemm = False
if M >= 2048 and N >= 2048:
large_gemm = True
for config in configs:
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
num_warps = config.get("num_warps")
if is_fp16:
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
if matrix_instr_nonkdim > mfma:
continue
if mfma == 4 and BLOCK_SIZE_K < 64:
continue
# some layouts could not work properly in case
# number elements per thread is less 1
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
continue
SPLIT_K = config.get("SPLIT_K", 1)
GROUP_M = config.get("GROUP_SIZE_M")
if is_fp16:
if (matrix_instr_nonkdim > BLOCK_SIZE_M
or matrix_instr_nonkdim > BLOCK_SIZE_N):
continue
if (matrix_instr_nonkdim >= M
and matrix_instr_nonkdim != BLOCK_SIZE_M):
continue
if (matrix_instr_nonkdim >= N
and matrix_instr_nonkdim != BLOCK_SIZE_N):
continue
# Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
continue
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
continue
# skip large split_k when not necessary
if SPLIT_K != 1 and not need_split_k(M, N, K):
continue
# skip split_k that leads to EVEN_K = false
leap = SPLIT_K * BLOCK_SIZE_K
modv = K % leap
if modv != 0:
continue
# skip large GROUP_M
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
if large_gemm:
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
continue
if BLOCK_SIZE_K < 64:
continue
if num_warps < 4:
continue
pruned_configs.append(config)
return pruned_configs
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
def merge_unique_dicts(list1, list2):
result = []
combined_list = list1.copy()
combined_list.extend(list2)
for dictionary in combined_list:
if dictionary not in result:
result.append(dictionary)
return result
@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(seed)
self.seed = seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
self.device_id = int(ray.get_gpu_ids()[0])
def benchmark(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
) -> Tuple[Dict[str, int], float]:
current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
dtype_str)
if op_config is None:
config = get_default_config(num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype_str,
is_marlin=False)
else:
config = op_config[min(op_config.keys(),
key=lambda x: abs(x - num_tokens))]
kernel_time = benchmark_config(config, num_tokens, num_experts,
shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8,
use_int8_w8a16)
return config, kernel_time
def tune(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
search_space: List[Dict[str, int]],
) -> Dict[str, int]:
best_config = None
best_time = float("inf")
if current_platform.is_rocm():
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = prune_rocm_search_space(num_tokens,
shard_intermediate_size,
hidden_size, search_space,
is_fp16)
with torch.cuda.device(self.device_id):
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=20)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None
return best_config
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M":
config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N":
config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K":
config["BLOCK_SIZE_K"],
"GROUP_SIZE_M":
config["GROUP_SIZE_M"],
"num_warps":
config["num_warps"],
"num_stages":
config["num_stages"],
**({
"waves_per_eu": config["waves_per_eu"]
} if "waves_per_eu" in config else {}),
**({
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
} if "matrix_instr_nonkdim" in config else {}),
**({
"kpack": config["kpack"]
} if "kpack" in config else {}),
}
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
shard_intermediate_size: int, hidden_size: int, topk: int,
dtype: torch.dtype, use_fp8_w8a8: bool,
use_int8_w8a16: bool) -> None:
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
dtype_str)
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def main(args: argparse.Namespace):
print(args)
config = AutoConfig.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "DeepseekV3ForCausalLM":
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral.
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
if args.batch_size is None:
batch_sizes = [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]
else:
batch_sizes = [args.batch_size]
ray.init()
num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
def _distribute(method: str, inputs: List[Any]) -> List[Any]:
outputs = []
worker_idx = 0
for input_args in inputs:
worker = workers[worker_idx]
worker_method = getattr(worker, method)
output = worker_method.remote(*input_args)
outputs.append(output)
worker_idx = (worker_idx + 1) % num_gpus
return ray.get(outputs)
if args.tune:
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = get_configs_compute_bound(is_fp16)
print(f"Start tuning over {len(search_space)} configurations...")
start = time.time()
configs = _distribute(
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
for batch_size in batch_sizes])
best_configs = {
M: sort_config(config)
for M, config in zip(batch_sizes, configs)
}
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
end = time.time()
print(f"Tuning took {end - start:.2f} seconds")
else:
outputs = _distribute(
"benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
for batch_size in batch_sizes])
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}")
print(f"Kernel time: {kernel_time:.2f} us")
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
parser.add_argument("--tp-size",
"-tp",
"--tensor-parallel-size",
type=int,
default=2)
parser.add_argument("--dtype",
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16"],
default="auto")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()
main(args)