vllm/csrc/rocm/torch_bindings.cpp

#include "core/registration.h"
#include "rocm/ops.h"
// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
  // vLLM custom ops for rocm

  // Custom attention op
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention.
  rocm_ops.def(
      "paged_attention(Tensor! out, Tensor exp_sums,"
      "                Tensor max_logits, Tensor tmp_out,"
      "                Tensor query, Tensor key_cache,"
      "                Tensor value_cache, int num_kv_heads,"
      "                float scale, Tensor block_tables,"
      "                Tensor context_lens,"
      "                Tensor? query_start_loc,"
      "                int block_size,"
      "                int max_context_len,"
      "                Tensor? alibi_slopes,"
      "                str kv_cache_dtype,"
      "                Tensor k_scale, Tensor v_scale) -> ()");
  rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}
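// Note on the dispatch key: ROCm/HIP builds of PyTorch reuse the CUDA device
// type, so torch::kCUDA above is the key under which the HIP implementation is
// registered.
//
// Sketch (an assumption, not part of this file): the paged_attention function
// bound above is declared in rocm/ops.h. Applying PyTorch's schema-to-C++ type
// mapping to the schema string (int -> int64_t, float -> double,
// str -> std::string, Tensor? -> std::optional<torch::Tensor>), the
// declaration would look roughly like:
//
//   void paged_attention(
//       torch::Tensor& out, torch::Tensor& exp_sums,
//       torch::Tensor& max_logits, torch::Tensor& tmp_out,
//       torch::Tensor& query, torch::Tensor& key_cache,
//       torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
//       torch::Tensor& block_tables, torch::Tensor& context_lens,
//       const std::optional<torch::Tensor>& query_start_loc,
//       int64_t block_size, int64_t max_context_len,
//       const std::optional<torch::Tensor>& alibi_slopes,
//       const std::string& kv_cache_dtype, torch::Tensor& k_scale,
//       torch::Tensor& v_scale);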
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
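// REGISTER_EXTENSION makes the compiled library importable from Python;
// importing it runs the TORCH_LIBRARY registration above, after which the op
// is reachable through the dispatcher as
// torch.ops.<extension name>.paged_attention(...), where <extension name> is
// whatever TORCH_EXTENSION_NAME expands to at build time (typically _rocm_C in
// vLLM's ROCm build, though that is an assumption about the build config).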