[torch.compile] Fine-grained CustomOp enabling mechanism (#9300)

Luka Govedič authored on 2024-10-17 14:36:37 -04:00 · committed by GitHub
parent 7871659abb
commit 0f41fbe5a3
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
8 changed files with 220 additions and 21 deletions

View File

@@ -0,0 +1,92 @@
import os
from typing import List
import pytest
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
# Registered subclass for test
@CustomOp.register("relu3")
class Relu3(ReLUSquaredActivation):
pass
@pytest.mark.parametrize(
"env, torch_level, ops_enabled, default_on",
[
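# Each ops_enabled entry maps to one op, in the order asserted below:
# [RMSNorm, SiluAndMul, GeluAndMul, Relu3].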
# Default values based on compile level
("", 0, [True] * 4, True),
("", 1, [True] * 4, True),
("", 2, [True] * 4, True), # All by default
("", 3, [False] * 4, False),
("", 4, [False] * 4, False), # None by default
# Explicitly enabling/disabling
#
# Default: all
#
# All but SiluAndMul
("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
# Only ReLU3
("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
# All but SiluAndMul
("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
# All but ReLU3 (even if ReLU2 is on)
("-relu3,relu2", 1, [1, 1, 1, 0], True),
# GeluAndMul and SiluAndMul
("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
# All but RMSNorm
("-rms_norm", 2, [0, 1, 1, 1], True),
#
# Default: none
#
# Only ReLU3
("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
# All but RMSNorm
("all,-rms_norm", 4, [0, 1, 1, 1], True),
])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
default_on: bool):
os.environ["VLLM_CUSTOM_OPS"] = env
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
# Reset default_on (computed once):
CustomOp.default_on.cache_clear()
assert CustomOp.default_on() == default_on
ops_enabled = [bool(x) for x in ops_enabled]
assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass
# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()
@pytest.mark.parametrize(
"env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str):
os.environ["VLLM_CUSTOM_OPS"] = env
CustomOp.default_on.cache_clear()
with pytest.raises(AssertionError):
RMSNorm(1024).enabled()

View File

@@ -65,6 +65,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_CUSTOM_OPS: List[str] = []
VLLM_DISABLED_KERNELS: List[str] = []
@@ -205,7 +206,17 @@ environment_variables: Dict[str, Callable[[], Any]] = {
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
"VLLM_TORCH_COMPILE_LEVEL":
lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),
# Fine-grained control over which custom ops to enable/disable.
# Use 'all' to enable all, 'none' to disable all.
# Also specify a list of custom op names to enable (prefixed with a '+'),
# or disable (prefixed with a '-').
# Examples:
# - 'all,-op1' to enable all except op1
# - 'none,+op1,+op2' to enable only op1 and op2
# By default, all custom ops are enabled when running without Inductor
# and disabled when running with Inductor (compile_level >= Inductor).
"VLLM_CUSTOM_OPS":
lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK":

View File

@@ -1,14 +1,24 @@
from functools import lru_cache
from typing import Dict, Type
import torch.nn as nn
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu
from vllm.utils import is_cpu, is_hip, is_xpu, print_warning_once
logger = init_logger(__name__)
class CustomOp(nn.Module):
"""
Base class for custom ops.
Dispatches the forward method to the appropriate backend.
"""
def __init__(self, *args, **kwargs):
def __init__(self):
super().__init__()
self._forward_method = self.dispatch_forward()
@@ -17,7 +27,6 @@ class CustomOp(nn.Module):
def forward_native(self, *args, **kwargs):
"""PyTorch-native implementation of the forward method.
This method is optional. If implemented, it can be used with compilers
such as torch.compile or PyTorch XLA. Also, it can be used for testing
purposes.
@@ -56,7 +65,11 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
enabled = self.enabled()
logger.debug("custom op %s %s", self.__class__.name,
"enabled" if enabled else "disabled")
if not enabled:
return self.forward_native
if is_hip():
@@ -69,3 +82,50 @@ class CustomOp(nn.Module):
return self.forward_xpu
else:
return self.forward_cuda
@classmethod
def enabled(cls) -> bool:
# if no name, then it was not registered
if not hasattr(cls, "name"):
print_warning_once(
f"Custom op {cls.__name__} was not registered, "
f"which means it won't appear in the op registry. "
f"It will be enabled/disabled based on the global settings.")
return CustomOp.default_on()
enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS
disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS
assert not (enabled
and disabled), f"Cannot enable and disable {cls.name}"
return (CustomOp.default_on() or enabled) and not disabled
# On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR
# Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
@staticmethod
@lru_cache()
def default_on() -> bool:
count_none = envs.VLLM_CUSTOM_OPS.count("none")
count_all = envs.VLLM_CUSTOM_OPS.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
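# Precedence note: the expression below reads as
# (VLLM_TORCH_COMPILE_LEVEL < INDUCTOR and 'none' not given) or 'all' given,
# since 'and' binds tighter than 'or'.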
return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \
not count_none > 0 or count_all > 0
# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: Dict[str, Type['CustomOp']] = {}
# Decorator to register custom ops.
@classmethod
def register(cls, name: str):
def decorator(op_cls):
assert name not in cls.op_registry, f"Duplicate op name: {name}"
op_cls.name = name
cls.op_registry[name] = op_cls
return op_cls
return decorator
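
As a quick sketch of how the register decorator above is meant to be used; the op name and subclass here are hypothetical, not part of the changed files:

from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul


@CustomOp.register("my_silu_and_mul")  # hypothetical name, for illustration
class MySiluAndMul(SiluAndMul):
    # Toggled under its own name, e.g. VLLM_CUSTOM_OPS="none,+my_silu_and_mul".
    pass


assert CustomOp.op_registry["my_silu_and_mul"] is MySiluAndMul
# True or False depending on VLLM_CUSTOM_OPS and the compile level.
print(MySiluAndMul.enabled())

Left unregistered, such a subclass would simply inherit the "silu_and_mul" name from its parent and follow that op's setting, which is what the new tests check with SiluAndMul2.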

View File

@@ -11,8 +11,10 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import LazyDict
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
"""An activation function for FATReLU.
@@ -40,6 +42,7 @@ class FatreluAndMul(CustomOp):
return self.forward_native(x)
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
@@ -74,6 +77,7 @@ class SiluAndMul(CustomOp):
return out
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.
@@ -123,6 +127,7 @@ class GeluAndMul(CustomOp):
return f'approximate={repr(self.approximate)}'
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -144,6 +149,7 @@ class NewGELU(CustomOp):
return ops.gelu_new(x)
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -164,8 +170,8 @@ class FastGELU(CustomOp):
return ops.gelu_fast(x)
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
@@ -189,6 +195,7 @@ class QuickGELU(CustomOp):
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
@@ -244,15 +251,22 @@ class ScaledActivation(nn.Module):
param_data.copy_(loaded_weight)
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
"gelu_fast": FastGELU(),
"gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
"relu2": ReLUSquaredActivation(),
"quick_gelu": QuickGELU(),
}
_ACTIVATION_REGISTRY = LazyDict({
"gelu":
lambda: nn.GELU(),
"gelu_fast":
lambda: FastGELU(),
"gelu_new":
lambda: NewGELU(),
"gelu_pytorch_tanh":
lambda: nn.GELU(approximate="tanh"),
"relu":
lambda: nn.ReLU(),
"relu2":
lambda: ReLUSquaredActivation(),
"quick_gelu":
lambda: QuickGELU(),
})
def get_act_fn(

View File

@@ -37,13 +37,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError
@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
@@ -74,7 +74,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
return self.forward(x=x,
layer=layer,
router_logits=router_logits,
@@ -97,7 +96,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
@@ -134,7 +132,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
assert num_expert_group is None

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
"""Root mean square normalization.
@@ -122,6 +123,7 @@ class RMSNorm(CustomOp):
return s
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
"""RMS normalization for Gemma.

View File

@@ -72,6 +72,7 @@ def _apply_rotary_emb(
return torch.stack((o1, o2), dim=-1).flatten(-2)
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding."""
@@ -468,7 +469,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
self.long_factor = long_factor
scale = self.max_position_embeddings / \
self.original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:

View File

@@ -17,6 +17,7 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, ensure_future
from collections.abc import Mapping
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
@@ -1442,3 +1443,24 @@ class AtomicCounter:
@property
def value(self):
return self._value
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping, Generic[T]):
def __init__(self, factory: Dict[str, Callable[[], T]]):
self._factory = factory
self._dict: Dict[str, T] = {}
def __getitem__(self, key) -> T:
if key not in self._dict:
if key not in self._factory:
raise KeyError(key)
self._dict[key] = self._factory[key]()
return self._dict[key]
def __iter__(self):
return iter(self._factory)
def __len__(self):
return len(self._factory)
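
A short usage sketch of the LazyDict added here (the entries are made up): each value is built by its factory on the first lookup and then cached, which is what lets the reworked activation registry defer constructing its modules until they are requested.

import torch.nn as nn

from vllm.utils import LazyDict

acts = LazyDict({
    "relu": lambda: nn.ReLU(),
    "gelu": lambda: nn.GELU(),
})

relu = acts["relu"]            # nn.ReLU() is constructed here, on first lookup
assert acts["relu"] is relu    # later lookups return the cached instance
assert len(acts) == 2 and set(acts) == {"relu", "gelu"}
# "gelu" has not been constructed yet at this point; acts["missing"] raises KeyError.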