605 lines
22 KiB
Python
605 lines
22 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import argparse
|
|
import json
|
|
import time
|
|
from contextlib import nullcontext
|
|
from datetime import datetime
|
|
from itertools import product
|
|
from typing import Any, TypedDict
|
|
|
|
import ray
|
|
import torch
|
|
import triton
|
|
from ray.experimental.tqdm_ray import tqdm
|
|
from transformers import AutoConfig
|
|
|
|
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils import FlexibleArgumentParser
|
|
|
|
# FP8 tensor dtype reported by the current platform; benchmark_config casts
# the expert weights to it when use_fp8_w8a8 is set.
FP8_DTYPE = current_platform.fp8_dtype()
|
|
|
|
|
|
class BenchmarkConfig(TypedDict):
    """Kernel launch parameters for one fused MoE tuning candidate.

    These are the keys every candidate produced by
    get_configs_compute_bound carries and that sort_config emits first.
    The ROCm search space adds further keys (waves_per_eu,
    matrix_instr_nonkdim, kpack) that are not declared here.
    """
    # Block size along the M dimension (compared against M when pruning).
    BLOCK_SIZE_M: int
    # Block size along the N dimension (compared against N when pruning).
    BLOCK_SIZE_N: int
    # Block size along the K dimension (must divide K on the ROCm path).
    BLOCK_SIZE_K: int
    # Grouping factor along M; presumably controls program scheduling order
    # in the kernel -- TODO confirm against the fused_moe kernel.
    GROUP_SIZE_M: int
    # Triton launch parameter: number of warps per program.
    num_warps: int
    # Triton launch parameter: number of software-pipelining stages.
    num_stages: int
