[Bugfix][CI] ALiBi test case in xformers multi_query_kv_attention (#11301)
commit 5ee10e990d
parent 3dbd2d813a
@@ -17,6 +17,8 @@ if not current_platform.is_rocm():
     from xformers import ops as xops
     from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
+    from vllm.attention.backends.xformers import _make_alibi_bias
+
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
@@ -345,16 +347,22 @@ def ref_multi_query_kv_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     scale: float,
+    alibi_bias: Optional[list[torch.Tensor]],
     dtype: torch.dtype,
 ) -> torch.Tensor:
     num_seqs = len(cu_seq_lens) - 1
     ref_outputs: list[torch.Tensor] = []
+    if alibi_bias:
+        assert len(alibi_bias) == num_seqs
     for i in range(num_seqs):
         start_idx = cu_seq_lens[i]
         end_idx = cu_seq_lens[i + 1]
         seq_len = end_idx - start_idx
 
-        # Create attention mask.
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
-                               diagonal=1)
-        attn_mask = attn_mask * torch.finfo(dtype).min
+        # Create attention mask. ALiBi already includes a tril causal mask.
+        if alibi_bias:
+            attn_mask = alibi_bias[i]
+        else:
+            attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                                   diagonal=1)
+            attn_mask = attn_mask * torch.finfo(dtype).min
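For context, each entry of the new alibi_bias argument is expected to be a dense, per-sequence additive mask that already encodes both the ALiBi slopes and causality, which is why only the else branch builds a triangular mask. A minimal sketch of such a tensor (the helper name and the power-of-two slope schedule below are illustrative, not part of this change):

import torch

def ref_alibi_bias(num_heads: int, seq_len: int,
                   dtype: torch.dtype) -> torch.Tensor:
    # Geometric slope schedule from the ALiBi paper; assumes num_heads is a
    # power of two (real implementations interpolate otherwise).
    slopes = torch.tensor(
        [2**(-8 * (i + 1) / num_heads) for i in range(num_heads)],
        dtype=torch.float32)
    pos = torch.arange(seq_len)
    rel = pos[None, :] - pos[:, None]               # key pos - query pos
    bias = slopes[:, None, None] * rel[None, :, :]  # [num_heads, seq, seq]
    # ALiBi already implies causality: mask out future positions.
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool),
                        diagonal=1)
    return bias.masked_fill(future, torch.finfo(dtype).min).to(dtype)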
@@ -372,7 +380,6 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)
 
 
-# TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -389,6 +396,7 @@ def test_multi_query_kv_attention(
     dtype: torch.dtype,
     seed: int,
     device: str,
+    use_alibi: bool = False,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -414,6 +422,30 @@ def test_multi_query_kv_attention(
     # Handle MQA and GQA
     key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
     value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
-    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
-    output = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
+    alibi_bias = None
+    if use_alibi:
+        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
+        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
+                                     seq_lens)
+        output = torch.empty_like(query)
+        start = 0
+        # Dynamic sequence length not supported with custom attn_bias.
+        for i, seq_len in enumerate(seq_lens):
+            end = start + seq_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=attn_bias[i],
+                p=0.0,
+                scale=scale)
+            output[start:end].copy_(out.view_as(query[start:end]))
+            start += seq_len
+        # xformers.AttentionBias to Tensor for use in reference impl.
+        alibi_bias = [
+            b.materialize(b.shape, device=device).squeeze() for b in attn_bias
+        ]
+    else:
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
+        output = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
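The ALiBi branch above converts each xformers AttentionBias returned by _make_alibi_bias into a dense tensor via materialize, so the reference implementation can add it directly to the attention scores. A standalone sketch of that conversion, assuming the xformers attn_bias API behaves as it is used above (shapes are arbitrary):

import torch
from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias

num_heads, seq_len = 8, 16
alibi_term = torch.randn(1, num_heads, seq_len, seq_len)

# Same bias type _make_alibi_bias appends: the lower-triangular (causal)
# mask is implied by the class, while the tensor carries the ALiBi term.
bias = LowerTriangularMaskWithTensorBias(alibi_term)

# materialize() yields a dense additive mask: the ALiBi term plus -inf
# above the diagonal, ready to be added to Q @ K^T.
dense = bias.materialize(alibi_term.shape, dtype=torch.float32, device="cpu")
print(dense.shape)  # torch.Size([1, 8, 16, 16])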
@@ -434,8 +466,37 @@ def test_multi_query_kv_attention(
         key,
         value,
         scale,
+        alibi_bias,
         dtype,
     )
     atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
     rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
     torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [64])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+@torch.inference_mode()
+def test_multi_query_kv_attention_with_alibi(
+    num_seqs: int,
+    num_heads: tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+) -> None:
+    return test_multi_query_kv_attention(
+        num_seqs,
+        num_heads,
+        head_size,
+        dtype,
+        seed,
+        device,
+        use_alibi=True,
+    )
@@ -439,14 +439,16 @@ def test_contexted_kv_attention_alibi(
     # heads.
     #
     # see also: vllm/model_executor/layers/attention.py
-    query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
-                       query.shape[-1])
     key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                     num_queries_per_kv, key.shape[-1])
     value = value[:, :,
                   None, :].expand(value.shape[0], num_kv_heads,
                                   num_queries_per_kv, value.shape[-1])
+    # [seq, num_kv_heads, num_queries_per_kv, dk]=>
+    # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
+    # codebase. We save some time reshaping alibi matrix at runtime.
+    key = key.reshape(key.shape[0], -1, key.shape[-1])
+    value = value.reshape(value.shape[0], -1, value.shape[-1])
     query = query.unsqueeze(0)
     key = key.unsqueeze(0)
     value = value.unsqueeze(0)
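As a quick illustration of the reshape the new comments describe, each KV head is expanded over its query-head group and then flattened back to the usual [seq, num_heads, head_dim] layout (sizes below are arbitrary):

import torch

seq, num_kv_heads, num_queries_per_kv, dk = 10, 2, 4, 64
key = torch.randn(seq, num_kv_heads, dk)

# [seq, num_kv_heads, dk] => [seq, num_kv_heads, num_queries_per_kv, dk]
key = key[:, :, None, :].expand(seq, num_kv_heads, num_queries_per_kv, dk)
# [seq, num_kv_heads, num_queries_per_kv, dk] => [seq, num_heads, dk]
key = key.reshape(seq, -1, dk)
print(key.shape)  # torch.Size([10, 8, 64])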
@@ -788,8 +788,6 @@ def _make_alibi_bias(
             dtype=dtype,
         )[:, :, :, :seq_len].copy_(bias)
         bias.mul_(alibi_slopes[:, None, None])
-        if num_heads != num_kv_heads:
-            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
         attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
 
     return attn_biases
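Finally, the two lines removed from _make_alibi_bias had grouped the per-head bias by KV head. A rough sketch of the shape difference (sizes are arbitrary, and this only sketches the bias tensor, not the attention call itself):

import torch

num_kv_heads, num_queries_per_kv = 2, 4
num_heads = num_kv_heads * num_queries_per_kv
seq_len, padded_len = 5, 8

# Bias as built in _make_alibi_bias: one slope-scaled distance matrix per
# query head.
bias = torch.randn(1, num_heads, seq_len, padded_len)

# Removed behaviour: group query heads under their KV head, producing a
# 5-D bias (presumably to match a 5-D grouped query layout).
grouped = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
print(grouped.shape)  # torch.Size([1, 2, 4, 5, 8])

# New behaviour: keep the flat per-query-head bias.
print(bias.shape)     # torch.Size([1, 8, 5, 8])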