[Kernel][Hardware][Amd]Custom paged attention kernel for rocm (#8310)
parent 851725202a
commit 1ef0d2efd0
CMakeLists.txt

@@ -324,6 +324,25 @@ define_gpu_extension_target(
   WITH_SOABI)
 
+if(VLLM_GPU_LANG STREQUAL "HIP")
+  #
+  # _rocm_C extension
+  #
+  set(VLLM_ROCM_EXT_SRC
+    "csrc/rocm/torch_bindings.cpp"
+    "csrc/rocm/attention.cu")
+
+  define_gpu_extension_target(
+    _rocm_C
+    DESTINATION vllm
+    LANGUAGE ${VLLM_GPU_LANG}
+    SOURCES ${VLLM_ROCM_EXT_SRC}
+    COMPILE_FLAGS ${VLLM_GPU_FLAGS}
+    ARCHITECTURES ${VLLM_GPU_ARCHES}
+    USE_SABI 3
+    WITH_SOABI)
+endif()
+
 if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
   message(STATUS "Enabling C extension.")
@@ -331,5 +350,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
 
   message(STATUS "Enabling moe extension.")
   add_dependencies(default _moe_C)
+endif()
+
+if(VLLM_GPU_LANG STREQUAL "HIP")
+  message(STATUS "Enabling rocm extension.")
+  add_dependencies(default _rocm_C)
 endif()

csrc/rocm/attention.cu (new file, 1038 lines)

File diff suppressed because it is too large.

csrc/rocm/ops.h (new file, 13 lines)

@@ -0,0 +1,13 @@
+#pragma once
+
+#include <torch/all.h>
+
+void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
+                     torch::Tensor& max_logits, torch::Tensor& tmp_out,
+                     torch::Tensor& query, torch::Tensor& key_cache,
+                     torch::Tensor& value_cache, int64_t num_kv_heads,
+                     double scale, torch::Tensor& block_tables,
+                     torch::Tensor& context_lens, int64_t block_size,
+                     int64_t max_context_len,
+                     const c10::optional<torch::Tensor>& alibi_slopes,
+                     const std::string& kv_cache_dtype);

csrc/rocm/torch_bindings.cpp (new file, 33 lines)

@@ -0,0 +1,33 @@
+#include "core/registration.h"
+#include "rocm/ops.h"
+
+// Note on op signatures:
+// The X_meta signatures are for the meta functions corresponding to op X.
+// They must be kept in sync with the signature for X. Generally, only
+// functions that return Tensors require a meta function.
+//
+// See the following links for detailed docs on op registration and function
+// schemas.
+// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
+// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
+  // vLLM custom ops for rocm
+
+  // Custom attention op
+  // Compute the attention between an input query and the cached
+  // keys/values using PagedAttention.
+  rocm_ops.def(
+      "paged_attention(Tensor! out, Tensor exp_sums,"
+      " Tensor max_logits, Tensor tmp_out,"
+      " Tensor query, Tensor key_cache,"
+      " Tensor value_cache, int num_kv_heads,"
+      " float scale, Tensor block_tables,"
+      " Tensor context_lens, int block_size,"
+      " int max_context_len,"
+      " Tensor? alibi_slopes,"
+      " str kv_cache_dtype) -> ()");
+  rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
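
For context (not part of the commit): the schema string passed to rocm_ops.def() is what PyTorch uses to type-check calls from Python, and it lines up positionally with the C++ declaration in csrc/rocm/ops.h — "Tensor! out" marks the output as mutated in place, int/float map to int64_t/double, "Tensor?" to c10::optional<torch::Tensor>, and str to std::string. A minimal sketch of inspecting the registered op from Python, assuming a ROCm build where the vllm._rocm_C extension imports cleanly:

# Sketch only; assumes a ROCm build of vLLM so that vllm._rocm_C exists.
import torch

import vllm._rocm_C  # noqa: F401  # loading the extension registers the op schema

# The op is exposed under the _rocm_C namespace; printing the default
# overload's schema shows the same signature that rocm_ops.def() registered.
op = torch.ops._rocm_C.paged_attention
print(op.default._schema)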

setup.py

@@ -462,6 +462,9 @@ if _build_core_ext():
 if _is_cuda() or _is_hip():
     ext_modules.append(CMakeExtension(name="vllm._moe_C"))
 
+if _is_hip():
+    ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
+
 if _build_custom_ops():
     ext_modules.append(CMakeExtension(name="vllm._C"))
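
For context (not part of the commit): with the CMake target above and this setup.py entry, a ROCm build should ship a vllm._rocm_C shared object alongside the existing vllm._C and vllm._moe_C extensions. A quick sanity check of a local install, assuming a ROCm build (where CMake selects HIP as VLLM_GPU_LANG):

# Sketch only; fails with ImportError on non-ROCm builds, which is expected.
import vllm._rocm_C as rocm_ext

# If the CMakeExtension was built and packaged, this points at the compiled
# _rocm_C shared library inside the installed vllm package.
print(rocm_ext.__file__)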

tests/kernels/test_attention.py

@@ -3,8 +3,6 @@ from typing import List, Optional, Tuple
 
 import pytest
 import torch
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
@@ -12,6 +10,10 @@ from vllm.utils import get_max_shared_memory_bytes, is_hip
 
 from .allclose_default import get_default_atol, get_default_rtol
 
+if not is_hip():
+    from xformers import ops as xops
+    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
+
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
@@ -328,6 +330,165 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)
 
 
+@pytest.mark.parametrize("version", ["rocm"])
+@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [64, 128])  # only test 64, 128
+@pytest.mark.parametrize("use_alibi", USE_ALIBI)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.skipif(not is_hip(), reason="only for rocm")
+def test_paged_attention_rocm(
+    kv_cache_factory,
+    version: str,
+    num_seqs: int,
+    num_heads: Tuple[int, int],
+    head_size: int,
+    use_alibi: bool,
+    block_size: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: str,
+    seed: int,
+    device: str,
+) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
+    scale = float(1.0 / (head_size**0.5))
+    num_query_heads, num_kv_heads = num_heads
+    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
+    query.uniform_(-scale, scale)
+
+    assert num_query_heads % num_kv_heads == 0
+    num_queries_per_kv = num_query_heads // num_kv_heads
+    alibi_slopes = None
+    if use_alibi:
+        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
+
+    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    context_lens[-1] = MAX_SEQ_LEN
+    # context_lens = [8192 for _ in range(num_seqs)]
+    max_context_len = max(context_lens)
+    context_lens = torch.tensor(context_lens, dtype=torch.int)
+    # print('>>> ctx lens', context_lens)
+
+    # Create the block tables.
+    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+    block_tables = []
+    for _ in range(num_seqs):
+        block_table = [
+            random.randint(0, NUM_BLOCKS - 1)
+            for _ in range(max_num_blocks_per_seq)
+        ]
+        block_tables.append(block_table)
+    block_tables = torch.tensor(block_tables, dtype=torch.int)
+
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
+                                                num_kv_heads, head_size,
+                                                kv_cache_dtype, dtype, seed,
+                                                device)
+    key_cache, value_cache = key_caches[0], value_caches[0]
+
+    # TODO(charlifu) enable fp8 kv cache
+    # Using default kv_scale
+    # kv_scale = 1.0
+
+    # Call the paged attention kernel.
+    output = torch.empty_like(query)
+    PARTITION_SIZE_ROCM = 256
+    num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
+                      PARTITION_SIZE_ROCM)
+    assert PARTITION_SIZE_ROCM % block_size == 0
+    num_seqs, num_heads, head_size = output.shape
+    tmp_output = torch.empty(
+        size=(num_seqs, num_heads, num_partitions, head_size),
+        dtype=output.dtype,
+    )
+    exp_sums = torch.empty(
+        size=(num_seqs, num_heads, num_partitions),
+        dtype=torch.float32,
+    )
+    max_logits = torch.empty_like(exp_sums)
+    if version == "rocm":
+        ops.paged_attention_rocm(
+            output,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            num_kv_heads,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+            kv_cache_dtype,
+        )
+    else:
+        raise AssertionError(f"Unknown version: {version}")
+
+    # Run the reference implementation.
+    if kv_cache_dtype == "fp8":
+        # Convert cache data back to dtype.
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
+                           block_size, x)
+        dequantized_key_cache = torch.empty(size=key_cache_shape,
+                                            dtype=dtype,
+                                            device=device)
+        ops.convert_fp8(key_cache, dequantized_key_cache)
+        key_cache = dequantized_key_cache
+
+        value_cache_shape = value_cache.shape
+        dequantized_value_cache = torch.empty(size=value_cache_shape,
+                                              dtype=dtype,
+                                              device=device)
+        ops.convert_fp8(value_cache, dequantized_value_cache)
+        value_cache = dequantized_value_cache
+
+    ref_output = torch.empty_like(query)
+    ref_single_query_cached_kv_attention(
+        ref_output,
+        query,
+        num_queries_per_kv,
+        key_cache,
+        value_cache,
+        block_tables,
+        context_lens,
+        scale,
+        alibi_slopes,
+    )
+
+    # NOTE(woosuk): Due to the kernel-level differences in the two
+    # implementations, there is a small numerical difference in the two
+    # outputs. Thus, we use a relaxed tolerance for the test.
+    atol = get_default_atol(output) if is_hip() else 1e-3
+    rtol = get_default_rtol(output) if is_hip() else 1e-5
+
+    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
+    # so we use a relaxed tolerance for the test.
+    atol, rtol = 1e-4, 1e-5
+    if dtype == torch.bfloat16:
+        atol, rtol = 2e-4, 1e-5
+    if use_alibi:
+        if dtype == torch.half:
+            atol, rtol = 5e-4, 1e-5
+        if dtype == torch.bfloat16:
+            atol, rtol = 1e-3, 1e-5
+    if kv_cache_dtype == "fp8":
+        atol, rtol = 1e-2, 1e-5
+    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
+
+
 # TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -335,6 +496,7 @@ def ref_multi_query_kv_attention(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.skipif(is_hip(), reason="skip for rocm")
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,
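
For context (not part of the commit): the test sizes its scratch buffers from a fixed partition length, mirroring what the kernel expects. A small worked example of that arithmetic, with illustrative numbers rather than values taken from the diff:

# Worked example of the partitioning arithmetic used in the test above.
PARTITION_SIZE_ROCM = 256
block_size = 16
max_context_len = 1025

# ceil(1025 / 256) = 5 partitions per (sequence, head); tmp_output is then
# (num_seqs, num_heads, 5, head_size) and exp_sums/max_logits are
# (num_seqs, num_heads, 5).
num_partitions = (max_context_len + PARTITION_SIZE_ROCM - 1) // PARTITION_SIZE_ROCM
assert num_partitions == 5

# The partition length must be a whole number of KV-cache blocks, which is
# what the assert in the test checks.
assert PARTITION_SIZE_ROCM % block_size == 0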

vllm/_custom_ops.py

@@ -17,6 +17,9 @@ if not current_platform.is_tpu():
     except ImportError as e:
         logger.warning("Failed to import from vllm._C with %r", e)
 
+if current_platform.is_rocm():
+    import vllm._rocm_C  # noqa: F401
+
 with contextlib.suppress(ImportError):
     import vllm._moe_C  # noqa: F401
 
@@ -127,6 +130,30 @@ def paged_attention_v2(
         blocksparse_block_size, blocksparse_head_sliding_step)
 
 
+def paged_attention_rocm(
+    out: torch.Tensor,
+    exp_sum: torch.Tensor,
+    max_logits: torch.Tensor,
+    tmp_out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    num_kv_heads: int,
+    scale: float,
+    block_tables: torch.Tensor,
+    seq_lens: torch.Tensor,
+    block_size: int,
+    max_seq_len: int,
+    alibi_slopes: Optional[torch.Tensor],
+    kv_cache_dtype: str,
+) -> None:
+    torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
+                                      key_cache, value_cache, num_kv_heads,
+                                      scale, block_tables, seq_lens,
+                                      block_size, max_seq_len, alibi_slopes,
+                                      kv_cache_dtype)
+
+
 # pos encoding ops
 def rotary_embedding(
     positions: torch.Tensor,
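
For context (not part of the commit): callers of the new wrapper pre-allocate the split-softmax workspaces (tmp_out, exp_sums, max_logits), sized by the number of context partitions, before invoking the op; the test above and the ROCm attention backend below both follow this pattern. A hedged sketch of that call sequence, with all shapes assumed for illustration:

# Sketch only; assumes a ROCm device and that query/key_cache/value_cache,
# block_tables and seq_lens have already been built the usual vLLM way.
import torch
from vllm import _custom_ops as ops

PARTITION_SIZE = 256                        # assumed to match the kernel
num_seqs, num_heads, head_size = 2, 8, 128  # illustrative shapes
max_seq_len = 1000

# One partial softmax result per (sequence, head, partition) is written into
# tmp_out/exp_sums/max_logits and then reduced into `out`.
max_num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
out = torch.empty(num_seqs, num_heads, head_size, dtype=torch.half, device="cuda")
tmp_out = torch.empty(num_seqs, num_heads, max_num_partitions, head_size,
                      dtype=out.dtype, device=out.device)
exp_sums = torch.empty(num_seqs, num_heads, max_num_partitions,
                       dtype=torch.float32, device=out.device)
max_logits = torch.empty_like(exp_sums)

# ops.paged_attention_rocm(out, exp_sums, max_logits, tmp_out, query, key_cache,
#                          value_cache, num_kv_heads, scale, block_tables,
#                          seq_lens, block_size, max_seq_len, alibi_slopes,
#                          kv_cache_dtype)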

vllm/attention/backends/rocm_flash_attn.py

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 import torch
 
 import vllm.envs as envs
+from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,
@@ -15,6 +16,9 @@ from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
+_PARTITION_SIZE = 256
+ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
+
 
 class ROCmFlashAttentionBackend(AttentionBackend):
 
@@ -480,20 +484,61 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            output[num_prefill_tokens:] = PagedAttention.forward_decode(
-                decode_query,
-                key_cache,
-                value_cache,
-                decode_meta.block_tables,
-                decode_meta.seq_lens_tensor,
-                decode_meta.max_decode_seq_len,
-                self.kv_cache_dtype,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-                k_scale,
-                v_scale,
-            )
+            # Whether to use rocm custom paged attention or not
+            num_seqs, num_heads, head_size = decode_query.shape
+            block_size = value_cache.shape[3]
+            gqa_ratio = num_heads // self.num_kv_heads
+            use_custom = use_rocm_custom_paged_attention(
+                decode_query.dtype, head_size, block_size, self.kv_cache_dtype,
+                gqa_ratio, decode_meta.max_decode_seq_len)
+            if use_custom:
+                max_seq_len = decode_meta.max_decode_seq_len
+                max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
+                                      _PARTITION_SIZE)
+                assert _PARTITION_SIZE % block_size == 0
+                tmp_output = torch.empty(
+                    size=(num_seqs, num_heads, max_num_partitions, head_size),
+                    dtype=output.dtype,
+                    device=output.device,
+                )
+                exp_sums = torch.empty(
+                    size=(num_seqs, num_heads, max_num_partitions),
+                    dtype=torch.float32,
+                    device=output.device,
+                )
+                max_logits = torch.empty_like(exp_sums)
+                ops.paged_attention_rocm(
+                    output[num_prefill_tokens:],
+                    exp_sums,
+                    max_logits,
+                    tmp_output,
+                    decode_query,
+                    key_cache,
+                    value_cache,
+                    self.num_kv_heads,
+                    self.scale,
+                    decode_meta.block_tables,
+                    decode_meta.seq_lens_tensor,
+                    block_size,
+                    max_seq_len,
+                    self.alibi_slopes,
+                    self.kv_cache_dtype,
+                )
+            else:
+                output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                    decode_query,
+                    key_cache,
+                    value_cache,
+                    decode_meta.block_tables,
+                    decode_meta.seq_lens_tensor,
+                    decode_meta.max_decode_seq_len,
+                    self.kv_cache_dtype,
+                    self.num_kv_heads,
+                    self.scale,
+                    self.alibi_slopes,
+                    k_scale,
+                    v_scale,
+                )
 
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
@@ -532,3 +577,14 @@ def _sdpa_attention(
         start = end
 
     return output
+
+
+def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
+                                    block_size: int, kv_cache_dtype: str,
+                                    gqa_ratio: int, max_seq_len: int) -> bool:
+    # The ROCm custom paged attention kernel is not supported on Navi (gfx1*).
+    return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
+            and (head_size == 64 or head_size == 128)
+            and (block_size == 16 or block_size == 32)
+            and kv_cache_dtype == "auto"
+            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
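
For context (not part of the commit): use_rocm_custom_paged_attention is a pure predicate, so it is easy to see which decode configurations take the custom-kernel path and which fall back to PagedAttention.forward_decode. An illustrative check, with values assumed for the example; it needs a ROCm device because the module reads the GPU arch (ON_NAVI) at import time:

# Sketch only; requires a ROCm build and device.
import torch
from vllm.attention.backends.rocm_flash_attn import use_rocm_custom_paged_attention

print(use_rocm_custom_paged_attention(torch.half, 128, 16, "auto", 8, 4096))
# True off Navi: fp16 query, supported head size and block size, unquantized
# ("auto") KV cache, GQA ratio within 1..16, context length <= 32768.
print(use_rocm_custom_paged_attention(torch.float, 128, 16, "auto", 8, 4096))
# False: only half/bfloat16 queries are routed to the custom kernel.
print(use_rocm_custom_paged_attention(torch.half, 128, 16, "fp8", 8, 4096))
# False: a quantized KV cache falls back to PagedAttention.forward_decode.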