# SPDX-License-Identifier: Apache-2.0

import argparse
import json
import time
from datetime import datetime
from itertools import product
from typing import Any, Dict, List, Tuple, TypedDict

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
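
# ROCm uses the float8_e4m3fnuz FP8 variant, while CUDA uses float8_e4m3fn.
# Expert weights are cast to this dtype when fp8_w8a8 quantization is selected.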
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
) else torch.float8_e4m3fn


class BenchmarkConfig(TypedDict):
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int
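

# benchmark_config times a single fused_moe tuning configuration: it builds
# random activations, expert weights, and gating logits, warms the kernel up,
# captures 10 invocations in a CUDA graph, and returns the average latency
# per invocation in microseconds.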
def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
) -> float:
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int8_w8a16:
        w1 = torch.randint(-127,
                           127, (
                               num_experts,
                               shard_intermediate_size,
                               hidden_size,
                           ),
                           dtype=torch.int8)
        w2 = torch.randint(-127,
                           127, (
                               num_experts,
                               hidden_size,
                               shard_intermediate_size // 2,
                           ),
                           dtype=torch.int8)
    else:
        w1 = torch.randn(num_experts,
                         shard_intermediate_size,
                         hidden_size,
                         dtype=init_dtype)
        w2 = torch.randn(num_experts,
                         hidden_size,
                         shard_intermediate_size // 2,
                         dtype=init_dtype)
    gating_output = torch.randn(num_iters,
                                num_tokens,
                                num_experts,
                                dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
                               dtype=torch.float32)
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_fp8_w8a8:
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        from vllm.model_executor.layers.fused_moe import override_config
        with override_config(config):
            fused_moe(
                x,
                w1,
                w2,
                input_gating,
                topk,
                renormalize=True,
                inplace=True,
                use_fp8_w8a8=use_fp8_w8a8,
                use_int8_w8a16=use_int8_w8a16,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: List[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg
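

# get_rocm_tuning_space and get_configs_compute_bound enumerate the Triton
# meta-parameters to sweep; the search space is the Cartesian product of the
# ranges below. For fp8 on ROCm, BLOCK_SIZE_K=16 and the MFMA-specific knobs
# (matrix_instr_nonkdim, kpack) are excluded.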
def get_rocm_tuning_space(use_fp16):
    block_mn_range = [16, 32, 64, 128, 256]
    block_k_range = [16, 32, 64, 128, 256]
    if not use_fp16:
        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8
    num_warps_range = [1, 2, 4, 8]
    group_m_range = [1, 4, 8, 16, 32]
    num_stage_range = [2]
    waves_per_eu_range = [0]
    matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
    kpack_range = [1, 2] if use_fp16 else []

    param_ranges = {
        "BLOCK_SIZE_M": block_mn_range,
        "BLOCK_SIZE_N": block_mn_range,
        "BLOCK_SIZE_K": block_k_range,
        "GROUP_SIZE_M": group_m_range,
        "num_warps": num_warps_range,
        "num_stages": num_stage_range,
        "waves_per_eu": waves_per_eu_range,
    }
    if use_fp16:
        param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
        param_ranges["kpack"] = kpack_range

    return param_ranges


def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
    configs: List[BenchmarkConfig] = []

    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model
        # to prune the search space.
        block_m_range = [16, 32, 64, 128, 256]
        block_n_range = [32, 64, 128, 256]
        block_k_range = [64, 128, 256]
        num_warps_range = [4, 8]
        group_m_range = [1, 16, 32, 64]
        num_stage_range = [2, 3, 4, 5]

        param_ranges = {
            "BLOCK_SIZE_M": block_m_range,
            "BLOCK_SIZE_N": block_n_range,
            "BLOCK_SIZE_K": block_k_range,
            "GROUP_SIZE_M": group_m_range,
            "num_warps": num_warps_range,
            "num_stages": num_stage_range,
        }

    keys, values = zip(*param_ranges.items())
    for config_values in product(*values):
        config = dict(zip(keys, config_values))
        configs.append(config)
    return configs
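

# The fused MoE layer runs two grouped GEMMs per token: one against w1
# (N = shard_intermediate_size, K = hidden_size) and one against w2
# (N = hidden_size, K = shard_intermediate_size // 2). The ROCm search space
# is pruned against both shapes and the surviving configs are merged.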
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
                            search_space, is_fp16):
    N1, K1 = shard_intermediate_size, hidden_size
    N2, K2 = hidden_size, shard_intermediate_size // 2
    pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
                                        is_fp16)
    pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
                                        is_fp16)
    search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
    return search_space


# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    pruned_configs = []
    elemBytes_a = 2 if is_fp16 else 1
    elemBytes_b = 2 if is_fp16 else 1

    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = False
    if M >= 2048 and N >= 2048:
        large_gemm = True

    for config in configs:
        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")

        if is_fp16:
            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
            if matrix_instr_nonkdim > mfma:
                continue
        if mfma == 4 and BLOCK_SIZE_K < 64:
            continue
        # Some layouts do not work properly when the number of elements
        # per thread is less than 1.
        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
            continue
        SPLIT_K = config.get("SPLIT_K", 1)
        GROUP_M = config.get("GROUP_SIZE_M")
        if is_fp16:
            if (matrix_instr_nonkdim > BLOCK_SIZE_M
                    or matrix_instr_nonkdim > BLOCK_SIZE_N):
                continue
            if (matrix_instr_nonkdim >= M
                    and matrix_instr_nonkdim != BLOCK_SIZE_M):
                continue
            if (matrix_instr_nonkdim >= N
                    and matrix_instr_nonkdim != BLOCK_SIZE_N):
                continue
        # Skip BLOCK_SIZE that is too large compared to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
            continue
        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
            continue
        # skip large split_k when not necessary
        if SPLIT_K != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        leap = SPLIT_K * BLOCK_SIZE_K
        modv = K % leap
        if modv != 0:
            continue
        # skip large GROUP_M
        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
               BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
        if LDS > 65536:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and fp8, we want to only use BLOCK_SIZE >= 64
        if large_gemm:
            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                continue
            if BLOCK_SIZE_K < 64:
                continue
            if num_warps < 4:
                continue

        pruned_configs.append(config)

    return pruned_configs


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


def merge_unique_dicts(list1, list2):
    result = []
    combined_list = list1.copy()
    combined_list.extend(list2)
    for dictionary in combined_list:
        if dictionary not in result:
            result.append(dictionary)
    return result
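

# BenchmarkWorker is a Ray actor pinned to one GPU (num_gpus=1). `benchmark`
# times the config vLLM would currently pick (a tuned JSON entry if one
# exists, otherwise the default heuristic), while `tune` sweeps a search
# space with a reduced iteration count and returns the fastest config found.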
@ray.remote(num_gpus=1)
class BenchmarkWorker:

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
    ) -> Tuple[Dict[str, int], float]:
        current_platform.seed_everything(self.seed)
        dtype_str = get_config_dtype_str(dtype,
                                         use_int8_w8a16=use_int8_w8a16,
                                         use_fp8_w8a8=use_fp8_w8a8)
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
                                    dtype_str)
        if op_config is None:
            config = get_default_config(num_tokens,
                                        num_experts,
                                        shard_intermediate_size,
                                        hidden_size,
                                        topk,
                                        dtype_str,
                                        is_marlin=False)
        else:
            config = op_config[min(op_config.keys(),
                                   key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(config, num_tokens, num_experts,
                                       shard_intermediate_size, hidden_size,
                                       topk, dtype, use_fp8_w8a8,
                                       use_int8_w8a16)
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: List[Dict[str, int]],
    ) -> Dict[str, int]:
        best_config = None
        best_time = float("inf")
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
            search_space = prune_rocm_search_space(num_tokens,
                                                   shard_intermediate_size,
                                                   hidden_size, search_space,
                                                   is_fp16)

        with torch.cuda.device(self.device_id):
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(config,
                                                   num_tokens,
                                                   num_experts,
                                                   shard_intermediate_size,
                                                   hidden_size,
                                                   topk,
                                                   dtype,
                                                   use_fp8_w8a8,
                                                   use_int8_w8a16,
                                                   num_iters=20)
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config
        now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
        return best_config
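

# sort_config rewrites a config with its keys in a fixed order so the JSON
# emitted by save_configs is stable across runs; the ROCm-only keys are
# appended only when present.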
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    return {
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
        **({
            "waves_per_eu": config["waves_per_eu"]
        } if "waves_per_eu" in config else {}),
        **({
            "matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
        } if "matrix_instr_nonkdim" in config else {}),
        **({
            "kpack": config["kpack"]
        } if "kpack" in config else {}),
    }


def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
                 shard_intermediate_size: int, hidden_size: int, topk: int,
                 dtype: torch.dtype, use_fp8_w8a8: bool,
                 use_int8_w8a16: bool) -> None:
    dtype_str = get_config_dtype_str(dtype,
                                     use_int8_w8a16=use_int8_w8a16,
                                     use_fp8_w8a8=use_fp8_w8a8)

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                    dtype_str)

    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")
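

# main derives the MoE dimensions (number of experts E, top-k, and the
# intermediate size per tensor-parallel shard) from the model's Hugging Face
# config, then distributes tuning or benchmarking over all visible GPUs via
# Ray, one batch size per task.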
def main(args: argparse.Namespace):
    print(args)

    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code)
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] == "DeepseekV3ForCausalLM":
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    else:
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: List[Any]) -> List[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
        search_space = get_configs_compute_bound(is_fp16)
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
                     for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, E, shard_intermediate_size, hidden_size,
                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
                          for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")
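

# Example invocations (the script path is illustrative; adjust to where this
# file lives in your checkout):
#   Tune fp8 MoE kernels for Mixtral-8x7B with TP=2 across all visible GPUs:
#     python benchmark_moe.py --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
#         --tp-size 2 --dtype fp8_w8a8 --tune
#   Benchmark the shipped/default configs for a single batch size:
#     python benchmark_moe.py --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
#         --batch-size 64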
if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument("--model",
                        type=str,
                        default="mistralai/Mixtral-8x7B-Instruct-v0.1")
    parser.add_argument("--tp-size",
                        "-tp",
                        "--tensor-parallel-size",
                        type=int,
                        default=2)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8_w8a8", "int8_w8a16"],
                        default="auto")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--tune", action="store_true")
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()

    main(args)