|
|
|
|
|
|
def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    block_quant_shape: list[int] = None,
) -> float:
    """Measure the fused MoE kernel latency for one launch configuration.

    Random inputs, expert weights and (when quantized) scales are created,
    the kernel is run once to trigger JIT compilation, and a batch of
    invocations is captured in a CUDA graph that is then replayed and timed
    with CUDA events.

    Args:
        config: Launch parameters applied through ``override_config``.
        num_tokens: Number of rows in the activation matrix.
        num_experts: Number of experts (leading dimension of w1/w2).
        shard_intermediate_size: Output dimension of w1; w2 consumes half
            of it.
        hidden_size: Model hidden dimension.
        topk: Number of experts selected per token.
        dtype: Activation dtype (also the weight dtype unless quantized).
        use_fp8_w8a8: Benchmark the fp8 weight / fp8 activation path.
        use_int8_w8a16: Benchmark the int8 weight / 16-bit activation path.
        num_iters: Number of timed graph replays.
        block_quant_shape: ``[block_n, block_k]`` for fp8 block
            quantization, or None for per-tensor scales.

    Returns:
        Average latency of a single ``fused_moe`` call, in microseconds.
    """
    # Resolved once here instead of on every run() call inside the graph.
    from vllm.model_executor.layers.fused_moe import override_config

    # Number of fused_moe calls captured in (and replayed by) one CUDA graph.
    invocations_per_graph = 10

    # fp8 weights are generated in fp16 first and cast to FP8_DTYPE below.
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int8_w8a16:
        w1 = torch.randint(-127,
                           127, (
                               num_experts,
                               shard_intermediate_size,
                               hidden_size,
                           ),
                           dtype=torch.int8)
        w2 = torch.randint(-127,
                           127, (
                               num_experts,
                               hidden_size,
                               shard_intermediate_size // 2,
                           ),
                           dtype=torch.int8)
    else:
        w1 = torch.randn(num_experts,
                         shard_intermediate_size,
                         hidden_size,
                         dtype=init_dtype)
        w2 = torch.randn(num_experts,
                         hidden_size,
                         shard_intermediate_size // 2,
                         dtype=init_dtype)
    # One independent set of router logits per timed iteration.
    gating_output = torch.randn(num_iters,
                                num_tokens,
                                num_experts,
                                dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        # NOTE(review): these scale shapes do not obviously line up with the
        # (experts, out, in) layout of w1/w2 above -- confirm against what
        # fused_moe expects for int8_w8a16.
        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
                               dtype=torch.float32)
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_fp8_w8a8:
        if block_quant_shape:
            # Block quantization: one scale per (block_n x block_k) weight
            # tile, with tile counts rounded up.
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
            E = num_experts
            N = shard_intermediate_size // 2
            K = hidden_size
            factor_for_scale = 1e-2
            n_tiles_w1 = (2 * N + block_n - 1) // block_n
            n_tiles_w2 = (K + block_n - 1) // block_n
            k_tiles_w1 = (K + block_k - 1) // block_k
            k_tiles_w2 = (N + block_k - 1) // block_k
            w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
                                  dtype=torch.float32) * factor_for_scale
            w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
                                  dtype=torch.float32) * factor_for_scale
        else:
            # Per-expert weight scales.
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)

        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    # Fixed buffer read by the captured graph; refreshed before each replay.
    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        with override_config(config):
            fused_moe(
                x,
                w1,
                w2,
                input_gating,
                topk,
                renormalize=True,
                inplace=True,
                use_fp8_w8a8=use_fp8_w8a8,
                use_int8_w8a16=use_int8_w8a16,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                block_shape=block_quant_shape,
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture several invocations in a single CUDA graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(invocations_per_graph):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    # elapsed_time is in milliseconds; convert the per-call mean to us.
    avg = sum(latencies) / (num_iters * invocations_per_graph) * 1000  # us
    graph.reset()
    return avg
|
|
|
|
|
|
def get_rocm_tuning_space(use_fp16):
    """Return the ROCm tuning parameter ranges as a name -> values mapping.

    When ``use_fp16`` is false, 16 is dropped from the BLOCK_SIZE_K range
    and the fp16-only keys (matrix_instr_nonkdim, kpack) are omitted.
    """
    block_sizes = [16, 32, 64, 128, 256]
    if use_fp16:
        block_k_sizes = list(block_sizes)
    else:
        # BLOCK_K=16 not supported for fp8
        block_k_sizes = [size for size in block_sizes if size != 16]

    space = {
        "BLOCK_SIZE_M": block_sizes,
        "BLOCK_SIZE_N": block_sizes,
        "BLOCK_SIZE_K": block_k_sizes,
        "GROUP_SIZE_M": [1, 4, 8, 16, 32],
        "num_warps": [1, 2, 4, 8],
        "num_stages": [2],
        "waves_per_eu": [0],
    }
    if use_fp16:
        space["matrix_instr_nonkdim"] = [16, 32]
        space["kpack"] = [1, 2]
    return space
|
|
|
|
|
|
def get_configs_compute_bound(use_fp16,
                              block_quant_shape) -> list[dict[str, int]]:
    """Enumerate every candidate kernel configuration to tune over.

    The candidates are the Cartesian product of the per-parameter ranges:
    the ROCm tuning space on ROCm, a reduced hand-picked space otherwise.

    Args:
        use_fp16: True when neither fp8 nor int8 quantization is in use.
        block_quant_shape: ``[block_n, block_k]`` of the fp8 block
            quantization, or None.

    Returns:
        One dict per candidate, in product order. When block quantization
        applies (``block_quant_shape`` given and not ``use_fp16``), only
        candidates whose BLOCK_SIZE_N / BLOCK_SIZE_K are multiples of
        block_n / block_k are kept.
    """
    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        param_ranges = {
            "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
            "BLOCK_SIZE_N": [32, 64, 128, 256],
            "BLOCK_SIZE_K": [64, 128, 256],
            "GROUP_SIZE_M": [1, 16, 32, 64],
            "num_warps": [4, 8],
            "num_stages": [2, 3, 4, 5],
        }

    keys = tuple(param_ranges)
    configs: list[BenchmarkConfig] = [
        dict(zip(keys, config_values))
        for config_values in product(*param_ranges.values())
    ]

    # Remove configs that are not compatible with fp8 block quantization
    # BLOCK_SIZE_K must be a multiple of block_k
    # BLOCK_SIZE_N must be a multiple of block_n
    if block_quant_shape is not None and not use_fp16:
        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
        # Single-pass filter; avoids repeated list.remove() scans.
        configs = [
            config for config in configs
            if config["BLOCK_SIZE_K"] % block_k == 0
            and config["BLOCK_SIZE_N"] % block_n == 0
        ]
    return configs
|
|
|
|
|
|
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
                            search_space, is_fp16, topk):
    """Prune the ROCm search space against both GEMMs of the MoE layer.

    The search space is pruned once for the
    (shard_intermediate_size x hidden_size) GEMM and once for the
    (hidden_size x shard_intermediate_size // 2) GEMM, each with
    M = num_tokens * topk. The two surviving lists are merged with
    duplicates removed.
    """
    m = num_tokens * topk
    gemm_shapes = (
        (shard_intermediate_size, hidden_size),
        (hidden_size, shard_intermediate_size // 2),
    )
    survivors = [
        prune_rocm_configs(m, n, k, search_space, is_fp16)
        for n, k in gemm_shapes
    ]
    return merge_unique_dicts(survivors[0], survivors[1])
|
|
|
|
|
|
# The following code is inspired by ROCm/Triton GEMM tuning script:
|
|
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
|
|
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    """Filter out configs that are unsuitable for an M x N x K GEMM on ROCm.

    Args:
        M, N, K: GEMM problem dimensions.
        configs: Iterable of config dicts (BLOCK_SIZE_M/N/K, GROUP_SIZE_M,
            num_warps, optionally SPLIT_K and, for fp16,
            matrix_instr_nonkdim).
        is_fp16: Selects 2-byte elements and enables the
            matrix_instr_nonkdim checks; otherwise elements are 1 byte.

    Returns:
        The configs that pass every check, in their original order.
    """
    # Both operands use the same element width.
    bytes_per_elem = 2 if is_fp16 else 1
    # Upper bound for matrix_instr_nonkdim given the problem shape.
    max_mfma = 16 if M < 32 or N < 32 else 32
    # TODO (zhanglx): figure out the boundary between large and small gemms
    is_large_gemm = M >= 2048 and N >= 2048
    lds_limit_bytes = 65536

    def _is_viable(candidate):
        block_m = candidate.get("BLOCK_SIZE_M")
        block_n = candidate.get("BLOCK_SIZE_N")
        block_k = candidate.get("BLOCK_SIZE_K")
        warps = candidate.get("num_warps")

        nonkdim = None
        if is_fp16:
            nonkdim = candidate.get("matrix_instr_nonkdim")
            if nonkdim > max_mfma:
                return False
        # max_mfma is only ever 16 or 32 above, so this guard cannot fire;
        # it is retained from the script this code is based on.
        if max_mfma == 4 and block_k < 64:
            return False
        # some layouts could not work properly in case
        # number elements per thread is less 1
        if block_m * block_n < 64:
            return False

        split_k = candidate.get("SPLIT_K", 1)
        group_m = candidate.get("GROUP_SIZE_M")

        if is_fp16:
            if nonkdim > block_m or nonkdim > block_n:
                return False
            if nonkdim >= M and nonkdim != block_m:
                return False
            if nonkdim >= N and nonkdim != block_n:
                return False
        # Skip BLOCK_SIZE that is too large compare to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < block_m and block_m != 16:
            return False
        if N * 2 < block_n and block_n != 16:
            return False
        # skip large split_k when not necessary
        if split_k != 1 and not need_split_k(M, N, K):
            return False
        # skip split_k that leads to EVEN_K = false
        if K % (split_k * block_k) != 0:
            return False
        # skip large GROUP_M
        if group_m * block_m > M and group_m != 1:
            return False
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        lds_bytes = (block_k * block_m * bytes_per_elem +
                     block_k * block_n * bytes_per_elem)
        if lds_bytes > lds_limit_bytes:
            return False
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
        if is_large_gemm and (block_m < 64 or block_n < 64 or block_k < 64
                              or warps < 4):
            return False
        return True

    return [candidate for candidate in configs if _is_viable(candidate)]
|
|
|
|
|
|
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    """Return True when M or N is below 64 while K exceeds 1024."""
    has_small_dim = SIZE_M < 64 or SIZE_N < 64
    return has_small_dim and SIZE_K > 1024
|
|
|
|
|
|
def merge_unique_dicts(list1, list2):
    """Concatenate two lists of dicts, keeping only first occurrences.

    Order is preserved: entries of ``list1`` come first, followed by the
    entries of ``list2`` that were not already seen. Neither input is
    modified.
    """
    unique = []
    for candidate in [*list1, *list2]:
        # Dicts are unhashable, so membership is checked by equality.
        if candidate in unique:
            continue
        unique.append(candidate)
    return unique
|
|
|
|
|
|
@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor reserving one GPU that benchmarks or tunes fused MoE."""

    def __init__(self, seed: int) -> None:
        """Make CUDA the default device and seed the RNGs with ``seed``."""
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        block_quant_shape: list[int] = None,
    ) -> tuple[dict[str, int], float]:
        """Time the kernel with the config vLLM would select for this shape.

        A tuned config table is looked up with ``get_moe_configs``; when one
        exists, the entry whose key is closest to ``num_tokens`` is used,
        otherwise ``get_default_config`` supplies the config.

        Args:
            block_quant_shape: ``[block_n, block_k]`` of the fp8 block
                quantization, or None. Accepted so that the argument tuple
                built by ``main`` (which includes it) can be dispatched
                here, and forwarded to ``benchmark_config``.

        Returns:
            ``(config, kernel_time)`` where ``kernel_time`` is the value
            returned by ``benchmark_config`` (microseconds).
        """
        current_platform.seed_everything(self.seed)
        dtype_str = get_config_dtype_str(dtype,
                                         use_int8_w8a16=use_int8_w8a16,
                                         use_fp8_w8a8=use_fp8_w8a8)
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        # NOTE(review): this lookup does not take block_quant_shape into
        # account, whereas save_configs includes it in the file name --
        # confirm whether get_moe_configs should receive it as well.
        op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
                                    dtype_str)
        if op_config is None:
            config = get_default_config(num_tokens,
                                        num_experts,
                                        shard_intermediate_size,
                                        hidden_size,
                                        topk,
                                        dtype_str,
                                        is_marlin=False)
        else:
            # Pick the tuned batch size nearest to the requested one.
            config = op_config[min(op_config.keys(),
                                   key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(config,
                                       num_tokens,
                                       num_experts,
                                       shard_intermediate_size,
                                       hidden_size,
                                       topk,
                                       dtype,
                                       use_fp8_w8a8,
                                       use_int8_w8a16,
                                       block_quant_shape=block_quant_shape)
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: list[dict[str, int]],
        block_quant_shape: list[int],
    ) -> dict[str, int]:
        """Benchmark every config in ``search_space`` and return the fastest.

        On ROCm the search space is first pruned for this problem shape and
        the benchmarks run under this worker's CUDA device context. Configs
        that raise Triton's ``OutOfResources`` are skipped.

        Raises:
            AssertionError: If no config could be benchmarked.
        """
        best_config = None
        best_time = float("inf")
        on_rocm = current_platform.is_rocm()
        if on_rocm:
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
            search_space = prune_rocm_search_space(num_tokens,
                                                   shard_intermediate_size,
                                                   hidden_size, search_space,
                                                   is_fp16, topk)

        device_ctx = (torch.cuda.device(self.device_id)
                      if on_rocm else nullcontext())
        with device_ctx:
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        num_iters=20,
                        block_quant_shape=block_quant_shape)
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config
        now = datetime.now()
        print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
        return best_config
|
|
|
|
|
|
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    """Return a copy of ``config`` with its keys in a canonical order.

    The six core launch parameters always come first (a missing one raises
    KeyError); the ROCm-specific keys follow, each only when present in
    ``config``. Any other key is dropped.
    """
    required_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )
    optional_keys = ("waves_per_eu", "matrix_instr_nonkdim", "kpack")

    ordered = {key: config[key] for key in required_keys}
    for key in optional_keys:
        if key in config:
            ordered[key] = config[key]
    return ordered
|
|
|
|
|
|
def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
                 shard_intermediate_size: int, hidden_size: int, topk: int,
                 dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
                 block_quant_shape: list[int]) -> None:
    """Write the tuned configs to the JSON file vLLM looks up for this shape.

    Args:
        configs: Best config per batch size.
        num_experts: Number of experts; part of the file name.
        shard_intermediate_size: Per-shard w1 output size; half of it is
            used in the file name.
        hidden_size: Unused here; kept for a uniform call signature.
        topk: Unused here; kept for a uniform call signature.
        dtype: Activation dtype, used to derive the dtype string.
        use_fp8_w8a8: Whether the fp8_w8a8 path was tuned.
        use_int8_w8a16: Whether the int8_w8a16 path was tuned.
        block_quant_shape: ``[block_n, block_k]`` of the fp8 block
            quantization, or None; passed on to the file-name helper.
    """
    dtype_str = get_config_dtype_str(dtype,
                                     use_int8_w8a16=use_int8_w8a16,
                                     use_fp8_w8a8=use_fp8_w8a8)

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                    dtype_str, block_quant_shape)

    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        # Integer batch-size keys are serialized as JSON strings.
        json.dump(configs, f, indent=4)
        f.write("\n")
|
|
|
|
|
|
def main(args: argparse.Namespace):
    """Benchmark or tune the fused MoE kernel for the model in ``args``.

    The MoE dimensions are read from the Hugging Face config of
    ``args.model``. One ``BenchmarkWorker`` is started per available GPU and
    the batch sizes are dispatched to them round-robin. With ``args.tune``
    the best config per batch size is searched for and saved; otherwise the
    currently selected config is timed and printed.
    """
    print(args)
    block_quant_shape = None
    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code)
    architecture = config.architectures[0]
    if architecture == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif architecture == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    elif architecture in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        # Not every checkpoint of these architectures carries a quantization
        # config with a block size; fall back to None instead of raising.
        quant_config = getattr(config, "quantization_config", None)
        if isinstance(quant_config, dict):
            block_quant_shape = quant_config.get("weight_block_size")
    elif architecture == "Qwen2MoeForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    else:
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    # Twice the intermediate size, split across tensor-parallel ranks.
    shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        """Call ``method`` on the workers round-robin and gather results."""
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        # Blocks until every task has finished; results keep input order.
        return ray.get(outputs)

    if args.tune:
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune",
            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
              use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
             for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, E, shard_intermediate_size, hidden_size,
                     topk, dtype, use_fp8_w8a8, use_int8_w8a16,
                     block_quant_shape)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
              use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
             for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line interface, declared as (option strings, keyword
    # arguments) pairs and registered in order.
    cli_options = [
        (("--model", ),
         dict(type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1")),
        (("--tp-size", "-tp", "--tensor-parallel-size"),
         dict(type=int, default=2)),
        (("--dtype", ),
         dict(type=str,
              choices=["auto", "fp8_w8a8", "int8_w8a16"],
              default="auto")),
        (("--seed", ), dict(type=int, default=0)),
        (("--batch-size", ), dict(type=int, required=False)),
        (("--tune", ), dict(action="store_true")),
        (("--trust-remote-code", ), dict(action="store_true")),
    ]

    parser = FlexibleArgumentParser()
    for option_strings, option_kwargs in cli_options:
        parser.add_argument(*option_strings, **option_kwargs)
    args = parser.parse_args()

    main(args)
|