[Bugfix][CI] ALiBi test case in xformers multi_query_kv_attention (#11301)

Nicolò Lucchesi 2025-03-06 05:00:53 +01:00 committed by GitHub
parent 3dbd2d813a
commit 5ee10e990d
3 changed files with 83 additions and 22 deletions

@@ -17,6 +17,8 @@ if not current_platform.is_rocm():
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

    from vllm.attention.backends.xformers import _make_alibi_bias

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
@@ -345,20 +347,26 @@ def ref_multi_query_kv_attention(
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    alibi_bias: Optional[list[torch.Tensor]],
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: list[torch.Tensor] = []
    if alibi_bias:
        assert len(alibi_bias) == num_seqs
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)
        # Create attention mask. ALiBi already includes a tril causal mask.
        if alibi_bias:
            attn_mask = alibi_bias[i]
        else:
            attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                                   diagonal=1)
            attn_mask = attn_mask * torch.finfo(dtype).min
            attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
@@ -372,7 +380,6 @@ def ref_multi_query_kv_attention(
    return torch.cat(ref_outputs, dim=0)


# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -389,6 +396,7 @@ def test_multi_query_kv_attention(
    dtype: torch.dtype,
    seed: int,
    device: str,
    use_alibi: bool = False,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
@@ -414,16 +422,40 @@ def test_multi_query_kv_attention(
    # Handle MQA and GQA
    key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
    value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)
    alibi_bias = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
                                     seq_lens)
        output = torch.empty_like(query)
        start = 0
        # Dynamic sequence length not supported with custom attn_bias.
        for i, seq_len in enumerate(seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=scale)
            output[start:end].copy_(out.view_as(query[start:end]))
            start += seq_len

        # xformers.AttentionBias to Tensor for use in reference impl.
        alibi_bias = [
            b.materialize(b.shape, device=device).squeeze() for b in attn_bias
        ]
    else:
        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
        output = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=scale,
        )
        output = output.squeeze(0)

    cu_seq_lens = [0]
    for seq_len in seq_lens:
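
The reference path materializes each per-sequence xformers AttentionBias into a dense tensor and hands it to ref_masked_attention, which is not shown in this diff. A minimal sketch of that kind of reference computation, assuming query/key/value of shape [seq_len, num_heads, head_size] and a bias of shape [num_heads, seq_len, seq_len] (an illustration, not the test's actual helper):

from typing import Optional

import torch


def naive_masked_attention(query: torch.Tensor,
                           key: torch.Tensor,
                           value: torch.Tensor,
                           scale: float,
                           attn_mask: Optional[torch.Tensor] = None
                           ) -> torch.Tensor:
    # query/key/value: [seq_len, num_heads, head_size]
    # attn_mask: [num_heads, seq_len, seq_len], e.g. a materialized ALiBi bias.
    scores = scale * torch.einsum("qhd,khd->hqk", query.float(), key.float())
    if attn_mask is not None:
        # A materialized ALiBi bias already carries the causal -inf entries
        # above the diagonal, so no separate triangular mask is added here.
        scores = scores + attn_mask.float()
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, value.float()).to(query.dtype)

This is why the reference function above only builds its own triu mask on the non-ALiBi branch.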
@@ -434,8 +466,37 @@ def test_multi_query_kv_attention(
        key,
        value,
        scale,
        alibi_bias,
        dtype,
    )
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    return test_multi_query_kv_attention(
        num_seqs,
        num_heads,
        head_size,
        dtype,
        seed,
        device,
        use_alibi=True,
    )
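
The new wrapper simply re-runs the parametrized test body with use_alibi=True, so the ALiBi case can be selected by name. A small usage sketch; the test file path here is an assumption and not part of this diff:

import pytest

# Run only the new ALiBi case of the xformers multi-query KV attention test.
pytest.main([
    "tests/kernels/test_attention.py",
    "-k", "multi_query_kv_attention_with_alibi",
])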

@@ -439,14 +439,16 @@ def test_contexted_kv_attention_alibi(
    # heads.
    #
    # see also: vllm/model_executor/layers/attention.py
    query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
                       query.shape[-1])
    key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                    num_queries_per_kv, key.shape[-1])
    value = value[:, :,
                  None, :].expand(value.shape[0], num_kv_heads,
                                  num_queries_per_kv, value.shape[-1])

    # [seq, num_kv_heads, num_queries_per_kv, dk] =>
    # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with the rest of
    # the codebase. We save some time reshaping the ALiBi matrix at runtime.
    key = key.reshape(key.shape[0], -1, key.shape[-1])
    value = value.reshape(value.shape[0], -1, value.shape[-1])

    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)
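
The comment above describes collapsing the per-KV-head grouping into one flat head axis. An illustrative sketch with toy sizes (all names and dimensions here are hypothetical) showing that the reshape keeps each KV group's query heads contiguous:

import torch

# Toy sizes: 3 tokens, 2 KV heads, 4 query heads per KV head, head dim 8.
seq_len, num_kv_heads, num_queries_per_kv, head_dim = 3, 2, 4, 8
key = torch.randn(seq_len, num_kv_heads, num_queries_per_kv, head_dim)

# Collapse (kv_head, query-in-group) into a single head axis of size
# num_kv_heads * num_queries_per_kv, preserving the group ordering.
flat = key.reshape(seq_len, num_kv_heads * num_queries_per_kv, head_dim)

# Flat head h maps back to KV head h // num_queries_per_kv and
# group-local query head h % num_queries_per_kv.
h = 5
assert torch.equal(flat[:, h],
                   key[:, h // num_queries_per_kv, h % num_queries_per_kv])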

@@ -788,8 +788,6 @@ def _make_alibi_bias(
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(bias)
        bias.mul_(alibi_slopes[:, None, None])
        if num_heads != num_kv_heads:
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases
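
With the unflatten removed, each bias keeps a single flat head axis of size num_heads, matching the [seq, num_kv_heads*num_queries_per_kv, dk] layout the tests now feed in, while LowerTriangularMaskWithTensorBias supplies the causal masking on top of the slope term. A rough sketch of the dense bias this encodes, written as an illustration of the semantics rather than the backend's actual code:

import torch


def naive_alibi_causal_bias(alibi_slopes: torch.Tensor, seq_len: int,
                            dtype: torch.dtype) -> torch.Tensor:
    # Relative positions (j - i), shape [seq_len, seq_len].
    pos = torch.arange(seq_len)
    rel = (pos[None, :] - pos[:, None]).to(dtype)
    # Per-head slope times relative position: [num_heads, seq_len, seq_len].
    bias = alibi_slopes.to(dtype)[:, None, None] * rel
    # Causal part: strictly-future positions (j > i) are fully masked out.
    return bias.masked_fill(rel > 0, torch.finfo(dtype).min)

On and below the diagonal the entry is slope * (j - i); above it the value is the dtype minimum, which is why the reference implementation can use the materialized bias directly in place of its own triu mask.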