[ROCm] [Attention] Cleanup ROCm output passing (#16431)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
parent 7bdfd29a35
commit aaec845f8e
@@ -27,6 +27,7 @@ _PARTITION_SIZE_ROCM = 256
 
 
 class ROCmFlashAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True
 
     @staticmethod
     def get_name() -> str:
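For context, a minimal caller-side sketch (not part of this commit) of what an accept_output_buffer flag enables: the caller pre-allocates the output tensor once and the backend writes into it, instead of each backend allocating and returning its own result. The attention_forward helper and the impl/backend objects below are hypothetical stand-ins, not vLLM's actual call path.

# Hypothetical sketch only; `impl`, `backend`, and this helper are
# illustrative stand-ins, not vLLM's real interfaces.
import torch

def attention_forward(impl, backend, query, key, value, attn_metadata):
    if getattr(backend, "accept_output_buffer", False):
        # Pre-allocate once; the backend writes results into `output` in place.
        output = torch.empty_like(query)
        impl.forward(query, key, value, attn_metadata, output=output)
        return output
    # Fallback: the backend allocates and returns its own tensor.
    return impl.forward(query, key, value, attn_metadata)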
@@ -515,7 +516,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
-            self.attn_func = triton_attention
+            self.triton_attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
             if self.sliding_window != (-1, -1):
                 logger.warning("ROCm Triton FA does not currently support "
@@ -531,7 +532,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         else:
             try:
                 from flash_attn import flash_attn_varlen_func  # noqa: F401
-                self.attn_func = flash_attn_varlen_func
+                self.fa_attn_func = flash_attn_varlen_func
                 logger.debug("Using CK FA in ROCmBackend")
             except ModuleNotFoundError:
                 self.use_naive_attn = True
@@ -542,7 +543,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         "ROCm Naive FlashAttention does not support "
                         "attention logits soft capping.")
 
-                self.attn_func = _sdpa_attention
+                self.sdpa_attn_func = _sdpa_attention
                 logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -613,6 +614,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
+
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@@ -656,7 +659,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
 
-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -704,11 +706,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                             query.dtype,
                             seq_lens,
                             make_attn_mask=causal_mask)  # type: ignore
-                    out, _ = self.attn_func(
+                    self.triton_attn_func(
                         query,
                         key,
                         value,
-                        None,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         key_seq_start_loc,
                         query_max_seq_len,
@@ -733,10 +735,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     key = key.movedim(0, key.dim() - 2)
                     value = value.movedim(0, value.dim() - 2)
                     # sdpa math backend attention
-                    out = self.attn_func(
+                    self.sdpa_attn_func(
                         query,
                         key,
                         value,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         num_prefill_tokens,
                         self.num_heads,
@@ -745,7 +748,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         attn_masks,
                     )
                 else:
-                    out = self.attn_func(
+                    # upstream FA does not support an output arg, copy
+                    output[:num_prefill_tokens] = self.fa_attn_func(
                         q=query,
                         k=key,
                         v=value,
@@ -760,12 +764,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         softcap=self.logits_soft_cap,
                     )
 
-                # common code for prefill
-                assert output[:num_prefill_tokens].shape == out.shape
-                if output.shape[0] > num_prefill_tokens:
-                    output[:num_prefill_tokens] = out
-                else:
-                    output = out
             else:
                 # prefix-enabled attention -
                 # not applicable for encoder-only models
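The hunks above leave two output-passing styles for prefill: kernels that take an output argument (Triton, SDPA) write into output[:num_prefill_tokens] directly, while upstream flash-attn returns a fresh tensor that is copied into that slice. A minimal, self-contained sketch of the slice-copy behaviour, with a stand-in kernel function (not the real flash-attn call):

# Illustrative only: assigning a returned tensor into a slice of a
# pre-allocated buffer keeps the buffer's storage and the decode rows intact.
import torch

num_prefill_tokens, num_tokens, heads, dim = 3, 5, 2, 4
output = torch.zeros(num_tokens, heads, dim)

def kernel_returning_new_tensor(q):  # stand-in for a kernel without an output arg
    return torch.ones_like(q)

prefill_q = torch.zeros(num_prefill_tokens, heads, dim)
data_ptr_before = output.data_ptr()
output[:num_prefill_tokens] = kernel_returning_new_tensor(prefill_q)

assert output.data_ptr() == data_ptr_before          # no reallocation
assert output[num_prefill_tokens:].abs().sum() == 0  # decode rows untouched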
@@ -818,14 +816,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     device=output.device,
                 )
                 max_logits = torch.empty_like(exp_sums)
-                if num_prefill_tokens > 0:
-                    out = output[num_prefill_tokens:]
-                else:
-                    out = output
 
                 query_start_loc = None
                 ops.paged_attention_rocm(
-                    out,
+                    output[num_prefill_tokens:],
                     exp_sums,
                     max_logits,
                     tmp_output,
@@ -878,7 +872,8 @@ def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    seq_lens: List[int],
+    output: torch.Tensor,
+    seq_lens: torch.Tensor,
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -886,9 +881,9 @@ def _sdpa_attention(
     attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
-    output = torch.empty((num_tokens, num_heads, head_size),
-                         dtype=query.dtype,
-                         device=query.device)
+    assert output.shape == (num_tokens, num_heads, head_size)
+    assert output.dtype == query.dtype
+    assert output.device == query.device
 
     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
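With the new signature, _sdpa_attention validates and fills a caller-provided output tensor instead of allocating one per call. A rough, hypothetical sketch of that pattern (not the vLLM implementation); it assumes query/key/value laid out as (num_heads, num_tokens, head_size) and output as (num_tokens, num_heads, head_size):

# Illustrative sketch only: per-sequence SDPA results written into slices of
# a caller-provided output tensor.
import torch
import torch.nn.functional as F

def sdpa_into_output(query, key, value, output, seq_lens, scale):
    # query/key/value: (num_heads, num_tokens, head_size)
    # output:          (num_tokens, num_heads, head_size), preallocated by caller
    start = 0
    for seq_len in seq_lens:
        end = start + int(seq_len)
        attn = F.scaled_dot_product_attention(
            query[:, start:end, :],
            key[:, start:end, :],
            value[:, start:end, :],
            is_causal=True,
            scale=scale)
        # Move the head dim behind the token dim and write in place.
        output[start:end] = attn.movedim(0, 1)
        start = end
    return output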