[ROCm] [Attention] Cleanup ROCm output passing (#16431)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
parent 7bdfd29a35
commit aaec845f8e
@@ -27,6 +27,7 @@ _PARTITION_SIZE_ROCM = 256
 
 
 class ROCmFlashAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True
 
     @staticmethod
     def get_name() -> str:
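Not part of the diff: a minimal sketch of what the accept_output_buffer flag is assumed to signal, namely that the backend's forward() will write into a caller-provided output tensor instead of allocating its own. The run_attention helper and the forward signature below are illustrative only, not vLLM's actual call site.

import torch

def run_attention(impl, backend_cls, query, key, value, kv_cache, attn_metadata):
    # Hypothetical caller: if the backend advertises accept_output_buffer,
    # pre-allocate the result once and let forward() fill it in place.
    if getattr(backend_cls, "accept_output_buffer", False):
        output = torch.empty_like(query)
        impl.forward(query, key, value, kv_cache, attn_metadata, output=output)
        return output
    # Otherwise fall back to whatever the backend returns.
    return impl.forward(query, key, value, kv_cache, attn_metadata)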
@@ -515,7 +516,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
-            self.attn_func = triton_attention
+            self.triton_attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
             if self.sliding_window != (-1, -1):
                 logger.warning("ROCm Triton FA does not currently support "
@@ -531,7 +532,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             else:
                 try:
                     from flash_attn import flash_attn_varlen_func  # noqa: F401
-                    self.attn_func = flash_attn_varlen_func
+                    self.fa_attn_func = flash_attn_varlen_func
                     logger.debug("Using CK FA in ROCmBackend")
                 except ModuleNotFoundError:
                     self.use_naive_attn = True
@@ -542,7 +543,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     "ROCm Naive FlashAttention does not support "
                     "attention logits soft capping.")
 
-            self.attn_func = _sdpa_attention
+            self.sdpa_attn_func = _sdpa_attention
             logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
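Not part of the diff: the single self.attn_func attribute is split into three distinctly named attributes (triton_attn_func, fa_attn_func, sdpa_attn_func), presumably because the three kernels have incompatible signatures: the Triton and SDPA paths accept an output buffer, while upstream flash-attn does not. A toy sketch of that naming pattern with placeholder callables (the real class branches on its use_triton_flash_attn / use_naive_attn flags, not on attribute presence):

class _AttnFuncNamingSketch:
    # Toy illustration only: one clearly named attribute per backend, so a
    # call site cannot accidentally invoke a kernel with the wrong signature.
    def __init__(self, use_triton: bool, have_flash_attn: bool):
        if use_triton:
            # fills a caller-provided buffer, return value unused
            self.triton_attn_func = lambda q, k, v, out: out.copy_(q)
        elif have_flash_attn:
            # upstream flash-attn style: returns a fresh tensor, no out arg
            self.fa_attn_func = lambda q, k, v: q.clone()
        else:
            # naive SDPA path also fills a caller-provided buffer
            self.sdpa_attn_func = lambda q, k, v, out: out.copy_(q)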
@@ -613,6 +614,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
+
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@@ -656,7 +659,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
 
-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -704,11 +706,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         query.dtype,
                         seq_lens,
                         make_attn_mask=causal_mask)  # type: ignore
-                    out, _ = self.attn_func(
+                    self.triton_attn_func(
                         query,
                         key,
                         value,
-                        None,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         key_seq_start_loc,
                         query_max_seq_len,
@@ -733,10 +735,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     key = key.movedim(0, key.dim() - 2)
                     value = value.movedim(0, value.dim() - 2)
                     # sdpa math backend attention
-                    out = self.attn_func(
+                    self.sdpa_attn_func(
                         query,
                         key,
                         value,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         num_prefill_tokens,
                         self.num_heads,
@@ -745,7 +748,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         attn_masks,
                     )
                 else:
-                    out = self.attn_func(
+                    # upstream FA does not support an output arg, copy
+                    output[:num_prefill_tokens] = self.fa_attn_func(
                         q=query,
                         k=key,
                         v=value,
@@ -760,12 +764,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         softcap=self.logits_soft_cap,
                     )
 
-                # common code for prefill
-                assert output[:num_prefill_tokens].shape == out.shape
-                if output.shape[0] > num_prefill_tokens:
-                    output[:num_prefill_tokens] = out
-                else:
-                    output = out
             else:
                 # prefix-enabled attention -
                 # not applicable for encoder-only models
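Not part of the diff: after this change the Triton and SDPA prefill paths write straight into output[:num_prefill_tokens], while upstream flash-attn still returns a fresh tensor that is copied into that slice, so the old "common code for prefill" copy-and-rebind block is no longer needed. A toy sketch of the two styles; returns_new_tensor stands in for flash_attn_varlen_func:

import torch

def returns_new_tensor(q: torch.Tensor) -> torch.Tensor:
    # stand-in for a kernel with no output argument (e.g. upstream FA)
    return q * 2.0

output = torch.zeros(8, 4)          # caller-owned buffer for all tokens
num_prefill_tokens = 5

# in-place style (Triton / SDPA paths): the kernel fills the view directly
prefill_view = output[:num_prefill_tokens]
prefill_view.copy_(torch.ones(num_prefill_tokens, 4))

# copy style (FA path): slice assignment copies the returned tensor into
# the same caller-owned storage, so `output` never needs to be rebound
output[:num_prefill_tokens] = returns_new_tensor(torch.ones(num_prefill_tokens, 4))
assert output[:num_prefill_tokens].data_ptr() == output.data_ptr()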
@@ -818,14 +816,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 device=output.device,
             )
             max_logits = torch.empty_like(exp_sums)
-            if num_prefill_tokens > 0:
-                out = output[num_prefill_tokens:]
-            else:
-                out = output
 
             query_start_loc = None
             ops.paged_attention_rocm(
-                out,
+                output[num_prefill_tokens:],
                 exp_sums,
                 max_logits,
                 tmp_output,
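Not part of the diff: the removed if/else was choosing between output[num_prefill_tokens:] and output, but slicing from index 0 already yields a view of the whole buffer, so the paged-attention kernel can always be handed output[num_prefill_tokens:]. A small sketch of why the view behaves the same either way:

import torch

output = torch.zeros(6, 2)
for num_prefill_tokens in (0, 3):
    decode_view = output[num_prefill_tokens:]   # a view in both cases
    decode_view.fill_(1.0)                      # writes through to `output`
    assert torch.all(output[num_prefill_tokens:] == 1.0)
    output.zero_()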
@@ -878,7 +872,8 @@ def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    seq_lens: List[int],
+    output: torch.Tensor,
+    seq_lens: torch.Tensor,
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -886,9 +881,9 @@ def _sdpa_attention(
     attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
-    output = torch.empty((num_tokens, num_heads, head_size),
-                         dtype=query.dtype,
-                         device=query.device)
+    assert output.shape == (num_tokens, num_heads, head_size)
+    assert output.dtype == query.dtype
+    assert output.device == query.device
 
     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
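Not part of the diff: a rough, simplified sketch of the per-sequence loop that _sdpa_attention is assumed to run after this change, writing each sequence's result into the corresponding slice of the caller-provided output buffer instead of allocating a fresh tensor. Argument handling is abridged; the layouts follow the shapes asserted above.

from typing import List, Optional

import torch
from torch.nn.functional import scaled_dot_product_attention


def sdpa_into_buffer(query: torch.Tensor,      # [num_heads, num_tokens, head_size]
                     key: torch.Tensor,
                     value: torch.Tensor,
                     output: torch.Tensor,     # [num_tokens, num_heads, head_size]
                     seq_lens: List[int],
                     scale: float,
                     attn_masks: Optional[List[torch.Tensor]] = None
                     ) -> torch.Tensor:
    start = 0
    for i, seq_len in enumerate(seq_lens):
        end = start + seq_len
        mask = attn_masks[i] if attn_masks is not None else None
        sub_out = scaled_dot_product_attention(
            query[:, start:end, :],
            key[:, start:end, :],
            value[:, start:end, :],
            attn_mask=mask,
            is_causal=mask is None,
            scale=scale)
        # move heads back behind tokens to match the output buffer layout,
        # then copy the result into the caller's slice
        output[start:end, :, :] = sub_out.movedim(0, 1)
        start = end
    return output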