vllm/tests/neuron/test_rotary_embedding.py

# SPDX-License-Identifier: Apache-2.0
"""
Tests for miscellaneous utilities
"""
import pytest
import torch

from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform


@pytest.mark.parametrize(
    "max_position,is_neox_style,rotary_dim,head_size,seq_len", [
        (16, False, 32, 32, 1024),
        (16, False, 32, 128, 1024),
        (16, True, 32, 32, 1024),
        (16, True, 32, 128, 1024),
    ])
def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
                                  head_size, seq_len):
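    # torch_xla is only present in Neuron/XLA environments, so import it
    # lazily inside the test rather than at module scope.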
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")
    batch_size = 1
    base = 10000
    num_heads = 8
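    # Build the rotary embedding module once; the same module (and its
    # precomputed cos/sin cache) is moved between devices so the CPU
    # reference path and the Neuron path share identical weights.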
    rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                          is_neox_style, torch.float32)
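
    # Create all inputs on CPU so both implementations see identical data.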
    positions = torch.randint(0,
                              max_position, (batch_size, seq_len),
                              device="cpu")
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=torch.float32,
                        device="cpu")
    key = torch.randn_like(query)
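
    # Reference: the pure-PyTorch forward_native path, evaluated on CPU.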
    assert positions.is_cpu, \
        "reference input tensor is expected to be a CPU tensor"
    ref_query, ref_key = rot.to(device="cpu").forward_native(
        positions, query, key)
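
    # Device under test: move the module and inputs to the XLA device and
    # run the Neuron custom op added in #13246, which should match the
    # native implementation numerically.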
    out_query, out_key = rot.to(device=device).forward_neuron(
        positions.to(device=device), query.to(device=device),
        key.to(device=device))
    assert out_query.is_xla and out_key.is_xla, \
        "output tensors are expected to be XLA tensors"
    torch.testing.assert_close(out_query.cpu(),
                               ref_query,
                               atol=1e-2,
                               rtol=1e-2)
    torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2)