
# - **Add SPDX license headers to python source files**
# - **Check for SPDX headers using pre-commit**
#
# commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
# Author: Russell Bryant <rbryant@redhat.com>
# Date:   Fri Jan 31 14:18:24 2025 -0500
#
#     Add SPDX license headers to python source files
#
#     This commit adds SPDX license headers to python source files as
#     recommended to the project by the Linux Foundation. These headers
#     provide a concise way that is both human and machine readable for
#     communicating license information for each source file. It helps
#     avoid any ambiguity about the license of the code and can also be
#     easily used by tools to help manage license compliance.
#
#     The Linux Foundation runs license scans against the codebase to help
#     ensure we are in compliance with the licenses of the code we use,
#     including dependencies. Having these headers in place helps that
#     tool do its job.
#
#     More information can be found on the SPDX site:
#
#     - https://spdx.dev/learn/handling-license-info/
#
#     Signed-off-by: Russell Bryant <rbryant@redhat.com>
#
# commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
# Author: Russell Bryant <rbryant@redhat.com>
# Date:   Fri Jan 31 14:36:32 2025 -0500
#
#     Check for SPDX headers using pre-commit
#
#     Signed-off-by: Russell Bryant <rbryant@redhat.com>
#
# ---------
#
# Signed-off-by: Russell Bryant <rbryant@redhat.com>
540 lines
22 KiB
Python
540 lines
22 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import math
|
|
import random
|
|
import time
|
|
|
|
import pytest
|
|
import torch
|
|
from xformers import ops as xops
|
|
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
|
|
|
|
from vllm.attention.backends.xformers import _make_alibi_bias
|
|
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
|
|
|
# Parametrization grids shared by every test in this module.
NUM_HEADS = [64]
# Query heads per KV head: 1 gives num_kv_heads == num_heads, 64 gives a
# single KV head for all 64 query heads.
NUM_QUERIES_PER_KV = [1, 8, 64]
HEAD_SIZES = [128, 96, 24]
DTYPES = [torch.float16]
# One device when exactly one GPU is visible, otherwise "cuda:0" and "cuda:1".
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
# The tests only apply a local-attention window when the value is > 0.
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
|
|
|
|
|
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ``context_attention_fwd`` against an xformers reference.

    Random queries, the new-token keys/values and a paged KV cache that
    holds each sequence's context are passed to ``context_attention_fwd``.
    Its output must match ``xops.memory_efficient_attention_forward`` run
    over the full, unpaged key/value tensors with a bottom-right causal
    block-diagonal mask, restricted to a local window when
    ``sliding_window > 0``.

    Args:
        num_heads: Number of query heads.
        num_queries_per_kv: Query heads sharing one KV head
            (``num_kv_heads = num_heads // num_queries_per_kv``).
        head_size: Per-head hidden size.
        sliding_window: Local attention window; values <= 0 leave the
            reference mask purely causal.
        dtype: Dtype of query/key/value/output tensors.
        kv_cache_dtype: ``"auto"`` (cache uses ``dtype``) or a key of
            ``STR_DTYPE_TO_TORCH_DTYPE``.
        device: Device string used as default and current CUDA device.
    """
    if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
            89):
        pytest.skip(
            'Triton limitation: fp8e4nv data type is not supported on CUDA'
            ' arch < 89')

    current_platform.seed_everything(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    # Random permutation of all cache blocks, split evenly across requests.
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                            dtype=torch.long),
                               dim=0)
    max_input_len = MAX_SEQ_LEN

    # copy kv to cache
    #
    # Offsets are tracked as Python ints and the block table is read back
    # to the host once, so every copy below is a plain slice copy rather
    # than one tiny copy per token indexed by a device tensor.
    k_cache_flat = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_flat = v_cache.view(-1, num_kv_heads, head_size)
    block_ids = block_table.tolist()
    query_start = 0  # offset of sequence i inside k / v / query
    seq_start = 0  # offset of sequence i inside key / value
    for i in range(BS):
        query_len = query_lens[i]
        ctx_len = ctx_lens[i]
        ctx_end = seq_start + ctx_len

        # New (non-context) tokens follow the context inside key/value.
        k[query_start:query_start + query_len].copy_(
            key[ctx_end:ctx_end + query_len])
        v[query_start:query_start + query_len].copy_(
            value[ctx_end:ctx_end + query_len])

        # Context tokens go into the paged cache, one block at a time; the
        # last block of a sequence may be only partially filled.
        for block_id, cur_ctx in enumerate(range(0, ctx_len, block_size)):
            start_loc = seq_start + cur_ctx
            end_loc = min(start_loc + block_size, ctx_end)
            start_slot = block_ids[i][block_id] * block_size
            end_slot = start_slot + end_loc - start_loc
            k_cache_flat[start_slot:end_slot].copy_(key[start_loc:end_loc])
            v_cache_flat[start_slot:end_slot].copy_(value[start_loc:end_loc])

        query_start += query_len
        seq_start += seq_lens[i]

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    def run_triton() -> None:
        # Writes its result into ``output`` in place.
        context_attention_fwd(query,
                              k,
                              v,
                              output,
                              kv_cache_dtype,
                              k_cache,
                              v_cache,
                              block_table,
                              b_start_loc,
                              b_seq_len,
                              b_ctx_len,
                              max_input_len,
                              k_scale,
                              v_scale,
                              sliding_window=sliding_window)

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    run_triton()
    torch.cuda.synchronize()
    start_time = time.time()
    run_triton()
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")

    scale = float(1.0 / (head_size**0.5))

    attn_op = xops.fmha.cutlass.FwOp()

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
                           query.shape[-1])
        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                        num_queries_per_kv, key.shape[-1])
        value = value[:, :,
                      None, :].expand(value.shape[0], num_kv_heads,
                                      num_queries_per_kv, value.shape[-1])
    # Add the leading batch dimension in both the MHA and MQA/GQA case.
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        query_lens, seq_lens)
    if sliding_window > 0:
        attn_bias = attn_bias.make_local_attention_from_bottomright(
            sliding_window)

    def run_reference() -> torch.Tensor:
        return xops.memory_efficient_attention_forward(
            query,
            key,
            value,
            attn_bias=attn_bias,
            p=0.0,
            scale=scale,
            op=attn_op,
        )

    # Same warm-up-then-measure pattern as for the Triton kernel.
    output_ref = run_reference()
    torch.cuda.synchronize()
    start_time = time.time()
    output_ref = run_reference()
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
    output_ref = output_ref.reshape(output.shape)
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
|
|
|
|
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ``context_attention_fwd`` with ALiBi slopes against xformers.

    The setup mirrors the non-ALiBi test: random queries, new-token
    keys/values and a paged KV cache holding each sequence's context. The
    kernel is called with ``alibi_slopes``; the reference runs
    ``xops.memory_efficient_attention_forward`` once per sequence with the
    bias produced by ``_make_alibi_bias`` and keeps only the rows that
    belong to the real (non-padded) query tokens.

    Args:
        num_heads: Number of query heads (also the number of slopes).
        num_queries_per_kv: Query heads sharing one KV head.
        head_size: Per-head hidden size.
        dtype: Dtype of query/key/value/output tensors.
        kv_cache_dtype: ``"auto"`` (cache uses ``dtype``) or a key of
            ``STR_DTYPE_TO_TORCH_DTYPE``.
        device: Device string used as default and current CUDA device.
    """
    if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
            89):
        pytest.skip(
            'Triton limitation: fp8e4nv data type is not supported on CUDA'
            ' arch < 89')

    current_platform.seed_everything(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
        # Fork from: vllm/vllm/model_executor/models/bloom.py#L44
        closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
        base = torch.tensor(
            2**(-(2**-(math.log2(closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
        slopes = torch.pow(base, powers)

        # Head counts that are not a power of two get extra slopes taken
        # from the odd powers of the next power-of-two base.
        if closest_power_of_2 != total_num_heads:
            extra_base = torch.tensor(
                2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
                dtype=torch.float32,
            )
            num_remaining_heads = min(closest_power_of_2,
                                      total_num_heads - closest_power_of_2)
            extra_powers = torch.arange(start=1,
                                        end=1 + 2 * num_remaining_heads,
                                        step=2,
                                        dtype=torch.int32)
            slopes = torch.cat(
                [slopes, torch.pow(extra_base, extra_powers)], dim=0)
        return slopes

    alibi_slopes = _get_alibi_slopes(num_heads).to(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)
    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    # Random permutation of all cache blocks, split evenly across requests.
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                            dtype=torch.long),
                               dim=0)
    max_input_len = MAX_SEQ_LEN

    # copy kv to cache
    #
    # Offsets are tracked as Python ints and the block table is read back
    # to the host once, so every copy below is a plain slice copy rather
    # than one tiny copy per token indexed by a device tensor.
    k_cache_flat = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_flat = v_cache.view(-1, num_kv_heads, head_size)
    block_ids = block_table.tolist()
    query_start = 0  # offset of sequence i inside k / v / query
    seq_start = 0  # offset of sequence i inside key / value
    for i in range(BS):
        query_len = query_lens[i]
        ctx_len = ctx_lens[i]
        ctx_end = seq_start + ctx_len

        # New (non-context) tokens follow the context inside key/value.
        k[query_start:query_start + query_len].copy_(
            key[ctx_end:ctx_end + query_len])
        v[query_start:query_start + query_len].copy_(
            value[ctx_end:ctx_end + query_len])

        # Context tokens go into the paged cache, one block at a time; the
        # last block of a sequence may be only partially filled.
        for block_id, cur_ctx in enumerate(range(0, ctx_len, block_size)):
            start_loc = seq_start + cur_ctx
            end_loc = min(start_loc + block_size, ctx_end)
            start_slot = block_ids[i][block_id] * block_size
            end_slot = start_slot + end_loc - start_loc
            k_cache_flat[start_slot:end_slot].copy_(key[start_loc:end_loc])
            v_cache_flat[start_slot:end_slot].copy_(value[start_loc:end_loc])

        query_start += query_len
        seq_start += seq_lens[i]

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    def run_triton() -> None:
        # Writes its result into ``output`` in place.
        context_attention_fwd(query,
                              k,
                              v,
                              output,
                              kv_cache_dtype,
                              k_cache,
                              v_cache,
                              block_table,
                              b_start_loc,
                              b_seq_len,
                              b_ctx_len,
                              max_input_len,
                              k_scale,
                              v_scale,
                              alibi_slopes=alibi_slopes)

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    run_triton()
    torch.cuda.synchronize()
    start_time = time.time()
    run_triton()
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
    scale = float(1.0 / (head_size**0.5))

    # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
    # we have to pad query tensor before MQA/GQA expanding.
    if query.shape[0] != key.shape[0]:
        query_pad = torch.empty(sum(seq_lens),
                                num_heads,
                                head_size,
                                dtype=dtype)
        query_pad.uniform_(-1e-3, 1e-3)
        seq_start = 0
        query_start = 0
        for query_len, seq_len in zip(query_lens, seq_lens):
            seq_end = seq_start + seq_len
            query_end = query_start + query_len
            # Context positions get zero queries; the real queries occupy
            # the tail of each sequence.
            query_pad[seq_start:seq_end, ...] = torch.cat([
                torch.zeros(
                    seq_len - query_len, num_heads, head_size, dtype=dtype),
                query[query_start:query_end, ...]
            ],
                                                          dim=0)
            seq_start += seq_len
            query_start += query_len
        query = query_pad

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
                           query.shape[-1])
        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                        num_queries_per_kv, key.shape[-1])
        value = value[:, :,
                      None, :].expand(value.shape[0], num_kv_heads,
                                      num_queries_per_kv, value.shape[-1])

    # Add the leading batch dimension in both the MHA and MQA/GQA case.
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
    output_ref = torch.empty_like(output)
    seq_start = 0
    query_start = 0
    # Drain the padding/expand work queued above so that it is not billed
    # to the reference timing.
    torch.cuda.synchronize()
    start_time = time.time()
    # Attention with alibi slopes.
    # FIXME(DefTruth): Because xformers does not support dynamic sequence
    # lengths with custom attention bias, we process each prompt one by
    # one. This is inefficient, especially when we have many short prompts.
    # modified from: vllm/attention/backends/xformers.py#L343
    for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
        seq_end = seq_start + seq_len
        query_end = query_start + query_len
        out = xops.memory_efficient_attention_forward(
            query[:, seq_start:seq_end],
            key[:, seq_start:seq_end],
            value[:, seq_start:seq_end],
            attn_bias=attn_bias[i],
            p=0.0,
            scale=scale)
        out = out.view_as(query[:, seq_start:seq_end]).view(
            seq_len, num_heads, head_size)
        # Only the trailing query_len rows correspond to real queries.
        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
                                                         ...])
        seq_start += seq_len
        query_start += query_len
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
|
|
|
|
|
# These tests are optional to only run when explicitly invoked
#
# pytest -v -s --optional \
# tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
#
# These tests are useful to test model dtype float32 on Turing devices.
# We skip them to not increase the time when running tests on CI
@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@torch.inference_mode()
def test_contexted_kv_attention_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Run ``test_contexted_kv_attention`` with ``dtype=torch.float32``.

    Marked ``optional``; the body only forwards its parameters to the
    float16-parametrized test function.
    """
    # Keyword arguments keep the forwarding immune to parameter reordering.
    test_contexted_kv_attention(num_heads=num_heads,
                                num_queries_per_kv=num_queries_per_kv,
                                head_size=head_size,
                                sliding_window=sliding_window,
                                dtype=dtype,
                                kv_cache_dtype=kv_cache_dtype,
                                device=device)
|
|
|
|
|
|
@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_contexted_kv_attention_alibi_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Run ``test_contexted_kv_attention_alibi`` with ``dtype=torch.float32``.

    Marked ``optional``; the body only forwards its parameters to the
    float16-parametrized ALiBi test function.
    """
    # Keyword arguments keep the forwarding immune to parameter reordering.
    test_contexted_kv_attention_alibi(num_heads=num_heads,
                                      num_queries_per_kv=num_queries_per_kv,
                                      head_size=head_size,
                                      dtype=dtype,
                                      kv_cache_dtype=kv_cache_dtype,
                                      device=device)
|