diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index fc549d7a..0d7898a9 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -17,6 +17,8 @@ if not current_platform.is_rocm():
     from xformers import ops as xops
     from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
+    from vllm.attention.backends.xformers import _make_alibi_bias
+
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
@@ -345,20 +347,26 @@ def ref_multi_query_kv_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     scale: float,
+    alibi_bias: Optional[list[torch.Tensor]],
     dtype: torch.dtype,
 ) -> torch.Tensor:
     num_seqs = len(cu_seq_lens) - 1
     ref_outputs: list[torch.Tensor] = []
+    if alibi_bias:
+        assert len(alibi_bias) == num_seqs
     for i in range(num_seqs):
         start_idx = cu_seq_lens[i]
         end_idx = cu_seq_lens[i + 1]
         seq_len = end_idx - start_idx
 
-        # Create attention mask.
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
-                               diagonal=1)
-        attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype)
+        # Create attention mask. ALiBi already includes a tril causal mask.
+        if alibi_bias:
+            attn_mask = alibi_bias[i]
+        else:
+            attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                                   diagonal=1)
+            attn_mask = attn_mask * torch.finfo(dtype).min
+            attn_mask = attn_mask.to(dtype=dtype)
 
         ref_output = ref_masked_attention(
             query[start_idx:end_idx],
@@ -372,7 +380,6 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)
 
 
-# TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -389,6 +396,7 @@ def test_multi_query_kv_attention(
     dtype: torch.dtype,
     seed: int,
     device: str,
+    use_alibi: bool = False,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -414,16 +422,40 @@ def test_multi_query_kv_attention(
     # Handle MQA and GQA
     key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
     value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
-    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
-    output = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
-        key.unsqueeze(0),
-        value.unsqueeze(0),
-        attn_bias=attn_bias,
-        p=0.0,
-        scale=scale,
-    )
-    output = output.squeeze(0)
+    alibi_bias = None
+    if use_alibi:
+        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
+        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
+                                     seq_lens)
+        output = torch.empty_like(query)
+        start = 0
+        # Dynamic sequence length not supported with custom attn_bias.
+        for i, seq_len in enumerate(seq_lens):
+            end = start + seq_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=attn_bias[i],
+                p=0.0,
+                scale=scale)
+            output[start:end].copy_(out.view_as(query[start:end]))
+            start += seq_len
+        # xformers.AttentionBias to Tensor for use in reference impl.
+        alibi_bias = [
+            b.materialize(b.shape, device=device).squeeze() for b in attn_bias
+        ]
+    else:
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
+        output = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key.unsqueeze(0),
+            value.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=scale,
+        )
+        output = output.squeeze(0)
 
     cu_seq_lens = [0]
     for seq_len in seq_lens:
@@ -434,8 +466,37 @@ def test_multi_query_kv_attention(
         key,
         value,
         scale,
+        alibi_bias,
         dtype,
     )
     atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
     rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
-    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
\ No newline at end of file
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [64])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+@torch.inference_mode()
+def test_multi_query_kv_attention_with_alibi(
+    num_seqs: int,
+    num_heads: tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+) -> None:
+    return test_multi_query_kv_attention(
+        num_seqs,
+        num_heads,
+        head_size,
+        dtype,
+        seed,
+        device,
+        use_alibi=True,
+    )
diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py
index c3ac6a37..f2c7f2c8 100644
--- a/tests/kernels/test_prefix_prefill.py
+++ b/tests/kernels/test_prefix_prefill.py
@@ -439,14 +439,16 @@ def test_contexted_kv_attention_alibi(
     # heads.
     #
     # see also: vllm/model_executor/layers/attention.py
-    query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
-                       query.shape[-1])
     key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                     num_queries_per_kv, key.shape[-1])
     value = value[:, :, None, :].expand(value.shape[0], num_kv_heads,
                                         num_queries_per_kv, value.shape[-1])
-
+    # [seq, num_kv_heads, num_queries_per_kv, dk]=>
+    # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
+    # codebase. We save some time reshaping alibi matrix at runtime.
+    key = key.reshape(key.shape[0], -1, key.shape[-1])
+    value = value.reshape(value.shape[0], -1, value.shape[-1])
 
     query = query.unsqueeze(0)
     key = key.unsqueeze(0)
     value = value.unsqueeze(0)
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index 9fa76634..14c94c9a 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -788,8 +788,6 @@ def _make_alibi_bias(
             dtype=dtype,
         )[:, :, :, :seq_len].copy_(bias)
         bias.mul_(alibi_slopes[:, None, None])
-        if num_heads != num_kv_heads:
-            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
         attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
 
     return attn_biases
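
As a companion to the reference path above, here is a minimal, self-contained sketch (plain PyTorch only; it does not import vLLM or xformers) of the idea the diff relies on: a per-head ALiBi bias that already embeds the causal mask can be handed directly to a naive masked-attention reference, which is the role alibi_bias plays in ref_multi_query_kv_attention and why the explicit triu mask is skipped in that branch. The helper names make_ref_alibi_bias and ref_masked_attention_example are illustrative only and are not part of the test suite or of vLLM.

# Sketch: build a per-head ALiBi bias with the causal mask folded in, then use
# it in a naive attention reference. Helper names are hypothetical.
import torch


def make_ref_alibi_bias(alibi_slopes: torch.Tensor, seq_len: int,
                        dtype: torch.dtype) -> torch.Tensor:
    # Relative position term: rel[i, j] = j - i, scaled per head by its slope.
    pos = torch.arange(seq_len, dtype=dtype)
    rel = pos[None, :] - pos[:, None]                     # [seq_len, seq_len]
    bias = rel[None, :, :] * alibi_slopes[:, None, None]  # [num_heads, s, s]
    # Fold in the causal mask so callers do not need a separate triu mask.
    causal = torch.triu(
        torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype),
        diagonal=1)
    return bias + causal


def ref_masked_attention_example(query: torch.Tensor, key: torch.Tensor,
                                 value: torch.Tensor, scale: float,
                                 attn_bias: torch.Tensor) -> torch.Tensor:
    # query/key/value: [seq_len, num_heads, head_size]
    # attn_bias: [num_heads, seq_len, seq_len]
    scores = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    scores = scores + attn_bias.float()
    probs = torch.softmax(scores, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)


if __name__ == "__main__":
    torch.manual_seed(0)
    seq_len, num_heads, head_size = 8, 4, 64
    dtype = torch.float32
    q = torch.randn(seq_len, num_heads, head_size, dtype=dtype)
    k = torch.randn(seq_len, num_heads, head_size, dtype=dtype)
    v = torch.randn(seq_len, num_heads, head_size, dtype=dtype)
    slopes = torch.randn(num_heads, dtype=dtype)
    bias = make_ref_alibi_bias(slopes, seq_len, dtype)
    out = ref_masked_attention_example(q, k, v, head_size**-0.5, bias)
    print(out.shape)  # torch.Size([8, 4, 64])

The bias construction mirrors the shape of what _make_alibi_bias produces per sequence (a relative-position term scaled per head), with masked-out positions pushed to the dtype's minimum so softmax zeroes them out.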