[Kernel] [V1] Improved performance for V1 Triton (ROCm) backend (#14152)
This commit is contained in:
parent
4f27044aab
commit
6bd1dd9d26
@ -3,6 +3,7 @@
|
|||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -10,6 +11,8 @@ from xformers import ops as xops
|
|||||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
|
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
|
||||||
|
|
||||||
from vllm.attention.backends.xformers import _make_alibi_bias
|
from vllm.attention.backends.xformers import _make_alibi_bias
|
||||||
|
from vllm.attention.ops.chunked_prefill_paged_decode import (
|
||||||
|
chunked_prefill_paged_decode)
|
||||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
@ -24,6 +27,8 @@ CUDA_DEVICES = [
|
|||||||
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
|
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
|
||||||
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
|
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
|
||||||
|
|
||||||
|
OPS = [chunked_prefill_paged_decode, context_attention_fwd]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||||
@ -32,6 +37,7 @@ KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
|
|||||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
|
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
|
||||||
|
@pytest.mark.parametrize("op", OPS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_contexted_kv_attention(
|
def test_contexted_kv_attention(
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
@ -41,6 +47,7 @@ def test_contexted_kv_attention(
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
device: str,
|
device: str,
|
||||||
|
op: Callable,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
|
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
|
||||||
@ -65,6 +72,9 @@ def test_contexted_kv_attention(
|
|||||||
block_size = 32
|
block_size = 32
|
||||||
max_block_per_request = 64
|
max_block_per_request = 64
|
||||||
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
||||||
|
# ensure one sequence in batch is a decode
|
||||||
|
query_lens[-1] = 1
|
||||||
|
|
||||||
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
|
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
|
||||||
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
||||||
num_kv_heads = num_heads // num_queries_per_kv
|
num_kv_heads = num_heads // num_queries_per_kv
|
||||||
@ -144,36 +154,36 @@ def test_contexted_kv_attention(
|
|||||||
|
|
||||||
# Warm up the Triton kernel by calling it once before actually measuring
|
# Warm up the Triton kernel by calling it once before actually measuring
|
||||||
# generation time
|
# generation time
|
||||||
context_attention_fwd(query,
|
op(query,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
output,
|
output,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
k_cache,
|
k_cache,
|
||||||
v_cache,
|
v_cache,
|
||||||
block_table,
|
block_table,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
max_input_len,
|
max_input_len,
|
||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
sliding_window=sliding_window)
|
sliding_window=sliding_window)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
context_attention_fwd(query,
|
op(query,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
output,
|
output,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
k_cache,
|
k_cache,
|
||||||
v_cache,
|
v_cache,
|
||||||
block_table,
|
block_table,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
max_input_len,
|
max_input_len,
|
||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
sliding_window=sliding_window)
|
sliding_window=sliding_window)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
|
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
|
||||||
@ -228,7 +238,7 @@ def test_contexted_kv_attention(
|
|||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
|
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
|
||||||
output_ref = output_ref.reshape(output.shape)
|
output_ref = output_ref.reshape(output.shape)
|
||||||
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
|
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
|
||||||
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
@ -238,6 +248,7 @@ def test_contexted_kv_attention(
|
|||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
|
@pytest.mark.parametrize("op", OPS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_contexted_kv_attention_alibi(
|
def test_contexted_kv_attention_alibi(
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
@ -246,6 +257,7 @@ def test_contexted_kv_attention_alibi(
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
device: str,
|
device: str,
|
||||||
|
op: Callable,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
|
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
|
||||||
@ -375,36 +387,36 @@ def test_contexted_kv_attention_alibi(
|
|||||||
|
|
||||||
# Warm up the Triton kernel by calling it once before actually measuring
|
# Warm up the Triton kernel by calling it once before actually measuring
|
||||||
# generation time
|
# generation time
|
||||||
context_attention_fwd(query,
|
op(query,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
output,
|
output,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
k_cache,
|
k_cache,
|
||||||
v_cache,
|
v_cache,
|
||||||
block_table,
|
block_table,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
max_input_len,
|
max_input_len,
|
||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
alibi_slopes=alibi_slopes)
|
alibi_slopes=alibi_slopes)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
context_attention_fwd(query,
|
op(query,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
output,
|
output,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
k_cache,
|
k_cache,
|
||||||
v_cache,
|
v_cache,
|
||||||
block_table,
|
block_table,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
max_input_len,
|
max_input_len,
|
||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
alibi_slopes=alibi_slopes)
|
alibi_slopes=alibi_slopes)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
|
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
|
||||||
@ -503,6 +515,7 @@ def test_contexted_kv_attention_alibi(
|
|||||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
|
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
|
||||||
|
@pytest.mark.parametrize("op", OPS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_contexted_kv_attention_f32(
|
def test_contexted_kv_attention_f32(
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
@ -512,9 +525,11 @@ def test_contexted_kv_attention_f32(
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
device: str,
|
device: str,
|
||||||
|
op: Callable,
|
||||||
) -> None:
|
) -> None:
|
||||||
test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
|
test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
|
||||||
sliding_window, dtype, kv_cache_dtype, device)
|
sliding_window, dtype, kv_cache_dtype, device,
|
||||||
|
op)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.optional
|
@pytest.mark.optional
|
||||||
@ -524,6 +539,7 @@ def test_contexted_kv_attention_f32(
|
|||||||
@pytest.mark.parametrize("dtype", [torch.float32])
|
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
|
@pytest.mark.parametrize("op", OPS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_contexted_kv_attention_alibi_f32(
|
def test_contexted_kv_attention_alibi_f32(
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
@ -532,6 +548,7 @@ def test_contexted_kv_attention_alibi_f32(
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
device: str,
|
device: str,
|
||||||
|
op: Callable,
|
||||||
) -> None:
|
) -> None:
|
||||||
test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
|
test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
|
||||||
dtype, kv_cache_dtype, device)
|
dtype, kv_cache_dtype, device, op)
|
||||||
|
289
vllm/attention/ops/chunked_prefill_paged_decode.py
Normal file
289
vllm/attention/ops/chunked_prefill_paged_decode.py
Normal file
@ -0,0 +1,289 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from .prefix_prefill import context_attention_fwd
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def cdiv_fn(x, y):
|
||||||
|
return (x + y - 1) // y
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def kernel_paged_attention_2d(
|
||||||
|
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||||
|
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||||
|
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
|
||||||
|
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
|
||||||
|
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||||
|
seq_lens_ptr, # [num_seqs]
|
||||||
|
alibi_slopes_ptr, # [num_query_heads]
|
||||||
|
scale, # float32
|
||||||
|
k_scale, # float32
|
||||||
|
v_scale, # float32
|
||||||
|
num_query_heads: tl.constexpr, # int
|
||||||
|
num_queries_per_kv: tl.constexpr, # int
|
||||||
|
block_table_stride: tl.constexpr, # int
|
||||||
|
query_stride_0: tl.constexpr, # int
|
||||||
|
query_stride_1: tl.constexpr, # int, should be equal to head_size
|
||||||
|
output_stride_0: tl.constexpr, # int
|
||||||
|
output_stride_1: tl.constexpr, # int, should be equal to head_size
|
||||||
|
BLOCK_SIZE: tl.constexpr, # int
|
||||||
|
HEAD_SIZE: tl.constexpr, # int
|
||||||
|
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||||
|
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||||
|
SLIDING_WINDOW: tl.constexpr, # int
|
||||||
|
x: tl.constexpr, # int
|
||||||
|
stride_k_cache_0: tl.constexpr, # int
|
||||||
|
stride_k_cache_1: tl.constexpr, # int
|
||||||
|
stride_k_cache_2: tl.constexpr, # int
|
||||||
|
stride_k_cache_3: tl.constexpr, # int
|
||||||
|
stride_k_cache_4: tl.constexpr, # int
|
||||||
|
stride_v_cache_0: tl.constexpr, # int
|
||||||
|
stride_v_cache_1: tl.constexpr, # int
|
||||||
|
stride_v_cache_2: tl.constexpr, # int
|
||||||
|
stride_v_cache_3: tl.constexpr, # int
|
||||||
|
filter_by_query_len: tl.constexpr, # bool
|
||||||
|
query_start_len_ptr, # [num_seqs+1]
|
||||||
|
):
|
||||||
|
seq_idx = tl.program_id(0)
|
||||||
|
query_head_idx = tl.program_id(1)
|
||||||
|
kv_head_idx = query_head_idx // num_queries_per_kv
|
||||||
|
|
||||||
|
if filter_by_query_len:
|
||||||
|
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
|
||||||
|
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx +
|
||||||
|
1)
|
||||||
|
cur_batch_query_len = cur_batch_in_all_stop_index \
|
||||||
|
- cur_batch_in_all_start_index
|
||||||
|
if cur_batch_query_len > 1:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
cur_batch_in_all_start_index = seq_idx
|
||||||
|
|
||||||
|
query_offset = (cur_batch_in_all_start_index * query_stride_0 +
|
||||||
|
query_head_idx * query_stride_1)
|
||||||
|
|
||||||
|
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
|
||||||
|
0).to(tl.int1)
|
||||||
|
|
||||||
|
# Q : (HEAD_SIZE,)
|
||||||
|
Q = tl.load(
|
||||||
|
query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED),
|
||||||
|
mask=dim_mask,
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
block_table_offset = seq_idx * block_table_stride
|
||||||
|
|
||||||
|
M = tl.full([1], float("-inf"), dtype=tl.float32)
|
||||||
|
L = tl.full([1], 1.0, dtype=tl.float32)
|
||||||
|
acc = tl.zeros([HEAD_SIZE_PADDED], dtype=tl.float32)
|
||||||
|
|
||||||
|
# sequence len for this particular sequence
|
||||||
|
seq_len = tl.load(seq_lens_ptr + seq_idx)
|
||||||
|
|
||||||
|
# alibi slope for this head
|
||||||
|
if USE_ALIBI_SLOPES:
|
||||||
|
alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx)
|
||||||
|
|
||||||
|
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
|
||||||
|
|
||||||
|
# iterate through tiles
|
||||||
|
for j in range(0, num_blocks):
|
||||||
|
|
||||||
|
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
|
||||||
|
|
||||||
|
offs_n = tl.arange(0, BLOCK_SIZE)
|
||||||
|
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
|
||||||
|
|
||||||
|
v_offset = (physical_block_idx * stride_v_cache_0 +
|
||||||
|
kv_head_idx * stride_v_cache_1 +
|
||||||
|
offs_d[:, None] * stride_v_cache_2 +
|
||||||
|
offs_n[None, :] * stride_v_cache_3)
|
||||||
|
|
||||||
|
k_offset = (physical_block_idx * stride_k_cache_0 +
|
||||||
|
kv_head_idx * stride_k_cache_1 +
|
||||||
|
(offs_d[:, None] // x) * stride_k_cache_2 +
|
||||||
|
offs_n[None, :] * stride_k_cache_3 +
|
||||||
|
(offs_d[:, None] % x) * stride_k_cache_4)
|
||||||
|
|
||||||
|
# K : (HEAD_SIZE, BLOCK_SIZE)
|
||||||
|
K_load = tl.load(key_cache_ptr + k_offset,
|
||||||
|
mask=dim_mask[:, None],
|
||||||
|
other=0.0)
|
||||||
|
|
||||||
|
if K_load.dtype.is_fp8():
|
||||||
|
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
|
||||||
|
else:
|
||||||
|
K = K_load
|
||||||
|
|
||||||
|
# V : (HEAD_SIZE, BLOCK_SIZE)
|
||||||
|
V_load = tl.load(value_cache_ptr + v_offset,
|
||||||
|
mask=dim_mask[:, None],
|
||||||
|
other=0.0)
|
||||||
|
|
||||||
|
if V_load.dtype.is_fp8():
|
||||||
|
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
|
||||||
|
else:
|
||||||
|
V = V_load
|
||||||
|
|
||||||
|
tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
|
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
|
||||||
|
mask_new = tmp < boundary
|
||||||
|
# S : (BLOCK_SIZE,)
|
||||||
|
S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32)
|
||||||
|
S += scale * tl.sum(K * Q[:, None], axis=0)
|
||||||
|
|
||||||
|
if SLIDING_WINDOW > 0:
|
||||||
|
S = tl.where((seq_len - 1 - tmp) < SLIDING_WINDOW, S, -10000)
|
||||||
|
|
||||||
|
if USE_ALIBI_SLOPES:
|
||||||
|
S += alibi_slope * (tmp - seq_len + 1)
|
||||||
|
|
||||||
|
# compute running maximum
|
||||||
|
# m_j : (1,)
|
||||||
|
m_j = tl.maximum(M, tl.max(S, axis=0))
|
||||||
|
|
||||||
|
# P : (BLOCK_SIZE,)
|
||||||
|
P = tl.exp(S - m_j)
|
||||||
|
|
||||||
|
# l_j : (1,)
|
||||||
|
l_j = tl.sum(P, axis=0)
|
||||||
|
|
||||||
|
# alpha : (1, )
|
||||||
|
alpha = tl.exp(M - m_j)
|
||||||
|
|
||||||
|
# acc : (BLOCK_SIZE,)
|
||||||
|
acc = acc * alpha
|
||||||
|
|
||||||
|
# update constants
|
||||||
|
L = L * alpha + l_j
|
||||||
|
M = m_j
|
||||||
|
|
||||||
|
# acc : (BLOCK_SIZE,)
|
||||||
|
acc += tl.sum(V * P[None, :], axis=1)
|
||||||
|
|
||||||
|
# epilogue
|
||||||
|
acc = acc / L
|
||||||
|
|
||||||
|
output_offset = (cur_batch_in_all_start_index * output_stride_0 +
|
||||||
|
query_head_idx * output_stride_1)
|
||||||
|
|
||||||
|
tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE_PADDED),
|
||||||
|
acc,
|
||||||
|
mask=dim_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def chunked_prefill_paged_decode(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
output,
|
||||||
|
kv_cache_dtype,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
block_table,
|
||||||
|
query_start_loc,
|
||||||
|
seq_lens,
|
||||||
|
max_query_len,
|
||||||
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
alibi_slopes=None,
|
||||||
|
sliding_window=None,
|
||||||
|
sm_scale=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
if sm_scale is None:
|
||||||
|
sm_scale = 1.0 / (query.shape[1]**0.5)
|
||||||
|
|
||||||
|
use_alibi_slopes = alibi_slopes is not None
|
||||||
|
|
||||||
|
if sliding_window is None or sliding_window <= 0:
|
||||||
|
sliding_window = 0
|
||||||
|
|
||||||
|
if max_query_len > 1:
|
||||||
|
context_attention_fwd(
|
||||||
|
q=query,
|
||||||
|
k=key,
|
||||||
|
v=value,
|
||||||
|
o=output,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
k_cache=key_cache,
|
||||||
|
v_cache=value_cache,
|
||||||
|
b_loc=block_table,
|
||||||
|
b_start_loc=query_start_loc,
|
||||||
|
b_seq_len=seq_lens,
|
||||||
|
max_input_len=max_query_len,
|
||||||
|
k_scale=k_scale,
|
||||||
|
v_scale=v_scale,
|
||||||
|
alibi_slopes=alibi_slopes,
|
||||||
|
sliding_window=sliding_window,
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
skip_decode=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
|
num_seqs = len(seq_lens)
|
||||||
|
num_query_heads = query.shape[1]
|
||||||
|
num_queries_per_kv = query.shape[1] // key.shape[1]
|
||||||
|
head_size = query.shape[2]
|
||||||
|
|
||||||
|
# Conversion of FP8 Tensor from uint8 storage to
|
||||||
|
# appropriate torch.dtype for interpretation by Triton
|
||||||
|
if "fp8" in kv_cache_dtype:
|
||||||
|
assert key_cache.dtype == torch.uint8
|
||||||
|
assert value_cache.dtype == torch.uint8
|
||||||
|
|
||||||
|
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||||
|
target_dtype = torch.float8_e4m3fn
|
||||||
|
elif kv_cache_dtype == "fp8_e5m2":
|
||||||
|
target_dtype = torch.float8_e5m2
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
|
||||||
|
|
||||||
|
key_cache = key_cache.view(target_dtype)
|
||||||
|
value_cache = value_cache.view(target_dtype)
|
||||||
|
|
||||||
|
kernel_paged_attention_2d[(
|
||||||
|
num_seqs,
|
||||||
|
num_query_heads,
|
||||||
|
)](
|
||||||
|
output_ptr=output,
|
||||||
|
query_ptr=query,
|
||||||
|
key_cache_ptr=key_cache,
|
||||||
|
value_cache_ptr=value_cache,
|
||||||
|
block_tables_ptr=block_table,
|
||||||
|
seq_lens_ptr=seq_lens,
|
||||||
|
alibi_slopes_ptr=alibi_slopes,
|
||||||
|
scale=sm_scale,
|
||||||
|
k_scale=k_scale,
|
||||||
|
v_scale=v_scale,
|
||||||
|
num_query_heads=num_query_heads,
|
||||||
|
num_queries_per_kv=num_queries_per_kv,
|
||||||
|
block_table_stride=block_table.stride(0),
|
||||||
|
query_stride_0=query.stride(0),
|
||||||
|
query_stride_1=query.stride(1),
|
||||||
|
output_stride_0=output.stride(0),
|
||||||
|
output_stride_1=output.stride(1),
|
||||||
|
BLOCK_SIZE=block_size,
|
||||||
|
HEAD_SIZE=head_size,
|
||||||
|
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||||
|
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||||
|
SLIDING_WINDOW=sliding_window,
|
||||||
|
x=key_cache.shape[4],
|
||||||
|
stride_k_cache_0=key_cache.stride(0),
|
||||||
|
stride_k_cache_1=key_cache.stride(1),
|
||||||
|
stride_k_cache_2=key_cache.stride(2),
|
||||||
|
stride_k_cache_3=key_cache.stride(3),
|
||||||
|
stride_k_cache_4=key_cache.stride(4),
|
||||||
|
stride_v_cache_0=value_cache.stride(0),
|
||||||
|
stride_v_cache_1=value_cache.stride(1),
|
||||||
|
stride_v_cache_2=value_cache.stride(2),
|
||||||
|
stride_v_cache_3=value_cache.stride(3),
|
||||||
|
filter_by_query_len=True,
|
||||||
|
query_start_len_ptr=query_start_loc,
|
||||||
|
)
|
@ -64,7 +64,9 @@ if triton.__version__ >= "2.1.0":
|
|||||||
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
|
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
SLIDING_WINDOW: tl.constexpr,
|
SLIDING_WINDOW: tl.constexpr,
|
||||||
|
SKIP_DECODE: tl.constexpr,
|
||||||
):
|
):
|
||||||
|
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
start_m = tl.program_id(2)
|
start_m = tl.program_id(2)
|
||||||
@ -78,6 +80,9 @@ if triton.__version__ >= "2.1.0":
|
|||||||
cur_batch_in_all_start_index)
|
cur_batch_in_all_start_index)
|
||||||
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
|
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
|
||||||
|
|
||||||
|
if SKIP_DECODE and cur_batch_query_len == 1:
|
||||||
|
return
|
||||||
|
|
||||||
# start position inside of the query
|
# start position inside of the query
|
||||||
# generally, N goes over kv, while M goes over query_len
|
# generally, N goes over kv, while M goes over query_len
|
||||||
block_start_loc = BLOCK_M * start_m
|
block_start_loc = BLOCK_M * start_m
|
||||||
@ -500,6 +505,7 @@ if triton.__version__ >= "2.1.0":
|
|||||||
BLOCK_DMODEL: tl.constexpr, # head size
|
BLOCK_DMODEL: tl.constexpr, # head size
|
||||||
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
|
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
|
SKIP_DECODE: tl.constexpr,
|
||||||
):
|
):
|
||||||
# attn_bias[]
|
# attn_bias[]
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
@ -518,6 +524,9 @@ if triton.__version__ >= "2.1.0":
|
|||||||
cur_batch_in_all_start_index)
|
cur_batch_in_all_start_index)
|
||||||
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
|
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
|
||||||
|
|
||||||
|
if SKIP_DECODE and cur_batch_query_len == 1:
|
||||||
|
return
|
||||||
|
|
||||||
block_start_loc = BLOCK_M * start_m
|
block_start_loc = BLOCK_M * start_m
|
||||||
|
|
||||||
# initialize offsets
|
# initialize offsets
|
||||||
@ -721,7 +730,8 @@ if triton.__version__ >= "2.1.0":
|
|||||||
v_scale: torch.Tensor,
|
v_scale: torch.Tensor,
|
||||||
alibi_slopes=None,
|
alibi_slopes=None,
|
||||||
sliding_window=None,
|
sliding_window=None,
|
||||||
sm_scale=None):
|
sm_scale=None,
|
||||||
|
skip_decode=False):
|
||||||
|
|
||||||
q_dtype_is_f32 = q.dtype is torch.float32
|
q_dtype_is_f32 = q.dtype is torch.float32
|
||||||
# need to reduce num. blocks when using fp32
|
# need to reduce num. blocks when using fp32
|
||||||
@ -823,6 +833,7 @@ if triton.__version__ >= "2.1.0":
|
|||||||
BLOCK_DMODEL=Lk,
|
BLOCK_DMODEL=Lk,
|
||||||
BLOCK_DMODEL_PADDED=Lk_padded,
|
BLOCK_DMODEL_PADDED=Lk_padded,
|
||||||
BLOCK_N=BLOCK,
|
BLOCK_N=BLOCK,
|
||||||
|
SKIP_DECODE=skip_decode,
|
||||||
num_warps=NUM_WARPS,
|
num_warps=NUM_WARPS,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
)
|
)
|
||||||
@ -875,6 +886,7 @@ if triton.__version__ >= "2.1.0":
|
|||||||
BLOCK_DMODEL_PADDED=Lk_padded,
|
BLOCK_DMODEL_PADDED=Lk_padded,
|
||||||
BLOCK_N=BLOCK,
|
BLOCK_N=BLOCK,
|
||||||
SLIDING_WINDOW=sliding_window,
|
SLIDING_WINDOW=sliding_window,
|
||||||
|
SKIP_DECODE=skip_decode,
|
||||||
num_warps=NUM_WARPS,
|
num_warps=NUM_WARPS,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
)
|
)
|
||||||
|
@ -6,8 +6,9 @@ import torch
|
|||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionMetadata, AttentionType)
|
AttentionMetadata, AttentionType)
|
||||||
|
from vllm.attention.ops.chunked_prefill_paged_decode import (
|
||||||
|
chunked_prefill_paged_decode)
|
||||||
from vllm.attention.ops.paged_attn import PagedAttention
|
from vllm.attention.ops.paged_attn import PagedAttention
|
||||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.attention.backends.flash_attn import (
|
from vllm.v1.attention.backends.flash_attn import (
|
||||||
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
|
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
|
||||||
@ -156,20 +157,22 @@ class ROCmAttentionImpl(AttentionImpl):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Compute attention and update output up to `num_actual_tokens`.
|
# Compute attention and update output up to `num_actual_tokens`.
|
||||||
context_attention_fwd(q=query[:num_actual_tokens],
|
chunked_prefill_paged_decode(
|
||||||
k=key[:num_actual_tokens],
|
query=query[:num_actual_tokens],
|
||||||
v=value[:num_actual_tokens],
|
key=key[:num_actual_tokens],
|
||||||
o=output[:num_actual_tokens],
|
value=value[:num_actual_tokens],
|
||||||
kv_cache_dtype=self.kv_cache_dtype,
|
output=output[:num_actual_tokens],
|
||||||
k_cache=key_cache,
|
kv_cache_dtype=self.kv_cache_dtype,
|
||||||
v_cache=value_cache,
|
key_cache=key_cache,
|
||||||
b_loc=attn_metadata.block_table,
|
value_cache=value_cache,
|
||||||
b_start_loc=attn_metadata.query_start_loc,
|
block_table=attn_metadata.block_table,
|
||||||
b_seq_len=attn_metadata.seq_lens,
|
query_start_loc=attn_metadata.query_start_loc,
|
||||||
max_input_len=attn_metadata.max_query_len,
|
seq_lens=attn_metadata.seq_lens,
|
||||||
k_scale=layer._k_scale,
|
max_query_len=attn_metadata.max_query_len,
|
||||||
v_scale=layer._v_scale,
|
k_scale=layer._k_scale,
|
||||||
alibi_slopes=self.alibi_slopes,
|
v_scale=layer._v_scale,
|
||||||
sliding_window=self.sliding_window[0],
|
alibi_slopes=self.alibi_slopes,
|
||||||
sm_scale=self.scale)
|
sliding_window=self.sliding_window[0],
|
||||||
|
sm_scale=self.scale)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
Loading…
x
Reference in New Issue
Block a user