# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform


@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [
    (7, 8, False, torch.half),
    (83, 768, False, torch.half),
    (83, 768, True, torch.half),
    (83, 768, True, torch.bfloat16),
    (83, 768, True, torch.float32),
])
@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
) -> None:
    # Imported lazily so the module can still be collected on hosts
    # without torch_xla installed.
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()

    current_platform.seed_everything(0)
    torch.set_default_device("cpu")

    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)

    # Keep values small so the fp16/bf16 reductions stay numerically stable.
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None
    residual_cpu = residual.cpu() if add_residual else None

    # Reference output from the unfused native implementation, run on CPU.
    ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu)

    assert x.is_xla, "The input tensor under test is expected to be an XLA tensor."
    out = layer.to(device=device)(x, residual)

    # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
    # numerical errors than other operators because they involve reductions.
    # Therefore, we use a larger tolerance.
    if add_residual:
        assert out[0].is_xla, "The output tensor is expected to be an XLA tensor."
        torch.testing.assert_close(out[0].cpu(), ref_out[0], atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(out[1].cpu(), ref_out[1], atol=1e-2, rtol=1e-2)
    else:
        assert out.is_xla, "The output tensor is expected to be an XLA tensor."
        torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2)
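

# For readers unfamiliar with forward_native: a minimal, self-contained sketch
# of the RMSNorm math the test compares against. This helper is hypothetical
# (not part of vLLM) and assumes the standard RMSNorm definition, computed in
# float32 with eps=1e-6; the actual reference used by the test is
# RMSNorm.forward_native.
def _rms_norm_sketch(x, weight, residual=None, eps=1e-6):
    """Sketch: y = x / sqrt(mean(x**2, dim=-1) + eps) * weight.

    When a residual is given, it is folded into x before normalization and
    the summed activations are returned alongside y, mirroring the
    (output, residual) tuple shape the test expects in the add_residual case.
    """
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    y = (x * torch.rsqrt(variance + eps)).to(orig_dtype) * weight
    return y if residual is None else (y, residual)


# Example usage of the sketch (CPU, no residual):
#   w = torch.ones(8)
#   y = _rms_norm_sketch(torch.randn(2, 8), w)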