From cfc15a1031ef0197a1b291d2ed93717a9bdad268 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 26 Feb 2024 13:48:56 -0800 Subject: [PATCH] Optimize Triton MoE Kernel (#2979) Co-authored-by: Cade Daniel --- benchmarks/kernels/benchmark_mixtral_moe.py | 172 ++++++++++++++++++ setup.py | 4 +- .../layers/fused_moe/__init__.py | 5 + ...584,device_name=NVIDIA_A100-SXM4-80GB.json | 20 ++ ...168,device_name=NVIDIA_H100_80GB_HBM3.json | 24 +++ .../layers/fused_moe/configs/README | 10 + .../layers/{ => fused_moe}/fused_moe.py | 75 ++++++-- 7 files changed, 296 insertions(+), 14 deletions(-) create mode 100644 benchmarks/kernels/benchmark_mixtral_moe.py create mode 100644 vllm/model_executor/layers/fused_moe/__init__.py create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/README rename vllm/model_executor/layers/{ => fused_moe}/fused_moe.py (85%) diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py new file mode 100644 index 00000000..9e08df76 --- /dev/null +++ b/benchmarks/kernels/benchmark_mixtral_moe.py @@ -0,0 +1,172 @@ +import json +import os +import sys + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +from vllm.model_executor.layers.fused_moe import fused_moe +import torch +import torch.nn.functional as F +import triton + + +def main(): + method = fused_moe + for bs in [ + 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, + 2048, 3072, 4096 + ]: + run_grid(bs, method=method) + + +def run_grid(bs, method): + d_model = 4096 + num_total_experts = 8 + top_k = 2 + tp_size = 2 + model_intermediate_size = 14336 + num_layers = 32 + num_calls = 100 + + num_warmup_trials = 1 + num_trials = 1 + + configs = [] + if bs <= 16: + BLOCK_SIZES_M = [16] + elif bs <= 32: + BLOCK_SIZES_M = [16, 32] + elif bs <= 64: + BLOCK_SIZES_M = [16, 32, 64] + elif bs <= 128: + BLOCK_SIZES_M = [16, 32, 64, 128] + else: + BLOCK_SIZES_M = [16, 32, 64, 128, 256] + + for block_size_n in [32, 64, 128, 256]: + for block_size_m in BLOCK_SIZES_M: + for block_size_k in [64, 128, 256]: + for group_size_m in [1, 16, 32, 64]: + for num_warps in [4, 8]: + configs.append({ + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": 4, + }) + + best_config = None + best_time_us = 1e20 + + for config in configs: + print(f'{tp_size=} {bs=}') + print(f'{config}') + # warmup + print(f'warming up') + try: + for _ in range(num_warmup_trials): + run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + method=method, + config=config, + ) + except triton.runtime.autotuner.OutOfResources: + continue + + # trial + print(f'benchmarking') + for _ in range(num_trials): + kernel_dur_ms = run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + method=method, + config=config, + ) + + kernel_dur_us = 1000 * kernel_dur_ms + model_dur_ms = kernel_dur_ms * num_layers + + if kernel_dur_us < best_time_us: + best_config = config + best_time_us = kernel_dur_us + + print( + f'{kernel_dur_us=:.1f} 
{model_dur_ms=:.1f} {bs=} {tp_size=} {top_k=} {num_total_experts=} {d_model=} {model_intermediate_size=} {num_layers=}' + ) + + print("best_time_us", best_time_us) + print("best_config", best_config) + + filename = "/tmp/config.jsonl" + print(f"writing config to file {filename}") + with open(filename, "a") as f: + f.write(json.dumps({str(bs): best_config}) + "\n") + + +def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, + top_k: int, tp_size: int, model_intermediate_size: int, method, + config) -> float: + shard_intermediate_size = model_intermediate_size // tp_size + + hidden_states = torch.rand( + (bs, d_model), + device="cuda:0", + dtype=torch.bfloat16, + ) + + ws = torch.rand( + (num_total_experts, 2 * shard_intermediate_size, d_model), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + w2s = torch.rand( + (num_total_experts, d_model, shard_intermediate_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + gating_output = F.softmax(torch.rand( + (num_calls, bs, num_total_experts), + device=hidden_states.device, + dtype=torch.float32, + ), + dim=-1) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for i in range(num_calls): + hidden_states = method( + hidden_states=hidden_states, + w1=ws, + w2=w2s, + gating_output=gating_output[i], + topk=2, + renormalize=True, + inplace=True, + override_config=config, + ) + end_event.record() + end_event.synchronize() + + dur_ms = start_event.elapsed_time(end_event) / num_calls + return dur_ms + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/setup.py b/setup.py index 8fcb8639..16978d74 100644 --- a/setup.py +++ b/setup.py @@ -432,7 +432,9 @@ def get_requirements() -> List[str]: return requirements -package_data = {"vllm": ["py.typed"]} +package_data = { + "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] +} if os.environ.get("VLLM_USE_PRECOMPILED"): ext_modules = [] package_data["vllm"].append("*.so") diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py new file mode 100644 index 00000000..1391d43c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -0,0 +1,5 @@ +from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe + +__all__ = [ + "fused_moe", +] diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 00000000..1fefb5ff --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,20 @@ +{ + "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, + "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7}, + "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6}, + "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7}, + "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7}, + "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, + "32": {"BLOCK_SIZE_M": 16, 
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, + "64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, + "96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4}, + "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6}, + "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6}, + "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}, + "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4}, + "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}, + "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}, + "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4}, + "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4}, + "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4} +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 00000000..64d49ca6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,24 @@ +{ + "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, + "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, + "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, + "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4}, + "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, + "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, + "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, + "80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, + "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, + "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, + "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, + "200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4}, + "208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4}, + "216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, 
"GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4}, + "224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4}, + "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4}, + "512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}, + "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}, + "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}, + "2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}, + "3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}, + "4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4} +} diff --git a/vllm/model_executor/layers/fused_moe/configs/README b/vllm/model_executor/layers/fused_moe/configs/README new file mode 100644 index 00000000..45d40cbf --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/README @@ -0,0 +1,10 @@ +This directory contains tuned configurations for different settings of the fused_moe kernel. +For different settings of +- E (number of experts) +- N (intermediate size) +- device_name (torch.cuda.get_device_name()) +the JSON file contains a mapping from M (batch size) to the chosen configuration. + +The example configurations provided are for the Mixtral model for TP2 on H100 +and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have +N = 7168 and for TP4 we have N = 3584. diff --git a/vllm/model_executor/layers/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py similarity index 85% rename from vllm/model_executor/layers/fused_moe.py rename to vllm/model_executor/layers/fused_moe/fused_moe.py index bc3aef18..830fde6c 100644 --- a/vllm/model_executor/layers/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1,11 +1,19 @@ """Fused MoE kernel.""" +import functools +import json +import os +from typing import Any, Dict, Optional + import torch import triton import triton.language as tl from vllm._C import ops +from vllm.logger import init_logger from vllm.utils import is_hip +logger = init_logger(__name__) + @triton.jit def fused_moe_kernel( @@ -210,6 +218,34 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ) +@functools.lru_cache +def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the fused MoE kernel. + + The return value will be a dictionary that maps an irregular grid of batch sizes + to configurations of the fused_moe kernel. To evaluate the kernel on a given batch + size bs, the closest batch size in the grid should be picked and the associated + configuration chosen to invoke the kernel. 
+ """ + + # First look up if an optimized configuration is available in the configs directory + device_name = torch.cuda.get_device_name().replace(" ", "_") + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", + f"E={E},N={N},device_name={device_name}.json") + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info( + f"Using configuration from {config_file_path} for MoE layer.") + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default configuration + return None + + def fused_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -218,6 +254,7 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. @@ -230,6 +267,7 @@ def fused_moe( - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -279,20 +317,31 @@ def fused_moe( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - config = { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - } + if override_config: + config = override_config + else: + # First try to load optimal config from the file + configs = get_moe_configs(E, w2.shape[2]) - if topk_ids.numel() <= w1.shape[0]: - config = { - 'BLOCK_SIZE_M': 16, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 1 - } + if configs: + # If an optimal configuration map has been found, look up the optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + } + + if M <= E: + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1 + } intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), device=hidden_states.device,