[torch.compile] directly register custom op (#9896)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Authored by youkaichao on 2024-10-31 21:56:09 -07:00; committed by GitHub
parent 031a7995f3
commit 96e0c9cbbd
9 changed files with 192 additions and 67 deletions
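The commit replaces decorator-based torch.library.custom_op registration with an explicit direct_register_custom_op call, which binds the real implementation to the CUDA dispatch key and the fake implementation for tracing. A condensed before/after sketch of the pattern, modeled on the silly::attention test op changed below (body simplified; illustrative, not a verbatim copy of the diff):

import torch
from torch.library import Library
from vllm.utils import direct_register_custom_op

# Before: decorator-based registration (removed by this commit).
#
# @torch.library.custom_op("silly::attention", mutates_args=["out"])
# def silly_attention(q, k, v, out) -> None: ...
#
# @silly_attention.register_fake
# def _(q, k, v, out) -> None: ...

# After: plain functions plus one explicit registration call.
silly_lib = Library("silly", "FRAGMENT")  # library object that owns the op

def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    out.copy_(q)  # simplified body for illustration

def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    return  # shape-only stand-in used while tracing/compiling

direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,  # registers under the "silly" namespace
)

# The op is then reached through the dispatcher:
# torch.ops.silly.attention(q, k, v, out)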

View File

@@ -6,18 +6,22 @@ import os
import torch
from torch import nn
from torch.library import Library
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.utils import direct_register_custom_op
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
global_counter = 0
# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
global global_counter
@@ -27,12 +31,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out[0] += 1
@silly_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)
@support_torch_compile
class SillyModel(nn.Module):

View File

@@ -8,6 +8,7 @@ from typing import Optional, Tuple
import torch
from torch import nn
from torch.library import Library
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
@@ -15,9 +16,12 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_compilation_config
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
out.copy_(q)
@@ -25,12 +29,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out += v
@silly_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)
@dataclass
class LlamaConfig:
hidden_size: int = 128

View File

@@ -14,7 +14,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.forward_context import get_forward_context
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
make_tensor_with_pad)
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -595,8 +596,6 @@ class FlashAttentionImpl(AttentionImpl):
return output
@torch.library.custom_op("vllm::unified_flash_attention",
mutates_args=["kv_cache"])
def unified_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -755,8 +754,7 @@ def unified_flash_attention(
return output.view(num_tokens, hidden_size)
@unified_flash_attention.register_fake
def _(
def unified_flash_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
@@ -773,3 +771,11 @@ def _(
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)
direct_register_custom_op(
op_name="unified_flash_attention",
op_func=unified_flash_attention,
mutates_args=["kv_cache"],
fake_impl=unified_flash_attention_fake,
)

View File

@@ -28,8 +28,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.forward_context import get_forward_context
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
get_kv_cache_torch_dtype, make_tensor_with_pad)
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -785,8 +785,6 @@ class FlashInferImpl(AttentionImpl):
)
@torch.library.custom_op("vllm::unified_flash_infer",
mutates_args=["kv_cache"])
def unified_flash_infer(
query: torch.Tensor,
key: torch.Tensor,
@@ -906,8 +904,7 @@ def unified_flash_infer(
return output.view(num_tokens, hidden_size)
@unified_flash_infer.register_fake
def _(
def unified_flash_infer_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
@@ -924,3 +921,11 @@ def _(
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
direct_register_custom_op(
op_name="unified_flash_infer",
op_func=unified_flash_infer,
mutates_args=["kv_cache"],
fake_impl=unified_flash_infer_fake,
)

View File

@@ -37,7 +37,7 @@ from torch.distributed import Backend, ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import supports_custom_op
from vllm.utils import direct_register_custom_op, supports_custom_op
@dataclass
@@ -99,8 +99,6 @@ def _register_group(group: "GroupCoordinator") -> None:
if supports_custom_op():
@torch.library.custom_op("vllm::inplace_all_reduce",
mutates_args=["tensor"])
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
@@ -108,11 +106,16 @@ if supports_custom_op():
raise ValueError(f"Group {group_name} is destroyed.")
group._all_reduce_in_place(tensor)
@inplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> None:
def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
return
@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
direct_register_custom_op(
op_name="inplace_all_reduce",
op_func=inplace_all_reduce,
mutates_args=["tensor"],
fake_impl=inplace_all_reduce_fake,
)
def outplace_all_reduce(tensor: torch.Tensor,
group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
@@ -121,10 +124,17 @@ if supports_custom_op():
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor)
@outplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
def outplace_all_reduce_fake(tensor: torch.Tensor,
group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)
direct_register_custom_op(
op_name="outplace_all_reduce",
op_func=outplace_all_reduce,
mutates_args=[],
fake_impl=outplace_all_reduce_fake,
)
class GroupCoordinator:
"""
@@ -338,6 +348,11 @@ class GroupCoordinator:
if self.world_size == 1:
return input_
if input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
return input_
if not supports_custom_op():
self._all_reduce_in_place(input_)
return input_
@@ -369,9 +384,6 @@
pynccl_comm = self.pynccl_comm
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
elif input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
else:
torch.distributed.all_reduce(input_, group=self.device_group)

View File

@@ -8,6 +8,7 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.scalar_type import scalar_types
from vllm.utils import direct_register_custom_op
def get_scalar_type(num_bits: int, has_zp: bool):
@@ -18,7 +19,6 @@ def get_scalar_type(num_bits: int, has_zp: bool):
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[])
def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
@@ -119,8 +119,7 @@ def single_marlin_moe(
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@single_marlin_moe.register_fake
def _(
def single_marlin_moe_fake(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
@@ -136,7 +135,14 @@ def _(
return torch.empty_like(hidden_states)
@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[])
direct_register_custom_op(
op_name="single_marlin_moe",
op_func=single_marlin_moe,
mutates_args=[],
fake_impl=single_marlin_moe_fake,
)
def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -324,8 +330,7 @@ def fused_marlin_moe(
dim=1)
@fused_marlin_moe.register_fake
def _(
def fused_marlin_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@@ -344,3 +349,11 @@ def _(
is_k_full: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="fused_marlin_moe",
op_func=fused_marlin_moe,
mutates_args=[],
fake_impl=fused_marlin_moe_fake,
)

View File

@@ -12,6 +12,7 @@ import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)
@@ -466,8 +467,6 @@ def get_config_dtype_str(dtype: torch.dtype,
return None
@torch.library.custom_op("vllm::inplace_fused_experts",
mutates_args=["hidden_states"])
def inplace_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@@ -484,22 +483,29 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale, a2_scale)
@inplace_fused_experts.register_fake
def _(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None) -> None:
def inplace_fused_experts_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None) -> None:
pass
@torch.library.custom_op("vllm::outplace_fused_experts", mutates_args=[])
direct_register_custom_op(
op_name="inplace_fused_experts",
op_func=inplace_fused_experts,
mutates_args=["hidden_states"],
fake_impl=inplace_fused_experts_fake,
)
def outplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -517,21 +523,29 @@ def outplace_fused_experts(
w2_scale, a1_scale, a2_scale)
@outplace_fused_experts.register_fake
def _(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
def outplace_fused_experts_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="outplace_fused_experts",
op_func=outplace_fused_experts,
mutates_args=[],
fake_impl=outplace_fused_experts_fake,
)
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,

View File

@@ -32,6 +32,7 @@ import torch
import torch.types
import yaml
from packaging.version import Version
from torch.library import Library
from typing_extensions import ParamSpec, TypeIs, assert_never
import vllm.envs as envs
@@ -1512,3 +1513,47 @@ def weak_ref_tensors(
if isinstance(tensors, tuple):
return tuple(weak_ref_tensor(t) for t in tensors)
raise ValueError("Invalid type for tensors")
def is_in_doc_build() -> bool:
try:
from sphinx.ext.autodoc.mock import _MockModule
return isinstance(torch, _MockModule)
except ModuleNotFoundError:
return False
# create a library to hold the custom op
vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
op_name: str,
op_func: Callable,
mutates_args: List[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
):
"""
`torch.library.custom_op` can have significant overhead because it
needs to consider complicated dispatching logic. This function
directly registers a custom op and dispatches it to the CUDA backend.
See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
for more details.
By default, the custom op is registered to the vLLM library. If you
want to register it to a different library, you can pass the library
object to the `target_lib` argument.
IMPORTANT: the lifetime of the operator is tied to the lifetime of the
library object. If you want to bind the operator to a different library,
make sure the library object is alive when the operator is used.
"""
if is_in_doc_build():
return
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str)
my_lib.impl(op_name, op_func, "CUDA")
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
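A minimal usage sketch for the helper (the toy op name and body are illustrative, not part of this commit): the real implementation is bound to the CUDA dispatch key, and the fake implementation is what tracing and compilation see.

import torch
from vllm.utils import direct_register_custom_op

def toy_scale(x: torch.Tensor, out: torch.Tensor) -> None:
    # CUDA-path implementation: writes 2 * x into the preallocated buffer.
    out.copy_(x)
    out *= 2

def toy_scale_fake(x: torch.Tensor, out: torch.Tensor) -> None:
    # Shape/dtype-only stand-in used during tracing; performs no work.
    return

direct_register_custom_op(
    op_name="toy_scale",
    op_func=toy_scale,
    mutates_args=["out"],
    fake_impl=toy_scale_fake,
)

# With no target_lib, the op lands in the default vLLM library and is
# reachable through the dispatcher as torch.ops.vllm.toy_scale(x, out).
# Since the impl is registered for the CUDA key only, call it with CUDA tensors.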

View File

@@ -7,6 +7,7 @@ import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -152,8 +153,6 @@ class FlashAttentionImpl(AttentionImpl):
return output
@torch.library.custom_op("vllm::unified_flash_attention",
mutates_args=["kv_cache"])
def unified_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -217,8 +216,7 @@ def unified_flash_attention(
return output.view(num_tokens, hidden_size)
@unified_flash_attention.register_fake
def _(
def unified_flash_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
@@ -235,3 +233,11 @@ def _(
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)
direct_register_custom_op(
op_name="unified_flash_attention",
op_func=unified_flash_attention,
mutates_args=["kv_cache"],
fake_impl=unified_flash_attention_fake,
)