[Bugfix][CI] ALiBi test case in xformers multi_query_kv_attention (#11301)
parent 3dbd2d813a
commit 5ee10e990d
@@ -17,6 +17,8 @@ if not current_platform.is_rocm():
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

    from vllm.attention.backends.xformers import _make_alibi_bias

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
@@ -345,20 +347,26 @@ def ref_multi_query_kv_attention(
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    alibi_bias: Optional[list[torch.Tensor]],
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: list[torch.Tensor] = []
    if alibi_bias:
        assert len(alibi_bias) == num_seqs
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)
        # Create attention mask. ALiBi already includes a tril causal mask.
        if alibi_bias:
            attn_mask = alibi_bias[i]
        else:
            attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                                   diagonal=1)
            attn_mask = attn_mask * torch.finfo(dtype).min
            attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
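Note: `ref_masked_attention` itself is not part of this hunk. The sketch below is a minimal stand-in (the function name and the [seq_len, num_heads, head_size] layout are assumptions, not the file's exact helper) showing how an additive mask such as `alibi_bias[i]` plugs into a naive attention reference.

import torch

def naive_masked_attention(query: torch.Tensor,   # [seq_len, num_heads, head_size]
                           key: torch.Tensor,     # [seq_len, num_heads, head_size]
                           value: torch.Tensor,   # [seq_len, num_heads, head_size]
                           scale: float,
                           attn_mask: torch.Tensor) -> torch.Tensor:
    # Per-head scores: [num_heads, seq_len, seq_len].
    scores = scale * torch.einsum("qhd,khd->hqk", query.float(), key.float())
    # The mask is additive: dtype.min above the diagonal in the causal case,
    # or a full ALiBi bias that already encodes causality.
    scores = scores + attn_mask.float()
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, value.float()).to(query.dtype)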
@@ -372,7 +380,6 @@ def ref_multi_query_kv_attention(
    return torch.cat(ref_outputs, dim=0)


# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -389,6 +396,7 @@ def test_multi_query_kv_attention(
    dtype: torch.dtype,
    seed: int,
    device: str,
    use_alibi: bool = False,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
@@ -414,16 +422,40 @@ def test_multi_query_kv_attention(
    # Handle MQA and GQA
    key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
    value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)
    alibi_bias = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
                                     seq_lens)
        output = torch.empty_like(query)
        start = 0
        # Dynamic sequence length not supported with custom attn_bias.
        for i, seq_len in enumerate(seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=scale)
            output[start:end].copy_(out.view_as(query[start:end]))
            start += seq_len
        # xformers.AttentionBias to Tensor for use in reference impl.
        alibi_bias = [
            b.materialize(b.shape, device=device).squeeze() for b in attn_bias
        ]
    else:
        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
        output = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=scale,
        )
        output = output.squeeze(0)

    cu_seq_lens = [0]
    for seq_len in seq_lens:
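The `b.materialize(b.shape, device=device)` call above turns an xformers attention-bias object into a plain additive tensor for the reference implementation. A minimal, self-contained illustration of the same idea, using `BlockDiagonalCausalMask` (already imported in this file); the sequence lengths are made up, and exact shape/device handling may vary across xformers versions.

import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

seq_lens = [2, 3]
total = sum(seq_lens)
mask = BlockDiagonalCausalMask.from_seqlens(seq_lens)
# Dense additive bias: 0 where attention is allowed, -inf elsewhere
# (outside a sequence's block or above its diagonal).
dense = mask.materialize((total, total), dtype=torch.float32, device="cpu")
print(dense)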
@@ -434,8 +466,37 @@ def test_multi_query_kv_attention(
        key,
        value,
        scale,
        alibi_bias,
        dtype,
    )
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    return test_multi_query_kv_attention(
        num_seqs,
        num_heads,
        head_size,
        dtype,
        seed,
        device,
        use_alibi=True,
    )
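To run just the new ALiBi case locally, the test can be selected by name. The test-file path is not shown in this diff, so the example below only uses a keyword filter (equivalent to passing -k on the pytest command line).

import pytest

# Collects and runs only tests whose names contain the keyword; add the path
# to the kernels test file in your checkout to narrow collection further.
pytest.main(["-k", "multi_query_kv_attention_with_alibi", "-q"])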
@@ -439,14 +439,16 @@ def test_contexted_kv_attention_alibi(
    # heads.
    #
    # see also: vllm/model_executor/layers/attention.py
    query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
                       query.shape[-1])
    key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                    num_queries_per_kv, key.shape[-1])
    value = value[:, :,
                  None, :].expand(value.shape[0], num_kv_heads,
                                  num_queries_per_kv, value.shape[-1])

    # [seq, num_kv_heads, num_queries_per_kv, dk]=>
    # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
    # codebase. We save some time reshaping alibi matrix at runtime.
    key = key.reshape(key.shape[0], -1, key.shape[-1])
    value = value.reshape(value.shape[0], -1, value.shape[-1])
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)
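The expand/reshape above repeats each KV head for its group of query heads (MQA/GQA). A tiny standalone shape check of the same pattern; the dimension sizes are arbitrary.

import torch

seq_len, num_kv_heads, num_queries_per_kv, head_size = 7, 2, 4, 64
key = torch.randn(seq_len, num_kv_heads, head_size)

# Insert a per-group axis, broadcast it, then fold it back into the head axis:
# [seq, kv_heads, dk] -> [seq, kv_heads * queries_per_kv, dk].
expanded = key[:, :, None, :].expand(seq_len, num_kv_heads,
                                     num_queries_per_kv, head_size)
flat = expanded.reshape(seq_len, -1, head_size)
assert flat.shape == (seq_len, num_kv_heads * num_queries_per_kv, head_size)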
@@ -788,8 +788,6 @@ def _make_alibi_bias(
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(bias)
        bias.mul_(alibi_slopes[:, None, None])
        if num_heads != num_kv_heads:
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases
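For reference, the tensor that `_make_alibi_bias` wraps in `LowerTriangularMaskWithTensorBias` is a per-head linear position penalty. The sketch below is a simplified, illustrative version only (the function name is hypothetical; it omits the batch dimension and key-length padding seen in the hunk above).

import torch

def simple_alibi_bias(alibi_slopes: torch.Tensor, seq_len: int) -> torch.Tensor:
    # rel[i, j] = j - i: zero on the diagonal, increasingly negative for
    # more distant keys once the causal mask keeps only j <= i.
    pos = torch.arange(seq_len)
    rel = (pos[None, :] - pos[:, None]).to(alibi_slopes.dtype)
    # One slope per head: [num_heads, seq_len, seq_len].
    return alibi_slopes[:, None, None] * rel[None, :, :]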