import random
from typing import List, Optional, Tuple

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip

from .allclose_default import get_default_atol, get_default_rtol

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
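# MAX_SEQ_LEN below appears to budget one float of scratch space per token in
# shared memory, with 512 tokens kept as headroom (rationale inferred from the
# formula; it is not documented here).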
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough GPU memory due to the large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
          ] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 128, 256
              ] if not is_hip() else [64, 80, 96, 112, 128]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
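    # Reference scaled dot-product attention. Shapes, as implied by the einsum
    # patterns below:
    #   query:     [num_query_tokens, num_heads, head_size]
    #   key/value: [num_kv_tokens, num_heads, head_size]
    #   attn_mask: broadcastable to [num_heads, num_query_tokens, num_kv_tokens]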
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
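    # Reference for single-token (decode) queries against the paged KV cache.
    # Cache layouts, as implied by the indexing below:
    #   key_cache:   [num_blocks, num_kv_heads, head_size // x, block_size, x]
    #   value_cache: [num_blocks, num_kv_heads, head_size, block_size]
    # Each sequence's block table maps logical block j // block_size to a
    # physical block, and keys/values are gathered token by token.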
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables = block_tables.cpu().tolist()
    seq_lens = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables[i]
        seq_len = int(seq_lens[i])

        keys = []
        values = []
        for j in range(seq_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values.append(v)
        keys = torch.stack(keys, dim=0)
        values = torch.stack(values, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
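            # The bias at position p is slope * (p - seq_len + 1): zero for the
            # last (current) token and increasingly negative further back.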
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
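    # Each block table maps a sequence's logical block index to a physical
    # block in the cache. Physical blocks are sampled at random and may repeat;
    # this is harmless here since the kernel and the reference only read them.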
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    kv_scale = 1.0

    # Call the paged attention kernel.
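    # v1 processes each sequence in a single pass, while v2 splits the sequence
    # into PARTITION_SIZE chunks and then reduces the partial results (hence
    # the exp_sums/max_logits/tmp_output buffers allocated below).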
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            kv_scale,
        )
    elif version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            kv_scale,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
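        # x is the number of dtype elements that fit in 16 bytes; the key cache
        # stores the head dimension as head_size // x groups of x elements.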
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
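        # Causal mask: strictly upper-triangular entries are set to the dtype's
        # most negative value so softmax gives (near-)zero weight to future
        # positions.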
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output


# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
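    # The block-diagonal causal mask lets all sequences be packed into a single
    # xformers call while preventing attention across sequence boundaries.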
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)
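
    # Build cumulative sequence lengths for the reference implementation,
    # e.g. seq_lens = [2, 3] -> cu_seq_lens = [0, 2, 5].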
    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5
    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)