[Kernel] Enhance MoE benchmarking & tuning script (#4921)
parent bd0e7802e0
commit 3a434b07ed

Deleted file (239 lines): the previous Mixtral-specific tuning script (benchmark_mixtral_moe)
@@ -1,239 +0,0 @@
import argparse
import json
import os
import sys

import torch
import torch.nn.functional as F
import triton
from tqdm import tqdm

from vllm.model_executor.layers.fused_moe import (fused_moe,
                                                  get_config_file_name)


def main(model, tp_size, gpu, dtype: str):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    method = fused_moe
    for bs in [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
    ]:
        run_grid(bs,
                 model=model,
                 method=method,
                 gpu=gpu,
                 tp_size=tp_size,
                 dtype=dtype)


def run_grid(bs, model, method, gpu, tp_size, dtype: str):
    if model == '8x7B':
        d_model = 4096
        model_intermediate_size = 14336
        num_layers = 32
    elif model == '8x22B':
        d_model = 6144
        model_intermediate_size = 16384
        num_layers = 56
    else:
        raise ValueError(f'Unsupported Mixtral model {model}')
    num_total_experts = 8
    top_k = 2
    # tp_size = 2
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    configs = []

    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in [16, 32, 64, 128, 256]:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        for num_stages in [2, 3, 4, 5]:
                            configs.append({
                                "BLOCK_SIZE_M": block_size_m,
                                "BLOCK_SIZE_N": block_size_n,
                                "BLOCK_SIZE_K": block_size_k,
                                "GROUP_SIZE_M": group_size_m,
                                "num_warps": num_warps,
                                "num_stages": num_stages,
                            })

    best_config = None
    best_time_us = 1e20

    print(f'{tp_size=} {bs=}')

    for config in tqdm(configs):
        # warmup
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                    dtype=dtype,
                )
        except triton.runtime.autotuner.OutOfResources:
            continue

        # trial
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
                dtype=dtype,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            tqdm.write(
                f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
                f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
                f'{d_model=} {model_intermediate_size=} {num_layers=}')

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    # holds Dict[str, Dict[str, int]]
    filename = get_config_file_name(num_total_experts,
                                    model_intermediate_size // tp_size,
                                    "float8" if dtype == "float8" else None)
    print(f"writing config to file {filename}")
    existing_content = {}
    if os.path.exists(filename):
        with open(filename, "r") as f:
            existing_content = json.load(f)
    existing_content[str(bs)] = best_config
    with open(filename, "w") as f:
        json.dump(existing_content, f, indent=4)
        f.write("\n")


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config, dtype: str) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.float16,
    )

    w1 = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2 = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None

    if dtype == "float8":
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        w2_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a1_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a2_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)

    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=hidden_states.device,
        dtype=torch.float32,
    ),
                              dim=-1)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            gating_output=gating_output[i],
            topk=2,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=dtype == "float8",
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='benchmark_mixtral_moe',
        description='Benchmark and tune the fused_moe kernel',
    )
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['float8', 'float16'],
        help='Data type used for fused_moe kernel computations',
    )
    parser.add_argument('--model',
                        type=str,
                        default='8x7B',
                        choices=['8x7B', '8x22B'],
                        help='The Mixtral model to benchmark')
    parser.add_argument('--tp-size',
                        type=int,
                        default=2,
                        help='Tensor parallel size')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help="GPU ID for benchmarking")
    args = parser.parse_args()
    sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype))
New file: benchmarks/kernels/benchmark_moe.py (319 lines)
@@ -0,0 +1,319 @@
import argparse
import time
from datetime import datetime
from typing import Any, Dict, List, Tuple

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *


def benchmark_config(
    config: Dict[str, int],
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8: bool,
    num_iters: int = 100,
) -> float:
    init_dtype = torch.float16 if use_fp8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    w1 = torch.randn(num_experts,
                     shard_intermediate_size,
                     hidden_size,
                     dtype=init_dtype)
    w2 = torch.randn(num_experts,
                     hidden_size,
                     shard_intermediate_size // 2,
                     dtype=init_dtype)
    gating_output = torch.randn(num_iters,
                                num_tokens,
                                num_experts,
                                dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_fp8:
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)

    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        fused_moe(
            x,
            w1,
            w2,
            input_gating,
            topk,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=use_fp8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg


def get_configs_compute_bound() -> List[Dict[str, int]]:
    # Reduced search space for faster tuning.
    # TODO(woosuk): Increase the search space and use a performance model to
    # prune the search space.
    configs = []
    for num_stages in [2, 3, 4, 5]:
        for block_m in [16, 32, 64, 128, 256]:
            for block_k in [64, 128, 256]:
                for block_n in [32, 64, 128, 256]:
                    for num_warps in [4, 8]:
                        for group_size in [1, 16, 32, 64]:
                            configs.append({
                                "BLOCK_SIZE_M": block_m,
                                "BLOCK_SIZE_N": block_n,
                                "BLOCK_SIZE_K": block_k,
                                "GROUP_SIZE_M": group_size,
                                "num_warps": num_warps,
                                "num_stages": num_stages,
                            })
    return configs


@ray.remote(num_gpus=1)
class BenchmarkWorker:

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        torch.cuda.manual_seed_all(seed)
        self.seed = seed

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8: bool,
    ) -> Tuple[Dict[str, int], float]:
        torch.cuda.manual_seed_all(self.seed)

        dtype_str = "float8" if use_fp8 else None
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
                                    dtype_str)
        if op_config is None:
            config = get_default_config(num_tokens, num_experts,
                                        shard_intermediate_size, hidden_size,
                                        topk, dtype_str)
        else:
            config = op_config[min(op_config.keys(),
                                   key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(config, num_tokens, num_experts,
                                       shard_intermediate_size, hidden_size,
                                       topk, dtype, use_fp8)
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8: bool,
        search_space: List[Dict[str, int]],
    ) -> Dict[str, int]:
        best_config = None
        best_time = float("inf")
        for config in tqdm(search_space):
            try:
                kernel_time = benchmark_config(config,
                                               num_tokens,
                                               num_experts,
                                               shard_intermediate_size,
                                               hidden_size,
                                               topk,
                                               dtype,
                                               use_fp8,
                                               num_iters=10)
            except triton.runtime.autotuner.OutOfResources:
                # Some configurations may be invalid and fail to compile.
                continue

            if kernel_time < best_time:
                best_time = kernel_time
                best_config = config
        now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        return best_config


def sort_config(config: Dict[str, int]) -> Dict[str, int]:
    return {
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
    }


def save_configs(
    configs: Dict[int, Dict[str, int]],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8: bool,
) -> None:
    dtype_str = "float8" if use_fp8 else None
    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                    dtype_str)
    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")


def main(args: argparse.Namespace):
    print(args)

    config = AutoConfig.from_pretrained(args.model)
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    else:
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = config.torch_dtype
    use_fp8 = args.dtype == "fp8"

    if args.batch_size is None:
        batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: List[Any]) -> List[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        search_space = get_configs_compute_bound()
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
                      topk, dtype, use_fp8, search_space)
                     for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, E, shard_intermediate_size, hidden_size,
                     topk, dtype, use_fp8)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute("benchmark",
                              [(batch_size, E, shard_intermediate_size,
                                hidden_size, topk, dtype, use_fp8)
                               for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model",
                        type=str,
                        default="mistralai/Mixtral-8x7B-Instruct-v0.1")
    parser.add_argument("--tp-size", "-tp", type=int, default=2)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8"],
                        default="auto")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--tune", action="store_true")
    args = parser.parse_args()

    main(args)
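A --tune run writes a JSON file via save_configs/get_config_file_name that maps each benchmarked batch size to its best kernel config. The sketch below shows how such a file could be read back and queried with the same nearest-batch-size lookup used in BenchmarkWorker.benchmark; the file name and the surrounding scaffolding are illustrative, not part of this commit:

    import json

    # Hypothetical path; the real name comes from get_config_file_name(...) and
    # encodes the expert count, sharded intermediate size, and dtype.
    with open("fused_moe_config.json") as f:
        tuned = {int(m): cfg for m, cfg in json.load(f).items()}

    num_tokens = 48
    # Pick the tuned entry whose batch size is closest to the current token count.
    config = tuned[min(tuned, key=lambda m: abs(m - num_tokens))]
    print(config)  # e.g. {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, ...}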
vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -308,6 +308,30 @@ def get_moe_configs(E: int, N: int,
     return None


+def get_default_config(
+    M: int,
+    E: int,
+    N: int,
+    K: int,
+    topk: int,
+    dtype: Optional[str],
+) -> Dict[str, int]:
+    config = {
+        'BLOCK_SIZE_M': 64,
+        'BLOCK_SIZE_N': 64,
+        'BLOCK_SIZE_K': 32,
+        'GROUP_SIZE_M': 8
+    }
+    if M <= E:
+        config = {
+            'BLOCK_SIZE_M': 16,
+            'BLOCK_SIZE_N': 32,
+            'BLOCK_SIZE_K': 64,
+            'GROUP_SIZE_M': 1
+        }
+    return config
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -382,20 +406,9 @@ def fused_experts(hidden_states: torch.Tensor,
             config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
         else:
             # Else use the default config
-            config = {
-                'BLOCK_SIZE_M': 64,
-                'BLOCK_SIZE_N': 64,
-                'BLOCK_SIZE_K': 32,
-                'GROUP_SIZE_M': 8
-            }
-
-            if M <= E:
-                config = {
-                    'BLOCK_SIZE_M': 16,
-                    'BLOCK_SIZE_N': 32,
-                    'BLOCK_SIZE_K': 64,
-                    'GROUP_SIZE_M': 1
-                }
+            config = get_default_config(M, E, N, w1.shape[2],
+                                        topk_ids.shape[1],
+                                        "float8" if use_fp8 else None)

     intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                       device=hidden_states.device,
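For reference, a minimal sketch of how the new fallback behaves, assuming get_default_config is in scope as defined above. The shape arguments here are only examples (taken from the old script's 8x7B settings); the current implementation only looks at M and E:

    # Small batches (M <= E) get the small-tile config.
    print(get_default_config(M=1, E=8, N=14336, K=4096, topk=2, dtype=None))
    # -> {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}

    # Larger batches get the 64x64x32 tiling with GROUP_SIZE_M = 8.
    print(get_default_config(M=1024, E=8, N=14336, K=4096, topk=2, dtype=None))
    # -> {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}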