vllm/tests/model_executor/test_enabled_custom_ops.py
Russell Bryant e489ad7a21
[Misc] Add SPDX-License-Identifier headers to python source files (#12628)
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files
    
    This commit adds SPDX license headers to Python source files as
    recommended to the project by the Linux Foundation. These headers
    provide a concise way that is both human and machine readable for
    communicating license information for each source file. It helps
    avoid any ambiguity about the license of the code and can also be
    easily used by tools to help manage license compliance.
    
    The Linux Foundation runs license scans against the codebase to help
    ensure we are in compliance with the licenses of the code we use,
    including dependencies. Having these headers in place helps that tool
    do its job.
    
    More information can be found on the SPDX site:
    
    - https://spdx.dev/learn/handling-license-info/
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-02-02 11:58:18 -08:00

92 lines
3.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
# Registered subclass for test
@CustomOp.register("relu3")
class Relu3(ReLUSquaredActivation):
    """ReLUSquaredActivation subclass registered under its own name.

    Registering it as ``"relu3"`` lets the tests below check that a
    registered subclass is toggled by its own name rather than by the
    name of the class it inherits from.
    """
@pytest.mark.parametrize(
    "env, torch_level, ops_enabled, default_on",
    [
        # Default values based on compile level
        ("", 0, [True] * 4, True),
        ("", 1, [True] * 4, True),
        ("", 2, [True] * 4, True),  # All by default
        ("", 3, [False] * 4, False),
        ("", 4, [False] * 4, False),  # None by default
        # Explicitly enabling/disabling
        #
        # Default: all
        #
        # All but SiluAndMul
        ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
        # Only ReLU3
        ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
        # All but SiluAndMul
        ("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
        # All but ReLU3, even though its parent ReLUSquaredActivation
        # (presumably registered as "relu2") is explicitly enabled.
        # Every other entry uses the +/- directive form (or all/none).
        # A bare "relu2" would not explicitly enable anything, so the
        # "+" prefix is needed for this case to test what it describes.
        ("-relu3,+relu2", 1, [1, 1, 1, 0], True),
        # GeluAndMul and SiluAndMul
        ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
        # All but RMSNorm
        ("-rms_norm", 2, [0, 1, 1, 1], True),
        #
        # Default: none
        #
        # Only ReLU3
        ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
        # All but RMSNorm
        ("all,-rms_norm", 4, [0, 1, 1, 1], True),
    ])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
                     default_on: bool):
    """Check which custom ops are enabled for a given compilation config.

    Args:
        env: Comma-separated ``custom_ops`` spec passed to
            ``CompilationConfig``. An empty string means no explicit spec.
        torch_level: Compilation ``level`` passed to ``CompilationConfig``.
        ops_enabled: Expected enabled flags (0/1 or bool) for, in order,
            RMSNorm, SiluAndMul, GeluAndMul and Relu3.
        default_on: Expected value of ``CustomOp.default_on()``.
    """
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=torch_level, custom_ops=env.split(",")))
    with set_current_vllm_config(vllm_config):
        assert CustomOp.default_on() == default_on

        # The parametrization mixes bools and 0/1 shorthand; normalize so
        # the comparisons against enabled() are bool-to-bool.
        ops_enabled = [bool(x) for x in ops_enabled]

        # Each op is checked through an instance and through the class
        # looked up in the registry by its registered name.
        assert RMSNorm(1024).enabled() == ops_enabled[0]
        assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]

        assert SiluAndMul().enabled() == ops_enabled[1]
        assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]

        assert GeluAndMul().enabled() == ops_enabled[2]
        assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]

        # If registered, subclasses should follow their own name
        assert Relu3().enabled() == ops_enabled[3]
        assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

        # Unregistered subclass
        class SiluAndMul2(SiluAndMul):
            pass

        # Subclasses should not require registration
        assert SiluAndMul2().enabled() == SiluAndMul().enabled()
@pytest.mark.parametrize("env", [
    "all,none",
    "all,+rms_norm,all",
    "+rms_norm,-rms_norm",
])
def test_enabled_ops_invalid(env: str):
    """Contradictory ``custom_ops`` specs must be rejected.

    The specs combine "all" with "none", repeat "all", or both enable and
    disable the same op. Each one is expected to raise, either while the
    config is built or when an op checks whether it is enabled.
    """
    spec = env.split(",")
    with pytest.raises(Exception):  # noqa
        config = VllmConfig(compilation_config=CompilationConfig(
            custom_ops=spec))
        with set_current_vllm_config(config):
            RMSNorm(1024).enabled()