# SPDX-License-Identifier: Apache-2.0
|
|
# fmt: off
|
|
# ruff: noqa: E501
|
|
import time
|
|
|
|
# Import DeepGEMM functions
|
|
import deep_gemm
|
|
import torch
|
|
import triton
|
|
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
|
|
|
|
# Import vLLM functions
|
|
from vllm import _custom_ops as ops
|
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
|
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
|
|
|
|
|
# Copied from
|
|
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
|
|
def per_token_cast_to_fp8(
|
|
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
"""Convert tensor to FP8 format with per-token scaling."""
|
|
assert x.dim() == 2 and x.size(1) % 128 == 0
|
|
m, n = x.shape
|
|
x_view = x.view(m, -1, 128)
|
|
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
|
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
|
|
torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
|
|
|
|
|
|
# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
def per_block_cast_to_fp8(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert tensor to FP8 format with per-block scaling.

    The input is zero-padded up to multiples of 128 in both dimensions,
    then each 128x128 block is scaled so its absolute maximum maps to
    448 (the FP8 E4M3 max) and cast to ``torch.float8_e4m3fn``.

    Returns:
        A ``(quantized, scales)`` pair: ``quantized`` is cropped back to
        the original shape; ``scales`` has one entry per 128x128 block.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    # Zero-pad both dimensions up to the next multiple of 128.
    padded = torch.zeros((ceil_div(rows, 128) * 128, ceil_div(cols, 128) * 128),
                         dtype=x.dtype,
                         device=x.device)
    padded[:rows, :cols] = x
    # View as a grid of 128x128 blocks.
    blocks = padded.view(-1, 128, padded.size(1) // 128, 128)
    # Per-block absolute maximum, floored to avoid division by zero.
    amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    quantized = (blocks * (448.0 / amax)).to(torch.float8_e4m3fn)
    scales = (amax / 448.0).view(blocks.size(0), blocks.size(2))
    return quantized.view_as(padded)[:rows, :cols].contiguous(), scales
|
|
|
|
|
|
def benchmark_shape(m: int,
                    n: int,
                    k: int,
                    warmup: int = 100,
                    repeat: int = 10000,
                    verbose: bool = False) -> dict:
    """Benchmark all implementations for a specific (m, n, k) shape.

    Runs three FP8 block GEMMs (DeepGEMM, vLLM Triton, vLLM CUTLASS) on
    random BF16 inputs, checks each against a BF16 reference matmul, then
    times each implementation with warmup + wall-clock timing.

    Args:
        m: Rows of A and C.
        n: Rows of B (columns of C); B is (n, k) and used transposed.
        k: Shared inner dimension.
        warmup: Untimed iterations per implementation.
        repeat: Timed iterations per implementation.
        verbose: If True, print per-shape diagnostics.

    Returns:
        Dict with the shape and, per implementation, timing (ms/us),
        TFLOPS, GB/s and accuracy diffs vs DeepGEMM and the reference.
    """
    if verbose:
        print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")

    # Create test tensors
    A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)

    # Reference result in BF16
    torch.cuda.synchronize()
    C_ref = A @ B.t()

    # Pre-quantize B for all implementations
    # (weights can be pre-quantized offline)
    B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B)
    B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)

    # Block size configuration
    block_size = [128, 128]

    # Pre-quantize A for all implementations.
    # NOTE(review): activation quantization normally depends on runtime
    # activations; it is hoisted out of the timed closures below, so the
    # timings measure the GEMM kernels only.
    A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
    # DeepGEMM requires per-token scales in a column-major TMA-aligned layout.
    A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
    C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
    # The CUTLASS path wants column-major activation scales.
    A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
        A, block_size[1], column_major_scales=True)

    # === DeepGEMM Implementation ===
    def deepgemm_gemm():
        # A quantization is hoisted out of the loop (see NOTE above); the
        # commented lines show the per-call alternative.
        # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
        # A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(
        #     A, block_size[1])
        # A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
        # C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
        deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
                                       (B_deepgemm, B_scale_deepgemm),
                                       C_deepgemm)
        return C_deepgemm

    # === vLLM Triton Implementation ===
    def vllm_triton_gemm():
        # A quantization is hoisted out of the loop (see NOTE above).
        # A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
        return w8a8_block_fp8_matmul(A_vllm,
                                     B_vllm,
                                     A_scale_vllm,
                                     B_scale_vllm,
                                     block_size,
                                     output_dtype=torch.bfloat16)

    # === vLLM CUTLASS Implementation ===
    def vllm_cutlass_gemm():
        # A quantization is hoisted out of the loop (see NOTE above).
        # A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
        #     A, block_size[1], column_major_scales=True)
        return ops.cutlass_scaled_mm(A_vllm_cutlass,
                                     B_vllm.T,
                                     scale_a=A_scale_vllm_cutlass,
                                     scale_b=B_scale_vllm.T,
                                     out_dtype=torch.bfloat16)

    # Run correctness check first
    if verbose:
        print("Running correctness check...")
    C_deepgemm = deepgemm_gemm()
    C_vllm_triton = vllm_triton_gemm()
    C_vllm_cutlass = vllm_cutlass_gemm()

    deepgemm_diff = calc_diff(C_deepgemm, C_ref)
    vllm_triton_diff = calc_diff(C_vllm_triton, C_ref)
    vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref)

    if verbose:
        print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
        print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
        print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
        print("vLLM Triton vs DeepGEMM difference: "
              f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
        print("vLLM CUTLASS vs DeepGEMM difference: "
              f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")

    # Benchmark implementations
    implementations = {
        "DeepGEMM": deepgemm_gemm,
        "vLLM Triton": vllm_triton_gemm,
        "vLLM CUTLASS": vllm_cutlass_gemm
    }

    benchmark_results = {
        "shape": {
            "m": m,
            "n": n,
            "k": k
        },
        "implementations": {}
    }

    for name, func in implementations.items():
        # Warmup
        for _ in range(warmup):
            func()
        torch.cuda.synchronize()

        # Timing loop (wall-clock around `repeat` kernel launches; the
        # synchronize before/after bounds the GPU work).
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(repeat):
            func()
        torch.cuda.synchronize()
        end = time.time()

        # Calculate timing and TFLOPS
        avg_time_ms = (end - start) / repeat * 1000
        avg_time_us = avg_time_ms * 1000
        # 2*m*n*k multiply-adds per GEMM.
        tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12
        # Bytes moved: FP8 A (m*k) + FP8 B (k*n) + BF16 C (2 bytes * m*n).
        gb_s = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3)

        benchmark_results["implementations"][name] = {
            "time_ms": avg_time_ms,
            "time_us": avg_time_us,
            "tflops": tflops,
            "gb_s": gb_s,
            "diff": {
                "DeepGEMM":
                0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
                "Reference":
                deepgemm_diff if name == "DeepGEMM" else
                (vllm_triton_diff
                 if name == "vLLM Triton" else vllm_cutlass_diff)
            }
        }

        if verbose:
            print(
                f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
            )

    # Calculate speedups (ratio of DeepGEMM time to each other impl's time).
    baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
    for name, data in benchmark_results["implementations"].items():
        if name != "DeepGEMM":
            speedup = baseline / data["time_ms"]
            benchmark_results["implementations"][name][
                "speedup_vs_deepgemm"] = speedup
            if verbose:
                print(f"DeepGEMM is {1/speedup:.2f}x "
                      f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")

    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
        "time_ms"]
    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
        "time_ms"]
    cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
    benchmark_results["implementations"]["vLLM CUTLASS"][
        "speedup_vs_triton"] = cutlass_vs_triton
    if verbose:
        print(
            f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
            f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton"
        )

    return benchmark_results
