[core] Perf improvement for DSv3 on AMD GPUs (#13718)
Signed-off-by: qli88 <qiang.li2@amd.com>
This commit is contained in:
Parent: cd813c6d4d
Commit: 8294773e48
@@ -237,14 +237,20 @@ from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
+    is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
     from flash_attn import flash_attn_varlen_func
+    is_vllm_fa = False
+
+from vllm.attention.ops.triton_flash_attention import triton_attention
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)
 
+is_hip = current_platform.is_rocm()
+
 
 class MLACommonBackend(AttentionBackend):
 
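The hunk above records which flash-attention implementation could be imported in a module-level flag, plus a ROCm check (is_hip) that later code uses to pick the Triton path. Below is a minimal runnable sketch of the same pattern in isolation; the stubbed fallback body is illustrative (the real code imports upstream flash_attn), and the names otherwise mirror the diff.

# Prefer the vLLM fork of flash-attention; fall back on ImportError and
# remember which implementation was imported so call sites can adapt to the
# differing signatures.
try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
    is_vllm_fa = True
except ImportError:
    # In vLLM this imports upstream flash_attn (the usual case on ROCm);
    # stubbed here so the sketch runs anywhere.
    def flash_attn_varlen_func(*args, **kwargs):
        raise NotImplementedError("stand-in for flash_attn.flash_attn_varlen_func")
    is_vllm_fa = False

import torch
# Stand-in for current_platform.is_rocm(): HIP builds of torch expose a hip version.
is_hip = torch.version.hip is not None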
@@ -1046,12 +1052,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
-        self.vllm_flash_attn_version = get_flash_attn_version()
+        self.triton_fa_func = triton_attention
 
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
         self.flash_attn_varlen_func = flash_attn_varlen_func
+        self.vllm_flash_attn_version = get_flash_attn_version()
         if self.vllm_flash_attn_version is not None:
             self.flash_attn_varlen_func = \
                 functools.partial(flash_attn_varlen_func,
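The comment block above explains why the constructor wraps flash_attn_varlen_func: the vllm_flash_attn fork takes an extra argument selecting FA2 vs FA3, which upstream flash_attn on ROCm does not, so the version is bound once with functools.partial and every call site keeps the same signature. A runnable sketch of that binding follows; the keyword name fa_version and the stub body are assumptions, since the diff truncates the partial(...) call.

import functools

def flash_attn_varlen_func(q, k, v, fa_version=2, **kwargs):
    # Illustrative stand-in for the real varlen attention entry point.
    return f"ran FA{fa_version}"

vllm_flash_attn_version = 3  # e.g. what get_flash_attn_version() might return

if vllm_flash_attn_version is not None:
    # Bind the version once; later calls keep the plain signature.
    flash_attn_varlen_func = functools.partial(
        flash_attn_varlen_func, fa_version=vllm_flash_attn_version)

print(flash_attn_varlen_func("q", "k", "v"))  # -> ran FA3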
@@ -1315,18 +1322,48 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 [0, q.shape[-1] - v.shape[-1]],
                 value=0)
 
-            attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                max_seqlen_q=prefill_metadata.max_query_len,
-                max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
-                softmax_scale=self.scale,
-                causal=False,  # Context is unmasked
-                return_softmax_lse=True,
-            )
+            if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+                attn_output, attn_softmax_lse = self.triton_fa_func(
+                    q,
+                    k,
+                    v_padded,
+                    None,
+                    prefill_metadata.query_start_loc,
+                    prefill_metadata.context_chunk_cu_seq_lens[i],
+                    prefill_metadata.max_query_len,
+                    prefill_metadata.context_chunk_max_seq_lens[i],
+                    False,  # causal
+                    self.scale,
+                    None,  # attn_mask is None unless applying ALiBi mask
+                )
+            elif is_vllm_fa:
+                attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
+                    q=q,
+                    k=k,
+                    v=v_padded,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_softmax_lse=True,
+                )
+            else:
+                attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
+                    q=q,
+                    k=k,
+                    v=v_padded,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_attn_probs=True,
+                )
 
             if output is None:
                 output = attn_output
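The rewritten block above dispatches the chunked-prefill context attention three ways: the Triton kernel when running on ROCm with VLLM_USE_TRITON_FLASH_ATTN set, the vLLM flash-attention fork, or upstream flash_attn, whose return_attn_probs path yields a third value that is discarded. Below is a condensed, runnable sketch of the dispatch and the differing return conventions, using CPU stand-ins for the kernels; all helper names here are illustrative, not from the diff.

import torch

def _triton_attention_stub(q, k, v):
    # Stand-in: the real triton_attention returns (output, softmax_lse).
    return torch.zeros_like(q), torch.zeros(q.shape[0])

def _fa_stub(q, k, v, return_softmax_lse=False, return_attn_probs=False):
    out, lse = torch.zeros_like(q), torch.zeros(q.shape[0])
    if return_attn_probs:        # upstream flash_attn style: 3 outputs
        return out, lse, None
    if return_softmax_lse:       # vLLM fork style: 2 outputs
        return out, lse
    return out

def context_attention(q, k, v, *, use_triton, use_vllm_fa):
    if use_triton:
        attn_output, attn_softmax_lse = _triton_attention_stub(q, k, v)
    elif use_vllm_fa:
        attn_output, attn_softmax_lse = _fa_stub(q, k, v, return_softmax_lse=True)
    else:
        attn_output, attn_softmax_lse, _ = _fa_stub(q, k, v, return_attn_probs=True)
    return attn_output, attn_softmax_lse

q = k = v = torch.randn(4, 8)
print(context_attention(q, k, v, use_triton=False, use_vllm_fa=True)[0].shape)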
@@ -1374,11 +1411,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        if has_context:
-            if not current_platform.is_cuda():
-                raise NotImplementedError(
-                    "Chunked Prefill for MLA is not currently supported on"
-                    "non-cuda platforms")
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+            output = self.triton_fa_func(
+                q,
+                k,
+                v_padded,
+                None,
+                prefill_metadata.query_start_loc,
+                prefill_metadata.query_start_loc,
+                prefill_metadata.max_prefill_seq_len,
+                prefill_metadata.max_prefill_seq_len,
+                True,  # causal
+                self.scale,
+                None,  # attn_mask is None unless applying ALiBi mask
+            )
+            ## triton flash attention always return 2 objects
+            if not has_context:
+                output = output[0]
+        elif is_vllm_fa:
             output = self.flash_attn_varlen_func(
                 q=q,
                 k=k,
@@ -1389,7 +1439,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 max_seqlen_k=prefill_metadata.max_prefill_seq_len,
                 softmax_scale=self.scale,
                 causal=True,
-                return_softmax_lse=True,
+                return_softmax_lse=has_context,
             )
         else:
             output = self.flash_attn_varlen_func(
@@ -1402,10 +1452,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 max_seqlen_k=prefill_metadata.max_prefill_seq_len,
                 softmax_scale=self.scale,
                 causal=True,
+                return_attn_probs=has_context,
             )
 
         if has_context:
-            suffix_output, suffix_lse = output
+            # ROCm flash_attn_varlen_func will return 3 objects instead of 2
+            suffix_output, suffix_lse, *rest = output
             context_output, context_lse = self._compute_prefill_context( \
                 q, kv_c_and_k_pe_cache, attn_metadata)
 
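In the hunks above, the suffix (new-token) attention only needs its softmax LSE when there is prefix context to merge, so the flag becomes return_softmax_lse=has_context (or return_attn_probs=has_context for upstream flash_attn), and the unpacking uses *rest so the 3-tuple returned by ROCm's flash_attn_varlen_func is handled the same way as the usual 2-tuple. A tiny runnable sketch of that unpacking:

def unpack_suffix(output):
    # *rest absorbs the optional third element (attention probs on ROCm).
    suffix_output, suffix_lse, *rest = output
    return suffix_output, suffix_lse

print(unpack_suffix(("out", "lse")))           # vLLM fork / Triton: 2-tuple
print(unpack_suffix(("out", "lse", "probs")))  # upstream ROCm flash_attn: 3-tuple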
@@ -178,7 +178,8 @@ def _decode_att_m_fwd(
     page_size,
     logit_cap,
 ):
-    BLOCK = 64
+    BLOCK = 64 if not is_hip_ else 8
+
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
@@ -188,7 +189,9 @@ def _decode_att_m_fwd(
     grid = (batch, head_num, NUM_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[-2]
 
-    num_warps = 4 if kv_group_num == 1 else 2
+    num_warps = 4
+    if kv_group_num != 1:
+        num_warps = 1 if is_hip_ else 2
 
    BLOCK_DMODEL = triton.next_power_of_2(Lk)
    BLOCK_DV = triton.next_power_of_2(Lv)
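The two hunks above retune the decode kernel launch for ROCm: a smaller KV block (8 instead of 64) and a single warp when kv_group_num != 1. A minimal runnable sketch of the resulting parameter selection, with is_hip_ as a plain boolean:

def decode_launch_params(kv_group_num: int, is_hip_: bool):
    # Mirrors _decode_att_m_fwd after this change.
    BLOCK = 64 if not is_hip_ else 8   # smaller KV tile on ROCm
    num_warps = 4
    if kv_group_num != 1:
        num_warps = 1 if is_hip_ else 2
    return BLOCK, num_warps

print(decode_launch_params(kv_group_num=16, is_hip_=True))   # (8, 1)
print(decode_launch_params(kv_group_num=16, is_hip_=False))  # (64, 2)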
@@ -418,14 +421,16 @@ def _decode_grouped_att_m_fwd(
     )
 
     extra_kargs = {}
+    num_stages = 2
     if is_hip_:
-        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+        # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {
-            "waves_per_eu": 4,
+            "waves_per_eu": 1,
             "matrix_instr_nonkdim": 16,
             "kpack": 2
         }
+        num_stages = 1
 
     _fwd_grouped_kernel_stage1[grid](
         q,
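On ROCm the grouped decode kernel now launches with waves_per_eu=1 instead of 4 and a single pipeline stage, alongside the existing AMD-specific matrix_instr_nonkdim and kpack hints (see the ROCm tuning guide linked in the hunk). A small runnable sketch of how that launch configuration is assembled before being splatted into the Triton kernel call:

def grouped_decode_launch_config(is_hip_: bool):
    extra_kargs = {}
    num_stages = 2
    if is_hip_:
        # AMD-specific Triton launch hints per the ROCm tuning guide.
        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
        num_stages = 1
    return num_stages, extra_kargs

# The kernel is then launched roughly as:
#   _fwd_grouped_kernel_stage1[grid](..., num_warps=4, num_stages=num_stages,
#                                    **extra_kargs)
print(grouped_decode_launch_config(True))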
@@ -456,7 +461,7 @@ def _decode_grouped_att_m_fwd(
         PAGE_SIZE=page_size,
         logit_cap=logit_cap,
         num_warps=4,
-        num_stages=2,
+        num_stages=num_stages,
         Lk=Lk,
         Lv=Lv,
         **extra_kargs,
@@ -0,0 +1,128 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    }
+}
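The new file above is a tuned fused-MoE kernel configuration table keyed by the number of tokens routed per expert; each entry carries the Triton tile sizes and launch parameters that benchmarked best for that batch size. A rough, runnable sketch of how such a table is typically consumed, picking the entry whose key is closest to the actual token count (the loader below is illustrative; the real lookup lives in vLLM's fused_moe code):

import json

# Illustrative two-entry subset of the table added above.
_CONFIGS = json.loads("""
{
  "1":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32,  "BLOCK_SIZE_K": 256,
          "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
  "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128,
          "GROUP_SIZE_M": 8, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0}
}
""")

def pick_config(num_tokens: int) -> dict:
    # Choose the benchmarked entry whose key is nearest the actual M.
    best = min((int(k) for k in _CONFIGS), key=lambda k: abs(k - num_tokens))
    return _CONFIGS[str(best)]

print(pick_config(700))  # -> the "512" entry in this sketch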