[Bug] [ROCm] Fix Llama 4 Enablement Bug on ROCm: V0 ROCmFlashAttentionImpl and Triton Fused MoE bugs (#16198)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com> Co-authored-by: Hongxia Yang <hongxia.yang@amd.com> Co-authored-by: kliuae <kuanfu.liu@embeddedllm.com>
This commit is contained in:
parent
102bf967f0
commit
2976dc27e9
@ -471,7 +471,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"ROCmFlashAttention does not support blocksparse attention.")
|
||||
|
||||
if use_irope:
|
||||
logger.warning(
|
||||
"Using irope in V0 is not supported yet, it will fall back "
|
||||
"to global attention for long context.")
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
self.logits_soft_cap = 0.0
|
||||
|
@ -1002,6 +1002,7 @@ direct_register_custom_op(
|
||||
op_func=inplace_fused_experts,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=inplace_fused_experts_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
@ -1060,6 +1061,7 @@ direct_register_custom_op(
|
||||
op_func=outplace_fused_experts,
|
||||
mutates_args=[],
|
||||
fake_impl=outplace_fused_experts_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
|
@ -40,7 +40,7 @@ from dataclasses import dataclass, field
|
||||
from functools import cache, lru_cache, partial, wraps
|
||||
from types import MappingProxyType
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
|
||||
Optional, Type, TypeVar, Union, cast, overload)
|
||||
Optional, Tuple, Type, TypeVar, Union, cast, overload)
|
||||
from uuid import uuid4
|
||||
|
||||
import cachetools
|
||||
@ -1935,12 +1935,13 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
def direct_register_custom_op(
|
||||
op_name: str,
|
||||
op_func: Callable,
|
||||
mutates_args: list[str],
|
||||
fake_impl: Optional[Callable] = None,
|
||||
target_lib: Optional[Library] = None,
|
||||
dispatch_key: str = "CUDA",
|
||||
op_name: str,
|
||||
op_func: Callable,
|
||||
mutates_args: list[str],
|
||||
fake_impl: Optional[Callable] = None,
|
||||
target_lib: Optional[Library] = None,
|
||||
dispatch_key: str = "CUDA",
|
||||
tags: Tuple[torch.Tag, ...] = (),
|
||||
):
|
||||
"""
|
||||
`torch.library.custom_op` can have significant overhead because it
|
||||
@ -1979,7 +1980,7 @@ def direct_register_custom_op(
|
||||
import torch._custom_op.impl
|
||||
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
|
||||
my_lib = target_lib or vllm_lib
|
||||
my_lib.define(op_name + schema_str)
|
||||
my_lib.define(op_name + schema_str, tags=tags)
|
||||
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
|
||||
if fake_impl is not None:
|
||||
my_lib._register_fake(op_name, fake_impl)
|
||||
|
Loading…
x
Reference in New Issue
Block a user