[torch.compile] Fine-grained CustomOp enabling mechanism (#9300)
This commit is contained in:
parent 7871659abb
commit 0f41fbe5a3
tests/model_executor/test_enabled_custom_ops.py  (new file, 92 lines)
@@ -0,0 +1,92 @@
import os
from typing import List

import pytest

from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
                                                   ReLUSquaredActivation,
                                                   SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm


# Registered subclass for test
@CustomOp.register("relu3")
class Relu3(ReLUSquaredActivation):
    pass


@pytest.mark.parametrize(
    "env, torch_level, ops_enabled, default_on",
    [
        # Default values based on compile level
        ("", 0, [True] * 4, True),
        ("", 1, [True] * 4, True),
        ("", 2, [True] * 4, True),  # All by default
        ("", 3, [False] * 4, False),
        ("", 4, [False] * 4, False),  # None by default
        # Explicitly enabling/disabling
        #
        # Default: all
        #
        # All but SiluAndMul
        ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
        # Only ReLU3
        ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
        # All but SiluAndMul
        ("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
        # All but ReLU3 (even if ReLU2 is on)
        ("-relu3,relu2", 1, [1, 1, 1, 0], True),
        # GeluAndMul and SiluAndMul
        ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
        # All but RMSNorm
        ("-rms_norm", 2, [0, 1, 1, 1], True),
        #
        # Default: none
        #
        # Only ReLU3
        ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
        # All but RMSNorm
        ("all,-rms_norm", 4, [0, 1, 1, 1], True),
    ])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
                     default_on: bool):
    os.environ["VLLM_CUSTOM_OPS"] = env
    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)

    # Reset default_on (computed once):
    CustomOp.default_on.cache_clear()

    assert CustomOp.default_on() == default_on

    ops_enabled = [bool(x) for x in ops_enabled]

    assert RMSNorm(1024).enabled() == ops_enabled[0]
    assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]

    assert SiluAndMul().enabled() == ops_enabled[1]
    assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]

    assert GeluAndMul().enabled() == ops_enabled[2]
    assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]

    # If registered, subclasses should follow their own name
    assert Relu3().enabled() == ops_enabled[3]
    assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

    # Unregistered subclass
    class SiluAndMul2(SiluAndMul):
        pass

    # Subclasses should not require registration
    assert SiluAndMul2().enabled() == SiluAndMul().enabled()


@pytest.mark.parametrize(
    "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str):
    os.environ["VLLM_CUSTOM_OPS"] = env
    CustomOp.default_on.cache_clear()

    with pytest.raises(AssertionError):
        RMSNorm(1024).enabled()
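Note: the Relu3 fixture above doubles as a usage example for the new registration API. A minimal standalone sketch of the same pattern (the op name and class below are illustrative, not part of this change) might look like:

from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("my_relu")    # hypothetical op name, for illustration only
class MyRelu(CustomOp):

    def forward_native(self, x):
        # Plain PyTorch fallback used whenever the custom kernel is disabled.
        return x.clamp(min=0)


# Enablement can then be queried on the class or via the registry,
# exactly as the assertions in test_enabled_ops do:
# MyRelu.enabled() == CustomOp.op_registry["my_relu"].enabled()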
vllm/envs.py  (13 lines changed)
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
    VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
    VLLM_SKIP_P2P_CHECK: bool = False
    VLLM_TORCH_COMPILE_LEVEL: int = 0
    VLLM_CUSTOM_OPS: List[str] = []
    VLLM_DISABLED_KERNELS: List[str] = []


@@ -205,7 +206,17 @@ environment_variables: Dict[str, Callable[[], Any]] = {
        os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
    "VLLM_TORCH_COMPILE_LEVEL":
    lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),

    # Fine-grained control over which custom ops to enable/disable.
    # Use 'all' to enable all, 'none' to disable all.
    # Also specify a list of custom op names to enable (prefixed with a '+'),
    # or disable (prefixed with a '-').
    # Examples:
    # - 'all,-op1' to enable all except op1
    # - 'none,+op1,+op2' to enable only op1 and op2
    # By default, all custom ops are enabled when running without Inductor
    # and disabled when running with Inductor (compile_level >= Inductor).
    "VLLM_CUSTOM_OPS":
    lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),

    # local rank of the process in the distributed setting, used to determine
    # the GPU device id
    "LOCAL_RANK":
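The comment block above defines the full VLLM_CUSTOM_OPS grammar. As a quick illustration of how a setting is parsed and interpreted (the value below is just an example):

import os

os.environ["VLLM_CUSTOM_OPS"] = "none, +rms_norm"
tokens = os.environ["VLLM_CUSTOM_OPS"].replace(" ", "").split(",")
# tokens == ["none", "+rms_norm"]: all custom ops default to off,
# and only rms_norm is explicitly switched back on.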
vllm/model_executor/custom_op.py
@@ -1,14 +1,24 @@
from functools import lru_cache
from typing import Dict, Type

import torch.nn as nn

import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu
from vllm.utils import is_cpu, is_hip, is_xpu, print_warning_once

logger = init_logger(__name__)


class CustomOp(nn.Module):
    """
    Base class for custom ops.
    Dispatches the forward method to the appropriate backend.
    """

    def __init__(self, *args, **kwargs):
    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

@@ -17,7 +27,6 @@ class CustomOp(nn.Module):

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.

        This method is optional. If implemented, it can be used with compilers
        such as torch.compile or PyTorch XLA. Also, it can be used for testing
        purposes.
@@ -56,7 +65,11 @@ class CustomOp(nn.Module):
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.

        if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
        enabled = self.enabled()
        logger.debug("custom op %s %s", self.__class__.name,
                     "enabled" if enabled else "disabled")

        if not enabled:
            return self.forward_native

        if is_hip():
@@ -69,3 +82,50 @@ class CustomOp(nn.Module):
            return self.forward_xpu
        else:
            return self.forward_cuda

    @classmethod
    def enabled(cls) -> bool:
        # if no name, then it was not registered
        if not hasattr(cls, "name"):
            print_warning_once(
                f"Custom op {cls.__name__} was not registered, "
                f"which means it won't appear in the op registry. "
                f"It will be enabled/disabled based on the global settings.")
            return CustomOp.default_on()

        enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS
        disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS
        assert not (enabled
                    and disabled), f"Cannot enable and disable {cls.name}"

        return (CustomOp.default_on() or enabled) and not disabled

    # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR
    # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
    @staticmethod
    @lru_cache()
    def default_on() -> bool:
        count_none = envs.VLLM_CUSTOM_OPS.count("none")
        count_all = envs.VLLM_CUSTOM_OPS.count("all")
        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
        return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \
            not count_none > 0 or count_all > 0

    # Dictionary of all custom ops (classes, indexed by registered name).
    # To check if an op with a name is enabled, call .enabled() on the class.
    # Examples:
    # - MyOp.enabled()
    # - op_registry["my_op"].enabled()
    op_registry: Dict[str, Type['CustomOp']] = {}

    # Decorator to register custom ops.
    @classmethod
    def register(cls, name: str):

        def decorator(op_cls):
            assert name not in cls.op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            cls.op_registry[name] = op_cls
            return op_cls

        return decorator
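For reference, the precedence implemented by enabled() and default_on() above can be restated as a single standalone function (an illustration for reasoning about the test matrix, not code from this change):

from typing import List


def op_enabled(name: str, ops_env: List[str], below_inductor: bool) -> bool:
    # 'all'/'none' set the default; '+name'/'-name' override it per op.
    default_on = (below_inductor and "none" not in ops_env) or "all" in ops_env
    explicitly_on = f"+{name}" in ops_env
    explicitly_off = f"-{name}" in ops_env
    return (default_on or explicitly_on) and not explicitly_off


# op_enabled("silu_and_mul", ["all", "-silu_and_mul"], True) -> False,
# matching the ("all,-silu_and_mul", 1, [1, 0, 1, 1], True) test case.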
vllm/model_executor/layers/activation.py
@@ -11,8 +11,10 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import LazyDict


@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

@@ -40,6 +42,7 @@ class FatreluAndMul(CustomOp):
        return self.forward_native(x)


@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

@@ -74,6 +77,7 @@ class SiluAndMul(CustomOp):
        return out


@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

@@ -123,6 +127,7 @@ class GeluAndMul(CustomOp):
        return f'approximate={repr(self.approximate)}'


@CustomOp.register("gelu_new")
class NewGELU(CustomOp):

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -144,6 +149,7 @@ class NewGELU(CustomOp):
        return ops.gelu_new(x)


@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -164,8 +170,8 @@ class FastGELU(CustomOp):
        return ops.gelu_fast(x)


@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
@@ -189,6 +195,7 @@ class QuickGELU(CustomOp):
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
@@ -244,15 +251,22 @@ class ScaledActivation(nn.Module):
        param_data.copy_(loaded_weight)


_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
    "relu2": ReLUSquaredActivation(),
    "quick_gelu": QuickGELU(),
}
_ACTIVATION_REGISTRY = LazyDict({
    "gelu":
    lambda: nn.GELU(),
    "gelu_fast":
    lambda: FastGELU(),
    "gelu_new":
    lambda: NewGELU(),
    "gelu_pytorch_tanh":
    lambda: nn.GELU(approximate="tanh"),
    "relu":
    lambda: nn.ReLU(),
    "relu2":
    lambda: ReLUSquaredActivation(),
    "quick_gelu":
    lambda: QuickGELU(),
})


def get_act_fn(
vllm/model_executor/layers/fused_moe/layer.py
@@ -37,13 +37,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
        raise NotImplementedError


@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
@@ -74,7 +74,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            num_expert_group: Optional[int] = None,
            custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:

        return self.forward(x=x,
                            layer=layer,
                            router_logits=router_logits,
@@ -97,7 +96,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            num_expert_group: Optional[int] = None,
            custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:

        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_experts)

@@ -134,7 +132,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            num_expert_group: Optional[int] = None,
            custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:

        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
        assert not use_grouped_topk
        assert num_expert_group is None
vllm/model_executor/layers/layernorm.py
@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

@@ -122,6 +123,7 @@ class RMSNorm(CustomOp):
        return s


@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.
vllm/model_executor/layers/rotary_embedding.py
@@ -72,6 +72,7 @@ def _apply_rotary_emb(
    return torch.stack((o1, o2), dim=-1).flatten(-2)


@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding."""

@@ -468,7 +469,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
        self.long_factor = long_factor

        scale = self.max_position_embeddings / \
                self.original_max_position_embeddings
            self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
vllm/utils.py
@@ -17,6 +17,7 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, ensure_future
from collections.abc import Mapping
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
@@ -1442,3 +1443,24 @@ class AtomicCounter:
    @property
    def value(self):
        return self._value


# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping, Generic[T]):

    def __init__(self, factory: Dict[str, Callable[[], T]]):
        self._factory = factory
        self._dict: Dict[str, T] = {}

    def __getitem__(self, key) -> T:
        if key not in self._dict:
            if key not in self._factory:
                raise KeyError(key)
            self._dict[key] = self._factory[key]()
        return self._dict[key]

    def __iter__(self):
        return iter(self._factory)

    def __len__(self):
        return len(self._factory)
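LazyDict is what lets _ACTIVATION_REGISTRY in activation.py defer constructing each activation module until it is first requested, presumably so that CustomOp dispatch (and therefore enablement) is decided at use time rather than at import time. A small usage sketch (values are illustrative only):

from vllm.utils import LazyDict


def make_act():
    print("constructing activation")   # runs only on the first lookup
    return object()


registry = LazyDict({"act": make_act})
first = registry["act"]    # prints "constructing activation"
second = registry["act"]   # served from the cache; factory not run again
assert first is second and len(registry) == 1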