[Kernel] Update fused_moe tuning script for FP8 (#4457)

This PR updates the tuning script for the fused_moe kernel to support FP8 and adds tuned configurations for TP4. Note that in the generated configuration I removed num_warps and num_stages for the small batch sizes: doing so improved performance and brought the benchmarks in that regime back on par with the earlier numbers, so the change is a strict improvement over the status quo.

All numbers below are for mistralai/Mixtral-8x7B-Instruct-v0.1 with 1000 input and 50 output tokens.

Before this PR (with static activation scaling):
qps = 1: 9.8 ms ITL, 0.49s e2e latency
qps = 2: 9.7 ms ITL, 0.49s e2e latency
qps = 4: 10.1 ms ITL, 0.52s e2e latency
qps = 6: 11.9 ms ITL, 0.59s e2e latency
qps = 8: 14.0 ms ITL, 0.70s e2e latency
qps = 10: 15.7 ms ITL, 0.79s e2e latency

After this PR (with static activation scaling):
qps = 1: 9.8 ms ITL, 0.49s e2e latency
qps = 2: 9.7 ms ITL, 0.49s e2e latency
qps = 4: 10.2 ms ITL, 0.53s e2e latency
qps = 6: 11.9 ms ITL, 0.59s e2e latency
qps = 8: 11.9 ms ITL, 0.59s e2e latency
qps = 10: 12.1 ms ITL, 0.61s e2e latency
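
For reference (illustration only, not part of the script): run_grid() below writes one JSON file per expert-count/shard-size/dtype combination, named by get_config_file_name, mapping each tuned batch size to its best Triton config. A minimal sketch of that shape, with placeholder values rather than tuned results; the script itself follows.

import json

# Hypothetical excerpt of a generated config file: batch size -> kernel config.
# Keys mirror the config dict built in run_grid(); the values are placeholders.
example_config = json.loads("""
{
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    }
}
""")
print(example_config["8"])
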
import argparse
import json
import os
import sys

import torch
import torch.nn.functional as F
import triton
from tqdm import tqdm

from vllm.model_executor.layers.fused_moe import (fused_moe,
                                                  get_config_file_name)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

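
# main() tunes one batch size at a time; each run_grid() call sweeps the full
# config grid for that batch size and records the winner in the config file.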
def main(dtype: str):
    method = fused_moe
    for bs in [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
    ]:
        run_grid(bs, method=method, dtype=dtype)

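
# run_grid() times every candidate Triton config for a single batch size,
# using Mixtral-8x7B shapes with the intermediate dimension sharded by tp_size,
# and records the fastest config in the kernel's JSON config file.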
def run_grid(bs, method, dtype: str):
    d_model = 4096
    num_total_experts = 8
    top_k = 2
    tp_size = 2
    model_intermediate_size = 14336
    num_layers = 32
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    configs = []

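    # Exhaustive sweep over Triton tile sizes, group size, warps, and pipeline
    # stages: 4 * 5 * 3 * 4 * 2 * 4 = 1920 candidate configs.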
    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in [16, 32, 64, 128, 256]:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        for num_stages in [2, 3, 4, 5]:
                            configs.append({
                                "BLOCK_SIZE_M": block_size_m,
                                "BLOCK_SIZE_N": block_size_n,
                                "BLOCK_SIZE_K": block_size_k,
                                "GROUP_SIZE_M": group_size_m,
                                "num_warps": num_warps,
                                "num_stages": num_stages,
                            })

    best_config = None
    best_time_us = 1e20

    print(f'{tp_size=} {bs=}')

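    # Benchmark each candidate: one warmup pass, then the timed pass. Configs
    # that do not fit on the device (Triton OutOfResources) are skipped.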
    for config in tqdm(configs):
        # warmup
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                    dtype=dtype,
                )
        except triton.runtime.autotuner.OutOfResources:
            continue

        # trial
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
                dtype=dtype,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            tqdm.write(
                f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
                f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
                f'{d_model=} {model_intermediate_size=} {num_layers=}')

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    # The config file holds Dict[str, Dict[str, int]]: batch size -> config.
    filename = get_config_file_name(num_total_experts,
                                    model_intermediate_size // tp_size,
                                    "float8" if dtype == "float8" else None)
    print(f"writing config to file {filename}")
    existing_content = {}
    if os.path.exists(filename):
        with open(filename, "r") as f:
            existing_content = json.load(f)
    existing_content[str(bs)] = best_config
    with open(filename, "w") as f:
        json.dump(existing_content, f, indent=4)
        f.write("\n")

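
# run_timing() builds random Mixtral-shaped activations and expert weights
# (weights cast to FP8 with unit scales when dtype == "float8") and returns
# the average time per fused_moe call in milliseconds, measured with CUDA
# events.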
def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config, dtype: str) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.float16,
    )

    w1 = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2 = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

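    # Optional FP8 quantization parameters: per-expert weight scales and static
    # (tensor-wide) activation scales. Unit scales suffice here because only
    # kernel speed is measured, not numerical accuracy.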
    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None

    if dtype == "float8":
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        w2_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a1_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a2_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)

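    # Pre-generate one softmax-normalized gating tensor per timed call so the
    # routing varies across calls without paying RNG/softmax cost inside the
    # timed region.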
    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=hidden_states.device,
        dtype=torch.float32,
    ),
                              dim=-1)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            gating_output=gating_output[i],
            topk=2,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=dtype == "float8",
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms

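
# CLI entry point. Pass --dtype float8 to tune the FP8 kernel path; any other
# value (including the default 'auto') exercises the float16 path.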
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='benchmark_mixtral_moe',
        description='Benchmark and tune the fused_moe kernel',
    )
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['float8', 'float16'],
        help='Data type used for fused_moe kernel computations',
    )
    args = parser.parse_args()
    sys.exit(main(args.dtype))