[Bug] [ROCm] Fix Llama 4 Enablement Bug on ROCm: V0 ROCmFlashAttentionImpl and Triton Fused MoE bugs (#16198)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Co-authored-by: Hongxia Yang <hongxia.yang@amd.com>
Co-authored-by: kliuae <kuanfu.liu@embeddedllm.com>
parent 102bf967f0
commit 2976dc27e9
@@ -471,7 +471,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         if blocksparse_params is not None:
             raise ValueError(
                 "ROCmFlashAttention does not support blocksparse attention.")
+        if use_irope:
+            logger.warning(
+                "Using irope in V0 is not supported yet, it will fall back "
+                "to global attention for long context.")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             self.logits_soft_cap = 0.0
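The hunk above adds a graceful fallback: V0's ROCm attention backend does not implement interleaved-RoPE (irope) local attention, so it logs a warning and proceeds with global attention instead of raising. A minimal, self-contained sketch of that pattern; the SimpleAttentionImpl class and its reduced signature are illustrative only, not vLLM's actual constructor:

import logging
from typing import Optional

logger = logging.getLogger(__name__)


class SimpleAttentionImpl:
    """Illustrative stand-in for a V0 attention backend (not vLLM's class)."""

    def __init__(self,
                 use_irope: bool = False,
                 logits_soft_cap: Optional[float] = None) -> None:
        if use_irope:
            # Unsupported feature: warn and degrade to global attention
            # instead of raising.
            logger.warning(
                "Using irope in V0 is not supported yet, it will fall back "
                "to global attention for long context.")
        # In flash-attn, a soft cap of 0 means "no soft cap".
        self.logits_soft_cap = (logits_soft_cap
                                if logits_soft_cap is not None else 0.0)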
@@ -1002,6 +1002,7 @@ direct_register_custom_op(
     op_func=inplace_fused_experts,
     mutates_args=["hidden_states"],
     fake_impl=inplace_fused_experts_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
 )
 
 
@@ -1060,6 +1061,7 @@ direct_register_custom_op(
     op_func=outplace_fused_experts,
     mutates_args=[],
     fake_impl=outplace_fused_experts_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
 )
 
 
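Both fused-MoE registrations above now pass tags=(torch.Tag.needs_fixed_stride_order, ), which signals to torch.compile/Inductor that the op needs its inputs in the eager stride order, so the compiler must not permute strides around the call. A minimal sketch of the same tagging done directly through torch.library on a recent PyTorch; the my_ops namespace and scale_rows op are made up for illustration:

import torch
from torch.library import Library

# Throwaway library namespace for the example (not vLLM's vllm_lib).
my_lib = Library("my_ops", "FRAGMENT")

# The tag marks the op as stride-order sensitive for the compiler stack.
my_lib.define("scale_rows(Tensor x, float s) -> Tensor",
              tags=(torch.Tag.needs_fixed_stride_order, ))


def scale_rows(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s


# CompositeExplicitAutograd makes the Python impl usable on any device.
my_lib.impl("scale_rows", scale_rows, dispatch_key="CompositeExplicitAutograd")

out = torch.ops.my_ops.scale_rows(torch.randn(4, 8), 2.0)
print(out.shape)  # torch.Size([4, 8])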
@@ -40,7 +40,7 @@ from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, Type, TypeVar, Union, cast, overload)
+                    Optional, Tuple, Type, TypeVar, Union, cast, overload)
 from uuid import uuid4
 
 import cachetools
@@ -1941,6 +1941,7 @@ def direct_register_custom_op(
     fake_impl: Optional[Callable] = None,
     target_lib: Optional[Library] = None,
     dispatch_key: str = "CUDA",
+    tags: Tuple[torch.Tag, ...] = (),
 ):
     """
     `torch.library.custom_op` can have significant overhead because it
@@ -1979,7 +1980,7 @@ def direct_register_custom_op(
     import torch._custom_op.impl
     schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
     my_lib = target_lib or vllm_lib
-    my_lib.define(op_name + schema_str)
+    my_lib.define(op_name + schema_str, tags=tags)
     my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
     if fake_impl is not None:
         my_lib._register_fake(op_name, fake_impl)
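With the plumbing above, callers of direct_register_custom_op can attach torch.Tag values at registration time, which is what the fused-MoE hunks rely on. A hedged usage sketch; the rms_scale op and its bodies are hypothetical, and only the tags argument mirrors this commit (dispatch_key is left at its "CUDA" default, so the op is registered for CUDA tensors):

import torch

from vllm.utils import direct_register_custom_op


def rms_scale(x: torch.Tensor, eps: float) -> torch.Tensor:
    # Hypothetical op body; a real kernel would live here.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def rms_scale_fake(x: torch.Tensor, eps: float) -> torch.Tensor:
    # Fake (meta) impl: only shape and dtype matter during tracing.
    return torch.empty_like(x)


direct_register_custom_op(
    op_name="rms_scale",
    op_func=rms_scale,
    mutates_args=[],
    fake_impl=rms_scale_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)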