[ROCm] [Attention] Cleanup ROCm output passing (#16431)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Luka Govedič 2025-04-18 01:46:45 -04:00 committed by GitHub
parent 7bdfd29a35
commit aaec845f8e


@@ -27,6 +27,7 @@ _PARTITION_SIZE_ROCM = 256
 class ROCmFlashAttentionBackend(AttentionBackend):
+accept_output_buffer: bool = True
 @staticmethod
 def get_name() -> str:
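
The new `accept_output_buffer = True` flag advertises that this backend can write attention results straight into a tensor allocated by the caller instead of returning a freshly allocated one. A minimal caller-side sketch of that contract; the function name and exact `forward` signature here are illustrative, not vLLM's actual call path:

```python
import torch

def run_attention_layer(impl, query, key, value, kv_cache, attn_metadata):
    # Hypothetical caller: since the backend sets accept_output_buffer = True,
    # allocate the result once up front and let the backend's kernels fill it
    # in place instead of copying whatever the backend returns.
    output = torch.empty_like(query)
    impl.forward(query, key, value, kv_cache, attn_metadata, output=output)
    return output
```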
@@ -515,7 +516,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
 triton_attention)
-self.attn_func = triton_attention
+self.triton_attn_func = triton_attention
 logger.debug("Using Triton FA in ROCmBackend")
 if self.sliding_window != (-1, -1):
 logger.warning("ROCm Triton FA does not currently support "
@@ -531,7 +532,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 else:
 try:
 from flash_attn import flash_attn_varlen_func # noqa: F401
-self.attn_func = flash_attn_varlen_func
+self.fa_attn_func = flash_attn_varlen_func
 logger.debug("Using CK FA in ROCmBackend")
 except ModuleNotFoundError:
 self.use_naive_attn = True
@@ -542,7 +543,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 "ROCm Naive FlashAttention does not support "
 "attention logits soft capping.")
-self.attn_func = _sdpa_attention
+self.sdpa_attn_func = _sdpa_attention
 logger.debug("Using naive (SDPA) attention in ROCmBackend")
 def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
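
Taken together, the three renames above split the single `self.attn_func` attribute into one attribute per implementation, so each call site below can rely on the calling convention of the function it actually dispatches to (the Triton and SDPA paths accept a destination tensor, upstream flash-attn does not). A condensed sketch of the selection; `use_triton`/`use_naive` mirror the flags in the surrounding code and `_sdpa_attention` is the helper defined later in this same file:

```python
def pick_attn_funcs(impl, use_triton: bool, use_naive: bool) -> None:
    # Simplified control flow: the real constructor also checks device
    # capability and falls back to naive attention on import errors.
    if use_triton:
        from vllm.attention.ops.triton_flash_attention import triton_attention
        impl.triton_attn_func = triton_attention    # can write into a buffer
    elif not use_naive:
        from flash_attn import flash_attn_varlen_func
        impl.fa_attn_func = flash_attn_varlen_func  # returns a new tensor
    else:
        impl.sdpa_attn_func = _sdpa_attention       # fills `output` in place
```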
@@ -613,6 +614,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 Returns:
 shape = [num_tokens, num_heads * head_size]
 """
+assert output is not None, "Output tensor must be provided."
 query = query.view(-1, self.num_heads, self.head_size)
 if key is not None:
 assert value is not None
@@ -656,7 +659,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 assert attn_metadata.num_encoder_tokens is not None
 num_prefill_tokens = attn_metadata.num_encoder_tokens
-output = torch.empty_like(query)
 # Query for decode. KV is not needed because it is already cached.
 decode_query = query[num_prefill_tokens:]
 # QKV for prefill.
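
The two changes above make the caller-provided buffer mandatory: `forward` asserts its presence up front, and the encoder-only path no longer allocates its own `torch.empty_like(query)`. A small sketch of the resulting contract (hypothetical helper, heavily simplified):

```python
from typing import Optional
import torch

def forward_sketch(query: torch.Tensor,
                   output: Optional[torch.Tensor] = None) -> torch.Tensor:
    # The caller must supply the buffer; every branch then writes into
    # slices of it rather than allocating temporaries.
    assert output is not None, "Output tensor must be provided."
    output[:query.shape[0]] = query  # stand-in for the real attention math
    return output

query = torch.randn(4, 8, 64)
buffer = torch.empty_like(query)
forward_sketch(query, output=buffer)
```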
@@ -704,11 +706,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 query.dtype,
 seq_lens,
 make_attn_mask=causal_mask) # type: ignore
-out, _ = self.attn_func(
+self.triton_attn_func(
 query,
 key,
 value,
-None,
+output[:num_prefill_tokens],
 query_seq_start_loc,
 key_seq_start_loc,
 query_max_seq_len,
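
The Triton call is now handed the prefill slice of the shared buffer as its destination, which is why the `out, _ = ...` unpacking and the later copy-back disappear. A toy demonstration of the write-through-a-view pattern (`fake_kernel` stands in for any kernel that takes an output tensor; it is not the real `triton_attention` signature):

```python
import torch

def fake_kernel(q: torch.Tensor, o: torch.Tensor) -> None:
    # Stand-in for a kernel with an output argument: writes into `o` in place.
    o.copy_(q)

num_prefill_tokens = 3
query = torch.randn(5, 8, 64)        # [num_tokens, num_heads, head_size]
output = torch.empty_like(query)     # shared by prefill and decode tokens

# output[:3] is a view, so the kernel fills the front of `output` directly;
# no temporary result tensor and no copy afterwards.
fake_kernel(query[:num_prefill_tokens], output[:num_prefill_tokens])
assert torch.equal(output[:num_prefill_tokens], query[:num_prefill_tokens])
```

The SDPA path in the next hunk follows the same pattern by passing `output[:num_prefill_tokens]` to `self.sdpa_attn_func`.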
@@ -733,10 +735,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 key = key.movedim(0, key.dim() - 2)
 value = value.movedim(0, value.dim() - 2)
 # sdpa math backend attention
-out = self.attn_func(
+self.sdpa_attn_func(
 query,
 key,
 value,
+output[:num_prefill_tokens],
 query_seq_start_loc,
 num_prefill_tokens,
 self.num_heads,
@@ -745,7 +748,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 attn_masks,
 )
 else:
-out = self.attn_func(
+# upstream FA does not support an output arg, copy
+output[:num_prefill_tokens] = self.fa_attn_func(
 q=query,
 k=key,
 v=value,
@@ -760,12 +764,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 softcap=self.logits_soft_cap,
 )
 # common code for prefill
-assert output[:num_prefill_tokens].shape == out.shape
-if output.shape[0] > num_prefill_tokens:
-output[:num_prefill_tokens] = out
-else:
-output = out
 else:
 # prefix-enabled attention -
 # not applicable for encoder-only models
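
Upstream `flash_attn_varlen_func` has no destination argument (hence the new comment), so its freshly allocated result is copied into the prefill slice, and that single assignment replaces the old shape-dependent reconciliation of `out` with `output`. Roughly, assuming prefill tokens always occupy the front of the buffer:

```python
import torch

def returns_new_tensor(q: torch.Tensor) -> torch.Tensor:
    # Stand-in for flash_attn_varlen_func: allocates and returns its result.
    return q * 2.0

num_prefill_tokens = 3
query = torch.randn(5, 8, 64)
output = torch.empty_like(query)

# Old logic: keep the temporary `out`, then branch on whether decode tokens
# exist before deciding how to merge it into `output`.
# New logic: one unconditional copy into the front of the shared buffer.
output[:num_prefill_tokens] = returns_new_tensor(query[:num_prefill_tokens])
```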
@@ -818,14 +816,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 device=output.device,
 )
 max_logits = torch.empty_like(exp_sums)
-if num_prefill_tokens > 0:
-out = output[num_prefill_tokens:]
-else:
-out = output
 query_start_loc = None
 ops.paged_attention_rocm(
-out,
+output[num_prefill_tokens:],
 exp_sums,
 max_logits,
 tmp_output,
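
On the decode side, `ops.paged_attention_rocm` already wrote into its first argument, so the old `if num_prefill_tokens > 0` branch collapses to passing the decode slice directly; when there are no prefill tokens, `output[0:]` is simply a view of the whole tensor. A quick check of that equivalence:

```python
import torch

output = torch.randn(4, 8, 64)
num_prefill_tokens = 0

# output[0:] shares the same storage and shape as output, so the former
# special case `out = output` was redundant.
decode_view = output[num_prefill_tokens:]
assert decode_view.shape == output.shape
assert decode_view.data_ptr() == output.data_ptr()
```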
@@ -878,7 +872,8 @@ def _sdpa_attention(
 query: torch.Tensor,
 key: torch.Tensor,
 value: torch.Tensor,
-seq_lens: List[int],
+output: torch.Tensor,
+seq_lens: torch.Tensor,
 num_tokens: int,
 num_heads: int,
 head_size: int,
@@ -886,9 +881,9 @@ def _sdpa_attention(
 attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
 start = 0
-output = torch.empty((num_tokens, num_heads, head_size),
-dtype=query.dtype,
-device=query.device)
+assert output.shape == (num_tokens, num_heads, head_size)
+assert output.dtype == query.dtype
+assert output.device == query.device
 for i, seq_len in enumerate(seq_lens):
 end = start + seq_len
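
Finally, `_sdpa_attention` now receives the destination tensor and only validates it instead of allocating a new one; each sequence's result is written into `output[start:end]` in place. A self-contained sketch of that per-sequence loop, assuming PyTorch's `scaled_dot_product_attention` as in the original helper (layouts and defaults simplified; `seq_lens` is accepted as any iterable of lengths here):

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

def sdpa_into_buffer(query: torch.Tensor, key: torch.Tensor,
                     value: torch.Tensor, output: torch.Tensor,
                     seq_lens, scale: float) -> torch.Tensor:
    # q/k/v: [num_heads, num_tokens, head_size]; output: [num_tokens,
    # num_heads, head_size], filled in place one sequence at a time.
    assert output.dtype == query.dtype and output.device == query.device
    start = 0
    for seq_len in seq_lens:
        end = start + int(seq_len)
        sub_out = scaled_dot_product_attention(
            query[:, start:end, :],
            key[:, start:end, :],
            value[:, start:end, :],
            is_causal=True,
            scale=scale,
        ).movedim(1, 0)              # back to [tokens, heads, head_size]
        output[start:end] = sub_out  # write into the caller's buffer
        start = end
    return output

# Usage: two sequences of lengths 2 and 3 packed into 5 tokens.
H, D = 8, 64
q, k, v = (torch.randn(H, 5, D) for _ in range(3))
buf = torch.empty(5, H, D)
sdpa_into_buffer(q, k, v, buf, seq_lens=[2, 3], scale=D ** -0.5)
```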