|
|
|
|
|
|
def format_table_row(values, widths):
    """Format a row with specified column widths.

    Each value is padded to its column width and the cells are joined
    with ``" | "`` separators inside outer pipes.
    """
    cells = [f"{value:{width}}" for value, width in zip(values, widths)]
    return "| " + " | ".join(cells) + " |"
|
|
|
|
|
|
def print_table(headers, rows, title=None):
    """Print a table with headers and rows.

    Column widths are computed from the widest entry (header or cell) in
    each column; an optional title is printed above the table.
    """
    if title:
        print(f"\n{title}")

    # Column width: the widest of the header and every cell in that column.
    widths = []
    for col, header in enumerate(headers):
        widest_cell = max(len(str(row[col])) for row in rows)
        widths.append(max(len(str(header)), widest_cell))

    # Horizontal rule matching the column layout.
    separator = "+-" + "-+-".join("-" * width for width in widths) + "-+"

    # Header section, then one line per data row.
    print(separator)
    print(format_table_row(headers, widths))
    print(separator)
    for row in rows:
        print(format_table_row(row, widths))
    print(separator)
|
|
|
|
|
|
def format_speedup(value):
    """Format speedup value with indicator if it's faster or slower."""
    direction = "faster" if value > 1.0 else "slower"
    return f"{value:.2f}x {direction}"
