diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 1befcb6b..f240074f 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -237,14 +237,20 @@ from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
+    is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
     from flash_attn import flash_attn_varlen_func
+    is_vllm_fa = False
+
+from vllm.attention.ops.triton_flash_attention import triton_attention
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)
 
+is_hip = current_platform.is_rocm()
+
 
 class MLACommonBackend(AttentionBackend):
 
@@ -1046,12 +1052,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
-        self.vllm_flash_attn_version = get_flash_attn_version()
+        self.triton_fa_func = triton_attention
 
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
         self.flash_attn_varlen_func = flash_attn_varlen_func
+        self.vllm_flash_attn_version = get_flash_attn_version()
         if self.vllm_flash_attn_version is not None:
             self.flash_attn_varlen_func = \
                 functools.partial(flash_attn_varlen_func,
@@ -1315,18 +1322,48 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                                                [0, q.shape[-1] - v.shape[-1]],
                                                value=0)
 
-            attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                max_seqlen_q=prefill_metadata.max_query_len,
-                max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
-                softmax_scale=self.scale,
-                causal=False,  # Context is unmasked
-                return_softmax_lse=True,
-            )
+            if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+                attn_output, attn_softmax_lse = self.triton_fa_func(
+                    q,
+                    k,
+                    v_padded,
+                    None,
+                    prefill_metadata.query_start_loc,
+                    prefill_metadata.context_chunk_cu_seq_lens[i],
+                    prefill_metadata.max_query_len,
+                    prefill_metadata.context_chunk_max_seq_lens[i],
+                    False,  # causal
+                    self.scale,
+                    None,  # attn_mask is None unless applying ALiBi mask
+                )
+            elif is_vllm_fa:
+                attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
+                    q=q,
+                    k=k,
+                    v=v_padded,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_softmax_lse=True,
+                )
+            else:
+                attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
+                    q=q,
+                    k=k,
+                    v=v_padded,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_attn_probs=True,
+                )
 
             if output is None:
                 output = attn_output
@@ -1374,11 +1411,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
-        if has_context:
-            if not current_platform.is_cuda():
-                raise NotImplementedError(
-                    "Chunked Prefill for MLA is not currently supported on"
-                    "non-cuda platforms")
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+            output = self.triton_fa_func(
+                q,
+                k,
+                v_padded,
+                None,
+                prefill_metadata.query_start_loc,
+                prefill_metadata.query_start_loc,
+                prefill_metadata.max_prefill_seq_len,
+                prefill_metadata.max_prefill_seq_len,
+                True,  # causal
+                self.scale,
+                None,  # attn_mask is None unless applying ALiBi mask
+            )
+            # Triton flash attention always returns 2 objects
+            if not has_context:
+                output = output[0]
+        elif is_vllm_fa:
             output = self.flash_attn_varlen_func(
                 q=q,
                 k=k,
                 v=v_padded,
@@ -1389,7 +1439,7 @@
                 max_seqlen_k=prefill_metadata.max_prefill_seq_len,
                 softmax_scale=self.scale,
                 causal=True,
-                return_softmax_lse=True,
+                return_softmax_lse=has_context,
             )
         else:
             output = self.flash_attn_varlen_func(
@@ -1402,10 +1452,12 @@
                 max_seqlen_k=prefill_metadata.max_prefill_seq_len,
                 softmax_scale=self.scale,
                 causal=True,
+                return_attn_probs=has_context,
             )
 
         if has_context:
-            suffix_output, suffix_lse = output
+            # ROCm flash_attn_varlen_func will return 3 objects instead of 2
+            suffix_output, suffix_lse, *rest = output
             context_output, context_lse = self._compute_prefill_context( \
                 q, kv_c_and_k_pe_cache, attn_metadata)
 
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json
new file mode 100644
index 00000000..2b1167fc
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json
@@ -0,0 +1,128 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    }
+}
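
Taken together, the mla/common.py hunks above reduce to a three-way dispatch for the prefill attention call: the Triton flash-attention kernel on ROCm when VLLM_USE_TRITON_FLASH_ATTN is set, vllm_flash_attn when that wheel imported successfully, and upstream flash_attn otherwise. The sketch below only illustrates that control flow; the free-standing function, the `meta` argument, and the injected flags/callables are hypothetical stand-ins for the attributes used inside MLACommonImpl, not part of the patch.

```python
# Illustrative sketch of the prefill dispatch added in mla/common.py.
# Flags and kernels are injected explicitly so the control flow can be read
# in isolation; this is not the patch itself.
def select_prefill_attention(q, k, v_padded, meta, scale, has_context, *,
                             is_hip, use_triton_fa, is_vllm_fa,
                             triton_fa_func, flash_attn_varlen_func):
    if is_hip and use_triton_fa:
        # Triton FA takes positional arguments and always returns an
        # (output, softmax_lse) pair, so the LSE is dropped when unused.
        out = triton_fa_func(
            q, k, v_padded, None,
            meta.query_start_loc, meta.query_start_loc,
            meta.max_prefill_seq_len, meta.max_prefill_seq_len,
            True,   # causal
            scale,
            None,   # attn_mask (only needed for ALiBi)
        )
        return out if has_context else out[0]
    if is_vllm_fa:
        # vllm_flash_attn exposes the LSE via return_softmax_lse.
        return flash_attn_varlen_func(
            q=q, k=k, v=v_padded,
            cu_seqlens_q=meta.query_start_loc,
            cu_seqlens_k=meta.query_start_loc,
            max_seqlen_q=meta.max_prefill_seq_len,
            max_seqlen_k=meta.max_prefill_seq_len,
            softmax_scale=scale, causal=True,
            return_softmax_lse=has_context,
        )
    # Upstream flash_attn (ROCm build) uses return_attn_probs and returns a
    # 3-tuple when it is requested, which is why the caller unpacks
    # `suffix_output, suffix_lse, *rest = output`.
    return flash_attn_varlen_func(
        q=q, k=k, v=v_padded,
        cu_seqlens_q=meta.query_start_loc,
        cu_seqlens_k=meta.query_start_loc,
        max_seqlen_q=meta.max_prefill_seq_len,
        max_seqlen_k=meta.max_prefill_seq_len,
        softmax_scale=scale, causal=True,
        return_attn_probs=has_context,
    )
```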
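
The triton_decode_attention.py changes are pure launch-parameter tuning. The helper below is an invented summary (the real values are set inline in `_decode_att_m_fwd` and `_decode_grouped_att_m_fwd`); it only gathers the ROCm-vs-CUDA choices in one place, assuming the same `is_hip_` flag the file already uses.

```python
# Hypothetical summary of the decode-kernel tuning in this diff; the actual
# assignments live inline in _decode_att_m_fwd / _decode_grouped_att_m_fwd.
def decode_kernel_tuning(kv_group_num: int, is_hip_: bool):
    # _decode_att_m_fwd: smaller K/V block and fewer warps on ROCm.
    block = 8 if is_hip_ else 64
    num_warps = 4
    if kv_group_num != 1:
        num_warps = 1 if is_hip_ else 2
    # _decode_grouped_att_m_fwd: shallower pipelining plus AMD-specific
    # kernel arguments (see the ROCm Triton tuning guide linked above).
    num_stages = 1 if is_hip_ else 2
    extra_kargs = {
        "waves_per_eu": 1,
        "matrix_instr_nonkdim": 16,
        "kpack": 2,
    } if is_hip_ else {}
    return block, num_warps, num_stages, extra_kargs
```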
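
The new JSON file follows the existing fused-MoE tuning-config naming scheme (E identifies the expert count, N the tuned GEMM dimension, plus device, dtype, and block shape), with one Triton kernel configuration per approximate token-batch size M. Below is a minimal sketch of how such a file can be consumed; the loader and the nearest-key heuristic are illustrative, not vLLM's actual lookup code.

```python
import json
from pathlib import Path

# Illustrative only: load a fused-MoE tuning file (such as the MI300X
# fp8_w8a8 config added above) and pick the entry whose batch-size key is
# closest to the current number of tokens M.
def pick_moe_config(config_path: str, num_tokens: int) -> dict:
    configs = json.loads(Path(config_path).read_text())
    nearest = min(configs, key=lambda key: abs(int(key) - num_tokens))
    return configs[nearest]

# e.g. num_tokens=50 selects the "48" entry:
# {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, ...}
```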