[Kernel][Hardware][AMD] Custom paged attention kernel for ROCm (#8310)
parent 851725202a
commit 1ef0d2efd0
CMakeLists.txt

@@ -324,6 +324,25 @@ define_gpu_extension_target(
   WITH_SOABI)
 
+if(VLLM_GPU_LANG STREQUAL "HIP")
+  #
+  # _rocm_C extension
+  #
+  set(VLLM_ROCM_EXT_SRC
+    "csrc/rocm/torch_bindings.cpp"
+    "csrc/rocm/attention.cu")
+
+  define_gpu_extension_target(
+    _rocm_C
+    DESTINATION vllm
+    LANGUAGE ${VLLM_GPU_LANG}
+    SOURCES ${VLLM_ROCM_EXT_SRC}
+    COMPILE_FLAGS ${VLLM_GPU_FLAGS}
+    ARCHITECTURES ${VLLM_GPU_ARCHES}
+    USE_SABI 3
+    WITH_SOABI)
+endif()
+
 if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
   message(STATUS "Enabling C extension.")

@@ -331,5 +350,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
   message(STATUS "Enabling moe extension.")
   add_dependencies(default _moe_C)
 endif()
 
+if(VLLM_GPU_LANG STREQUAL "HIP")
+  message(STATUS "Enabling rocm extension.")
+  add_dependencies(default _rocm_C)
+endif()
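The _rocm_C target above is configured only when VLLM_GPU_LANG is "HIP", and add_dependencies(default _rocm_C) makes it part of the default build. A minimal smoke check, assuming a completed ROCm build of vLLM, is to import the extension and confirm that the op registered by csrc/rocm/torch_bindings.cpp (shown below) is visible under torch.ops:

# Sketch: verify the _rocm_C extension loaded and registered its op (assumes a ROCm build).
import torch

import vllm._rocm_C  # noqa: F401 -- loading the module runs its TORCH_LIBRARY registration

assert hasattr(torch.ops._rocm_C, "paged_attention")
print("rocm paged_attention op is registered")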
csrc/rocm/attention.cu (new file, 1038 lines)

File diff suppressed because it is too large.
csrc/rocm/ops.h (new file, 13 lines)

@@ -0,0 +1,13 @@
+#pragma once
+
+#include <torch/all.h>
+
+void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
+                     torch::Tensor& max_logits, torch::Tensor& tmp_out,
+                     torch::Tensor& query, torch::Tensor& key_cache,
+                     torch::Tensor& value_cache, int64_t num_kv_heads,
+                     double scale, torch::Tensor& block_tables,
+                     torch::Tensor& context_lens, int64_t block_size,
+                     int64_t max_context_len,
+                     const c10::optional<torch::Tensor>& alibi_slopes,
+                     const std::string& kv_cache_dtype);
csrc/rocm/torch_bindings.cpp (new file, 33 lines)

@@ -0,0 +1,33 @@
+#include "core/registration.h"
+#include "rocm/ops.h"
+
+// Note on op signatures:
+// The X_meta signatures are for the meta functions corresponding to op X.
+// They must be kept in sync with the signature for X. Generally, only
+// functions that return Tensors require a meta function.
+//
+// See the following links for detailed docs on op registration and function
+// schemas.
+// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
+// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
+  // vLLM custom ops for rocm
+
+  // Custom attention op
+  // Compute the attention between an input query and the cached
+  // keys/values using PagedAttention.
+  rocm_ops.def(
+      "paged_attention(Tensor! out, Tensor exp_sums,"
+      " Tensor max_logits, Tensor tmp_out,"
+      " Tensor query, Tensor key_cache,"
+      " Tensor value_cache, int num_kv_heads,"
+      " float scale, Tensor block_tables,"
+      " Tensor context_lens, int block_size,"
+      " int max_context_len,"
+      " Tensor? alibi_slopes,"
+      " str kv_cache_dtype) -> ()");
+  rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
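In the schema above, "Tensor! out" marks the output as written in place and the op returns "()"; the schema's int/float/str arguments correspond to the int64_t/double/std::string parameters declared in csrc/rocm/ops.h. On a ROCm build the registered schema can be inspected from Python; a sketch, assuming PyTorch's usual torch.ops overload objects:

# Sketch: inspect the schema registered above (assumes a ROCm build of vLLM).
import torch
import vllm._rocm_C  # noqa: F401

op = torch.ops._rocm_C.paged_attention.default  # the unnamed overload created by rocm_ops.def(...)
print(op._schema)  # should show "Tensor(a!) out", i.e. the output tensor is mutated in place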
setup.py

@@ -462,6 +462,9 @@ if _build_core_ext():
 if _is_cuda() or _is_hip():
     ext_modules.append(CMakeExtension(name="vllm._moe_C"))
 
+if _is_hip():
+    ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
+
 if _build_custom_ops():
     ext_modules.append(CMakeExtension(name="vllm._C"))
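With this change the vllm._rocm_C extension is built only for HIP, alongside the existing vllm._C and vllm._moe_C modules. A quick way to see which compiled extensions are present in an installed package (a sketch; the module names are the ones registered above):

# Sketch: report which compiled vLLM extension modules exist in this installation.
import importlib.util

for name in ("vllm._C", "vllm._moe_C", "vllm._rocm_C"):
    present = importlib.util.find_spec(name) is not None
    print(f"{name}: {'present' if present else 'not built for this platform'}")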
tests/kernels/test_attention.py

@@ -3,8 +3,6 @@ from typing import List, Optional, Tuple
 
 import pytest
 import torch
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops

@@ -12,6 +10,10 @@ from vllm.utils import get_max_shared_memory_bytes, is_hip
 
 from .allclose_default import get_default_atol, get_default_rtol
 
+if not is_hip():
+    from xformers import ops as xops
+    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
+
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer

@@ -328,6 +330,165 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)
 
 
+@pytest.mark.parametrize("version", ["rocm"])
+@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [64, 128])  # only test 64 128
+@pytest.mark.parametrize("use_alibi", USE_ALIBI)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.skipif(not is_hip(), reason="only for rocm")
+def test_paged_attention_rocm(
+    kv_cache_factory,
+    version: str,
+    num_seqs: int,
+    num_heads: Tuple[int, int],
+    head_size: int,
+    use_alibi: bool,
+    block_size: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: str,
+    seed: int,
+    device: str,
+) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
+    scale = float(1.0 / (head_size**0.5))
+    num_query_heads, num_kv_heads = num_heads
+    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
+    query.uniform_(-scale, scale)
+
+    assert num_query_heads % num_kv_heads == 0
+    num_queries_per_kv = num_query_heads // num_kv_heads
+    alibi_slopes = None
+    if use_alibi:
+        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
+
+    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    context_lens[-1] = MAX_SEQ_LEN
+    #context_lens = [8192 for _ in range(num_seqs)]
+    max_context_len = max(context_lens)
+    context_lens = torch.tensor(context_lens, dtype=torch.int)
+    #print('>>> ctx lens', context_lens)
+
+    # Create the block tables.
+    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+    block_tables = []
+    for _ in range(num_seqs):
+        block_table = [
+            random.randint(0, NUM_BLOCKS - 1)
+            for _ in range(max_num_blocks_per_seq)
+        ]
+        block_tables.append(block_table)
+    block_tables = torch.tensor(block_tables, dtype=torch.int)
+
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
+                                                num_kv_heads, head_size,
+                                                kv_cache_dtype, dtype, seed,
+                                                device)
+    key_cache, value_cache = key_caches[0], value_caches[0]
+
+    # TODO(charlifu) enable fp8 kv cache
+    # Using default kv_scale
+    # kv_scale = 1.0
+
+    # Call the paged attention kernel.
+    output = torch.empty_like(query)
+    PARTITION_SIZE_ROCM = 256
+    num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
+                      PARTITION_SIZE_ROCM)
+    assert PARTITION_SIZE_ROCM % block_size == 0
+    num_seqs, num_heads, head_size = output.shape
+    tmp_output = torch.empty(
+        size=(num_seqs, num_heads, num_partitions, head_size),
+        dtype=output.dtype,
+    )
+    exp_sums = torch.empty(
+        size=(num_seqs, num_heads, num_partitions),
+        dtype=torch.float32,
+    )
+    max_logits = torch.empty_like(exp_sums)
+    if version == "rocm":
+        ops.paged_attention_rocm(
+            output,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            num_kv_heads,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+            kv_cache_dtype,
+        )
+    else:
+        raise AssertionError(f"Unknown version: {version}")
+
+    # Run the reference implementation.
+    if kv_cache_dtype == "fp8":
+        # Convert cache data back to dtype.
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
+                           block_size, x)
+        dequantized_key_cache = torch.empty(size=key_cache_shape,
+                                            dtype=dtype,
+                                            device=device)
+        ops.convert_fp8(key_cache, dequantized_key_cache)
+        key_cache = dequantized_key_cache
+
+        value_cache_shape = value_cache.shape
+        dequantized_value_cache = torch.empty(size=value_cache_shape,
+                                              dtype=dtype,
+                                              device=device)
+        ops.convert_fp8(value_cache, dequantized_value_cache)
+        value_cache = dequantized_value_cache
+
+    ref_output = torch.empty_like(query)
+    ref_single_query_cached_kv_attention(
+        ref_output,
+        query,
+        num_queries_per_kv,
+        key_cache,
+        value_cache,
+        block_tables,
+        context_lens,
+        scale,
+        alibi_slopes,
+    )
+
+    # NOTE(woosuk): Due to the kernel-level differences in the two
+    # implementations, there is a small numerical difference in the two
+    # outputs. Thus, we use a relaxed tolerance for the test.
+    atol = get_default_atol(output) if is_hip() else 1e-3
+    rtol = get_default_rtol(output) if is_hip() else 1e-5
+
+    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
+    # so we use a relaxed tolerance for the test.
+    atol, rtol = 1e-4, 1e-5
+    if dtype == torch.bfloat16:
+        atol, rtol = 2e-4, 1e-5
+    if use_alibi:
+        if dtype == torch.half:
+            atol, rtol = 5e-4, 1e-5
+        if dtype == torch.bfloat16:
+            atol, rtol = 1e-3, 1e-5
+    if kv_cache_dtype == "fp8":
+        atol, rtol = 1e-2, 1e-5
+    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
+
+
 # TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)

@@ -335,6 +496,7 @@ def ref_multi_query_kv_attention(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.skipif(is_hip(), reason="skip for rocm")
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,
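The test above sizes its workspaces around a fixed 256-token partition: tmp_output holds one partial result per (sequence, head, partition), while exp_sums and max_logits hold the per-partition softmax statistics used to combine those partials into output. A small worked example of that sizing arithmetic (illustrative values, not taken from the test matrix):

# Worked example of the partition/workspace sizing used in the test (illustrative values).
PARTITION_SIZE_ROCM = 256
block_size = 16
max_context_len = 1000

# The partition size must be a whole number of KV-cache blocks.
assert PARTITION_SIZE_ROCM % block_size == 0

num_partitions = (max_context_len + PARTITION_SIZE_ROCM - 1) // PARTITION_SIZE_ROCM
print(num_partitions)  # 4 -- tmp_output, exp_sums and max_logits get a partition axis of size 4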
vllm/_custom_ops.py

@@ -17,6 +17,9 @@ if not current_platform.is_tpu():
     except ImportError as e:
         logger.warning("Failed to import from vllm._C with %r", e)
 
+if current_platform.is_rocm():
+    import vllm._rocm_C  # noqa: F401
+
 with contextlib.suppress(ImportError):
     import vllm._moe_C  # noqa: F401
 

@@ -127,6 +130,30 @@ def paged_attention_v2(
         blocksparse_block_size, blocksparse_head_sliding_step)
 
 
+def paged_attention_rocm(
+    out: torch.Tensor,
+    exp_sum: torch.Tensor,
+    max_logits: torch.Tensor,
+    tmp_out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    num_kv_heads: int,
+    scale: float,
+    block_tables: torch.Tensor,
+    seq_lens: torch.Tensor,
+    block_size: int,
+    max_seq_len: int,
+    alibi_slopes: Optional[torch.Tensor],
+    kv_cache_dtype: str,
+) -> None:
+    torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
+                                      key_cache, value_cache, num_kv_heads,
+                                      scale, block_tables, seq_lens,
+                                      block_size, max_seq_len, alibi_slopes,
+                                      kv_cache_dtype)
+
+
 # pos encoding ops
 def rotary_embedding(
     positions: torch.Tensor,
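The wrapper forwards directly to torch.ops._rocm_C.paged_attention; note that it names the length arguments seq_lens and max_seq_len, while the C++ declaration calls them context_lens and max_context_len. Below is a self-contained invocation sketch with assumed shapes taken from the test and backend code in this diff; it only runs on a ROCm build with a visible GPU, and the tensor contents are random, so it exercises the call path rather than producing meaningful attention output:

# Invocation sketch for ops.paged_attention_rocm (assumes a ROCm build of vLLM).
import torch
from vllm import _custom_ops as ops

num_seqs, num_heads, num_kv_heads, head_size = 2, 8, 8, 128
block_size, num_blocks, max_seq_len = 16, 32, 512
partition_size = 256  # assumed to match _PARTITION_SIZE in the backend change below
num_partitions = (max_seq_len + partition_size - 1) // partition_size

dtype, device = torch.half, "cuda"
query = torch.randn(num_seqs, num_heads, head_size, dtype=dtype, device=device)
# Cache layouts as produced by vLLM's kv_cache_factory (see the test's fp8 branch):
#   key:   (num_blocks, num_kv_heads, head_size // x, block_size, x) with x = 16 // element_size
#   value: (num_blocks, num_kv_heads, head_size, block_size)
x = 16 // query.element_size()
key_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x,
                        dtype=dtype, device=device)
value_cache = torch.randn(num_blocks, num_kv_heads, head_size, block_size,
                          dtype=dtype, device=device)
block_tables = torch.randint(0, num_blocks, (num_seqs, max_seq_len // block_size),
                             dtype=torch.int, device=device)
seq_lens = torch.tensor([max_seq_len] * num_seqs, dtype=torch.int, device=device)

# Caller-allocated output and partition workspaces.
out = torch.empty_like(query)
tmp_out = torch.empty(num_seqs, num_heads, num_partitions, head_size,
                      dtype=dtype, device=device)
exp_sums = torch.empty(num_seqs, num_heads, num_partitions,
                       dtype=torch.float32, device=device)
max_logits = torch.empty_like(exp_sums)

ops.paged_attention_rocm(out, exp_sums, max_logits, tmp_out, query, key_cache,
                         value_cache, num_kv_heads, 1.0 / head_size**0.5,
                         block_tables, seq_lens, block_size, max_seq_len,
                         None, "auto")
print(out.shape)  # (num_seqs, num_heads, head_size)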
vllm/attention/backends/rocm_flash_attn.py

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 import torch
 
 import vllm.envs as envs
+from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,

@@ -15,6 +16,9 @@ from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
+_PARTITION_SIZE = 256
+ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
+
 
 class ROCmFlashAttentionBackend(AttentionBackend):
 

@@ -480,20 +484,61 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            output[num_prefill_tokens:] = PagedAttention.forward_decode(
-                decode_query,
-                key_cache,
-                value_cache,
-                decode_meta.block_tables,
-                decode_meta.seq_lens_tensor,
-                decode_meta.max_decode_seq_len,
-                self.kv_cache_dtype,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-                k_scale,
-                v_scale,
-            )
+            # Whether to use rocm custom paged attention or not
+            num_seqs, num_heads, head_size = decode_query.shape
+            block_size = value_cache.shape[3]
+            gqa_ratio = num_heads // self.num_kv_heads
+            use_custom = use_rocm_custom_paged_attention(
+                decode_query.dtype, head_size, block_size, self.kv_cache_dtype,
+                gqa_ratio, decode_meta.max_decode_seq_len)
+            if use_custom:
+                max_seq_len = decode_meta.max_decode_seq_len
+                max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
+                                      _PARTITION_SIZE)
+                assert _PARTITION_SIZE % block_size == 0
+                tmp_output = torch.empty(
+                    size=(num_seqs, num_heads, max_num_partitions, head_size),
+                    dtype=output.dtype,
+                    device=output.device,
+                )
+                exp_sums = torch.empty(
+                    size=(num_seqs, num_heads, max_num_partitions),
+                    dtype=torch.float32,
+                    device=output.device,
+                )
+                max_logits = torch.empty_like(exp_sums)
+                ops.paged_attention_rocm(
+                    output[num_prefill_tokens:],
+                    exp_sums,
+                    max_logits,
+                    tmp_output,
+                    decode_query,
+                    key_cache,
+                    value_cache,
+                    self.num_kv_heads,
+                    self.scale,
+                    decode_meta.block_tables,
+                    decode_meta.seq_lens_tensor,
+                    block_size,
+                    max_seq_len,
+                    self.alibi_slopes,
+                    self.kv_cache_dtype,
+                )
+            else:
+                output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                    decode_query,
+                    key_cache,
+                    value_cache,
+                    decode_meta.block_tables,
+                    decode_meta.seq_lens_tensor,
+                    decode_meta.max_decode_seq_len,
+                    self.kv_cache_dtype,
+                    self.num_kv_heads,
+                    self.scale,
+                    self.alibi_slopes,
+                    k_scale,
+                    v_scale,
+                )
 
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)

@@ -532,3 +577,14 @@ def _sdpa_attention(
         start = end
 
     return output
+
+
+def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
+                                    block_size: int, kv_cache_dtype: str,
+                                    gqa_ratio: int, max_seq_len: int) -> bool:
+    # rocm custom page attention not support on navi (gfx1*)
+    return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
+            and (head_size == 64 or head_size == 128)
+            and (block_size == 16 or block_size == 32)
+            and kv_cache_dtype == "auto"
+            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
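The helper above gates dispatch to the custom kernel: fp16 or bf16 queries, head sizes 64 or 128, block sizes 16 or 32, the default ("auto") KV-cache dtype, GQA ratios 1 through 16, contexts up to 32768 tokens, and non-Navi (non-gfx1*) GPUs; anything else falls back to PagedAttention.forward_decode. A small probe of those conditions, assuming the helper is importable from the ROCm attention backend module (the excerpt does not name the file) and that a ROCm GPU is visible, since ON_NAVI is computed at import time:

# Sketch: probe the custom-kernel gating (assumed import path; requires a ROCm GPU at import).
import torch

from vllm.attention.backends.rocm_flash_attn import use_rocm_custom_paged_attention

# Eligible decode shape on a non-Navi GPU: fp16, head_size 128, block_size 16,
# default kv-cache dtype, GQA ratio 8, 2k-token context.
print(use_rocm_custom_paged_attention(torch.half, 128, 16, "auto", 8, 2048))

# An fp8 KV cache is not handled by the custom path yet, so this returns False.
print(use_rocm_custom_paged_attention(torch.half, 128, 16, "fp8", 8, 2048))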