[Neuron] Add custom_ops for neuron backend (#13246)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: George Novack <gnovack@amazon.com>
Co-authored-by: Aoyu Zhang <aoyuzhan@amazon.com>
This commit is contained in:
Liangfu Chen 2025-02-25 11:47:49 -08:00 committed by GitHub
parent 340e39e387
commit f75aa72732
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 346 additions and 3 deletions

View File

@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import torch.nn.functional as F
from vllm.model_executor.layers.activation import FastGELU, SiluAndMul
from vllm.platforms import current_platform
@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"])
@pytest.mark.parametrize("num_tokens,d,dtype", [
(7, 512, torch.half),
(7, 512, torch.float),
(83, 512, torch.half),
])
@torch.inference_mode()
def test_act_and_mul(
activation: str,
num_tokens: int,
d: int,
dtype: torch.dtype,
) -> None:
import torch_xla.core.xla_model as xm
device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")
x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device)
if activation == "silu_and_mul":
layer = SiluAndMul()
fn = layer.forward_native
elif activation == "gelu_fast":
layer = FastGELU()
fn = F.gelu
else:
raise NotImplementedError(
f"activation {activation} is not implemented.")
assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
out = layer.to(device=device).forward_neuron(x)
ref_out = fn(x.cpu())
torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0)

View File

@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [
(7, 8, False, torch.half),
(83, 768, False, torch.half),
(83, 768, True, torch.half),
(83, 768, True, torch.bfloat16),
(83, 768, True, torch.float32),
])
@torch.inference_mode()
def test_rms_norm(
num_tokens: int,
hidden_size: int,
add_residual: bool,
dtype: torch.dtype,
) -> None:
import torch_xla.core.xla_model as xm
device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")
layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device)
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None
residual_cpu = residual.cpu() if add_residual else None
ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu)
assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
out = layer.to(device=device)(x, residual)
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions.
# Therefore, we use a larger tolerance.
if add_residual:
assert out[0].is_xla, "output tensor is expected to be XLA tensor"
torch.testing.assert_close(out[0].cpu(),
ref_out[0],
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(out[1].cpu(),
ref_out[1],
atol=1e-2,
rtol=1e-2)
else:
assert out.is_xla, "output tensor is expected to be XLA tensor"
torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2)

View File

@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Tuple
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available
class MockLogitsProcessor(LogitsProcessor):
def __init__(self, vocab_size: int, scale: float,
fake_logits: torch.Tensor):
super().__init__(vocab_size=vocab_size, scale=scale)
self.fake_logits = fake_logits.clone()
def forward(self, *args, **kwargs):
with patch(
"vllm.model_executor.layers.logits_processor._prune_hidden_states",
lambda x, y: x
), patch(
"vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs)
def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
dtype=input_tensor.dtype)
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
return input_tensor, fake_logits, logits_processor
RANDOM_SEEDS = list(range(8))
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_logits_processors(seed: int):
import torch_xla.core.xla_model as xm
device = xm.xla_device()
set_random_seed(seed)
torch.set_default_device("cpu")
batch_size = random.randint(1, 256)
input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)
# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = float("inf")
return logits
seq_group_metadata_list = []
seq_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
logits_processor_output = logits_processor(
lm_head=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
fake_logits *= logits_processor.scale
torch.testing.assert_close(logits_processor_output[:, 1],
fake_logits[:, 1],
rtol=1e-4,
atol=0.0)

View File

@ -345,6 +345,7 @@ def test_contexted_kv_attention(
torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)
torch.set_default_device("cpu")
dtype = torch.float32
min_ctx_len = 32
@ -438,9 +439,9 @@ def test_contexted_kv_attention(
# transform block table
active_block_table = get_active_block_tables(
block_table,
torch.tensor(query_lens),
torch.tensor(seq_lens),
block_table.cpu(),
torch.tensor(query_lens).cpu(),
torch.tensor(seq_lens).cpu(),
block_size,
num_active_blocks,
)

View File

@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
"""
Tests for miscellaneous utilities
"""
import pytest
import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
@pytest.mark.parametrize(
"max_position,is_neox_style,rotary_dim,head_size,seq_len", [
(16, False, 32, 32, 1024),
(16, False, 32, 128, 1024),
(16, True, 32, 32, 1024),
(16, True, 32, 128, 1024),
])
def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
head_size, seq_len):
import torch_xla.core.xla_model as xm
device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")
batch_size = 1
base = 10000
num_heads = 8
rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, torch.float32)
positions = torch.randint(0,
max_position, (batch_size, seq_len),
device="cpu")
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=torch.float32,
device="cpu")
key = torch.randn_like(query)
assert positions.is_cpu, \
"reference input tensor is expected to be CPU tensor."
ref_query, ref_key = rot.to(device="cpu").forward_native(
positions, query, key)
out_query, out_key = rot.to(device=device).forward_neuron(
positions.to(device=device), query.to(device=device),
key.to(device=device))
assert out_query.is_xla and out_key.is_xla, \
"output tensor is expected to be XLA tensor"
torch.testing.assert_close(out_query.cpu(),
ref_query,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2)

View File

@ -59,6 +59,11 @@ class CustomOp(nn.Module):
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)
def forward_neuron(self, *args, **kwargs):
# By default, we assume that Neuron ops are compatible with the
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)
def forward_oot(self, *args, **kwargs):
# By default, we assume that OOT ops are compatible with the
# PyTorch-native implementation.
@ -88,6 +93,8 @@ class CustomOp(nn.Module):
return self.forward_tpu
elif current_platform.is_xpu():
return self.forward_xpu
elif current_platform.is_neuron():
return self.forward_neuron
elif current_platform.is_out_of_tree():
return self.forward_oot
else:

View File

@ -89,6 +89,13 @@ class SiluAndMul(CustomOp):
self.op(out, x)
return out
def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
x_reshaped = x.view(-1, x.shape[-1])
s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
result = s * x_reshaped[:, d:]
return result.view(*x.shape[:-1], d)
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):

View File

@ -53,6 +53,7 @@ class LogitsProcessor(nn.Module):
# Whether to use gather or all-gather to gather the logits.
parallel_config = get_current_vllm_config().parallel_config
self.use_all_gather = current_platform.is_tpu() \
or current_platform.is_neuron() \
or envs.VLLM_USE_V1 \
or parallel_config.distributed_executor_backend == "external_launcher" # noqa

View File

@ -254,6 +254,82 @@ class RotaryEmbedding(CustomOp):
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_neuron(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
def _apply_rotary_emb_neuron(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
# x1 = x[..., ::2]
# x2 = x[..., 1::2]
d = x.shape[-1] // 2
x_reshaped = x.view(-1, x.shape[-1])
x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d)
x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d)
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
if offsets is not None:
positions = positions + offsets
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
if self.rotary_dim == self.head_size:
query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
query = query.reshape(query_shape)
key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
key = key.reshape(key_shape)
else:
head_size = query.shape[-1]
query_reshaped = query.view(-1, head_size)
query_pass = query_reshaped[:, self.rotary_dim:].view(
*query.shape[:-1], head_size - self.rotary_dim)
query_rot = query_reshaped[:, :self.rotary_dim].view(
*query.shape[:-1], self.rotary_dim)
query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin,
self.is_neox_style)
query = torch.cat((query_rot, query_pass),
dim=-1).reshape(query_shape)
key_reshaped = key.view(-1, head_size)
key_pass = key_reshaped[:, self.rotary_dim:].view(
*key.shape[:-1], head_size - self.rotary_dim)
key_rot = key_reshaped[:, :self.rotary_dim].view(
*key.shape[:-1], self.rotary_dim)
key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"