[core] Perf improvement for DSv3 on AMD GPUs (#13718)

Signed-off-by: qli88 <qiang.li2@amd.com>
qli88 2025-02-27 16:14:30 -06:00 committed by GitHub
parent cd813c6d4d
commit 8294773e48
3 changed files with 210 additions and 25 deletions


@@ -237,14 +237,20 @@ from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
is_vllm_fa = True
except ImportError:
# For rocm use upstream flash attention
from flash_attn import flash_attn_varlen_func
is_vllm_fa = False
from vllm.attention.ops.triton_flash_attention import triton_attention
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
is_hip = current_platform.is_rocm()
class MLACommonBackend(AttentionBackend):
@@ -1046,12 +1052,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
self.vllm_flash_attn_version = get_flash_attn_version()
self.triton_fa_func = triton_attention
# Handle the differences between the flash_attn_varlen_func from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
# latter has an additional parameter to control FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version()
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,
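The hunk above pins the FA2/FA3 selection once at construction time instead of passing it at every call site. A minimal sketch of that functools.partial pattern, with a stub standing in for vllm_flash_attn's flash_attn_varlen_func and an assumed fa_version keyword (illustrative, not vLLM code):

import functools

def flash_attn_varlen_func_stub(q, k, v, *, fa_version=2, **kwargs):
    # Stand-in for vllm_flash_attn's entry point; the real kernel takes many
    # more arguments. Only the version knob matters for this illustration.
    return f"ran FA{fa_version} on {len(q)} queries"

# Bind the version once so every later call site can use the same signature
# as upstream flash_attn, which has no such parameter.
fa = functools.partial(flash_attn_varlen_func_stub, fa_version=3)
print(fa([0.0], [0.0], [0.0]))  # -> "ran FA3 on 1 queries"
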
@@ -1315,18 +1322,48 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
[0, q.shape[-1] - v.shape[-1]],
value=0)
attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
attn_output, attn_softmax_lse = self.triton_fa_func(
q,
k,
v_padded,
None,
prefill_metadata.query_start_loc,
prefill_metadata.context_chunk_cu_seq_lens[i],
prefill_metadata.max_query_len,
prefill_metadata.context_chunk_max_seq_lens[i],
False, # causal
self.scale,
None, # attn_mask is None unless applying ALiBi mask
)
elif is_vllm_fa:
attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.
context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
else:
attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.
context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_attn_probs=True,
)
if output is None:
output = attn_output
@@ -1374,11 +1411,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)
if has_context:
if not current_platform.is_cuda():
raise NotImplementedError(
"Chunked Prefill for MLA is not currently supported on"
"non-cuda platforms")
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
output = self.triton_fa_func(
q,
k,
v_padded,
None,
prefill_metadata.query_start_loc,
prefill_metadata.query_start_loc,
prefill_metadata.max_prefill_seq_len,
prefill_metadata.max_prefill_seq_len,
True, # causal
self.scale,
None, # attn_mask is None unless applying ALiBi mask
)
## triton flash attention always returns 2 objects
if not has_context:
output = output[0]
elif is_vllm_fa:
output = self.flash_attn_varlen_func(
q=q,
k=k,
@@ -1389,7 +1439,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=True,
return_softmax_lse=has_context,
)
else:
output = self.flash_attn_varlen_func(
@@ -1402,10 +1452,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_attn_probs=has_context,
)
if has_context:
suffix_output, suffix_lse = output
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output, suffix_lse, *rest = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)
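Taken together, the hunks above replace the single flash_attn_varlen_func call with a three-way dispatch: the Triton flash attention kernel on ROCm, vllm_flash_attn when available, and upstream flash_attn otherwise, with the return values normalized afterwards. A condensed, self-contained sketch of that control flow; the stub kernels and the helper name are illustrative, not vLLM's real implementations:

import torch

def triton_fa_stub(q, k, v, bias, cu_q, cu_k, max_q, max_k, causal, scale, mask):
    # The Triton kernel takes positional arguments and always returns
    # (output, softmax_lse).
    return torch.zeros_like(q), torch.zeros(q.shape[0])

def fa_stub(q, k, v, *, return_softmax_lse=False, return_attn_probs=False):
    out, lse, probs = torch.zeros_like(q), torch.zeros(q.shape[0]), None
    if return_softmax_lse:        # vllm_flash_attn flavour: 2 outputs
        return out, lse
    if return_attn_probs:         # upstream flash_attn flavour: 3 outputs
        return out, lse, probs
    return out

def prefill_attention(q, k, v, *, is_hip, use_triton_fa, is_vllm_fa, need_lse):
    if is_hip and use_triton_fa:
        out, lse = triton_fa_stub(q, k, v, None, None, None, 0, 0, True, 1.0, None)
        return (out, lse) if need_lse else out
    if is_vllm_fa:
        return fa_stub(q, k, v, return_softmax_lse=need_lse)
    result = fa_stub(q, k, v, return_attn_probs=need_lse)
    # Upstream flash_attn returns 3 objects when probs are requested, so the
    # extra element is dropped before merging with the context output.
    return result[:2] if need_lse else result

q = torch.randn(4, 8)
print(type(prefill_attention(q, q, q, is_hip=True, use_triton_fa=True,
                             is_vllm_fa=False, need_lse=False)))

The sketch mirrors the structure of the branches above only; the real code also threads cu_seqlens, max_seqlen, and the softmax scale through each call.
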


@@ -178,7 +178,8 @@ def _decode_att_m_fwd(
page_size,
logit_cap,
):
BLOCK = 64
BLOCK = 64 if not is_hip_ else 8
NUM_KV_SPLITS = num_kv_splits
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
@@ -188,7 +189,9 @@ def _decode_att_m_fwd(
grid = (batch, head_num, NUM_KV_SPLITS)
kv_group_num = q.shape[1] // k_buffer.shape[-2]
num_warps = 4 if kv_group_num == 1 else 2
num_warps = 4
if kv_group_num != 1:
num_warps = 1 if is_hip_ else 2
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DV = triton.next_power_of_2(Lv)
@@ -418,14 +421,16 @@ def _decode_grouped_att_m_fwd(
)
extra_kargs = {}
num_stages = 2
if is_hip_:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {
"waves_per_eu": 4,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
num_stages = 1
_fwd_grouped_kernel_stage1[grid](
q,
@@ -456,7 +461,7 @@ def _decode_grouped_att_m_fwd(
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=4,
num_stages=2,
num_stages=num_stages,
Lk=Lk,
Lv=Lv,
**extra_kargs,
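The two hunks above tighten the ROCm launch parameters: a smaller BLOCK size, fewer warps for grouped heads, waves_per_eu lowered from 4 to 1, and a single pipeline stage. A minimal sketch of how such backend-specific overrides are assembled and splatted into the Triton launch; the helper name is hypothetical and the values simply mirror the diff:

def build_rocm_overrides(is_hip: bool) -> tuple[dict, int]:
    # Defaults used on non-ROCm platforms.
    extra_kargs: dict = {}
    num_stages = 2
    if is_hip:
        # AMD-backend Triton knobs (see the ROCm tuning guides linked above);
        # the diff switches this grouped-decode kernel to one wave per EU and
        # a single software-pipeline stage on ROCm.
        extra_kargs = {
            "waves_per_eu": 1,
            "matrix_instr_nonkdim": 16,
            "kpack": 2,
        }
        num_stages = 1
    return extra_kargs, num_stages

# The result is forwarded to the kernel launch, e.g.:
# _fwd_grouped_kernel_stage1[grid](..., num_stages=num_stages, **extra_kargs)
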


@@ -0,0 +1,128 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
}
}
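The new file is a tuned fused-MoE kernel configuration table keyed by token-batch size. A minimal sketch of how such a table might be consumed: load it and pick the entry whose key is closest to the current number of tokens. The function name and the nearest-key policy are illustrative assumptions, not a description of vLLM's actual loader:

import json

def pick_tuned_config(path: str, num_tokens: int) -> dict:
    # Keys in the JSON are strings ("1", "2", ..., "1024"); convert to ints.
    with open(path) as f:
        table = {int(k): v for k, v in json.load(f).items()}
    # Fall back to the nearest tuned batch size when there is no exact match.
    nearest = min(table, key=lambda k: abs(k - num_tokens))
    return table[nearest]

# Example: with the table above, a batch of 200 tokens would map to the
# "256" entry (BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, BLOCK_SIZE_K=128, ...).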