Optimize Triton MoE Kernel (#2979)

Co-authored-by: Cade Daniel <edacih@gmail.com>
This commit is contained in:
Philipp Moritz 2024-02-26 13:48:56 -08:00 committed by GitHub
parent 70f3e8e3a1
commit cfc15a1031
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 296 additions and 14 deletions

View File

@ -0,0 +1,172 @@
import json
import os
import sys
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from vllm.model_executor.layers.fused_moe import fused_moe
import torch
import torch.nn.functional as F
import triton
def main():
method = fused_moe
for bs in [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]:
run_grid(bs, method=method)
def run_grid(bs, method):
d_model = 4096
num_total_experts = 8
top_k = 2
tp_size = 2
model_intermediate_size = 14336
num_layers = 32
num_calls = 100
num_warmup_trials = 1
num_trials = 1
configs = []
if bs <= 16:
BLOCK_SIZES_M = [16]
elif bs <= 32:
BLOCK_SIZES_M = [16, 32]
elif bs <= 64:
BLOCK_SIZES_M = [16, 32, 64]
elif bs <= 128:
BLOCK_SIZES_M = [16, 32, 64, 128]
else:
BLOCK_SIZES_M = [16, 32, 64, 128, 256]
for block_size_n in [32, 64, 128, 256]:
for block_size_m in BLOCK_SIZES_M:
for block_size_k in [64, 128, 256]:
for group_size_m in [1, 16, 32, 64]:
for num_warps in [4, 8]:
configs.append({
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": 4,
})
best_config = None
best_time_us = 1e20
for config in configs:
print(f'{tp_size=} {bs=}')
print(f'{config}')
# warmup
print(f'warming up')
try:
for _ in range(num_warmup_trials):
run_timing(
num_calls=num_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
)
except triton.runtime.autotuner.OutOfResources:
continue
# trial
print(f'benchmarking')
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
)
kernel_dur_us = 1000 * kernel_dur_ms
model_dur_ms = kernel_dur_ms * num_layers
if kernel_dur_us < best_time_us:
best_config = config
best_time_us = kernel_dur_us
print(
f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f} {bs=} {tp_size=} {top_k=} {num_total_experts=} {d_model=} {model_intermediate_size=} {num_layers=}'
)
print("best_time_us", best_time_us)
print("best_config", best_config)
filename = "/tmp/config.jsonl"
print(f"writing config to file {filename}")
with open(filename, "a") as f:
f.write(json.dumps({str(bs): best_config}) + "\n")
def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
top_k: int, tp_size: int, model_intermediate_size: int, method,
config) -> float:
shard_intermediate_size = model_intermediate_size // tp_size
hidden_states = torch.rand(
(bs, d_model),
device="cuda:0",
dtype=torch.bfloat16,
)
ws = torch.rand(
(num_total_experts, 2 * shard_intermediate_size, d_model),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
w2s = torch.rand(
(num_total_experts, d_model, shard_intermediate_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
gating_output = F.softmax(torch.rand(
(num_calls, bs, num_total_experts),
device=hidden_states.device,
dtype=torch.float32,
),
dim=-1)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(num_calls):
hidden_states = method(
hidden_states=hidden_states,
w1=ws,
w2=w2s,
gating_output=gating_output[i],
topk=2,
renormalize=True,
inplace=True,
override_config=config,
)
end_event.record()
end_event.synchronize()
dur_ms = start_event.elapsed_time(end_event) / num_calls
return dur_ms
if __name__ == "__main__":
sys.exit(main())

View File

@ -432,7 +432,9 @@ def get_requirements() -> List[str]:
return requirements return requirements
package_data = {"vllm": ["py.typed"]} package_data = {
"vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
}
if os.environ.get("VLLM_USE_PRECOMPILED"): if os.environ.get("VLLM_USE_PRECOMPILED"):
ext_modules = [] ext_modules = []
package_data["vllm"].append("*.so") package_data["vllm"].append("*.so")

View File

@ -0,0 +1,5 @@
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
__all__ = [
"fused_moe",
]

View File

@ -0,0 +1,20 @@
{
"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
"4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
"8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
"16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
"24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
"192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
"256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
"512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
"1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
"2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
"4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
}

View File

@ -0,0 +1,24 @@
{
"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
"2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4},
"16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
"24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
"208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
"216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
"256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
"512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}
}

View File

@ -0,0 +1,10 @@
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.
The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.

View File

@ -1,11 +1,19 @@
"""Fused MoE kernel.""" """Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Dict, Optional
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm._C import ops from vllm._C import ops
from vllm.logger import init_logger
from vllm.utils import is_hip from vllm.utils import is_hip
logger = init_logger(__name__)
@triton.jit @triton.jit
def fused_moe_kernel( def fused_moe_kernel(
@ -210,6 +218,34 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
) )
@functools.lru_cache
def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
The return value will be a dictionary that maps an irregular grid of batch sizes
to configurations of the fused_moe kernel. To evaluate the kernel on a given batch
size bs, the closest batch size in the grid should be picked and the associated
configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs directory
device_name = torch.cuda.get_device_name().replace(" ", "_")
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs",
f"E={E},N={N},device_name={device_name}.json")
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
f"Using configuration from {config_file_path} for MoE layer.")
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default configuration
return None
def fused_moe( def fused_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
@ -218,6 +254,7 @@ def fused_moe(
topk: int, topk: int,
renormalize: bool, renormalize: bool,
inplace: bool = False, inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
@ -230,6 +267,7 @@ def fused_moe(
- topk (int): The number of top-k experts to select. - topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place. Defaults to False. - inplace (bool): If True, perform the operation in-place. Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration.
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
@ -279,20 +317,31 @@ def fused_moe(
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
config = { if override_config:
'BLOCK_SIZE_M': 64, config = override_config
'BLOCK_SIZE_N': 64, else:
'BLOCK_SIZE_K': 32, # First try to load optimal config from the file
'GROUP_SIZE_M': 8 configs = get_moe_configs(E, w2.shape[2])
}
if topk_ids.numel() <= w1.shape[0]: if configs:
config = { # If an optimal configuration map has been found, look up the optimal config
'BLOCK_SIZE_M': 16, config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
'BLOCK_SIZE_N': 32, else:
'BLOCK_SIZE_K': 64, # Else use the default config
'GROUP_SIZE_M': 1 config = {
} 'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}
if M <= E:
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
}
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device, device=hidden_states.device,