
# SPDX-License-Identifier: Apache-2.0
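
# Numerical test for vLLM's RMSNorm on an XLA device: the layer's output on
# the XLA device is compared against the native CPU implementation within a
# loose tolerance.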

import pytest
import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform


@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [
    (7, 8, False, torch.half),
    (83, 768, False, torch.half),
    (83, 768, True, torch.half),
    (83, 768, True, torch.bfloat16),
    (83, 768, True, torch.float32),
])
@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
) -> None:
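    # Importing torch_xla inside the test (rather than at module level) means
    # this file can be collected on hosts without an XLA/Neuron installation.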
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")
    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
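    # Inputs are scaled down by 1 / (2 * hidden_size), presumably to keep the
    # random activations small for the float16/bfloat16 cases.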
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

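    # Reference path: run the same layer on CPU through forward_native, then
    # compare it with the XLA execution below.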
    residual_cpu = residual.cpu() if add_residual else None
    ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu)
    assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
    out = layer.to(device=device)(x, residual)

    # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
    # numerical errors than other operators because they involve reductions.
    # Therefore, we use a larger tolerance.
    if add_residual:
        assert out[0].is_xla, "output tensor is expected to be XLA tensor"
        torch.testing.assert_close(out[0].cpu(),
                                   ref_out[0],
                                   atol=1e-2,
                                   rtol=1e-2)
        torch.testing.assert_close(out[1].cpu(),
                                   ref_out[1],
                                   atol=1e-2,
                                   rtol=1e-2)
    else:
        assert out.is_xla, "output tensor is expected to be XLA tensor"
        torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2)