Add benchmark for DeepGEMM and vLLM Block FP8 Dense GEMM (#13917)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
ffad94397d
commit
ca100c90fe
129
benchmarks/kernels/deepgemm/README.md
Normal file
129
benchmarks/kernels/deepgemm/README.md
Normal file
@ -0,0 +1,129 @@
|
||||
# DeepSeek DeepGEMM Kernels Benchmark
|
||||
|
||||
This directory includes benchmarks between DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing triton and CUTLASS-based kernels.
|
||||
|
||||
Currently this just includes dense GEMMs and only works on Hopper GPUs.
|
||||
|
||||
## Setup
|
||||
|
||||
You need to install vLLM in your usual fashion, then install DeepGEMM from source in its own directory:
|
||||
|
||||
```
|
||||
git clone --recursive https://github.com/deepseek-ai/DeepGEMM
|
||||
cd DeepGEMM
|
||||
python setup.py install
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
python benchmark_fp8_block_dense_gemm.py
|
||||
INFO 02-26 21:55:13 [__init__.py:207] Automatically detected platform cuda.
|
||||
===== STARTING FP8 GEMM BENCHMARK =====
|
||||
PyTorch version: 2.5.1+cu124
|
||||
CUDA version: 12.4
|
||||
Triton version: 3.1.0
|
||||
Using device: NVIDIA H100 80GB HBM3
|
||||
WARNING 02-26 21:55:15 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
|
||||
INFO 02-26 21:55:15 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
|
||||
WARNING 02-26 21:55:16 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
|
||||
WARNING 02-26 21:55:17 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
|
||||
INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
|
||||
INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
|
||||
|
||||
===== PERFORMANCE COMPARISON =====
|
||||
|
||||
DeepGEMM Implementation:
|
||||
+------+-------+-------+-----------+--------+--------+
|
||||
| m | n | k | Time (μs) | TFLOPS | GB/s |
|
||||
+------+-------+-------+-----------+--------+--------+
|
||||
| 8 | 4096 | 7168 | 102.9 | 4.6 | 286.4 |
|
||||
| 8 | 7168 | 18432 | 70.8 | 29.8 | 1868.8 |
|
||||
| 8 | 18432 | 7168 | 69.3 | 30.5 | 1911.8 |
|
||||
| 64 | 4096 | 7168 | 69.1 | 54.4 | 439.0 |
|
||||
| 64 | 7168 | 18432 | 69.4 | 243.6 | 1933.6 |
|
||||
| 64 | 18432 | 7168 | 70.4 | 240.3 | 1917.2 |
|
||||
| 64 | 24576 | 1536 | 70.1 | 68.9 | 584.6 |
|
||||
| 64 | 32768 | 512 | 68.4 | 31.4 | 307.1 |
|
||||
| 64 | 7168 | 16384 | 69.5 | 216.3 | 1718.5 |
|
||||
| 128 | 4096 | 7168 | 141.1 | 53.3 | 222.1 |
|
||||
| 128 | 7168 | 18432 | 71.9 | 470.5 | 1896.1 |
|
||||
| 128 | 18432 | 7168 | 69.3 | 488.2 | 1988.2 |
|
||||
| 1024 | 4096 | 7168 | 89.7 | 670.1 | 502.5 |
|
||||
| 1024 | 18432 | 7168 | 279.0 | 969.8 | 635.2 |
|
||||
| 2048 | 4096 | 7168 | 175.1 | 687.0 | 347.4 |
|
||||
| 4096 | 4096 | 7168 | 335.4 | 717.0 | 275.1 |
|
||||
+------+-------+-------+-----------+--------+--------+
|
||||
|
||||
vLLM Triton Implementation:
|
||||
+------+-------+-------+-----------+--------+--------+--------------+
|
||||
| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM |
|
||||
+------+-------+-------+-----------+--------+--------+--------------+
|
||||
| 8 | 4096 | 7168 | 74.0 | 6.3 | 398.2 | 1.39x faster |
|
||||
| 8 | 7168 | 18432 | 89.6 | 23.6 | 1478.1 | 0.79x slower |
|
||||
| 8 | 18432 | 7168 | 113.2 | 18.7 | 1170.4 | 0.61x slower |
|
||||
| 64 | 4096 | 7168 | 79.4 | 47.3 | 382.2 | 0.87x slower |
|
||||
| 64 | 7168 | 18432 | 98.5 | 171.7 | 1363.0 | 0.70x slower |
|
||||
| 64 | 18432 | 7168 | 119.5 | 141.5 | 1129.4 | 0.59x slower |
|
||||
| 64 | 24576 | 1536 | 37.6 | 128.4 | 1089.7 | 1.86x faster |
|
||||
| 64 | 32768 | 512 | 38.7 | 55.5 | 542.6 | 1.77x faster |
|
||||
| 64 | 7168 | 16384 | 86.1 | 174.5 | 1386.4 | 0.81x slower |
|
||||
| 128 | 4096 | 7168 | 90.7 | 82.9 | 345.4 | 1.56x faster |
|
||||
| 128 | 7168 | 18432 | 144.0 | 234.9 | 946.9 | 0.50x slower |
|
||||
| 128 | 18432 | 7168 | 229.5 | 147.4 | 600.1 | 0.30x slower |
|
||||
| 1024 | 4096 | 7168 | 242.3 | 248.2 | 186.1 | 0.37x slower |
|
||||
| 1024 | 18432 | 7168 | 897.8 | 301.4 | 197.4 | 0.31x slower |
|
||||
| 2048 | 4096 | 7168 | 463.0 | 259.7 | 131.4 | 0.38x slower |
|
||||
| 4096 | 4096 | 7168 | 901.8 | 266.7 | 102.3 | 0.37x slower |
|
||||
+------+-------+-------+-----------+--------+--------+--------------+
|
||||
|
||||
vLLM CUTLASS Implementation:
|
||||
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|
||||
| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | vs Triton |
|
||||
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|
||||
| 8 | 4096 | 7168 | 34.6 | 13.6 | 852.3 | 2.98x faster | 2.14x faster |
|
||||
| 8 | 7168 | 18432 | 78.9 | 26.8 | 1677.3 | 0.90x slower | 1.13x faster |
|
||||
| 8 | 18432 | 7168 | 81.2 | 26.0 | 1631.1 | 0.85x slower | 1.39x faster |
|
||||
| 64 | 4096 | 7168 | 36.9 | 101.9 | 822.9 | 1.87x faster | 2.15x faster |
|
||||
| 64 | 7168 | 18432 | 87.4 | 193.4 | 1535.2 | 0.79x slower | 1.13x faster |
|
||||
| 64 | 18432 | 7168 | 85.0 | 199.0 | 1587.6 | 0.83x slower | 1.41x faster |
|
||||
| 64 | 24576 | 1536 | 28.0 | 172.8 | 1465.8 | 2.51x faster | 1.35x faster |
|
||||
| 64 | 32768 | 512 | 28.8 | 74.5 | 728.5 | 2.37x faster | 1.34x faster |
|
||||
| 64 | 7168 | 16384 | 77.9 | 193.0 | 1532.8 | 0.89x slower | 1.11x faster |
|
||||
| 128 | 4096 | 7168 | 39.1 | 192.4 | 802.0 | 3.61x faster | 2.32x faster |
|
||||
| 128 | 7168 | 18432 | 93.7 | 360.8 | 1454.2 | 0.77x slower | 1.54x faster |
|
||||
| 128 | 18432 | 7168 | 85.7 | 394.8 | 1608.0 | 0.81x slower | 2.68x faster |
|
||||
| 1024 | 4096 | 7168 | 99.7 | 603.1 | 452.2 | 0.90x slower | 2.43x faster |
|
||||
| 1024 | 18432 | 7168 | 331.3 | 816.7 | 534.9 | 0.84x slower | 2.71x faster |
|
||||
| 2048 | 4096 | 7168 | 198.3 | 606.6 | 306.7 | 0.88x slower | 2.34x faster |
|
||||
| 4096 | 4096 | 7168 | 392.2 | 613.2 | 235.3 | 0.86x slower | 2.30x faster |
|
||||
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|
||||
|
||||
===== AVERAGE PERFORMANCE =====
|
||||
+----------------+------------+----------+---------------+
|
||||
| Implementation | Avg TFLOPS | Avg GB/s | Avg Time (ms) |
|
||||
+----------------+------------+----------+---------------+
|
||||
| DeepGEMM | 310.98 | 1052.10 | 0.11 |
|
||||
| vLLM Triton | 144.30 | 715.60 | 0.23 |
|
||||
| vLLM CUTLASS | 286.78 | 1076.67 | 0.11 |
|
||||
+----------------+------------+----------+---------------+
|
||||
|
||||
===== AVERAGE SPEEDUPS =====
|
||||
+-----------------------------+--------------+
|
||||
| Comparison | Speedup |
|
||||
+-----------------------------+--------------+
|
||||
| DeepGEMM vs vLLM Triton | 1.71x faster |
|
||||
| DeepGEMM vs vLLM CUTLASS | 0.94x slower |
|
||||
| vLLM CUTLASS vs vLLM Triton | 1.84x faster |
|
||||
+-----------------------------+--------------+
|
||||
|
||||
===== ACCURACY COMPARISON =====
|
||||
+----------------+-----------------------+
|
||||
| Implementation | Avg Diff vs Reference |
|
||||
+----------------+-----------------------+
|
||||
| DeepGEMM | 0.000684 |
|
||||
| vLLM Triton | 0.000684 |
|
||||
| vLLM CUTLASS | 0.000684 |
|
||||
+----------------+-----------------------+
|
||||
```
|
464
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
Normal file
464
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
Normal file
@ -0,0 +1,464 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# fmt: off
|
||||
# ruff: noqa: E501
|
||||
import time
|
||||
|
||||
# Import DeepGEMM functions
|
||||
import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
|
||||
|
||||
# Import vLLM functions
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
|
||||
|
||||
# Copied from
|
||||
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
|
||||
def per_token_cast_to_fp8(
|
||||
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert tensor to FP8 format with per-token scaling."""
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
|
||||
torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
# Copied from
|
||||
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert tensor to FP8 format with per-block scaling."""
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
|
||||
x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
||||
|
||||
|
||||
def benchmark_shape(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
warmup: int = 100,
|
||||
repeat: int = 10000,
|
||||
verbose: bool = False) -> dict:
|
||||
"""Benchmark all implementations for a specific (m, n, k) shape."""
|
||||
if verbose:
|
||||
print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")
|
||||
|
||||
# Create test tensors
|
||||
A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
|
||||
# Reference result in BF16
|
||||
torch.cuda.synchronize()
|
||||
C_ref = A @ B.t()
|
||||
|
||||
# Pre-quantize B for all implementations
|
||||
# (weights can be pre-quantized offline)
|
||||
B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B)
|
||||
B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)
|
||||
|
||||
# Block size configuration
|
||||
block_size = [128, 128]
|
||||
|
||||
# Pre-quantize A for all implementations
|
||||
A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
|
||||
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
|
||||
C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
|
||||
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
|
||||
A, block_size[1], column_major_scales=True)
|
||||
|
||||
# === DeepGEMM Implementation ===
|
||||
def deepgemm_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
|
||||
# A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(
|
||||
# A, block_size[1])
|
||||
# A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
|
||||
# C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
|
||||
(B_deepgemm, B_scale_deepgemm),
|
||||
C_deepgemm)
|
||||
return C_deepgemm
|
||||
|
||||
# === vLLM Triton Implementation ===
|
||||
def vllm_triton_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
|
||||
return w8a8_block_fp8_matmul(A_vllm,
|
||||
B_vllm,
|
||||
A_scale_vllm,
|
||||
B_scale_vllm,
|
||||
block_size,
|
||||
output_dtype=torch.bfloat16)
|
||||
|
||||
# === vLLM CUTLASS Implementation ===
|
||||
def vllm_cutlass_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
|
||||
# A, block_size[1], column_major_scales=True)
|
||||
return ops.cutlass_scaled_mm(A_vllm_cutlass,
|
||||
B_vllm.T,
|
||||
scale_a=A_scale_vllm_cutlass,
|
||||
scale_b=B_scale_vllm.T,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
# Run correctness check first
|
||||
if verbose:
|
||||
print("Running correctness check...")
|
||||
C_deepgemm = deepgemm_gemm()
|
||||
C_vllm_triton = vllm_triton_gemm()
|
||||
C_vllm_cutlass = vllm_cutlass_gemm()
|
||||
|
||||
deepgemm_diff = calc_diff(C_deepgemm, C_ref)
|
||||
vllm_triton_diff = calc_diff(C_vllm_triton, C_ref)
|
||||
vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref)
|
||||
|
||||
if verbose:
|
||||
print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
|
||||
print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
|
||||
print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
|
||||
print("vLLM Triton vs DeepGEMM difference: "
|
||||
f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
|
||||
print("vLLM CUTLASS vs DeepGEMM difference: "
|
||||
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")
|
||||
|
||||
# Benchmark implementations
|
||||
implementations = {
|
||||
"DeepGEMM": deepgemm_gemm,
|
||||
"vLLM Triton": vllm_triton_gemm,
|
||||
"vLLM CUTLASS": vllm_cutlass_gemm
|
||||
}
|
||||
|
||||
benchmark_results = {
|
||||
"shape": {
|
||||
"m": m,
|
||||
"n": n,
|
||||
"k": k
|
||||
},
|
||||
"implementations": {}
|
||||
}
|
||||
|
||||
for name, func in implementations.items():
|
||||
# Warmup
|
||||
for _ in range(warmup):
|
||||
func()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Timing loop
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
for _ in range(repeat):
|
||||
func()
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
|
||||
# Calculate timing and TFLOPS
|
||||
avg_time_ms = (end - start) / repeat * 1000
|
||||
avg_time_us = avg_time_ms * 1000
|
||||
tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12
|
||||
gb_s = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3)
|
||||
|
||||
benchmark_results["implementations"][name] = {
|
||||
"time_ms": avg_time_ms,
|
||||
"time_us": avg_time_us,
|
||||
"tflops": tflops,
|
||||
"gb_s": gb_s,
|
||||
"diff": {
|
||||
"DeepGEMM":
|
||||
0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
|
||||
"Reference":
|
||||
deepgemm_diff if name == "DeepGEMM" else
|
||||
(vllm_triton_diff
|
||||
if name == "vLLM Triton" else vllm_cutlass_diff)
|
||||
}
|
||||
}
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
|
||||
)
|
||||
|
||||
# Calculate speedups
|
||||
baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
|
||||
for name, data in benchmark_results["implementations"].items():
|
||||
if name != "DeepGEMM":
|
||||
speedup = baseline / data["time_ms"]
|
||||
benchmark_results["implementations"][name][
|
||||
"speedup_vs_deepgemm"] = speedup
|
||||
if verbose:
|
||||
print(f"DeepGEMM is {1/speedup:.2f}x "
|
||||
f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")
|
||||
|
||||
vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
|
||||
"time_ms"]
|
||||
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
|
||||
"time_ms"]
|
||||
cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
|
||||
benchmark_results["implementations"]["vLLM CUTLASS"][
|
||||
"speedup_vs_triton"] = cutlass_vs_triton
|
||||
if verbose:
|
||||
print(
|
||||
f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
|
||||
f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton"
|
||||
)
|
||||
|
||||
return benchmark_results
|
||||
|
||||
|
||||
def format_table_row(values, widths):
|
||||
"""Format a row with specified column widths."""
|
||||
return "| " + " | ".join(f"{val:{w}}"
|
||||
for val, w in zip(values, widths)) + " |"
|
||||
|
||||
|
||||
def print_table(headers, rows, title=None):
|
||||
"""Print a table with headers and rows."""
|
||||
if title:
|
||||
print(f"\n{title}")
|
||||
|
||||
# Calculate column widths based on headers and data
|
||||
widths = [
|
||||
max(len(str(h)), max(len(str(row[i])) for row in rows))
|
||||
for i, h in enumerate(headers)
|
||||
]
|
||||
|
||||
# Create separator line
|
||||
separator = "+-" + "-+-".join("-" * w for w in widths) + "-+"
|
||||
|
||||
# Print table
|
||||
print(separator)
|
||||
print(format_table_row(headers, widths))
|
||||
print(separator)
|
||||
for row in rows:
|
||||
print(format_table_row(row, widths))
|
||||
print(separator)
|
||||
|
||||
|
||||
def format_speedup(value):
|
||||
"""Format speedup value with indicator if it's faster or slower."""
|
||||
return f"{value:.2f}x {'faster' if value > 1.0 else 'slower'}"
|
||||
|
||||
|
||||
def run_benchmarks(verbose: bool = False):
|
||||
"""Run benchmarks for a set of common shapes."""
|
||||
print("===== STARTING FP8 GEMM BENCHMARK =====")
|
||||
|
||||
# Make sure we're using the GPU
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA not available! Tests require GPU.")
|
||||
return
|
||||
|
||||
# Print system information
|
||||
print(f"PyTorch version: {torch.__version__}")
|
||||
print(f"CUDA version: {torch.version.cuda}")
|
||||
print(f"Triton version: {triton.__version__}")
|
||||
print(f"Using device: {torch.cuda.get_device_name()}")
|
||||
|
||||
# Enable TF32 for better performance
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# Set seeds for reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
|
||||
# Define benchmark shapes (m, n, k)
|
||||
shapes = [
|
||||
(8, 4096, 7168),
|
||||
(8, 7168, 18432),
|
||||
(8, 18432, 7168),
|
||||
(64, 4096, 7168),
|
||||
(64, 7168, 18432),
|
||||
(64, 18432, 7168),
|
||||
(64, 24576, 1536),
|
||||
(64, 32768, 512),
|
||||
(64, 7168, 16384),
|
||||
(128, 4096, 7168),
|
||||
(128, 7168, 18432),
|
||||
(128, 18432, 7168),
|
||||
(1024, 4096, 7168),
|
||||
(1024, 18432, 7168),
|
||||
(2048, 4096, 7168),
|
||||
(4096, 4096, 7168),
|
||||
]
|
||||
shapes = [
|
||||
# (64, 2112, 7168),
|
||||
(64, 24576, 1536),
|
||||
(64, 32768, 512),
|
||||
(64, 7168, 16384),
|
||||
(64, 4096, 7168),
|
||||
(64, 7168, 2048),
|
||||
# (128, 2112, 7168),
|
||||
(128, 24576, 1536),
|
||||
(128, 32768, 512),
|
||||
(128, 7168, 16384),
|
||||
(128, 4096, 7168),
|
||||
(128, 7168, 2048),
|
||||
# (4096, 2112, 7168),
|
||||
(4096, 24576, 1536),
|
||||
(4096, 32768, 512),
|
||||
(4096, 7168, 16384),
|
||||
(4096, 4096, 7168),
|
||||
(4096, 7168, 2048),
|
||||
]
|
||||
|
||||
all_results = []
|
||||
for m, n, k in shapes:
|
||||
result = benchmark_shape(m, n, k, verbose=verbose)
|
||||
all_results.append(result)
|
||||
|
||||
# Print results in a nicely formatted table
|
||||
print("\n===== PERFORMANCE COMPARISON =====")
|
||||
|
||||
# Print DeepGEMM table
|
||||
deepgemm_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"]
|
||||
deepgemm_rows = []
|
||||
for result in all_results:
|
||||
shape = result["shape"]
|
||||
impl_data = result["implementations"]["DeepGEMM"]
|
||||
deepgemm_rows.append([
|
||||
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
|
||||
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
|
||||
])
|
||||
|
||||
print_table(deepgemm_headers,
|
||||
deepgemm_rows,
|
||||
title="DeepGEMM Implementation:")
|
||||
|
||||
# Print vLLM Triton table
|
||||
triton_headers = [
|
||||
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
|
||||
]
|
||||
triton_rows = []
|
||||
for result in all_results:
|
||||
shape = result["shape"]
|
||||
impl_data = result["implementations"]["vLLM Triton"]
|
||||
speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
|
||||
triton_rows.append([
|
||||
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
|
||||
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
|
||||
format_speedup(speedup)
|
||||
])
|
||||
|
||||
print_table(triton_headers,
|
||||
triton_rows,
|
||||
title="vLLM Triton Implementation:")
|
||||
|
||||
# Print vLLM CUTLASS table
|
||||
cutlass_headers = [
|
||||
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
|
||||
"vs Triton"
|
||||
]
|
||||
cutlass_rows = []
|
||||
for result in all_results:
|
||||
shape = result["shape"]
|
||||
impl_data = result["implementations"]["vLLM CUTLASS"]
|
||||
vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
|
||||
vs_triton = impl_data.get("speedup_vs_triton", 1.0)
|
||||
cutlass_rows.append([
|
||||
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
|
||||
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
|
||||
format_speedup(vs_deepgemm),
|
||||
format_speedup(vs_triton)
|
||||
])
|
||||
|
||||
print_table(cutlass_headers,
|
||||
cutlass_rows,
|
||||
title="vLLM CUTLASS Implementation:")
|
||||
|
||||
# Calculate and print averages
|
||||
print("\n===== AVERAGE PERFORMANCE =====")
|
||||
|
||||
implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
|
||||
avg_metrics = {
|
||||
impl: {
|
||||
"tflops": 0,
|
||||
"gb_s": 0,
|
||||
"time_ms": 0
|
||||
}
|
||||
for impl in implementations
|
||||
}
|
||||
|
||||
for result in all_results:
|
||||
for impl in implementations:
|
||||
impl_data = result["implementations"][impl]
|
||||
avg_metrics[impl]["tflops"] += impl_data["tflops"]
|
||||
avg_metrics[impl]["gb_s"] += impl_data["gb_s"]
|
||||
avg_metrics[impl]["time_ms"] += impl_data["time_ms"]
|
||||
|
||||
num_shapes = len(all_results)
|
||||
avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"]
|
||||
avg_rows = []
|
||||
|
||||
for impl in implementations:
|
||||
avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
|
||||
avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
|
||||
avg_time = avg_metrics[impl]["time_ms"] / num_shapes
|
||||
avg_rows.append([
|
||||
impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
|
||||
])
|
||||
|
||||
print_table(avg_headers, avg_rows)
|
||||
|
||||
# Calculate average speedups
|
||||
avg_speedups = {
|
||||
"DeepGEMM vs vLLM Triton": 0,
|
||||
"DeepGEMM vs vLLM CUTLASS": 0,
|
||||
"vLLM CUTLASS vs vLLM Triton": 0
|
||||
}
|
||||
|
||||
for result in all_results:
|
||||
deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
|
||||
vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
|
||||
vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][
|
||||
"time_ms"]
|
||||
|
||||
avg_speedups[
|
||||
"DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
|
||||
avg_speedups[
|
||||
"DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
|
||||
avg_speedups[
|
||||
"vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
|
||||
|
||||
print("\n===== AVERAGE SPEEDUPS =====")
|
||||
speedup_headers = ["Comparison", "Speedup"]
|
||||
speedup_rows = []
|
||||
for comparison, total in avg_speedups.items():
|
||||
avg_speedup = total / num_shapes
|
||||
status = "faster" if avg_speedup > 1 else "slower"
|
||||
speedup_rows.append([comparison, f"{avg_speedup:.2f}x {status}"])
|
||||
|
||||
print_table(speedup_headers, speedup_rows)
|
||||
|
||||
# Average accuracy comparison
|
||||
print("\n===== ACCURACY COMPARISON =====")
|
||||
avg_diff = {impl: 0 for impl in implementations}
|
||||
|
||||
for result in all_results:
|
||||
for impl in implementations:
|
||||
avg_diff[impl] += result["implementations"][impl]["diff"][
|
||||
"Reference"]
|
||||
|
||||
diff_headers = ["Implementation", "Avg Diff vs Reference"]
|
||||
diff_rows = []
|
||||
for impl in implementations:
|
||||
diff_rows.append([impl, f"{avg_diff[impl] / num_shapes:.6f}"])
|
||||
|
||||
print_table(diff_headers, diff_rows)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_benchmarks(verbose=False)
|
Loading…
x
Reference in New Issue
Block a user