[Bug] [ROCm] Fix Llama 4 Enablement Bug on ROCm: V0 ROCmFlashAttentionImpl and Triton Fused MoE bugs (#16198)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Co-authored-by: Hongxia Yang <hongxia.yang@amd.com>
Co-authored-by: kliuae <kuanfu.liu@embeddedllm.com>
Authored by TJian on 2025-04-09 10:12:34 +08:00; committed by GitHub
parent 102bf967f0
commit 2976dc27e9
3 changed files with 15 additions and 9 deletions


@@ -471,7 +471,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")
        if use_irope:
            logger.warning(
                "Using irope in V0 is not supported yet, it will fall back "
                "to global attention for long context.")
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            self.logits_soft_cap = 0.0
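Context for the hunk above: Llama 4's interleaved-RoPE layers pass `use_irope` down to the attention backend, but the V0 ROCm backend has no local chunked-attention path, so the flag is accepted and only a warning is emitted. A minimal sketch of the resulting behaviour (the class name and surrounding signature are illustrative; only the warning branch and the soft-cap default are taken from the diff):

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


class ROCmFlashAttentionImplSketch:
    """Illustrative stand-in; only the use_irope / logits_soft_cap handling
    mirrors the hunk above, the rest of the real __init__ is omitted."""

    def __init__(self,
                 logits_soft_cap: Optional[float] = None,
                 use_irope: bool = False) -> None:
        if use_irope:
            # V0 cannot run interleaved-RoPE local attention, so these layers
            # degrade to global attention instead of failing hard.
            logger.warning(
                "Using irope in V0 is not supported yet, it will fall back "
                "to global attention for long context.")
        if logits_soft_cap is None:
            # flash-attn treats a soft cap of 0.0 as "no soft cap".
            self.logits_soft_cap = 0.0
        else:
            self.logits_soft_cap = logits_soft_cap
```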


@@ -1002,6 +1002,7 @@ direct_register_custom_op(
    op_func=inplace_fused_experts,
    mutates_args=["hidden_states"],
    fake_impl=inplace_fused_experts_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
@@ -1060,6 +1061,7 @@ direct_register_custom_op(
    op_func=outplace_fused_experts,
    mutates_args=[],
    fake_impl=outplace_fused_experts_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
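The `needs_fixed_stride_order` tag tells torch.compile that the op must receive its inputs in the same stride order as in eager mode, so Inductor cannot hand the Triton fused-MoE kernels a re-laid-out activation tensor. A standalone sketch of how the tag attaches to a registered op (the `moe_sketch` library and `noop` op are made up for illustration; the mechanism is the same `Library.define(..., tags=...)` call the helper below uses):

```python
import torch
from torch.library import Library

# Throwaway library/op names, used only for this sketch.
sketch_lib = Library("moe_sketch", "FRAGMENT")

# Register the op with the same tag the diff adds to the fused-MoE ops.
# The tag makes torch.compile keep the eager stride order of the inputs
# instead of letting Inductor pick a different layout for them.
sketch_lib.define(
    "noop(Tensor x) -> Tensor",
    tags=(torch.Tag.needs_fixed_stride_order, ),
)


def noop(x: torch.Tensor) -> torch.Tensor:
    return x.clone()


def noop_fake(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


sketch_lib.impl("noop", noop, dispatch_key="CPU")
sketch_lib._register_fake("noop", noop_fake)

# The tag is visible on the resulting operator overload.
assert torch.Tag.needs_fixed_stride_order in torch.ops.moe_sketch.noop.default.tags
```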


@@ -40,7 +40,7 @@ from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
                    Optional, Type, TypeVar, Union, cast, overload)
                    Optional, Tuple, Type, TypeVar, Union, cast, overload)
from uuid import uuid4

import cachetools
@@ -1935,12 +1935,13 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
    op_name: str,
    op_func: Callable,
    mutates_args: list[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
    tags: Tuple[torch.Tag, ...] = (),
):
    """
    `torch.library.custom_op` can have significant overhead because it
@@ -1979,7 +1980,7 @@ def direct_register_custom_op(
    import torch._custom_op.impl
    schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
    my_lib = target_lib or vllm_lib
    my_lib.define(op_name + schema_str)
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
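For reference, a caller-side sketch of the extended helper, mirroring how the fused-MoE registrations above forward the new `tags` argument (the `scale_rows` op and its fake impl are hypothetical, and the import path assumes `direct_register_custom_op` lives in `vllm.utils` as in upstream vLLM):

```python
import torch

from vllm.utils import direct_register_custom_op  # assumed location of the helper


def scale_rows(hidden_states: torch.Tensor, factor: float) -> None:
    # Mutates its input in place, hence it is listed in mutates_args below.
    hidden_states.mul_(factor)


def scale_rows_fake(hidden_states: torch.Tensor, factor: float) -> None:
    return None


direct_register_custom_op(
    op_name="scale_rows_sketch",
    op_func=scale_rows,
    mutates_args=["hidden_states"],
    fake_impl=scale_rows_fake,
    # New in this commit: tags are forwarded to Library.define(), so the
    # compiled graph keeps the eager stride order for this op's inputs.
    tags=(torch.Tag.needs_fixed_stride_order, ),
)

# The op is registered on vllm_lib = Library("vllm", "FRAGMENT"), so it is
# reachable as torch.ops.vllm.scale_rows_sketch(...).
```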