# SPDX-License-Identifier: Apache-2.0

import math
import random
import time
from collections.abc import Callable

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.chunked_prefill_paged_decode import (
    chunked_prefill_paged_decode)
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
HEAD_SIZES = [128, 96, 24]
DTYPES = [torch.float16]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]

OPS = [chunked_prefill_paged_decode, context_attention_fwd]


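# Both ops under test implement prefix-prefill (chunked-prefill) attention
# over vLLM's paged KV cache; each test below compares their output against
# an xformers reference computed on the unpaged key/value tensors.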
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
2024-03-16 11:58:10 +08:00
|
|
|
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
2024-01-17 16:32:10 -08:00
|
|
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
|
|
|
@pytest.mark.parametrize("dtype", DTYPES)
|
2024-08-12 15:47:41 -07:00
|
|
|
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
2024-02-02 07:46:39 +08:00
|
|
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
2024-05-02 11:23:37 -07:00
|
|
|
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
|
2025-03-06 16:39:16 +01:00
|
|
|
@pytest.mark.parametrize("op", OPS)
|
2024-01-17 16:32:10 -08:00
|
|
|
@torch.inference_mode()
|
|
|
|
def test_contexted_kv_attention(
|
|
|
|
num_heads: int,
|
2024-02-27 17:14:31 +08:00
|
|
|
num_queries_per_kv: int,
|
2024-01-17 16:32:10 -08:00
|
|
|
head_size: int,
|
2024-05-02 11:23:37 -07:00
|
|
|
sliding_window: int,
|
2024-01-17 16:32:10 -08:00
|
|
|
dtype: torch.dtype,
|
2024-08-12 15:47:41 -07:00
|
|
|
kv_cache_dtype: str,
|
2024-02-02 07:46:39 +08:00
|
|
|
device: str,
|
2025-03-06 16:39:16 +01:00
|
|
|
op: Callable,
|
2024-01-17 16:32:10 -08:00
|
|
|
) -> None:
|
2024-11-25 14:23:32 -03:00
|
|
|
|
|
|
|
    if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
            89):
        pytest.skip(
            'Triton limitation: fp8e4nv data type is not supported on CUDA'
            ' arch < 89')

    current_platform.seed_everything(0)
    torch.set_default_device(device)

    # Needed; otherwise, when capturing the CUDA graph, the process for
    # GPU 1 would run on both GPU 0 and GPU 1 and hang.
    #
    # See also this similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    # Ensure at least one sequence in the batch is a decode (query length 1).
    query_lens[-1] = 1

    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
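    # Build the paged-cache indirection: physical block ids are shuffled so
    # that block_table[i, j] gives the physical block backing logical block j
    # of sequence i.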
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                            dtype=torch.long),
                               dim=0)
    max_input_len = MAX_SEQ_LEN
    # Copy the key/value tensors into the paged cache.
    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
                                                dtype=torch.long),
                                   dim=0)
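    # For each sequence, the suffix (new) tokens go into the query-aligned
    # k/v tensors, while the prefix (context) tokens are scattered into the
    # paged caches block by block via the block table.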
    for i in range(BS):
        for j in range(query_lens[i]):
            k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
                                            j])
            v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
                                              b_ctx_len[i] + j])
        cur_ctx = 0
        block_id = 0
        while cur_ctx < b_ctx_len[i]:
            start_loc = b_seq_start_loc[i] + cur_ctx
            if cur_ctx + block_size > b_ctx_len[i]:
                end_loc = b_seq_start_loc[i] + b_ctx_len[i]
            else:
                end_loc = start_loc + block_size
            start_slot = block_table[i, block_id] * block_size
            end_slot = start_slot + end_loc - start_loc
            k_cache.view(-1, num_kv_heads,
                         head_size)[start_slot:end_slot].copy_(
                             key[start_loc:end_loc])
            v_cache.view(-1, num_kv_heads,
                         head_size)[start_slot:end_slot].copy_(
                             value[start_loc:end_loc])
            cur_ctx += block_size
            block_id += 1
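    # NOTE: the factor of 8 below splits head_size into head_size // 8 groups
    # of 8 contiguous elements (head_size must be divisible by 8). This
    # matches the key-cache layout the kernels expect, presumably so they can
    # issue vectorized loads.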
    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    op(query,
       k,
       v,
       output,
       kv_cache_dtype,
       k_cache,
       v_cache,
       block_table,
       b_start_loc,
       b_seq_len,
       max_input_len,
       k_scale,
       v_scale,
       sliding_window=sliding_window)
    torch.cuda.synchronize()
    start_time = time.time()
    op(query,
       k,
       v,
       output,
       kv_cache_dtype,
       k_cache,
       v_cache,
       block_table,
       b_start_loc,
       b_seq_len,
       max_input_len,
       k_scale,
       v_scale,
       sliding_window=sliding_window)
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")

    scale = float(1.0 / (head_size**0.5))

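    # Reference path: recompute attention with xformers over the full
    # (context + new-token) keys and values. The cutlass forward op is
    # selected explicitly, presumably because it handles the non-standard
    # head sizes (96, 24) exercised by this test.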
    attn_op = xops.fmha.cutlass.FwOp()

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # expand the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
                           query.shape[-1])
        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                        num_queries_per_kv, key.shape[-1])
        value = value[:, :,
                      None, :].expand(value.shape[0], num_kv_heads,
                                      num_queries_per_kv, value.shape[-1])
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

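    # The bottom-right causal mask aligns the last query token of each
    # sequence with its last key token, which is exactly the chunked-prefill
    # setting: the queries are the suffix of a longer context.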
    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        query_lens, seq_lens)
    if sliding_window > 0:
        attn_bias = attn_bias.make_local_attention_from_bottomright(
            sliding_window)
    output_ref = xops.memory_efficient_attention_forward(
        query,
        key,
        value,
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    start_time = time.time()
    output_ref = xops.memory_efficient_attention_forward(
        query,
        key,
        value,
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
    output_ref = output_ref.reshape(output.shape)
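    # fp8 KV caches store quantized keys/values, so a looser tolerance is
    # used for those runs.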
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


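# The ALiBi variant exercises the same kernels with per-head linear position
# biases in place of a sliding window.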
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
|
|
|
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
|
|
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
|
|
|
@pytest.mark.parametrize("dtype", DTYPES)
|
2024-08-12 15:47:41 -07:00
|
|
|
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
2024-05-09 00:19:58 +08:00
|
|
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
2025-03-06 16:39:16 +01:00
|
|
|
@pytest.mark.parametrize("op", OPS)
|
2024-05-09 00:19:58 +08:00
|
|
|
@torch.inference_mode()
|
|
|
|
def test_contexted_kv_attention_alibi(
|
|
|
|
num_heads: int,
|
|
|
|
num_queries_per_kv: int,
|
|
|
|
head_size: int,
|
|
|
|
dtype: torch.dtype,
|
2024-08-12 15:47:41 -07:00
|
|
|
kv_cache_dtype: str,
|
2024-05-09 00:19:58 +08:00
|
|
|
device: str,
|
2025-03-06 16:39:16 +01:00
|
|
|
op: Callable,
|
2024-05-09 00:19:58 +08:00
|
|
|
) -> None:
|
2024-11-25 14:23:32 -03:00
|
|
|
|
|
|
|
    if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
            89):
        pytest.skip(
            'Triton limitation: fp8e4nv data type is not supported on CUDA'
            ' arch < 89')

    current_platform.seed_everything(0)
    torch.set_default_device(device)

    # Needed; otherwise, when capturing the CUDA graph, the process for
    # GPU 1 would run on both GPU 0 and GPU 1 and hang.
    #
    # See also this similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
        # Forked from: vllm/vllm/model_executor/models/bloom.py#L44
        closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
        base = torch.tensor(
            2**(-(2**-(math.log2(closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
        slopes = torch.pow(base, powers)

        if closest_power_of_2 != total_num_heads:
            extra_base = torch.tensor(
                2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
                dtype=torch.float32,
            )
            num_remaining_heads = min(closest_power_of_2,
                                      total_num_heads - closest_power_of_2)
            extra_powers = torch.arange(start=1,
                                        end=1 + 2 * num_remaining_heads,
                                        step=2,
                                        dtype=torch.int32)
            slopes = torch.cat(
                [slopes, torch.pow(extra_base, extra_powers)], dim=0)
        return slopes

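    # For a power-of-two head count h, the slopes above are geometric:
    # 2^(-8/h), 2^(-16/h), ..., 2^(-8); the branch for non-power-of-two
    # counts interleaves slopes for the remaining heads.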
    alibi_slopes = _get_alibi_slopes(num_heads).to(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)
    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                            dtype=torch.long),
                               dim=0)
    max_input_len = MAX_SEQ_LEN
    # Copy the key/value tensors into the paged cache.
    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
                                                dtype=torch.long),
                                   dim=0)
    for i in range(BS):
        for j in range(query_lens[i]):
            k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
                                            j])
            v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
                                              b_ctx_len[i] + j])
        cur_ctx = 0
        block_id = 0
        while cur_ctx < b_ctx_len[i]:
            start_loc = b_seq_start_loc[i] + cur_ctx
            if cur_ctx + block_size > b_ctx_len[i]:
                end_loc = b_seq_start_loc[i] + b_ctx_len[i]
            else:
                end_loc = start_loc + block_size
            start_slot = block_table[i, block_id] * block_size
            end_slot = start_slot + end_loc - start_loc
            k_cache.view(-1, num_kv_heads,
                         head_size)[start_slot:end_slot].copy_(
                             key[start_loc:end_loc])
            v_cache.view(-1, num_kv_heads,
                         head_size)[start_slot:end_slot].copy_(
                             value[start_loc:end_loc])
            cur_ctx += block_size
            block_id += 1
    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    op(query,
       k,
       v,
       output,
       kv_cache_dtype,
       k_cache,
       v_cache,
       block_table,
       b_start_loc,
       b_seq_len,
       max_input_len,
       k_scale,
       v_scale,
       alibi_slopes=alibi_slopes)
    torch.cuda.synchronize()
    start_time = time.time()
    op(query,
       k,
       v,
       output,
       kv_cache_dtype,
       k_cache,
       v_cache,
       block_table,
       b_start_loc,
       b_seq_len,
       max_input_len,
       k_scale,
       v_scale,
       alibi_slopes=alibi_slopes)
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
    scale = float(1.0 / (head_size**0.5))

    # NOTE(DefTruth): In order to reuse the _make_alibi_bias function,
    # we have to pad the query tensor before the MQA/GQA expansion.
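    # Zeros are prepended so that row t of the padded query corresponds to
    # absolute position t of its sequence; the outputs for these zero rows
    # are sliced away again when copying into output_ref below.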
    if query.shape[0] != key.shape[0]:
        query_pad = torch.empty(sum(seq_lens),
                                num_heads,
                                head_size,
                                dtype=dtype)
        query_pad.uniform_(-1e-3, 1e-3)
        seq_start = 0
        query_start = 0
        for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
            seq_end = seq_start + seq_len
            query_end = query_start + query_len
            query_pad[seq_start:seq_end, ...] = torch.cat([
                torch.zeros(
                    seq_len - query_len, num_heads, head_size, dtype=dtype),
                query[query_start:query_end, ...]
            ],
                                                          dim=0)
            seq_start += seq_len
            query_start += query_len
        query = query_pad

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # expand the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                        num_queries_per_kv, key.shape[-1])
        value = value[:, :,
                      None, :].expand(value.shape[0], num_kv_heads,
                                      num_queries_per_kv, value.shape[-1])
        # [seq, num_kv_heads, num_queries_per_kv, dk] =>
        # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with the rest
        # of the codebase. This saves some time reshaping the alibi matrix
        # at runtime.
        key = key.reshape(key.shape[0], -1, key.shape[-1])
        value = value.reshape(value.shape[0], -1, value.shape[-1])
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
    output_ref = torch.empty_like(output)
    seq_start = 0
    query_start = 0
    start_time = time.time()
    # Attention with alibi slopes.
    # FIXME(DefTruth): Because xformers does not support dynamic sequence
    # lengths with custom attention bias, we process each prompt one by
    # one. This is inefficient, especially when we have many short prompts.
    # modified from: vllm/attention/backends/xformers.py#L343
    for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
        seq_end = seq_start + seq_len
        query_end = query_start + query_len
        out = xops.memory_efficient_attention_forward(
            query[:, seq_start:seq_end],
            key[:, seq_start:seq_end],
            value[:, seq_start:seq_end],
            attn_bias=attn_bias[i],
            p=0.0,
            scale=scale)
        out = out.view_as(query[:, seq_start:seq_end]).view(
            seq_len, num_heads, head_size)
        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
                                                         ...])
        seq_start += seq_len
        query_start += query_len
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


# These tests are optional and run only when explicitly invoked:
#
# pytest -v -s --optional \
#     tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
#
# They are useful for testing model dtype float32 on Turing devices.
# We skip them by default to keep CI runtime down.
@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
                                sliding_window, dtype, kv_cache_dtype, device,
                                op)


@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
                                      dtype, kv_cache_dtype, device, op)