[Kernel] Update fused_moe tuning script for FP8 (#4457)
This PR updates the tuning script for the fused_moe kernel to support FP8 and also adds configurations for TP4. Note that in the generated configuration I removed num_warps and num_stages for the small batch sizes, since doing so improved performance and brought the benchmarks in that regime back on par with the previous numbers, making this a strict improvement over the status quo.

All numbers below are for mistralai/Mixtral-8x7B-Instruct-v0.1 with 1000 input and 50 output tokens.

Before this PR (with static activation scaling):
qps = 1:   9.8 ms ITL, 0.49s e2e latency
qps = 2:   9.7 ms ITL, 0.49s e2e latency
qps = 4:  10.1 ms ITL, 0.52s e2e latency
qps = 6:  11.9 ms ITL, 0.59s e2e latency
qps = 8:  14.0 ms ITL, 0.70s e2e latency
qps = 10: 15.7 ms ITL, 0.79s e2e latency

After this PR (with static activation scaling):
qps = 1:   9.8 ms ITL, 0.49s e2e latency
qps = 2:   9.7 ms ITL, 0.49s e2e latency
qps = 4:  10.2 ms ITL, 0.53s e2e latency
qps = 6:  11.9 ms ITL, 0.59s e2e latency
qps = 8:  11.9 ms ITL, 0.59s e2e latency
qps = 10: 12.1 ms ITL, 0.61s e2e latency
Commit: 24bb4fe432 (parent: a657bfc48a)
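For orientation before reading the diff: per batch size, the script sweeps a grid of Triton launch configurations, times each one, and keeps the fastest. The sketch below condenses that loop using the grid values from the diff; it is a simplified illustration, and run_timing / write_best_config here are stand-in callables rather than the script's exact signatures.

# Condensed sketch of the per-batch-size tuning loop (simplified from the
# diff below; the real script passes many more arguments and skips configs
# that raise triton OutOfResources).
import itertools

def tune_one_batch_size(bs, run_timing, write_best_config):
    # Full Cartesian grid; num_stages is now swept instead of being fixed at 4.
    grid = itertools.product(
        [16, 32, 64, 128, 256],  # BLOCK_SIZE_M
        [32, 64, 128, 256],      # BLOCK_SIZE_N
        [64, 128, 256],          # BLOCK_SIZE_K
        [1, 16, 32, 64],         # GROUP_SIZE_M
        [4, 8],                  # num_warps
        [2, 3, 4, 5],            # num_stages
    )
    best_config, best_time_us = None, float("inf")
    for bm, bn, bk, gm, warps, stages in grid:
        config = {
            "BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": bn, "BLOCK_SIZE_K": bk,
            "GROUP_SIZE_M": gm, "num_warps": warps, "num_stages": stages,
        }
        kernel_dur_us = 1000 * run_timing(bs, config)  # run_timing returns ms
        if kernel_dur_us < best_time_us:
            best_config, best_time_us = config, kernel_dur_us
    write_best_config(bs, best_config)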
Changes to the tuning script (benchmark_mixtral_moe):

@@ -1,3 +1,4 @@
+import argparse
 import json
 import os
 import sys
@@ -5,6 +6,7 @@ import sys
 import torch
 import torch.nn.functional as F
 import triton
+from tqdm import tqdm

 from vllm.model_executor.layers.fused_moe import (fused_moe,
                                                   get_config_file_name)
@@ -12,16 +14,16 @@ from vllm.model_executor.layers.fused_moe import (fused_moe,
 os.environ['CUDA_VISIBLE_DEVICES'] = '0'


-def main():
+def main(dtype: str):
     method = fused_moe
     for bs in [
             1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
             2048, 3072, 4096
     ]:
-        run_grid(bs, method=method)
+        run_grid(bs, method=method, dtype=dtype)


-def run_grid(bs, method):
+def run_grid(bs, method, dtype: str):
     d_model = 4096
     num_total_experts = 8
     top_k = 2
@@ -34,39 +36,29 @@ def run_grid(bs, method):
     num_trials = 1

     configs = []
-    if bs <= 16:
-        BLOCK_SIZES_M = [16]
-    elif bs <= 32:
-        BLOCK_SIZES_M = [16, 32]
-    elif bs <= 64:
-        BLOCK_SIZES_M = [16, 32, 64]
-    elif bs <= 128:
-        BLOCK_SIZES_M = [16, 32, 64, 128]
-    else:
-        BLOCK_SIZES_M = [16, 32, 64, 128, 256]

     for block_size_n in [32, 64, 128, 256]:
-        for block_size_m in BLOCK_SIZES_M:
+        for block_size_m in [16, 32, 64, 128, 256]:
             for block_size_k in [64, 128, 256]:
                 for group_size_m in [1, 16, 32, 64]:
                     for num_warps in [4, 8]:
-                        configs.append({
-                            "BLOCK_SIZE_M": block_size_m,
-                            "BLOCK_SIZE_N": block_size_n,
-                            "BLOCK_SIZE_K": block_size_k,
-                            "GROUP_SIZE_M": group_size_m,
-                            "num_warps": num_warps,
-                            "num_stages": 4,
-                        })
+                        for num_stages in [2, 3, 4, 5]:
+                            configs.append({
+                                "BLOCK_SIZE_M": block_size_m,
+                                "BLOCK_SIZE_N": block_size_n,
+                                "BLOCK_SIZE_K": block_size_k,
+                                "GROUP_SIZE_M": group_size_m,
+                                "num_warps": num_warps,
+                                "num_stages": num_stages,
+                            })

     best_config = None
     best_time_us = 1e20

-    for config in configs:
-        print(f'{tp_size=} {bs=}')
-        print(f'{config}')
+    print(f'{tp_size=} {bs=}')
+
+    for config in tqdm(configs):
         # warmup
-        print('warming up')
         try:
             for _ in range(num_warmup_trials):
                 run_timing(
@@ -79,12 +71,12 @@ def run_grid(bs, method):
                     model_intermediate_size=model_intermediate_size,
                     method=method,
                     config=config,
+                    dtype=dtype,
                 )
         except triton.runtime.autotuner.OutOfResources:
             continue

         # trial
-        print('benchmarking')
         for _ in range(num_trials):
             kernel_dur_ms = run_timing(
                 num_calls=num_calls,
@@ -96,6 +88,7 @@ def run_grid(bs, method):
                 model_intermediate_size=model_intermediate_size,
                 method=method,
                 config=config,
+                dtype=dtype,
             )

             kernel_dur_us = 1000 * kernel_dur_ms
@@ -105,16 +98,18 @@ def run_grid(bs, method):
                 best_config = config
                 best_time_us = kernel_dur_us

-            print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
-                  f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
-                  f'{d_model=} {model_intermediate_size=} {num_layers=}')
+            tqdm.write(
+                f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
+                f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
+                f'{d_model=} {model_intermediate_size=} {num_layers=}')

     print("best_time_us", best_time_us)
     print("best_config", best_config)

     # holds Dict[str, Dict[str, int]]
     filename = get_config_file_name(num_total_experts,
-                                    model_intermediate_size // tp_size)
+                                    model_intermediate_size // tp_size,
+                                    "float8" if dtype == "float8" else None)
     print(f"writing config to file {filename}")
     existing_content = {}
     if os.path.exists(filename):
@@ -128,27 +123,48 @@ def run_grid(bs, method):

 def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
                top_k: int, tp_size: int, model_intermediate_size: int, method,
-               config) -> float:
+               config, dtype: str) -> float:
     shard_intermediate_size = model_intermediate_size // tp_size

     hidden_states = torch.rand(
         (bs, d_model),
         device="cuda:0",
-        dtype=torch.bfloat16,
+        dtype=torch.float16,
     )

-    ws = torch.rand(
+    w1 = torch.rand(
         (num_total_experts, 2 * shard_intermediate_size, d_model),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )

-    w2s = torch.rand(
+    w2 = torch.rand(
         (num_total_experts, d_model, shard_intermediate_size),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )

+    w1_scale = None
+    w2_scale = None
+    a1_scale = None
+    a2_scale = None
+
+    if dtype == "float8":
+        w1 = w1.to(torch.float8_e4m3fn)
+        w2 = w2.to(torch.float8_e4m3fn)
+        w1_scale = torch.ones(num_total_experts,
+                              device=hidden_states.device,
+                              dtype=torch.float32)
+        w2_scale = torch.ones(num_total_experts,
+                              device=hidden_states.device,
+                              dtype=torch.float32)
+        a1_scale = torch.ones(1,
+                              device=hidden_states.device,
+                              dtype=torch.float32)
+        a2_scale = torch.ones(1,
+                              device=hidden_states.device,
+                              dtype=torch.float32)
+
     gating_output = F.softmax(torch.rand(
         (num_calls, bs, num_total_experts),
         device=hidden_states.device,
@@ -163,13 +179,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
     for i in range(num_calls):
         hidden_states = method(
             hidden_states=hidden_states,
-            w1=ws,
-            w2=w2s,
+            w1=w1,
+            w2=w2,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             gating_output=gating_output[i],
             topk=2,
             renormalize=True,
             inplace=True,
             override_config=config,
+            use_fp8=dtype == "float8",
         )
     end_event.record()
     end_event.synchronize()
@@ -179,4 +200,16 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,


 if __name__ == "__main__":
-    sys.exit(main())
+    parser = argparse.ArgumentParser(
+        prog='benchmark_mixtral_moe',
+        description='Benchmark and tune the fused_moe kernel',
+    )
+    parser.add_argument(
+        '--dtype',
+        type=str,
+        default='auto',
+        choices=['float8', 'float16'],
+        help='Data type used for fused_moe kernel computations',
+    )
+    args = parser.parse_args()
+    sys.exit(main(args.dtype))
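The run_timing function above reports per-call kernel time in milliseconds measured with CUDA events; only the end_event calls appear in the hunks, but the pattern is the standard torch.cuda.Event one. A minimal sketch of that timing pattern follows; time_kernel_ms is an illustrative helper rather than part of the script, and the start_event usage is assumed by symmetry with the end_event calls shown above.

# Minimal sketch of the CUDA-event timing pattern run_timing relies on.
import torch

def time_kernel_ms(fn, num_calls: int) -> float:
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_calls):
        fn()
    end_event.record()
    end_event.synchronize()
    # elapsed_time returns the total milliseconds between the two events.
    return start_event.elapsed_time(end_event) / num_calls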
New tuned configuration file added by this PR (all 140 lines new), keyed by batch size. As noted in the description, the small batch sizes omit num_warps and num_stages.

@@ -0,0 +1,140 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1
    },
    "8": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "24": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 5
    },
    "32": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 2
    },
    "128": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 2
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    }
}
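The file above maps batch sizes (token counts) to the best launch configuration found by the tuner. A minimal sketch of how such a file can be consumed is shown below; load_best_config and the nearest-key selection are illustrative stand-ins here, not vLLM's actual loading code, which lives inside the fused_moe layer.

# Illustrative helper: load a tuned config file like the one above and pick
# the entry whose batch-size key is closest to the current number of tokens.
import json

def load_best_config(filename: str, num_tokens: int) -> dict:
    with open(filename) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    closest_bs = min(configs, key=lambda bs: abs(bs - num_tokens))
    return configs[closest_bs]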