[core] Perf improvement for DSv3 on AMD GPUs (#13718)
Signed-off-by: qli88 <qiang.li2@amd.com>
parent cd813c6d4d
commit 8294773e48
@@ -237,14 +237,20 @@ from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down

try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
    is_vllm_fa = True
except ImportError:
    # For rocm use upstream flash attention
    from flash_attn import flash_attn_varlen_func
    is_vllm_fa = False

from vllm.attention.ops.triton_flash_attention import triton_attention

if TYPE_CHECKING:
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)

is_hip = current_platform.is_rocm()


class MLACommonBackend(AttentionBackend):
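The hunk above keeps a single `flash_attn_varlen_func` name whichever package provides it and records the choice in `is_vllm_fa`. A minimal sketch of how call sites can key off that flag when building keyword arguments, since the two packages request the softmax LSE differently (both keywords appear in the later hunks); the helper itself is illustrative, not part of the commit:

```python
def lse_kwargs(is_vllm_fa: bool) -> dict:
    # vLLM's flash-attn:   return_softmax_lse=True -> (output, lse)
    # upstream flash-attn: return_attn_probs=True  -> (output, lse, probs)
    return ({"return_softmax_lse": True} if is_vllm_fa
            else {"return_attn_probs": True})
```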
@@ -1046,12 +1052,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
        self.q_proj = q_proj
        self.kv_b_proj = kv_b_proj
        self.o_proj = o_proj
        self.vllm_flash_attn_version = get_flash_attn_version()
        self.triton_fa_func = triton_attention

        # Handle the differences between the flash_attn_varlen from flash_attn
        # and the one from vllm_flash_attn. The former is used on RoCM and the
        # latter has an additional parameter to control FA2 vs FA3
        self.flash_attn_varlen_func = flash_attn_varlen_func
        self.vllm_flash_attn_version = get_flash_attn_version()
        if self.vllm_flash_attn_version is not None:
            self.flash_attn_varlen_func = \
                functools.partial(flash_attn_varlen_func,
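The hunk is cut off at the `functools.partial` call; the pattern is to bind the detected FA version into the callable once so later call sites stay backend-agnostic. A hedged sketch of that pattern with a stand-in function and an assumed `fa_version` keyword (the commit's exact bound argument is truncated above):

```python
import functools

def varlen_attn(q, k, v, *, softmax_scale, causal, fa_version=2):
    """Stand-in for flash_attn_varlen_func; only the call shape matters here."""
    raise NotImplementedError


class AttnImpl:
    def __init__(self, scale, fa_version=None):
        self.scale = scale
        # One callable for every call site; bind the version up front only
        # when the backend reports one (hypothetical keyword: fa_version).
        self.attn = varlen_attn
        if fa_version is not None:
            self.attn = functools.partial(varlen_attn, fa_version=fa_version)
```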
@@ -1315,18 +1322,48 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                [0, q.shape[-1] - v.shape[-1]],
                value=0)

-            attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                max_seqlen_q=prefill_metadata.max_query_len,
-                max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
-                softmax_scale=self.scale,
-                causal=False,  # Context is unmasked
-                return_softmax_lse=True,
-            )
+            if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+                attn_output, attn_softmax_lse = self.triton_fa_func(
+                    q,
+                    k,
+                    v_padded,
+                    None,
+                    prefill_metadata.query_start_loc,
+                    prefill_metadata.context_chunk_cu_seq_lens[i],
+                    prefill_metadata.max_query_len,
+                    prefill_metadata.context_chunk_max_seq_lens[i],
+                    False,  # causal
+                    self.scale,
+                    None,  # attn_mask is None unless applying ALiBi mask
+                )
+            elif is_vllm_fa:
+                attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
+                    q=q,
+                    k=k,
+                    v=v_padded,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_softmax_lse=True,
+                )
+            else:
+                attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
+                    q=q,
+                    k=k,
+                    v=v_padded,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_attn_probs=True,
+                )

            if output is None:
                output = attn_output
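The three branches above return the same logical result with different arity: the Triton kernel and the vLLM FA path hand back `(output, softmax_lse)`, while upstream `flash_attn` with `return_attn_probs=True` appends a third element. A small, hypothetical normalizer illustrating the idea (the commit instead unpacks per branch, and later uses `*rest` for the same purpose):

```python
def normalize_attn_result(result):
    """Return (output, softmax_lse) regardless of which backend produced it.

    Hypothetical helper: the upstream flash-attn varlen call returns a third
    element (attention probabilities) that the MLA path does not need.
    """
    output, softmax_lse, *rest = result  # *rest absorbs the optional extras
    return output, softmax_lse
```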
@@ -1374,11 +1411,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                           value=0)

-        if has_context:
-            if not current_platform.is_cuda():
-                raise NotImplementedError(
-                    "Chunked Prefill for MLA is not currently supported on"
-                    "non-cuda platforms")
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+            output = self.triton_fa_func(
+                q,
+                k,
+                v_padded,
+                None,
+                prefill_metadata.query_start_loc,
+                prefill_metadata.query_start_loc,
+                prefill_metadata.max_prefill_seq_len,
+                prefill_metadata.max_prefill_seq_len,
+                True,  # causal
+                self.scale,
+                None,  # attn_mask is None unless applying ALiBi mask
+            )
+            ## triton flash attention always return 2 objects
+            if not has_context:
+                output = output[0]
+        elif is_vllm_fa:
+            output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
@@ -1389,7 +1439,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
                softmax_scale=self.scale,
                causal=True,
-                return_softmax_lse=True,
+                return_softmax_lse=has_context,
            )
        else:
            output = self.flash_attn_varlen_func(
@@ -1402,10 +1452,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
                softmax_scale=self.scale,
                causal=True,
                return_attn_probs=has_context,
            )

        if has_context:
-            suffix_output, suffix_lse = output
+            # ROCm flash_attn_varlen_func will return 3 objects instead of 2
+            suffix_output, suffix_lse, *rest = output
            context_output, context_lse = self._compute_prefill_context( \
                q, kv_c_and_k_pe_cache, attn_metadata)
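When `has_context` is set, the new-token ("suffix") pass and the cached-context pass each produce a partial attention output plus its log-sum-exp, and the two have to be combined. The combination happens outside this hunk; below is a hedged, self-contained sketch of the standard LSE-based merge, assuming `[tokens, heads, dim]` outputs and `[heads, tokens]` LSEs (the flash-attn convention). It illustrates the math only, not vLLM's actual merge helper.

```python
import torch

def merge_attn_partials(out_a: torch.Tensor, lse_a: torch.Tensor,
                        out_b: torch.Tensor, lse_b: torch.Tensor) -> torch.Tensor:
    """Merge two attention partials computed over disjoint key sets."""
    # Bring LSEs to [tokens, heads, 1] so they broadcast against the outputs.
    lse_a = lse_a.transpose(0, 1).unsqueeze(-1)
    lse_b = lse_b.transpose(0, 1).unsqueeze(-1)
    # Subtract the running max before exponentiating for numerical stability.
    m = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - m)
    w_b = torch.exp(lse_b - m)
    return (out_a * w_a + out_b * w_b) / (w_a + w_b)
```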
@@ -178,7 +178,8 @@ def _decode_att_m_fwd(
    page_size,
    logit_cap,
):
-    BLOCK = 64
+    BLOCK = 64 if not is_hip_ else 8

    NUM_KV_SPLITS = num_kv_splits
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]
@@ -188,7 +189,9 @@ def _decode_att_m_fwd(
    grid = (batch, head_num, NUM_KV_SPLITS)
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

-    num_warps = 4 if kv_group_num == 1 else 2
+    num_warps = 4
+    if kv_group_num != 1:
+        num_warps = 1 if is_hip_ else 2

    BLOCK_DMODEL = triton.next_power_of_2(Lk)
    BLOCK_DV = triton.next_power_of_2(Lv)
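The change above picks the KV tile size and warp count per platform rather than using one CUDA-tuned setting. A hedged standalone sketch of the same selection logic (the numbers mirror the diff; the helper name is illustrative):

```python
def pick_decode_launch_params(is_hip: bool, kv_group_num: int):
    """Return (BLOCK, num_warps) for the decode-attention stage-1 kernel."""
    block = 8 if is_hip else 64        # KV positions processed per program
    num_warps = 4
    if kv_group_num != 1:              # grouped KV heads (GQA/MQA-style)
        num_warps = 1 if is_hip else 2
    return block, num_warps
```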
@@ -418,14 +421,16 @@ def _decode_grouped_att_m_fwd(
    )

    extra_kargs = {}
+    num_stages = 2
    if is_hip_:
-        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+        # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {
-            "waves_per_eu": 4,
+            "waves_per_eu": 1,
            "matrix_instr_nonkdim": 16,
            "kpack": 2
        }
+        num_stages = 1

    _fwd_grouped_kernel_stage1[grid](
        q,
@@ -456,7 +461,7 @@ def _decode_grouped_att_m_fwd(
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=4,
-        num_stages=2,
+        num_stages=num_stages,
        Lk=Lk,
        Lv=Lv,
        **extra_kargs,
@@ -0,0 +1,128 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "2": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 4,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 4,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 4,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 4,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 4,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "512": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 8,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 8,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    }
}
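The new file is a per-batch-size tuning table: each top-level key is a token count M and its value is the Triton GEMM tile configuration to use at that size. A hedged sketch of how such a table can be consumed, loading it once and picking the entry closest to the current M; the file name, loader, and nearest-key policy are illustrative, not vLLM's own fused-MoE config lookup.

```python
import json
from pathlib import Path

def load_moe_config(path: str) -> dict[int, dict]:
    """Parse a tuning table like the JSON above into {batch_size: kernel_config}."""
    raw = json.loads(Path(path).read_text())
    return {int(k): v for k, v in raw.items()}

def pick_config(configs: dict[int, dict], m: int) -> dict:
    """Choose the tuned entry whose batch-size key is closest to M."""
    best_key = min(configs, key=lambda k: abs(k - m))
    return configs[best_key]
```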