
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files

    This commit adds SPDX license headers to python source files as
    recommended to the project by the Linux Foundation. These headers
    provide a concise way that is both human and machine readable for
    communicating license information for each source file. It helps
    avoid any ambiguity about the license of the code and can also be
    easily used by tools to help manage license compliance.

    The Linux Foundation runs license scans against the codebase to help
    ensure we are in compliance with the licenses of the code we use,
    including dependencies. Having these headers in place helps that tool
    do its job.

    More information can be found on the SPDX site:

    - https://spdx.dev/learn/handling-license-info/

    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit

    Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
265 lines
7.9 KiB
Python
265 lines
7.9 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import itertools
|
|
from typing import Optional, Tuple, Union
|
|
|
|
import torch
|
|
import triton
|
|
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
|
from torch import nn
|
|
|
|
from vllm import _custom_ops as vllm_ops
|
|
|
|
|
|
class HuggingFaceRMSNorm(nn.Module):
    """Reference RMSNorm built from plain PyTorch ops.

    The reduction runs in float32. The normalized tensor is cast back to the
    input dtype and then multiplied by the learnable ``weight``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        """Create a unit-initialized scale of length ``hidden_size``.

        Args:
            hidden_size: Size of the last (normalized) dimension.
            eps: Constant added to the mean square before ``rsqrt``.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize ``x`` over its last dimension.

        Args:
            x: Input tensor.
            residual: Optional tensor added to ``x`` (in float32) before
                normalization.

        Returns:
            The normalized tensor when ``residual`` is ``None``; otherwise a
            ``(normalized, updated_residual)`` pair in which
            ``updated_residual`` is ``x + residual`` cast back to the dtype
            of ``x``.
        """
        input_dtype = x.dtype
        has_residual = residual is not None

        hidden = x.to(torch.float32)
        if has_residual:
            hidden = hidden + residual.to(torch.float32)
            # The pre-normalization sum is what gets handed back as residual.
            residual = hidden.to(input_dtype)

        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back to the input dtype *before* applying the scale.
        normed = (hidden * inv_rms).to(input_dtype) * self.weight

        if has_residual:
            return normed, residual
        return normed
|
|
|
|
|
|
def rmsnorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """Run the pure-PyTorch ``HuggingFaceRMSNorm`` reference on ``x``.

    A fresh module is built on every call, its scale is replaced by
    ``weight``, and it is moved to the device of ``x``. Inputs are flattened
    to 2-D for the call and the results are viewed back to the shape of
    ``x``.

    Args:
        x: Input tensor; the last dimension is normalized.
        weight: Scale applied after normalization.
        residual: Optional tensor added to ``x`` before normalization.
        eps: Constant added to the mean square.

    Returns:
        The normalized tensor, or ``(normalized, updated_residual)`` when a
        residual is supplied.
    """
    hidden_size = x.shape[-1]
    module = HuggingFaceRMSNorm(hidden_size, eps=eps)
    module.weight = nn.Parameter(weight)
    module = module.to(x.device)

    full_shape = x.shape
    flat_x = x.view(-1, hidden_size)
    flat_residual = (None if residual is None else residual.view(
        -1, residual.shape[-1]))

    result = module(flat_x, flat_residual)

    if not isinstance(result, tuple):
        return result.view(full_shape)
    normed, updated_residual = result[0], result[1]
    return (normed.view(full_shape), updated_residual.view(full_shape))
|
|
|
|
|
|
def rmsnorm_flashinfer(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """Apply RMSNorm through the FlashInfer kernels.

    Inputs are flattened to 2-D for the kernel call and results are viewed
    back to the original shape of ``x``.

    Args:
        x: Input tensor; the last dimension is normalized.
        weight: Scale passed to the kernel.
        residual: Optional residual tensor. When given, the fused
            add + norm kernel is used.
        eps: Epsilon passed to the kernel.

    Returns:
        The tensor returned by ``rmsnorm`` when ``residual`` is ``None``;
        otherwise ``(x, residual)`` as they stand after the fused call.
    """
    full_shape = x.shape
    flat_x = x.view(-1, x.shape[-1])

    if residual is None:
        return rmsnorm(flat_x, weight, eps).view(full_shape)

    flat_residual = residual.view(-1, residual.shape[-1])
    # The return value is ignored: the results are read back from the
    # argument tensors, i.e. the call is relied on to update them in place.
    fused_add_rmsnorm(flat_x, flat_residual, weight, eps)
    return (flat_x.view(full_shape), flat_residual.view(full_shape))
|
|
|
|
|
|
def rmsnorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """Apply RMSNorm through vLLM's custom ops.

    Inputs are flattened to 2-D for the kernel call and results are viewed
    back to the original shape of ``x``.

    Args:
        x: Input tensor; the last dimension is normalized.
        weight: Scale passed to the kernel.
        residual: Optional residual tensor. When given, the fused
            add + norm op is used.
        eps: Epsilon passed to the kernel.

    Returns:
        A newly allocated output tensor when ``residual`` is ``None``;
        otherwise ``(x, residual)`` as they stand after the fused call.
    """
    full_shape = x.shape
    flat_x = x.view(-1, x.shape[-1])

    if residual is None:
        # Out-of-place path: the op is expected to fill ``normed``.
        normed = torch.empty_like(flat_x)
        vllm_ops.rms_norm(normed, flat_x, weight, eps)
        return normed.view(full_shape)

    flat_residual = residual.view(-1, residual.shape[-1])
    # The return value is ignored: the results are read back from the
    # argument tensors, i.e. the op is relied on to update them in place.
    vllm_ops.fused_add_rms_norm(flat_x, flat_residual, weight, eps)
    return (flat_x.view(full_shape), flat_residual.view(full_shape))
|
|
|
|
|
|
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
    """Print the three implementations' outputs and whether they agree.

    Random bfloat16 inputs are created on the CUDA device and each
    implementation receives its own clones of them. When ``use_residual`` is
    set, only the normalized output (element 0 of each result) is printed and
    compared. Agreement is checked with ``torch.allclose`` at
    ``atol=rtol=1e-2`` against the naive implementation.

    Args:
        batch_size: Size of the first input dimension.
        seq_len: Size of the second input dimension.
        hidden_size: Size of the last (normalized) dimension.
        use_residual: Whether to pass a random residual tensor.
    """
    dtype = torch.bfloat16
    x = torch.randn(batch_size,
                    seq_len,
                    hidden_size,
                    dtype=dtype,
                    device="cuda")
    weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x) if use_residual else None

    def fresh_args():
        # Every implementation gets private copies of x and residual.
        residual_copy = None if residual is None else residual.clone()
        return x.clone(), weight, residual_copy

    output_naive = rmsnorm_naive(*fresh_args())
    output_flashinfer = rmsnorm_flashinfer(*fresh_args())
    output_vllm = rmsnorm_vllm(*fresh_args())

    if use_residual:
        # Results are (normalized, residual) pairs; keep the normalized part.
        output_naive, output_flashinfer, output_vllm = (
            output_naive[0], output_flashinfer[0], output_vllm[0])

    print(f"Naive output={output_naive}")
    print(f"FlashInfer output={output_flashinfer}")
    print(f"VLLM output={output_vllm}")

    tolerance = {"atol": 1e-2, "rtol": 1e-2}
    all_match = (torch.allclose(output_naive, output_flashinfer, **tolerance)
                 and torch.allclose(output_naive, output_vllm, **tolerance))

    if not all_match:
        print("❌ Implementations differ")
    else:
        print("✅ All implementations match")
|
|
|
|
|
|
# Benchmark sweep: powers of two for batch size (exponents 0, 2, 4, 6) and
# sequence length (exponents 6 through 10), crossed with two head counts.
batch_size_range = [1 << exponent for exponent in range(0, 7, 2)]
seq_length_range = [1 << exponent for exponent in range(6, 11)]
head_num_range = [32, 48]
# Each entry is a (head_num, batch_size, seq_len) tuple.
configs = [
    combo for combo in itertools.product(head_num_range, batch_size_range,
                                         seq_length_range)
]
|
|
|
|
|
|
def get_benchmark(use_residual):
    """Build the Triton perf-report benchmark for the RMSNorm variants.

    Args:
        use_residual: Whether the benchmarked calls receive a residual
            tensor. Also selects the plot name
            (``rmsnorm-perf-with-residual`` / ``rmsnorm-perf-without-residual``).

    Returns:
        The ``triton.testing.perf_report``-decorated ``benchmark`` function;
        call ``.run(...)`` on it to execute the sweep over ``configs``.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["head_num", "batch_size", "seq_len"],
            x_vals=[list(cfg) for cfg in configs],
            line_arg="provider",
            line_vals=["huggingface", "flashinfer", "vllm"],
            line_names=["HuggingFace", "FlashInfer", "vLLM"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name=
            f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
            args={},
        ))
    def benchmark(head_num, batch_size, seq_len, provider):
        """Time one provider on one (head_num, batch_size, seq_len) point.

        Returns:
            ``(median, p20, p80)`` latency in microseconds, in the
            ``(y_mean, y_min, y_max)`` order that ``perf_report`` unpacks.
        """
        dtype = torch.bfloat16
        hidden_size = head_num * 128  # assuming head_dim = 128

        x = torch.randn(batch_size,
                        seq_len,
                        hidden_size,
                        dtype=dtype,
                        device="cuda")
        weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
        residual = torch.randn_like(x) if use_residual else None

        # Any provider other than the first two falls through to vLLM, as
        # in the original if/elif/else chain.
        if provider == "huggingface":
            impl = rmsnorm_naive
        elif provider == "flashinfer":
            impl = rmsnorm_flashinfer
        else:
            impl = rmsnorm_vllm

        def run_once():
            # The clones are made inside the timed callable so that every
            # invocation starts from fresh copies of x and residual.
            return impl(
                x.clone(),
                weight,
                residual.clone() if residual is not None else None,
            )

        # do_bench returns one value per requested quantile, in this order.
        quantiles = [0.5, 0.2, 0.8]
        median_ms, p20_ms, p80_ms = triton.testing.do_bench(
            run_once,
            quantiles=quantiles,
        )

        # perf_report unpacks the result as (y_mean, y_min, y_max). The
        # metric here is latency, so the low quantile is the minimum and the
        # high quantile the maximum. The previous (ms, max_ms, min_ms) order
        # is only correct for inverted metrics such as throughput.
        return 1000 * median_ms, 1000 * p20_ms, 1000 * p80_ms

    return benchmark
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()

    # Integer-valued options, registered in the order they appear in --help.
    int_options = (
        ("--batch-size", 4, "Batch size"),
        ("--seq-len", 128, "Sequence length"),
        ("--hidden-size", 4096,
         "Hidden size (2nd dimension) of the sequence"),
    )
    for flag, default_value, help_text in int_options:
        parser.add_argument(flag,
                            type=int,
                            default=default_value,
                            help=help_text)

    parser.add_argument("--use-residual",
                        action="store_true",
                        help="Whether to use residual connection")
    parser.add_argument("--save-path",
                        type=str,
                        default="./configs/rmsnorm/",
                        help="Path to save rmsnorm benchmark results")

    args = parser.parse_args()

    # Step 1: correctness check for the shape given on the command line.
    calculate_diff(
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        hidden_size=args.hidden_size,
        use_residual=args.use_residual,
    )

    # Step 2: build the benchmark for the chosen residual mode and run the
    # performance sweep, printing the table and saving results to save_path.
    benchmark = get_benchmark(args.use_residual)
    benchmark.run(print_data=True, save_path=args.save_path)
|