From aaec845f8ed7f445c66ba0d28c84bec9d184f5ed Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Luka=20Govedi=C4=8D?=
Date: Fri, 18 Apr 2025 01:46:45 -0400
Subject: [PATCH] [ROCm] [Attention] Cleanup ROCm output passing (#16431)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Luka Govedič
---
 vllm/attention/backends/rocm_flash_attn.py | 41 ++++++++++------------
 1 file changed, 18 insertions(+), 23 deletions(-)

diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 7376f930..90a21906 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -27,6 +27,7 @@ _PARTITION_SIZE_ROCM = 256
 
 
 class ROCmFlashAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True
 
     @staticmethod
     def get_name() -> str:
@@ -515,7 +516,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
-            self.attn_func = triton_attention
+            self.triton_attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
             if self.sliding_window != (-1, -1):
                 logger.warning("ROCm Triton FA does not currently support "
@@ -531,7 +532,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         else:
             try:
                 from flash_attn import flash_attn_varlen_func  # noqa: F401
-                self.attn_func = flash_attn_varlen_func
+                self.fa_attn_func = flash_attn_varlen_func
                 logger.debug("Using CK FA in ROCmBackend")
             except ModuleNotFoundError:
                 self.use_naive_attn = True
@@ -542,7 +543,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     "ROCm Naive FlashAttention does not support "
                     "attention logits soft capping.")
 
-            self.attn_func = _sdpa_attention
+            self.sdpa_attn_func = _sdpa_attention
            logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -613,6 +614,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
+
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@@ -656,7 +659,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
 
-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -704,11 +706,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                             query.dtype,
                             seq_lens,
                             make_attn_mask=causal_mask)  # type: ignore
-                    out, _ = self.attn_func(
+                    self.triton_attn_func(
                         query,
                         key,
                         value,
-                        None,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         key_seq_start_loc,
                         query_max_seq_len,
@@ -733,10 +735,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     key = key.movedim(0, key.dim() - 2)
                     value = value.movedim(0, value.dim() - 2)
                     # sdpa math backend attention
-                    out = self.attn_func(
+                    self.sdpa_attn_func(
                         query,
                         key,
                         value,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         num_prefill_tokens,
                         self.num_heads,
@@ -745,7 +748,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         attn_masks,
                     )
                 else:
-                    out = self.attn_func(
+                    # upstream FA does not support an output arg, copy
+                    output[:num_prefill_tokens] = self.fa_attn_func(
                         q=query,
                         k=key,
                         v=value,
@@ -760,12 +764,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         softcap=self.logits_soft_cap,
                     )
 
-                # common code for prefill
-                assert output[:num_prefill_tokens].shape == out.shape
-                if output.shape[0] > num_prefill_tokens:
-                    output[:num_prefill_tokens] = out
-                else:
-                    output = out
             else:
                 # prefix-enabled attention -
                 # not applicable for encoder-only models
@@ -818,14 +816,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     device=output.device,
                 )
                 max_logits = torch.empty_like(exp_sums)
-                if num_prefill_tokens > 0:
-                    out = output[num_prefill_tokens:]
-                else:
-                    out = output
 
                 query_start_loc = None
                 ops.paged_attention_rocm(
-                    out,
+                    output[num_prefill_tokens:],
                     exp_sums,
                     max_logits,
                     tmp_output,
@@ -878,7 +872,8 @@ def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    seq_lens: List[int],
+    output: torch.Tensor,
+    seq_lens: torch.Tensor,
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -886,9 +881,9 @@
     attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
-    output = torch.empty((num_tokens, num_heads, head_size),
-                         dtype=query.dtype,
-                         device=query.device)
+    assert output.shape == (num_tokens, num_heads, head_size)
+    assert output.dtype == query.dtype
+    assert output.device == query.device
 
     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
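
For reviewers unfamiliar with the output-buffer convention this patch adopts, here is a
minimal, self-contained sketch of the pattern. Only `accept_output_buffer` and the
`output[:num_prefill_tokens]` / `output[num_prefill_tokens:]` slicing mirror the patch;
every other name (`toy_attention_forward`, the stand-in arithmetic) is hypothetical and
not part of the vLLM API.

    import torch


    def toy_attention_forward(query: torch.Tensor, output: torch.Tensor,
                              num_prefill_tokens: int) -> torch.Tensor:
        # The caller preallocates `output`; the kernel wrappers write directly
        # into slices of it instead of allocating per-branch `out` tensors and
        # copying them back afterwards.
        assert output is not None, "Output tensor must be provided."
        assert output.shape == query.shape

        # Prefill tokens land in the front slice (stand-in math, not attention).
        output[:num_prefill_tokens] = query[:num_prefill_tokens] * 2.0
        # Decode tokens land in the back slice.
        output[num_prefill_tokens:] = query[num_prefill_tokens:] * 3.0
        return output


    # Caller side: allocate once and pass the buffer down, mirroring what
    # accept_output_buffer = True signals to the attention layer.
    query = torch.randn(8, 4)
    output = torch.empty_like(query)
    toy_attention_forward(query, output, num_prefill_tokens=5)

Writing into slices of a single caller-owned buffer is what lets the patch drop the
intermediate allocations and the shape-dependent copy that the removed
"common code for prefill" block used to perform.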