Optimize Triton MoE Kernel (#2979)
Co-authored-by: Cade Daniel <edacih@gmail.com>
parent 70f3e8e3a1
commit cfc15a1031
benchmarks/kernels/benchmark_mixtral_moe.py (new file, 172 lines)
@@ -0,0 +1,172 @@
import json
import os
import sys

# Must be set before vllm/torch initialize CUDA so only one GPU is used.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from vllm.model_executor.layers.fused_moe import fused_moe
import torch
import torch.nn.functional as F
import triton


def main():
    method = fused_moe
    for bs in [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
    ]:
        run_grid(bs, method=method)


def run_grid(bs, method):
    # Mixtral-8x7B shapes with tensor parallel size 2.
    d_model = 4096
    num_total_experts = 8
    top_k = 2
    tp_size = 2
    model_intermediate_size = 14336
    num_layers = 32
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    configs = []
    if bs <= 16:
        BLOCK_SIZES_M = [16]
    elif bs <= 32:
        BLOCK_SIZES_M = [16, 32]
    elif bs <= 64:
        BLOCK_SIZES_M = [16, 32, 64]
    elif bs <= 128:
        BLOCK_SIZES_M = [16, 32, 64, 128]
    else:
        BLOCK_SIZES_M = [16, 32, 64, 128, 256]

    # Exhaustive grid search over the kernel tiling parameters.
    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in BLOCK_SIZES_M:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        configs.append({
                            "BLOCK_SIZE_M": block_size_m,
                            "BLOCK_SIZE_N": block_size_n,
                            "BLOCK_SIZE_K": block_size_k,
                            "GROUP_SIZE_M": group_size_m,
                            "num_warps": num_warps,
                            "num_stages": 4,
                        })

    best_config = None
    best_time_us = 1e20

    for config in configs:
        print(f'{tp_size=} {bs=}')
        print(f'{config}')

        # warmup
        print('warming up')
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                )
        except triton.runtime.autotuner.OutOfResources:
            # Skip configurations that exceed shared memory / register limits.
            continue

        # trial
        print('benchmarking')
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            print(
                f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f} {bs=} {tp_size=} {top_k=} {num_total_experts=} {d_model=} {model_intermediate_size=} {num_layers=}'
            )

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    filename = "/tmp/config.jsonl"
    print(f"writing config to file {filename}")
    with open(filename, "a") as f:
        f.write(json.dumps({str(bs): best_config}) + "\n")


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.bfloat16,
    )

    ws = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2s = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=hidden_states.device,
        dtype=torch.float32,
    ),
                              dim=-1)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=ws,
            w2=w2s,
            gating_output=gating_output[i],
            topk=top_k,
            renormalize=True,
            inplace=True,
            override_config=config,
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())
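Each run_grid call appends one {batch_size: best_config} JSON object per line to /tmp/config.jsonl. A minimal post-processing sketch (hypothetical helper, not part of this commit) for merging those lines into a single mapping shaped like the configs/*.json files added below:

# Hypothetical helper: merge /tmp/config.jsonl into one JSON mapping.
import json

merged = {}
with open("/tmp/config.jsonl") as f:
    for line in f:
        merged.update(json.loads(line))  # each line: e.g. {"64": {...}}

# Sort keys numerically so the output reads in batch-size order.
merged = {k: merged[k] for k in sorted(merged, key=int)}
print(json.dumps(merged))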
setup.py (4 changes)
@@ -432,7 +432,9 @@ def get_requirements() -> List[str]:
     return requirements
 
 
-package_data = {"vllm": ["py.typed"]}
+package_data = {
+    "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
+}
 if os.environ.get("VLLM_USE_PRECOMPILED"):
     ext_modules = []
     package_data["vllm"].append("*.so")
vllm/model_executor/layers/fused_moe/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

__all__ = [
    "fused_moe",
]
vllm/model_executor/layers/fused_moe/configs/….json (new file, 20 lines; exact filename not captured in this view)
@@ -0,0 +1,20 @@
{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
}
vllm/model_executor/layers/fused_moe/configs/….json (new file, 24 lines; exact filename not captured in this view)
@@ -0,0 +1,24 @@
{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
    "208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
    "216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}
}
vllm/model_executor/layers/fused_moe/configs/README (new file, 10 lines)
@@ -0,0 +1,10 @@
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.

The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
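A minimal sketch (values illustrative only, mirroring get_moe_configs in fused_moe.py below) of how these three settings select a config file:

# Sketch: how E, N, and the device name form the config file name.
import torch

E, N = 8, 7168  # e.g. Mixtral with TP2
device_name = torch.cuda.get_device_name().replace(" ", "_")  # requires a CUDA device
config_file = f"E={E},N={N},device_name={device_name}.json"
# The file body maps str(M) -> kernel config, e.g.
# {"1": {"BLOCK_SIZE_M": 16, ...}, "2": {...}, ...}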
vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1,11 +1,19 @@
 """Fused MoE kernel."""
+import functools
+import json
+import os
+from typing import Any, Dict, Optional
+
 import torch
 import triton
 import triton.language as tl
 
 from vllm._C import ops
+from vllm.logger import init_logger
 from vllm.utils import is_hip
 
+logger = init_logger(__name__)
+
 
 @triton.jit
 def fused_moe_kernel(
@@ -210,6 +218,34 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
     )
+
+
+@functools.lru_cache
+def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the fused MoE kernel.
+
+    The return value will be a dictionary that maps an irregular grid of batch sizes
+    to configurations of the fused_moe kernel. To evaluate the kernel on a given batch
+    size bs, the closest batch size in the grid should be picked and the associated
+    configuration chosen to invoke the kernel.
+    """
+
+    # First look up if an optimized configuration is available in the configs directory
+    device_name = torch.cuda.get_device_name().replace(" ", "_")
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs",
+        f"E={E},N={N},device_name={device_name}.json")
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info(
+                f"Using configuration from {config_file_path} for MoE layer.")
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default configuration
+    return None
+
+
 def fused_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -218,6 +254,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
@@ -230,6 +267,7 @@ def fused_moe(
     - topk (int): The number of top-k experts to select.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - inplace (bool): If True, perform the operation in-place. Defaults to False.
+    - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
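For illustration, a hedged usage sketch of the new parameter (shapes follow the benchmark script above; the config values are only an example):

# Sketch: calling fused_moe with an explicit kernel config.
import torch
from vllm.model_executor.layers.fused_moe import fused_moe

E, d_model, shard_N = 8, 4096, 7168  # Mixtral-like shapes, TP2
x = torch.rand(4, d_model, device="cuda", dtype=torch.bfloat16)
w1 = torch.rand(E, 2 * shard_N, d_model, device="cuda", dtype=torch.bfloat16)
w2 = torch.rand(E, d_model, shard_N, device="cuda", dtype=torch.bfloat16)
gating = torch.rand(4, E, device="cuda", dtype=torch.float32)

out = fused_moe(x, w1, w2, gating, topk=2, renormalize=True,
                override_config={'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32,
                                 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1})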
@@ -279,20 +317,31 @@
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
-    config = {
-        'BLOCK_SIZE_M': 64,
-        'BLOCK_SIZE_N': 64,
-        'BLOCK_SIZE_K': 32,
-        'GROUP_SIZE_M': 8
-    }
-
-    if topk_ids.numel() <= w1.shape[0]:
-        config = {
-            'BLOCK_SIZE_M': 16,
-            'BLOCK_SIZE_N': 32,
-            'BLOCK_SIZE_K': 64,
-            'GROUP_SIZE_M': 1
-        }
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        configs = get_moe_configs(E, w2.shape[2])
+
+        if configs:
+            # If an optimal configuration map has been found, look up the optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32,
+                'GROUP_SIZE_M': 8
+            }
+
+            if M <= E:
+                config = {
+                    'BLOCK_SIZE_M': 16,
+                    'BLOCK_SIZE_N': 32,
+                    'BLOCK_SIZE_K': 64,
+                    'GROUP_SIZE_M': 1
+                }
 
     intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                       device=hidden_states.device,
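The closest-batch-size lookup above can be sanity-checked on a toy mapping (values hypothetical):

# Toy check of the nearest-M lookup used by fused_moe.
configs = {1: 'a', 16: 'b', 256: 'c'}
M = 48
assert configs[min(configs.keys(), key=lambda x: abs(x - M))] == 'b'  # 16 is nearest to 48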