[torch.compile] directly register custom op (#9896)

Signed-off-by: youkaichao <youkaichao@gmail.com>

parent 031a7995f3
commit 96e0c9cbbd
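The change replaces decorator-based registration (`@torch.library.custom_op` plus `.register_fake`) with plain functions handed to a new `direct_register_custom_op` helper, because the decorator path goes through heavyweight dispatching logic (see the docstring added to `vllm.utils` below). A side-by-side sketch of the two styles, assuming vLLM at this commit and PyTorch with `torch.library.custom_op`; the op names and bodies here are illustrative stand-ins modeled on the test files, not code from the diff itself:

import torch

from vllm.utils import direct_register_custom_op


# Old style: decorator plus register_fake (what the removed lines below did).
@torch.library.custom_op("silly::attention_old", mutates_args=["out"])
def attention_old(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  out: torch.Tensor) -> None:
    # real implementation: writes into the caller-provided output buffer
    out.copy_(q + k + v)


@attention_old.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
      out: torch.Tensor) -> None:
    # fake (meta) implementation: shapes/dtypes only, nothing to compute
    return


# New style: plain functions plus one explicit registration call.
def attention_new(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  out: torch.Tensor) -> None:
    out.copy_(q + k + v)


def attention_new_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                       out: torch.Tensor) -> None:
    return


direct_register_custom_op(
    op_name="attention_new",
    op_func=attention_new,
    mutates_args=["out"],
    fake_impl=attention_new_fake,
)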
@@ -6,18 +6,22 @@ import os
 
 import torch
 from torch import nn
+from torch.library import Library
 
 from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.compilation.levels import CompilationLevel
+from vllm.utils import direct_register_custom_op
 
 os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
 
 global_counter = 0
 
+# create a library to hold the custom op
+silly_lib = Library("silly", "FRAGMENT")  # noqa
+
 
-@torch.library.custom_op("silly::attention", mutates_args=["out"])
 def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     out: torch.Tensor) -> None:
     global global_counter
@@ -27,12 +31,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
     out[0] += 1
 
 
-@silly_attention.register_fake
-def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-      out: torch.Tensor) -> None:
+def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                         out: torch.Tensor) -> None:
     return
 
 
+direct_register_custom_op(
+    op_name="attention",
+    op_func=silly_attention,
+    mutates_args=["out"],
+    fake_impl=silly_attention_fake,
+    target_lib=silly_lib,
+)
+
+
 @support_torch_compile
 class SillyModel(nn.Module):
 
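In this test the op is registered into the local `silly` fragment library via `target_lib=silly_lib`, so the compiled model reaches it by qualified name. A short standalone sketch of that call path (the model's forward code is outside this diff; shapes and device are illustrative, and a CUDA device is assumed since only a CUDA impl is registered):

import torch

# Assuming the registration above has executed, the op is dispatched by name,
# which keeps it an opaque call inside torch.compile's piecewise graphs.
q = k = v = torch.randn(2, 16, device="cuda")
out = torch.empty_like(q)
torch.ops.silly.attention(q, k, v, out)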
@@ -8,6 +8,7 @@ from typing import Optional, Tuple
 
 import torch
 from torch import nn
+from torch.library import Library
 
 from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.config import CompilationConfig
@@ -15,9 +16,12 @@ from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.compilation.levels import CompilationLevel
 from vllm.plugins import set_compilation_config
+from vllm.utils import direct_register_custom_op
 
+# create a library to hold the custom op
+silly_lib = Library("silly", "FRAGMENT")  # noqa
+
 
-@torch.library.custom_op("silly::attention", mutates_args=["out"])
 def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     out: torch.Tensor) -> None:
     out.copy_(q)
@@ -25,12 +29,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
     out += v
 
 
-@silly_attention.register_fake
-def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-      out: torch.Tensor) -> None:
+def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                         out: torch.Tensor) -> None:
     return
 
 
+direct_register_custom_op(
+    op_name="attention",
+    op_func=silly_attention,
+    mutates_args=["out"],
+    fake_impl=silly_attention_fake,
+    target_lib=silly_lib,
+)
+
+
 @dataclass
 class LlamaConfig:
     hidden_size: int = 128
@@ -14,7 +14,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.forward_context import get_forward_context
-from vllm.utils import async_tensor_h2d, make_tensor_with_pad
+from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
+                        make_tensor_with_pad)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -595,8 +596,6 @@ class FlashAttentionImpl(AttentionImpl):
         return output
 
 
-@torch.library.custom_op("vllm::unified_flash_attention",
-                         mutates_args=["kv_cache"])
 def unified_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -755,8 +754,7 @@ def unified_flash_attention(
     return output.view(num_tokens, hidden_size)
 
 
-@unified_flash_attention.register_fake
-def _(
+def unified_flash_attention_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -773,3 +771,11 @@ def _(
     logits_soft_cap: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(query)
+
+
+direct_register_custom_op(
+    op_name="unified_flash_attention",
+    op_func=unified_flash_attention,
+    mutates_args=["kv_cache"],
+    fake_impl=unified_flash_attention_fake,
+)
@@ -28,8 +28,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            is_block_tables_empty)
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.forward_context import get_forward_context
-from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
-                        make_tensor_with_pad)
+from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
+                        get_kv_cache_torch_dtype, make_tensor_with_pad)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -785,8 +785,6 @@ class FlashInferImpl(AttentionImpl):
         )
 
 
-@torch.library.custom_op("vllm::unified_flash_infer",
-                         mutates_args=["kv_cache"])
 def unified_flash_infer(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -906,8 +904,7 @@ def unified_flash_infer(
     return output.view(num_tokens, hidden_size)
 
 
-@unified_flash_infer.register_fake
-def _(
+def unified_flash_infer_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -924,3 +921,11 @@ def _(
     logits_soft_cap: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(query).contiguous()
+
+
+direct_register_custom_op(
+    op_name="unified_flash_infer",
+    op_func=unified_flash_infer,
+    mutates_args=["kv_cache"],
+    fake_impl=unified_flash_infer_fake,
+)
@@ -37,7 +37,7 @@ from torch.distributed import Backend, ProcessGroup
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import supports_custom_op
+from vllm.utils import direct_register_custom_op, supports_custom_op
 
 
 @dataclass
@@ -99,8 +99,6 @@ def _register_group(group: "GroupCoordinator") -> None:
 
 if supports_custom_op():
 
-    @torch.library.custom_op("vllm::inplace_all_reduce",
-                             mutates_args=["tensor"])
     def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
         assert group_name in _groups, f"Group {group_name} is not found."
         group = _groups[group_name]()
@@ -108,11 +106,16 @@ if supports_custom_op():
             raise ValueError(f"Group {group_name} is destroyed.")
         group._all_reduce_in_place(tensor)
 
-    @inplace_all_reduce.register_fake
-    def _(tensor: torch.Tensor, group_name: str) -> None:
+    def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
         return
 
-    @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
+    direct_register_custom_op(
+        op_name="inplace_all_reduce",
+        op_func=inplace_all_reduce,
+        mutates_args=["tensor"],
+        fake_impl=inplace_all_reduce_fake,
+    )
+
     def outplace_all_reduce(tensor: torch.Tensor,
                             group_name: str) -> torch.Tensor:
         assert group_name in _groups, f"Group {group_name} is not found."
@@ -121,10 +124,17 @@ if supports_custom_op():
             raise ValueError(f"Group {group_name} is destroyed.")
         return group._all_reduce_out_place(tensor)
 
-    @outplace_all_reduce.register_fake
-    def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    def outplace_all_reduce_fake(tensor: torch.Tensor,
+                                 group_name: str) -> torch.Tensor:
         return torch.empty_like(tensor)
 
+    direct_register_custom_op(
+        op_name="outplace_all_reduce",
+        op_func=outplace_all_reduce,
+        mutates_args=[],
+        fake_impl=outplace_all_reduce_fake,
+    )
+
 
 class GroupCoordinator:
     """
@@ -338,6 +348,11 @@ class GroupCoordinator:
         if self.world_size == 1:
             return input_
 
+        if input_.is_cpu:
+            import intel_extension_for_pytorch as ipex
+            ipex.distributed.all_reduce(input_, group=self.device_group)
+            return input_
+
         if not supports_custom_op():
             self._all_reduce_in_place(input_)
             return input_
@@ -369,9 +384,6 @@ class GroupCoordinator:
         pynccl_comm = self.pynccl_comm
         if (pynccl_comm is not None and not pynccl_comm.disabled):
            pynccl_comm.all_reduce(input_)
-        elif input_.is_cpu:
-            import intel_extension_for_pytorch as ipex
-            ipex.distributed.all_reduce(input_, group=self.device_group)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
 
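The fake implementations matter here: `inplace_all_reduce` mutates its argument, while `outplace_all_reduce` returns a fresh tensor whose shape the fake (`torch.empty_like`) exposes to the tracer. The actual call sites in `GroupCoordinator` are outside the hunks shown above, so the following is only a hedged sketch of how the registered collectives are reached by qualified name:

import torch

def all_reduce_out_of_place(t: torch.Tensor, group_name: str) -> torch.Tensor:
    # Dispatch by name through the "vllm" fragment; the string group_name keeps
    # the signature traceable, and the GroupCoordinator is looked up inside the op.
    return torch.ops.vllm.outplace_all_reduce(t, group_name=group_name)

def all_reduce_in_place(t: torch.Tensor, group_name: str) -> None:
    torch.ops.vllm.inplace_all_reduce(t, group_name=group_name)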
@@ -8,6 +8,7 @@ from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size, try_get_optimal_moe_config)
 from vllm.scalar_type import scalar_types
+from vllm.utils import direct_register_custom_op
 
 
 def get_scalar_type(num_bits: int, has_zp: bool):
@@ -18,7 +19,6 @@ def get_scalar_type(num_bits: int, has_zp: bool):
     return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
 
 
-@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[])
 def single_marlin_moe(
     hidden_states: torch.Tensor,
     w: torch.Tensor,
@@ -119,8 +119,7 @@ def single_marlin_moe(
     return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
 
 
-@single_marlin_moe.register_fake
-def _(
+def single_marlin_moe_fake(
     hidden_states: torch.Tensor,
     w: torch.Tensor,
     scales: torch.Tensor,
@@ -136,7 +135,14 @@ def _(
     return torch.empty_like(hidden_states)
 
 
-@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[])
+direct_register_custom_op(
+    op_name="single_marlin_moe",
+    op_func=single_marlin_moe,
+    mutates_args=[],
+    fake_impl=single_marlin_moe_fake,
+)
+
+
 def fused_marlin_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -324,8 +330,7 @@ def fused_marlin_moe(
                      dim=1)
 
 
-@fused_marlin_moe.register_fake
-def _(
+def fused_marlin_moe_fake(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
@@ -344,3 +349,11 @@ def _(
     is_k_full: bool = True,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
+
+
+direct_register_custom_op(
+    op_name="fused_marlin_moe",
+    op_func=fused_marlin_moe,
+    mutates_args=[],
+    fake_impl=fused_marlin_moe_fake,
+)
@@ -12,6 +12,7 @@ import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
 
@@ -466,8 +467,6 @@ def get_config_dtype_str(dtype: torch.dtype,
     return None
 
 
-@torch.library.custom_op("vllm::inplace_fused_experts",
-                         mutates_args=["hidden_states"])
 def inplace_fused_experts(hidden_states: torch.Tensor,
                           w1: torch.Tensor,
                           w2: torch.Tensor,
@@ -484,22 +483,29 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                   a1_scale, a2_scale)
 
 
-@inplace_fused_experts.register_fake
-def _(hidden_states: torch.Tensor,
-      w1: torch.Tensor,
-      w2: torch.Tensor,
-      topk_weights: torch.Tensor,
-      topk_ids: torch.Tensor,
-      use_fp8_w8a8: bool = False,
-      use_int8_w8a16: bool = False,
-      w1_scale: Optional[torch.Tensor] = None,
-      w2_scale: Optional[torch.Tensor] = None,
-      a1_scale: Optional[torch.Tensor] = None,
-      a2_scale: Optional[torch.Tensor] = None) -> None:
+def inplace_fused_experts_fake(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None) -> None:
     pass
 
 
-@torch.library.custom_op("vllm::outplace_fused_experts", mutates_args=[])
+direct_register_custom_op(
+    op_name="inplace_fused_experts",
+    op_func=inplace_fused_experts,
+    mutates_args=["hidden_states"],
+    fake_impl=inplace_fused_experts_fake,
+)
+
+
 def outplace_fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -517,21 +523,29 @@ def outplace_fused_experts(
                            w2_scale, a1_scale, a2_scale)
 
 
-@outplace_fused_experts.register_fake
-def _(hidden_states: torch.Tensor,
-      w1: torch.Tensor,
-      w2: torch.Tensor,
-      topk_weights: torch.Tensor,
-      topk_ids: torch.Tensor,
-      use_fp8_w8a8: bool = False,
-      use_int8_w8a16: bool = False,
-      w1_scale: Optional[torch.Tensor] = None,
-      w2_scale: Optional[torch.Tensor] = None,
-      a1_scale: Optional[torch.Tensor] = None,
-      a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+def outplace_fused_experts_fake(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
 
+direct_register_custom_op(
+    op_name="outplace_fused_experts",
+    op_func=outplace_fused_experts,
+    mutates_args=[],
+    fake_impl=outplace_fused_experts_fake,
+)
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w2: torch.Tensor,
@@ -32,6 +32,7 @@ import torch
 import torch.types
 import yaml
 from packaging.version import Version
+from torch.library import Library
 from typing_extensions import ParamSpec, TypeIs, assert_never
 
 import vllm.envs as envs
@@ -1512,3 +1513,47 @@ def weak_ref_tensors(
     if isinstance(tensors, tuple):
         return tuple(weak_ref_tensor(t) for t in tensors)
     raise ValueError("Invalid type for tensors")
+
+
+def is_in_doc_build() -> bool:
+    try:
+        from sphinx.ext.autodoc.mock import _MockModule
+        return isinstance(torch, _MockModule)
+    except ModuleNotFoundError:
+        return False
+
+
+# create a library to hold the custom op
+vllm_lib = Library("vllm", "FRAGMENT")  # noqa
+
+
+def direct_register_custom_op(
+        op_name: str,
+        op_func: Callable,
+        mutates_args: List[str],
+        fake_impl: Optional[Callable] = None,
+        target_lib: Optional[Library] = None,
+):
+    """
+    `torch.library.custom_op` can have significant overhead because it
+    needs to consider complicated dispatching logic. This function
+    directly registers a custom op and dispatches it to the CUDA backend.
+    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
+    for more details.
+
+    By default, the custom op is registered to the vLLM library. If you
+    want to register it to a different library, you can pass the library
+    object to the `target_lib` argument.
+
+    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
+    library object. If you want to bind the operator to a different library,
+    make sure the library object is alive when the operator is used.
+    """
+    if is_in_doc_build():
+        return
+    schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
+    my_lib = target_lib or vllm_lib
+    my_lib.define(op_name + schema_str)
+    my_lib.impl(op_name, op_func, "CUDA")
+    if fake_impl is not None:
+        my_lib._register_fake(op_name, fake_impl)
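Taken together, the helper infers the op schema from the function signature (`torch.library.infer_schema`), binds the real function as the CUDA implementation, and optionally attaches a fake implementation, all on a `torch.library.Library` fragment that must outlive the op. A small usage sketch against the helper as defined above; the op name `scale_out` and the tensors are illustrative, and a CUDA device is assumed since only a CUDA impl is registered:

import torch

from vllm.utils import direct_register_custom_op


def scale_out(x: torch.Tensor, factor: float, out: torch.Tensor) -> None:
    # real implementation, writing into a pre-allocated buffer
    out.copy_(x * factor)


def scale_out_fake(x: torch.Tensor, factor: float, out: torch.Tensor) -> None:
    # fake impl: nothing to compute, shapes are already carried by `out`
    return


direct_register_custom_op(
    op_name="scale_out",
    op_func=scale_out,
    mutates_args=["out"],
    fake_impl=scale_out_fake,
)

# With no target_lib, the op lands in the default "vllm" fragment library:
x = torch.randn(8, device="cuda")
out = torch.empty_like(x)
torch.ops.vllm.scale_out(x, 2.0, out)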
@@ -7,6 +7,7 @@ import torch
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.forward_context import get_forward_context
+from vllm.utils import direct_register_custom_op
 from vllm.vllm_flash_attn import flash_attn_varlen_func
 
 
@@ -152,8 +153,6 @@ class FlashAttentionImpl(AttentionImpl):
         return output
 
 
-@torch.library.custom_op("vllm::unified_flash_attention",
-                         mutates_args=["kv_cache"])
 def unified_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -217,8 +216,7 @@ def unified_flash_attention(
     return output.view(num_tokens, hidden_size)
 
 
-@unified_flash_attention.register_fake
-def _(
+def unified_flash_attention_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -235,3 +233,11 @@ def _(
     logits_soft_cap: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(query)
+
+
+direct_register_custom_op(
+    op_name="unified_flash_attention",
+    op_func=unified_flash_attention,
+    mutates_args=["kv_cache"],
+    fake_impl=unified_flash_attention_fake,
+)