[Neuron] Add custom_ops for neuron backend (#13246)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: George Novack <gnovack@amazon.com>
Co-authored-by: Aoyu Zhang <aoyuzhan@amazon.com>
parent 340e39e387
commit f75aa72732

tests/neuron/test_activation.py (new file, 42 lines)
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.activation import FastGELU, SiluAndMul
from vllm.platforms import current_platform


@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"])
@pytest.mark.parametrize("num_tokens,d,dtype", [
    (7, 512, torch.half),
    (7, 512, torch.float),
    (83, 512, torch.half),
])
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")
    x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device)
    if activation == "silu_and_mul":
        layer = SiluAndMul()
        fn = layer.forward_native
    elif activation == "gelu_fast":
        layer = FastGELU()
        fn = F.gelu
    else:
        raise NotImplementedError(
            f"activation {activation} is not implemented.")
    assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
    out = layer.to(device=device).forward_neuron(x)
    ref_out = fn(x.cpu())
    torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0)
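All of the new tests share one pattern: run the op's forward_neuron on the XLA device, recompute a reference on CPU via the PyTorch-native path, and compare the results on the host. A minimal sketch of that pattern (check_against_cpu is an illustrative helper, not part of this commit; it assumes torch_xla and a Neuron device are available to produce x_device):

import torch

def check_against_cpu(device_fn, ref_fn, x_device,
                      atol: float = 1e-2, rtol: float = 1e-2):
    # Run the op on the XLA device, recompute a CPU reference,
    # then compare both on the host.
    out = device_fn(x_device)
    ref = ref_fn(x_device.cpu())
    torch.testing.assert_close(out.cpu(), ref, atol=atol, rtol=rtol)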
tests/neuron/test_layernorm.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform


@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [
    (7, 8, False, torch.half),
    (83, 768, False, torch.half),
    (83, 768, True, torch.half),
    (83, 768, True, torch.bfloat16),
    (83, 768, True, torch.float32),
])
@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
) -> None:
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")
    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

    residual_cpu = residual.cpu() if add_residual else None
    ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu)
    assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
    out = layer.to(device=device)(x, residual)

    # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
    # numerical errors than other operators because they involve reductions.
    # Therefore, we use a larger tolerance.
    if add_residual:
        assert out[0].is_xla, "output tensor is expected to be XLA tensor"
        torch.testing.assert_close(out[0].cpu(),
                                   ref_out[0],
                                   atol=1e-2,
                                   rtol=1e-2)
        torch.testing.assert_close(out[1].cpu(),
                                   ref_out[1],
                                   atol=1e-2,
                                   rtol=1e-2)
    else:
        assert out.is_xla, "output tensor is expected to be XLA tensor"
        torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2)
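For reference, RMSNorm normalizes each row by its root-mean-square over the hidden dimension and applies a learned per-channel scale; when a residual is passed, the residual sum is computed first and returned alongside the normalized output, which is why the test indexes out[0] and out[1]. A minimal sketch of the non-residual math (standalone; the eps value is illustrative):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the hidden dimension,
    # computed in float32 for stability, then scale per channel.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight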
tests/neuron/test_logits_processor.py (new file, 95 lines)
@@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0

import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available


class MockLogitsProcessor(LogitsProcessor):

    def __init__(self, vocab_size: int, scale: float,
                 fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size, scale=scale)
        self.fake_logits = fake_logits.clone()

    def forward(self, *args, **kwargs):
        with patch(
                "vllm.model_executor.layers.logits_processor._prune_hidden_states",
                lambda x, y: x
        ), patch(
                "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
                lambda *args, **kwargs: self.fake_logits):
            return super().forward(*args, **kwargs)


def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             dtype=input_tensor.dtype)
    logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
    return input_tensor, fake_logits, logits_processor


RANDOM_SEEDS = list(range(8))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_logits_processors(seed: int):
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    set_random_seed(seed)
    torch.set_default_device("cpu")
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)

    # This sample logits processor gives infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...]
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    seq_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())
    logits_processor_output = logits_processor(
        lm_head=None,
        hidden_states=input_tensor,
        sampling_metadata=sampling_metadata)

    fake_logits *= logits_processor.scale
    torch.testing.assert_close(logits_processor_output[:, 1],
                               fake_logits[:, 1],
                               rtol=1e-4,
                               atol=0.0)
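The pick_ith function above depends on the per-sequence logits-processor contract: a callable that receives the token ids generated so far and the logits row for that sequence, and returns the adjusted row. A standalone sketch of that contract (illustrative, outside vLLM):

import torch

def apply_logits_processors(processors, token_ids, logits_row):
    # Each processor sees the tokens generated so far and may rewrite
    # the logits row in place or return a new one.
    for proc in processors:
        logits_row = proc(token_ids, logits_row)
    return logits_row

logits = torch.zeros(10)
out = apply_logits_processors(
    [lambda ids, l: l.index_fill_(0, torch.tensor([len(ids)]), float("inf"))],
    [1, 2, 3], logits)
assert out.argmax().item() == 3  # the i-th token gets an infinite score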
@@ -345,6 +345,7 @@ def test_contexted_kv_attention(
 
     torch.manual_seed(0)
     torch.set_printoptions(sci_mode=False)
+    torch.set_default_device("cpu")
     dtype = torch.float32
 
     min_ctx_len = 32
@@ -438,9 +439,9 @@ def test_contexted_kv_attention(
     # transform block table
     active_block_table = get_active_block_tables(
-        block_table,
-        torch.tensor(query_lens),
-        torch.tensor(seq_lens),
+        block_table.cpu(),
+        torch.tensor(query_lens).cpu(),
+        torch.tensor(seq_lens).cpu(),
         block_size,
        num_active_blocks,
     )
tests/neuron/test_rotary_embedding.py (new file, 58 lines)
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
"""
Tests for rotary embedding on the Neuron backend
"""

import pytest
import torch

from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform


@pytest.mark.parametrize(
    "max_position,is_neox_style,rotary_dim,head_size,seq_len", [
        (16, False, 32, 32, 1024),
        (16, False, 32, 128, 1024),
        (16, True, 32, 32, 1024),
        (16, True, 32, 128, 1024),
    ])
def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
                                  head_size, seq_len):
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")

    batch_size = 1
    base = 10000
    num_heads = 8

    rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                          is_neox_style, torch.float32)

    positions = torch.randint(0,
                              max_position, (batch_size, seq_len),
                              device="cpu")
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=torch.float32,
                        device="cpu")
    key = torch.randn_like(query)

    assert positions.is_cpu, \
        "reference input tensor is expected to be CPU tensor."
    ref_query, ref_key = rot.to(device="cpu").forward_native(
        positions, query, key)
    out_query, out_key = rot.to(device=device).forward_neuron(
        positions.to(device=device), query.to(device=device),
        key.to(device=device))
    assert out_query.is_xla and out_key.is_xla, \
        "output tensor is expected to be XLA tensor"
    torch.testing.assert_close(out_query.cpu(),
                               ref_query,
                               atol=1e-2,
                               rtol=1e-2)
    torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2)
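Both the CPU reference and the Neuron path consume the same cos_sin_cache, built from the standard RoPE inverse frequencies. A minimal sketch of that construction (simplified from the RotaryEmbedding internals; dtype and scaling details elided):

import torch

def build_cos_sin_cache(rotary_dim: int, max_position: int,
                        base: float) -> torch.Tensor:
    # Standard RoPE inverse frequencies: base^(-2i/rotary_dim).
    inv_freq = 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    t = torch.arange(max_position, dtype=torch.float)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    # One row per position: [cos | sin], later split with chunk(2, dim=-1).
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)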
@@ -59,6 +59,11 @@ class CustomOp(nn.Module):
         # PyTorch-native implementation.
         return self.forward_native(*args, **kwargs)
 
+    def forward_neuron(self, *args, **kwargs):
+        # By default, we assume that Neuron ops are compatible with the
+        # PyTorch-native implementation.
+        return self.forward_native(*args, **kwargs)
+
     def forward_oot(self, *args, **kwargs):
         # By default, we assume that OOT ops are compatible with the
         # PyTorch-native implementation.
@@ -88,6 +93,8 @@ class CustomOp(nn.Module):
             return self.forward_tpu
         elif current_platform.is_xpu():
             return self.forward_xpu
+        elif current_platform.is_neuron():
+            return self.forward_neuron
         elif current_platform.is_out_of_tree():
             return self.forward_oot
         else:
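Together, the two hunks above give CustomOp a Neuron entry in its per-platform dispatch: ops that define forward_neuron get an XLA-friendly path, and everything else silently falls back to forward_native. A standalone sketch of the pattern (illustrative stand-in, not the vLLM class; the platform string replaces the current_platform checks):

import torch.nn as nn

class SketchCustomOp(nn.Module):

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_neuron(self, *args, **kwargs):
        # Default: Neuron falls back to the PyTorch-native implementation;
        # ops like SiluAndMul override this with an XLA-friendly version.
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self, platform: str):
        # Pick the backend-specific forward once, at construction time.
        if platform == "neuron":
            return self.forward_neuron
        return self.forward_native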
@@ -89,6 +89,13 @@ class SiluAndMul(CustomOp):
         self.op(out, x)
         return out
 
+    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        x_reshaped = x.view(-1, x.shape[-1])
+        s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
+        result = s * x_reshaped[:, d:]
+        return result.view(*x.shape[:-1], d)
+
 
 @CustomOp.register("mul_and_silu")
 class MulAndSilu(CustomOp):
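Note that forward_neuron flattens the input to 2D before slicing instead of using x[..., :d] directly, apparently to keep the slices simple strided views for the Neuron compiler (the rotary code below makes the same substitution). A quick CPU check that the 2D-view formulation matches the native one (standalone sketch):

import torch
import torch.nn.functional as F

def silu_and_mul_native(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def silu_and_mul_2d(x: torch.Tensor) -> torch.Tensor:
    # Flatten to 2D so the slices are plain column ranges, as in
    # forward_neuron above, then restore the leading dimensions.
    d = x.shape[-1] // 2
    xr = x.view(-1, x.shape[-1])
    s = xr[:, :d] * torch.sigmoid(xr[:, :d])
    return (s * xr[:, d:]).view(*x.shape[:-1], d)

x = torch.randn(3, 5, 8)
torch.testing.assert_close(silu_and_mul_2d(x), silu_and_mul_native(x))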
@@ -53,6 +53,7 @@ class LogitsProcessor(nn.Module):
         # Whether to use gather or all-gather to gather the logits.
         parallel_config = get_current_vllm_config().parallel_config
         self.use_all_gather = current_platform.is_tpu() \
+            or current_platform.is_neuron() \
             or envs.VLLM_USE_V1 \
             or parallel_config.distributed_executor_backend == "external_launcher"  # noqa
 
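The one-line change above routes Neuron through the all-gather path when logits are collected across tensor-parallel ranks, so every rank ends up with the full vocabulary dimension. A plain-torch illustration of the gather vs. all-gather outcome (a simulation, not the vllm.distributed collectives):

import torch

world_size = 4
# One vocab shard of the logits per tensor-parallel rank.
shards = [torch.full((2, 8), float(rank)) for rank in range(world_size)]

# gather: only the destination rank sees the full vocab dimension.
gathered_on_rank0 = torch.cat(shards, dim=-1)

# all-gather: every rank holds the same full tensor, which is what the
# Neuron/XLA path relies on.
all_gathered = [torch.cat(shards, dim=-1) for _ in range(world_size)]
assert all(t.equal(gathered_on_rank0) for t in all_gathered)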
@@ -254,6 +254,82 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward_neuron(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        def _apply_rotary_emb_neuron(
+            x: torch.Tensor,
+            cos: torch.Tensor,
+            sin: torch.Tensor,
+            is_neox_style: bool,
+        ) -> torch.Tensor:
+            cos = cos.unsqueeze(-2).to(x.dtype)
+            sin = sin.unsqueeze(-2).to(x.dtype)
+            if is_neox_style:
+                x1, x2 = torch.chunk(x, 2, dim=-1)
+            else:
+                # x1 = x[..., ::2]
+                # x2 = x[..., 1::2]
+                d = x.shape[-1] // 2
+                x_reshaped = x.view(-1, x.shape[-1])
+                x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d)
+                x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d)
+            o1 = x1 * cos - x2 * sin
+            o2 = x2 * cos + x1 * sin
+            if is_neox_style:
+                return torch.cat((o1, o2), dim=-1)
+            else:
+                return torch.stack((o1, o2), dim=-1).flatten(-2)
+
+        if offsets is not None:
+            positions = positions + offsets
+
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                   dtype=query.dtype)
+
+        positions = positions.flatten()
+        num_tokens = positions.shape[0]
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        cos, sin = cos_sin.chunk(2, dim=-1)
+
+        query_shape = query.shape
+        query = query.view(num_tokens, -1, self.head_size)
+        key_shape = key.shape
+        key = key.view(num_tokens, -1, self.head_size)
+
+        if self.rotary_dim == self.head_size:
+            query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
+            query = query.reshape(query_shape)
+            key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
+            key = key.reshape(key_shape)
+        else:
+            head_size = query.shape[-1]
+            query_reshaped = query.view(-1, head_size)
+            query_pass = query_reshaped[:, self.rotary_dim:].view(
+                *query.shape[:-1], head_size - self.rotary_dim)
+            query_rot = query_reshaped[:, :self.rotary_dim].view(
+                *query.shape[:-1], self.rotary_dim)
+            query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin,
+                                                 self.is_neox_style)
+            query = torch.cat((query_rot, query_pass),
+                              dim=-1).reshape(query_shape)
+
+            key_reshaped = key.view(-1, head_size)
+            key_pass = key_reshaped[:, self.rotary_dim:].view(
+                *key.shape[:-1], head_size - self.rotary_dim)
+            key_rot = key_reshaped[:, :self.rotary_dim].view(
+                *key.shape[:-1], self.rotary_dim)
+            key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
+                                               self.is_neox_style)
+            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        return query, key
+
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"
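_apply_rotary_emb_neuron handles both RoPE layouts: NeoX style rotates pairs taken from the two halves of the rotary dimensions, while the interleaved (GPT-J) style pairs even and odd channels. As in SiluAndMul, the interleaved split is expressed through a 2D view rather than x[..., ::2]. A small CPU check that the view-based split matches direct strided slicing (standalone sketch):

import torch

def interleaved_halves_2d(x: torch.Tensor):
    # Equivalent to (x[..., ::2], x[..., 1::2]) but expressed through a
    # 2D view, mirroring _apply_rotary_emb_neuron above.
    d = x.shape[-1] // 2
    xr = x.view(-1, x.shape[-1])
    return (xr[:, ::2].view(*x.shape[:-1], d),
            xr[:, 1::2].view(*x.shape[:-1], d))

x = torch.randn(2, 4, 6, 8)
x1, x2 = interleaved_halves_2d(x)
torch.testing.assert_close(x1, x[..., ::2])
torch.testing.assert_close(x2, x[..., 1::2])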