"""Tests comparing the fused vllm rms_norm CUDA kernel against a pure-PyTorch reference."""
import pytest
import torch
import torch.nn as nn

from vllm import layernorm_ops

# Parameter sweeps for the parametrized kernel test below.
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
SEEDS = [0]
|
class RefRMSNorm(nn.Module):
    """Pure-PyTorch RMSNorm reference.

    Normalizes over the last (hidden) dimension using float32 math, then
    casts back to the input dtype and applies a learned per-channel gain.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Gain initialized near 1.0 so the test exercises a non-trivial weight.
        gain = torch.empty(hidden_size)
        gain.normal_(mean=1.0, std=0.1)
        self.weight = nn.Parameter(gain)
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Do the statistics in float32 for accuracy regardless of input dtype.
        x = hidden_states.to(torch.float32)
        mean_sq = x.square().mean(dim=-1, keepdim=True)
        x = x * (mean_sq + self.variance_epsilon).rsqrt()
        return self.weight * x.to(orig_dtype)
|
|
|
|
|
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check the fused CUDA rms_norm kernel against RefRMSNorm."""
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Scale inputs by 1/sqrt(hidden_size) to keep magnitudes small for
    # the low-precision dtypes.
    scale = float(hidden_size**-0.5)
    x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
    x.uniform_(-scale, scale)
    ref = RefRMSNorm(hidden_size).to(dtype).cuda()

    # Kernel writes its result into a preallocated output tensor.
    out = torch.empty_like(x)
    layernorm_ops.rms_norm(out, x, ref.weight.data, ref.variance_epsilon)

    expected = ref(x)
    # Loose atol: half/bfloat16 accumulate rounding error in the fused kernel.
    assert torch.allclose(out, expected, atol=1e-2, rtol=1e-5)
|