
- **Add SPDX license headers to python source files** - **Check for SPDX headers using pre-commit** commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745 Author: Russell Bryant <rbryant@redhat.com> Date: Fri Jan 31 14:18:24 2025 -0500 Add SPDX license headers to python source files This commit adds SPDX license headers to python source files as recommended to the project by the Linux Foundation. These headers provide a concise way that is both human and machine readable for communicating license information for each source file. It helps avoid any ambiguity about the license of the code and can also be easily used by tools to help manage license compliance. The Linux Foundation runs license scans against the codebase to help ensure we are in compliance with the licenses of the code we use, including dependencies. Having these headers in place helps that tool do its job. More information can be found on the SPDX site: - https://spdx.dev/learn/handling-license-info/ Signed-off-by: Russell Bryant <rbryant@redhat.com> commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea Author: Russell Bryant <rbryant@redhat.com> Date: Fri Jan 31 14:36:32 2025 -0500 Check for SPDX headers using pre-commit Signed-off-by: Russell Bryant <rbryant@redhat.com> --------- Signed-off-by: Russell Bryant <rbryant@redhat.com>
110 lines
3.6 KiB
Python
110 lines
3.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import random
|
|
from typing import Callable, Tuple, Type
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from tests.kernels.utils import opcheck
|
|
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
|
|
GeluAndMul, MulAndSilu,
|
|
NewGELU, QuickGELU,
|
|
SiluAndMul)
|
|
from vllm.platforms import current_platform
|
|
|
|
from .allclose_default import get_default_atol, get_default_rtol
|
|
|
|
# Parametrization grids shared by the tests in this module.
DTYPES = [torch.float16, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
D = [512, 13824]  # Arbitrary values for testing
SEEDS = [0]
# Exactly one visible GPU -> test only cuda:0. Any other count (including
# zero) -> parametrize over cuda:0 and cuda:1.
CUDA_DEVICES = [
    f"cuda:{device_idx}"
    for device_idx in range(2 if torch.cuda.device_count() != 1 else 1)
]
|
|
|
|
|
|
@pytest.mark.parametrize(
    "activation",
    ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare a fused activation-and-multiply layer with its native path.

    A random input with ``2 * d`` columns is fed to the selected layer. The
    layer's default forward must match ``forward_native`` exactly. The raw
    ``torch.ops._C`` kernel for the same activation is then passed to
    ``opcheck`` together with a preallocated output of width ``d``.

    Args:
        activation: Name selecting the layer / kernel pair under test.
        num_tokens: Number of rows in the random input.
        d: Output width; the input width is ``2 * d``.
        dtype: Element type of the input and output tensors.
        seed: Seed passed to ``current_platform.seed_everything``.
        device: Device string installed as torch's default device.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    current_platform.seed_everything(seed)
    # NOTE(review): changes the process-wide default device; it is not
    # restored when the test finishes.
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)

    # Arguments passed to the raw kernel after (out, x); only FatReLU
    # needs one (its threshold).
    extra_args: tuple = ()
    if activation == "silu_and_mul":
        layer = SiluAndMul()
        fn = torch.ops._C.silu_and_mul
    elif activation == "mul_and_silu":
        layer = MulAndSilu()
        fn = torch.ops._C.mul_and_silu
    elif activation == "gelu":
        layer = GeluAndMul(approximate="none")
        fn = torch.ops._C.gelu_and_mul
    elif activation == "gelu_tanh":
        layer = GeluAndMul(approximate="tanh")
        fn = torch.ops._C.gelu_tanh_and_mul
    elif activation == "fatrelu":
        # Presumably reproducible because seed_everything also seeds
        # Python's `random` module -- confirm.
        threshold = random.uniform(0, 1)
        layer = FatreluAndMul(threshold)
        fn = torch.ops._C.fatrelu_and_mul
        extra_args = (threshold, )
    else:
        raise ValueError(f"Unsupported activation: {activation!r}")

    out = layer(x)
    ref_out = layer.forward_native(x)
    # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
    # equivalent to the native PyTorch implementations, so we can do exact
    # comparison.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

    # The raw kernel receives a preallocated buffer whose last dimension is
    # d (half of x's last dimension) as its first argument.
    kernel_out = torch.empty(x.shape[:-1] + (d, ),
                             dtype=x.dtype,
                             device=x.device)
    opcheck(fn, (kernel_out, x, *extra_args))
|
|
|
|
|
|
@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
                                        (NewGELU, torch.ops._C.gelu_new),
                                        (QuickGELU, torch.ops._C.gelu_quick)])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation: Tuple[Type[torch.nn.Module], Callable],
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare an elementwise GELU-variant layer with its native path.

    ``activation`` is a ``(layer class, raw kernel)`` pair. The layer's
    default forward is compared with ``forward_native`` using the tolerances
    from ``get_default_atol`` / ``get_default_rtol``. The raw kernel is then
    passed to ``opcheck`` with an output buffer shaped like the input.

    Args:
        activation: ``(layer_cls, kernel)`` pair under test.
        num_tokens: Number of rows in the random input.
        d: Width of the input (and of the output).
        dtype: Element type of the input and output tensors.
        seed: Seed passed to ``current_platform.seed_everything``.
        device: Device string installed as torch's default device.
    """
    current_platform.seed_everything(seed)
    # NOTE(review): changes the process-wide default device; it is not
    # restored when the test finishes.
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)

    layer_cls, fn = activation
    layer = layer_cls()
    out = layer(x)
    ref_out = layer.forward_native(x)
    # Compared with tolerances rather than exactly; the helpers receive the
    # output tensor, so presumably pick values per dtype -- confirm.
    torch.testing.assert_close(out,
                               ref_out,
                               atol=get_default_atol(out),
                               rtol=get_default_rtol(out))

    # The raw kernel receives a preallocated buffer shaped like x as its
    # first argument.
    kernel_out = torch.empty_like(x)
    opcheck(fn, (kernel_out, x))
|