|
|
|
|
|
|
def run_benchmarks(verbose: bool = False):
    """Run FP8 GEMM benchmarks for a set of common shapes.

    Prints system information, per-implementation performance tables
    (DeepGEMM, vLLM Triton, vLLM CUTLASS), average performance, average
    speedups and an accuracy summary. Requires a CUDA device.

    Args:
        verbose: Forwarded to benchmark_shape to print per-shape details.
    """
    print("===== STARTING FP8 GEMM BENCHMARK =====")

    # Make sure we're using the GPU
    if not torch.cuda.is_available():
        print("CUDA not available! Tests require GPU.")
        return

    # Print system information
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Triton version: {triton.__version__}")
    print(f"Using device: {torch.cuda.get_device_name()}")

    # Enable TF32 for better performance
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Set seeds for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Define benchmark shapes (m, n, k).
    # NOTE: a previous, larger shape list that was immediately shadowed by
    # this one (dead code) has been removed.
    shapes = [
        # (64, 2112, 7168),
        (64, 24576, 1536),
        (64, 32768, 512),
        (64, 7168, 16384),
        (64, 4096, 7168),
        (64, 7168, 2048),
        # (128, 2112, 7168),
        (128, 24576, 1536),
        (128, 32768, 512),
        (128, 7168, 16384),
        (128, 4096, 7168),
        (128, 7168, 2048),
        # (4096, 2112, 7168),
        (4096, 24576, 1536),
        (4096, 32768, 512),
        (4096, 7168, 16384),
        (4096, 4096, 7168),
        (4096, 7168, 2048),
    ]

    all_results = []
    for m, n, k in shapes:
        result = benchmark_shape(m, n, k, verbose=verbose)
        all_results.append(result)

    # Print results in a nicely formatted table
    print("\n===== PERFORMANCE COMPARISON =====")

    # Print DeepGEMM table
    deepgemm_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"]
    deepgemm_rows = []
    for result in all_results:
        shape = result["shape"]
        impl_data = result["implementations"]["DeepGEMM"]
        deepgemm_rows.append([
            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
        ])

    print_table(deepgemm_headers,
                deepgemm_rows,
                title="DeepGEMM Implementation:")

    # Print vLLM Triton table (includes speedup vs DeepGEMM)
    triton_headers = [
        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
    ]
    triton_rows = []
    for result in all_results:
        shape = result["shape"]
        impl_data = result["implementations"]["vLLM Triton"]
        speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
        triton_rows.append([
            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
            format_speedup(speedup)
        ])

    print_table(triton_headers,
                triton_rows,
                title="vLLM Triton Implementation:")

    # Print vLLM CUTLASS table (speedups vs DeepGEMM and vs Triton)
    cutlass_headers = [
        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
        "vs Triton"
    ]
    cutlass_rows = []
    for result in all_results:
        shape = result["shape"]
        impl_data = result["implementations"]["vLLM CUTLASS"]
        vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
        vs_triton = impl_data.get("speedup_vs_triton", 1.0)
        cutlass_rows.append([
            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
            format_speedup(vs_deepgemm),
            format_speedup(vs_triton)
        ])

    print_table(cutlass_headers,
                cutlass_rows,
                title="vLLM CUTLASS Implementation:")

    # Calculate and print averages
    print("\n===== AVERAGE PERFORMANCE =====")

    implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
    avg_metrics = {
        impl: {
            "tflops": 0,
            "gb_s": 0,
            "time_ms": 0
        }
        for impl in implementations
    }

    # Accumulate totals, then divide by the shape count below.
    for result in all_results:
        for impl in implementations:
            impl_data = result["implementations"][impl]
            avg_metrics[impl]["tflops"] += impl_data["tflops"]
            avg_metrics[impl]["gb_s"] += impl_data["gb_s"]
            avg_metrics[impl]["time_ms"] += impl_data["time_ms"]

    num_shapes = len(all_results)
    avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"]
    avg_rows = []

    for impl in implementations:
        avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
        avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
        avg_time = avg_metrics[impl]["time_ms"] / num_shapes
        avg_rows.append([
            impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
        ])

    print_table(avg_headers, avg_rows)

    # Calculate average speedups (mean of per-shape time ratios)
    avg_speedups = {
        "DeepGEMM vs vLLM Triton": 0,
        "DeepGEMM vs vLLM CUTLASS": 0,
        "vLLM CUTLASS vs vLLM Triton": 0
    }

    for result in all_results:
        deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
        vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][
            "time_ms"]

        avg_speedups[
            "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
        avg_speedups[
            "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
        avg_speedups[
            "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time

    print("\n===== AVERAGE SPEEDUPS =====")
    speedup_headers = ["Comparison", "Speedup"]
    speedup_rows = []
    for comparison, total in avg_speedups.items():
        avg_speedup = total / num_shapes
        status = "faster" if avg_speedup > 1 else "slower"
        speedup_rows.append([comparison, f"{avg_speedup:.2f}x {status}"])

    print_table(speedup_headers, speedup_rows)

    # Average accuracy comparison (diff vs the BF16 reference matmul)
    print("\n===== ACCURACY COMPARISON =====")
    avg_diff = {impl: 0 for impl in implementations}

    for result in all_results:
        for impl in implementations:
            avg_diff[impl] += result["implementations"][impl]["diff"][
                "Reference"]

    diff_headers = ["Implementation", "Avg Diff vs Reference"]
    diff_rows = []
    for impl in implementations:
        diff_rows.append([impl, f"{avg_diff[impl] / num_shapes:.6f}"])

    print_table(diff_headers, diff_rows)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the full benchmark suite without
    # per-shape diagnostics.
    run_benchmarks(verbose=False)
|