[ROCm] [Attention] Cleanup ROCm output passing (#16431)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
commit aaec845f8e (parent 7bdfd29a35)
Author: Luka Govedič
Date:   2025-04-18 01:46:45 -04:00 (committed via GitHub)

@@ -27,6 +27,7 @@ _PARTITION_SIZE_ROCM = 256

 class ROCmFlashAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True

     @staticmethod
     def get_name() -> str:
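
The new accept_output_buffer flag advertises that this backend can write attention results into a caller-provided tensor. As a rough sketch of the caller-side contract (the function and forward() signature below are illustrative assumptions, not vLLM's exact API), a layer would preallocate the buffer once and thread it through:

    import torch

    # Hypothetical caller-side pattern; `impl`, `backend`, and the
    # forward() signature are assumptions for illustration only.
    def run_attention(impl, backend, query, key, value, attn_metadata):
        if getattr(backend, "accept_output_buffer", False):
            # Allocate once; the backend writes prefill and decode
            # results into this buffer instead of allocating internally.
            output = torch.empty_like(query)
            impl.forward(query, key, value, attn_metadata, output=output)
            return output
        # Fallback: the backend allocates and returns its own tensor.
        return impl.forward(query, key, value, attn_metadata)
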
@@ -515,7 +516,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
-            self.attn_func = triton_attention
+            self.triton_attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
             if self.sliding_window != (-1, -1):
                 logger.warning("ROCm Triton FA does not currently support "
@ -531,7 +532,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
try: try:
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func self.fa_attn_func = flash_attn_varlen_func
logger.debug("Using CK FA in ROCmBackend") logger.debug("Using CK FA in ROCmBackend")
except ModuleNotFoundError: except ModuleNotFoundError:
self.use_naive_attn = True self.use_naive_attn = True
@@ -542,7 +543,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     "ROCm Naive FlashAttention does not support "
                     "attention logits soft capping.")
-            self.attn_func = _sdpa_attention
+            self.sdpa_attn_func = _sdpa_attention
             logger.debug("Using naive (SDPA) attention in ROCmBackend")

     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -613,6 +614,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
+
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@ -656,7 +659,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert attn_metadata.num_encoder_tokens is not None assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens num_prefill_tokens = attn_metadata.num_encoder_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached. # Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:] decode_query = query[num_prefill_tokens:]
# QKV for prefill. # QKV for prefill.
@@ -704,11 +706,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         query.dtype,
                         seq_lens,
                         make_attn_mask=causal_mask)  # type: ignore
-                out, _ = self.attn_func(
+                self.triton_attn_func(
                     query,
                     key,
                     value,
-                    None,
+                    output[:num_prefill_tokens],
                     query_seq_start_loc,
                     key_seq_start_loc,
                     query_max_seq_len,
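
Passing output[:num_prefill_tokens] works because slicing a tensor yields a view over the same storage, so a kernel that accepts an output argument writes its prefill results straight into the shared buffer. A minimal sketch of that mechanism (the toy kernel below is a stand-in, not the real triton_attention):

    import torch

    def kernel_with_out(q: torch.Tensor, out: torch.Tensor) -> None:
        # Toy stand-in for a kernel that takes an output buffer.
        out.copy_(q * 2.0)

    output = torch.zeros(8, 4)
    num_prefill_tokens = 5
    # output[:num_prefill_tokens] is a view sharing storage with
    # `output`, so the kernel's writes land in the parent buffer
    # with no extra copy afterwards.
    kernel_with_out(torch.ones(num_prefill_tokens, 4),
                    output[:num_prefill_tokens])
    assert output[:num_prefill_tokens].eq(2.0).all()
    assert output[num_prefill_tokens:].eq(0.0).all()
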
@@ -733,10 +735,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     key = key.movedim(0, key.dim() - 2)
                     value = value.movedim(0, value.dim() - 2)
                     # sdpa math backend attention
-                    out = self.attn_func(
+                    self.sdpa_attn_func(
                         query,
                         key,
                         value,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         num_prefill_tokens,
                         self.num_heads,
@@ -745,7 +748,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         attn_masks,
                     )
                 else:
-                    out = self.attn_func(
+                    # upstream FA does not support an output arg, copy
+                    output[:num_prefill_tokens] = self.fa_attn_func(
                         q=query,
                         k=key,
                         v=value,
@@ -760,12 +764,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         softcap=self.logits_soft_cap,
                     )

-            # common code for prefill
-            assert output[:num_prefill_tokens].shape == out.shape
-            if output.shape[0] > num_prefill_tokens:
-                output[:num_prefill_tokens] = out
-            else:
-                output = out
         else:
             # prefix-enabled attention -
             # not applicable for encoder-only models
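
Upstream flash_attn_varlen_func has no output parameter, so this branch keeps one explicit copy: assigning the returned tensor into the slice writes it back into the preallocated buffer while leaving the decode region untouched, which is what lets the "common code for prefill" block above be deleted. A small sketch of that copy-back convention (the toy function stands in for the FA call):

    import torch

    def fa_like(q: torch.Tensor) -> torch.Tensor:
        # Toy stand-in for flash_attn_varlen_func: allocates and
        # returns a fresh tensor rather than filling a caller buffer.
        return q * 2.0

    output = torch.zeros(8, 4)
    num_prefill_tokens = 5
    # Slice assignment copies the result into the shared buffer in
    # place; output[num_prefill_tokens:] (decode) is left untouched.
    output[:num_prefill_tokens] = fa_like(
        torch.ones(num_prefill_tokens, 4))
    assert output[:num_prefill_tokens].eq(2.0).all()
    assert output[num_prefill_tokens:].eq(0.0).all()
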
@@ -818,14 +816,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 device=output.device,
             )
             max_logits = torch.empty_like(exp_sums)
-            if num_prefill_tokens > 0:
-                out = output[num_prefill_tokens:]
-            else:
-                out = output

             query_start_loc = None
             ops.paged_attention_rocm(
-                out,
+                output[num_prefill_tokens:],
                 exp_sums,
                 max_logits,
                 tmp_output,
@ -878,7 +872,8 @@ def _sdpa_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
seq_lens: List[int], output: torch.Tensor,
seq_lens: torch.Tensor,
num_tokens: int, num_tokens: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
@@ -886,9 +881,9 @@ def _sdpa_attention(
     attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
-    output = torch.empty((num_tokens, num_heads, head_size),
-                         dtype=query.dtype,
-                         device=query.device)
+    assert output.shape == (num_tokens, num_heads, head_size)
+    assert output.dtype == query.dtype
+    assert output.device == query.device

     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
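
With output threaded through, _sdpa_attention validates the caller's buffer instead of allocating one, and each per-sequence result is written into the matching slice. A minimal sketch of that loop under the visible signature (the [num_heads, num_tokens, head_size] input layout is inferred from the movedim calls in the SDPA branch above; treat it and the helper name as assumptions):

    import torch
    from torch.nn.functional import scaled_dot_product_attention

    def sdpa_into_buffer(query: torch.Tensor, key: torch.Tensor,
                         value: torch.Tensor, output: torch.Tensor,
                         seq_lens: torch.Tensor) -> torch.Tensor:
        # query/key/value: [num_heads, num_tokens, head_size]
        # output: preallocated [num_tokens, num_heads, head_size]
        start = 0
        for seq_len in seq_lens:
            end = start + int(seq_len)
            attn = scaled_dot_product_attention(query[:, start:end],
                                                key[:, start:end],
                                                value[:, start:end],
                                                is_causal=True)
            # movedim restores [seq_len, num_heads, head_size]; the
            # slice assignment writes into the caller's buffer, so
            # no per-call allocation is needed.
            output[start:end] = attn.movedim(1, 0)
            start = end
        return output
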