[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)

bnellnm 2024-06-09 16:23:30 -04:00 committed by GitHub
parent 5d7e3d0176
commit 5467ac3196
55 changed files with 833 additions and 451 deletions
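
The gist of the change, sketched below with the silu_and_mul op that appears in this diff (the real macro plumbing lives in csrc/registration.h further down): each extension stops exporting its functions through pybind11 and instead registers them with the PyTorch dispatcher, exporting only an empty module so the shared library remains importable. Dispatcher schemas express scalars as int64_t and double on the C++ side, which is why kernel signatures throughout this diff move from int/float to int64_t/double.

// Before: pybind11 bindings, which require linking libtorch_python.so.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
}

// After: dispatcher registration; the op is callable as torch.ops._C.silu_and_mul.
// In the schema string, "Tensor!" marks an argument mutated in place and
// "-> ()" means the op returns nothing.
TORCH_LIBRARY(_C, ops) {  // the real code spells _C via TORCH_EXTENSION_NAME
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
}
REGISTER_EXTENSION(_C)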


@ -66,19 +66,6 @@ endif()
#
find_package(Torch REQUIRED)
#
# Normally `torch.utils.cpp_extension.CUDAExtension` would add
# `libtorch_python.so` for linking against an extension. Torch's cmake
# configuration does not include this library (presumably since the cmake
# config is used for standalone C++ binaries that link against torch).
# The `libtorch_python.so` library defines some of the glue code between
# torch/python via pybind and is required by VLLM extensions for this
reason. So, add it manually with `find_library` using torch's
# installed library path.
#
find_library(torch_python_LIBRARY torch_python PATHS
"${TORCH_INSTALL_PREFIX}/lib")
#
# Forward the non-CUDA device extensions to external CMake scripts.
#
@ -171,7 +158,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent)
@ -218,6 +205,7 @@ define_gpu_extension_target(
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)
#
@ -225,7 +213,7 @@ define_gpu_extension_target(
#
set(VLLM_MOE_EXT_SRC
"csrc/moe/moe_ops.cpp"
"csrc/moe/torch_bindings.cpp"
"csrc/moe/topk_softmax_kernels.cu")
define_gpu_extension_target(
@ -235,6 +223,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_MOE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)
#
@ -249,7 +238,7 @@ set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/punica_ops.cu"
"csrc/punica/punica_pybind.cpp")
"csrc/punica/torch_bindings.cpp")
#
# Copy GPU compilation flags+update for punica
@ -286,6 +275,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
SOURCES ${VLLM_PUNICA_EXT_SRC}
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)
else()
message(WARNING "Unable to create _punica_C target because none of the "


@ -106,9 +106,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
&& python3 setup.py install \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \
&& cd ..


@ -12,7 +12,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc")
#
# Check the compile flags
#
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-DVLLM_CPU_EXTENSION")
@ -44,8 +44,8 @@ if (AVX512_FOUND)
find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
else()
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
@ -73,7 +73,7 @@ set(VLLM_EXT_SRC
"csrc/cpu/cache.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/pybind.cpp")
"csrc/cpu/torch_bindings.cpp")
define_gpu_extension_target(
_C
@ -81,10 +81,10 @@ define_gpu_extension_target(
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI
)
add_custom_target(default)
message(STATUS "Enabling C extension.")
add_dependencies(default _C)


@ -5,7 +5,7 @@
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
set(Python_EXECUTABLE ${EXECUTABLE})
find_package(Python COMPONENTS Interpreter Development.Module)
find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
if (NOT Python_FOUND)
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
endif()
@ -294,6 +294,7 @@ endmacro()
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
# LIBRARIES <libraries> - Extra link libraries.
# WITH_SOABI - Generate library with python SOABI suffix name.
# USE_SABI <version> - Use the Python stable API version <version>
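# (USE_SABI is forwarded to CMake's Python_add_library, which defines
#  Py_LIMITED_API for the given version so the module is built against the
#  Python stable ABI.)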
#
# Note: optimization level/debug info is set via cmake build type.
#
@ -301,7 +302,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
cmake_parse_arguments(PARSE_ARGV 1
GPU
"WITH_SOABI"
"DESTINATION;LANGUAGE"
"DESTINATION;LANGUAGE;USE_SABI"
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
# Add hipify preprocessing step when building with HIP/ROCm.
@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME)
set(GPU_WITH_SOABI)
endif()
Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI})
if (GPU_USE_SABI)
Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
else()
Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
endif()
if (GPU_LANGUAGE STREQUAL "HIP")
# Make this target dependent on the hipify preprocessor step.


@ -1,5 +1,5 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>


@ -17,7 +17,7 @@
* limitations under the License.
*/
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
@ -808,16 +808,17 @@ void paged_attention_v1(
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, // [num_heads]
float scale,
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int block_size, int max_seq_len,
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
@ -972,16 +973,17 @@ void paged_attention_v2(
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, // [num_heads]
float scale,
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int block_size, int max_seq_len,
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
@ -990,4 +992,4 @@ void paged_attention_v2(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP


@ -1,6 +1,6 @@
#pragma once
#include <torch/extension.h>
#include <torch/all.h>
#include <map>
#include <vector>
@ -8,14 +8,18 @@
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping);
void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const float kv_scale);
const std::string& kv_cache_dtype,
const double kv_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
@ -25,4 +29,4 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float scale, const std::string& kv_cache_dtype);
const double scale, const std::string& kv_cache_dtype);


@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
@ -95,8 +95,11 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
} // namespace vllm
void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
@ -255,7 +258,7 @@ void reshape_and_cache(
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, const float kv_scale) {
const std::string& kv_cache_dtype, const double kv_scale) {
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
@ -334,7 +337,7 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
// Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float kv_scale, const std::string& kv_cache_dtype) {
const double kv_scale, const std::string& kv_cache_dtype) {
torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")


@ -420,12 +420,13 @@ void paged_attention_v1_impl_launcher(
void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
@ -738,12 +739,13 @@ void paged_attention_v2_impl_launcher(
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");


@ -5,8 +5,8 @@
namespace {
template <typename scalar_t>
void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& mapping_pairs,
const int element_num_per_block,
const int layer_num) {
@ -82,8 +82,11 @@ void reshape_and_cache_cpu_impl(
}
}; // namespace
void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) {
unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
@ -104,7 +107,7 @@ void copy_blocks(std::vector<torch::Tensor>& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, float kv_scale) {
const std::string& kv_cache_dtype, double kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
int num_tokens = key.size(0);


@ -3,7 +3,7 @@
#define CPU_TYPES_HPP
#include <immintrin.h>
#include <torch/extension.h>
#include <torch/all.h>
namespace vec_op {


@ -88,7 +88,7 @@ void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
} // namespace
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
float epsilon) {
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
@ -102,7 +102,7 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
}
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, float epsilon) {
torch::Tensor& weight, double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;


@ -168,7 +168,7 @@ void rotary_embedding_gptj_impl(
}; // namespace
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int head_size,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1);


@ -1,43 +0,0 @@
#include "cache.h"
#include "ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops
ops.def("paged_attention_v1", &paged_attention_v1,
"Compute the attention between an input query and the cached "
"keys/values using PagedAttention.");
ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
// Activation ops
ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
ops.def("gelu_and_mul", &gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
// Layernorm
ops.def("rms_norm", &rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding
ops.def("rotary_embedding", &rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def("swap_blocks", &swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def("copy_blocks", &copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def("reshape_and_cache", &reshape_and_cache,
"Reshape the key and value tensors and cache them");
}

csrc/cpu/torch_bindings.cpp (new file, 106 lines)

@ -0,0 +1,106 @@
#include "cache.h"
#include "ops.h"
#include "registration.h"
#include <torch/library.h>
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
// Attention ops
// Compute the attention between an input query and the cached keys/values
// using PagedAttention.
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);
// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);
// Activation function used in GeGLU with `tanh` approximation.
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);
// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCPU, &gelu_new);
// Approximate GELU implementation.
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCPU, &gelu_fast);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm", torch::kCPU, &rms_norm);
// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops.def(
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);
// Copy the cache blocks from src to dst.
cache_ops.def(
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
"block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float kv_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)


@ -1,7 +1,5 @@
#pragma once
#include <torch/extension.h>
int get_device_attribute(int attribute, int device_id);
int64_t get_device_attribute(int64_t attribute, int64_t device_id);
int get_max_shared_memory_per_block_device_attribute(int device_id);
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);


@ -2,7 +2,7 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#endif
int get_device_attribute(int attribute, int device_id) {
int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
int device, value;
if (device_id < 0) {
cudaGetDevice(&device);
@ -14,8 +14,8 @@ int get_device_attribute(int attribute, int device_id) {
return value;
}
int get_max_shared_memory_per_block_device_attribute(int device_id) {
int attribute;
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) {
int64_t attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74


@ -1,17 +1,17 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <torch/all.h>
#include "custom_all_reduce.cuh"
// fake pointer type
using fptr_t = uint64_t;
// fake pointer type, must match fptr_t type in ops.h
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int rank,
const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink) {
int world_size = offsets.size();
if (world_size > 8)
@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
t.numel() * t.element_size());
}
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16
@ -125,7 +125,7 @@ void dispose(fptr_t _fa) {
delete fa;
}
int meta_size() { return sizeof(vllm::Signal); }
int64_t meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
@ -134,10 +134,16 @@ void register_buffer(fptr_t _fa, torch::Tensor& t,
fa->register_buffer(handles, offsets, t.data_ptr());
}
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
return fa->get_graph_buffer_ipc_meta();
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
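// Copy the raw handle bytes into a CPU uint8 tensor, presumably because a
// std::vector<uint8_t> return is not expressible in a TORCH_LIBRARY op
// schema, while Tensor is.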
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handles =
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
return {handles, std::move(offsets)};
}
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,


@ -4,7 +4,7 @@
*/
#pragma once
#include <torch/extension.h>
#include <torch/all.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \


@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
@ -291,7 +291,7 @@ fused_add_rms_norm_kernel(
void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
@ -319,7 +319,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;


@ -1,8 +0,0 @@
#include "moe_ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax,
"Apply topk softmax to the gating outputs.");
}


@ -1,6 +1,6 @@
#pragma once
#include <torch/extension.h>
#include <torch/all.h>
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,


@ -16,7 +16,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"


@ -0,0 +1,12 @@
#include "registration.h"
#include "moe_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)


@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/ATen.h>
@ -108,8 +108,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
}
} // namespace vllm
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
int block_size, torch::Tensor sorted_token_ids,
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();


@ -1,40 +1,42 @@
#pragma once
#include <torch/extension.h>
#include <torch/library.h>
void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step);
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step);
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
float epsilon);
double epsilon);
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, float epsilon);
torch::Tensor& weight, double epsilon);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int head_size,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int head_size,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox,
int rot_dim,
int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
@ -60,12 +62,12 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros,
int split_k_iters);
int64_t split_k_iters);
torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros, int split_k_iters, int thx,
int thy);
torch::Tensor _zeros, int64_t split_k_iters,
int64_t thx, int64_t thy);
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
@ -88,9 +90,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
#endif
@ -106,9 +108,9 @@ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int bit);
bool use_exllama, int64_t bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
@ -116,28 +118,28 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
int block_size, torch::Tensor sorted_token_ids,
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM
using fptr_t = uint64_t;
using fptr_t = int64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int rank,
const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out);
void dispose(fptr_t _fa);
int meta_size();
int64_t meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets);


@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
@ -127,7 +127,7 @@ void rotary_embedding(
// [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int head_size,
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1);
@ -138,7 +138,7 @@ void rotary_embedding(
int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
@ -168,9 +168,9 @@ void batched_rotary_embedding(
// [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int head_size,
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, int rot_dim,
bool is_neox, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
) {
int64_t num_tokens = cos_sin_cache_offsets.size(0);
@ -180,7 +180,7 @@ void batched_rotary_embedding(
int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {


@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cstdint>
@ -88,7 +88,7 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
}
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, float scale) {
torch::Tensor indicies, int64_t layer_idx, double scale) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
@ -320,7 +320,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx,
float scale, int64_t h_in, int64_t h_out,
double scale, int64_t h_in, int64_t h_out,
int64_t y_offset) {
CHECK_INPUT(y);
CHECK_INPUT(x);


@ -1,11 +1,11 @@
#pragma once
#include <torch/extension.h>
#include <torch/all.h>
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, float scale);
torch::Tensor indicies, int64_t layer_idx, double scale);
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx,
float scale, int64_t h_in, int64_t h_out,
double scale, int64_t h_in, int64_t h_out,
int64_t y_offset);


@ -1,13 +0,0 @@
#include <torch/extension.h>
#include "punica_ops.h"
//====== pybind ======
#define DEFINE_pybind(name) m.def(#name, &name, #name);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
"dispatch_bgmv_low_level");
}


@ -0,0 +1,18 @@
#include "registration.h"
#include "punica_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
"layer_idx, float scale) -> ()");
m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
m.def(
"dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
"Tensor indicies, int layer_idx,"
"float scale, int h_in, int h_out,"
"int y_offset) -> ()");
m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)


@ -1,114 +0,0 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops
ops.def("paged_attention_v1", &paged_attention_v1,
"Compute the attention between an input query and the cached "
"keys/values using PagedAttention.");
ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
// Activation ops
ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
ops.def("gelu_and_mul", &gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
// Layernorm
ops.def("rms_norm", &rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding
ops.def("rotary_embedding", &rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
ops.def("batched_rotary_embedding", &batched_rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
"(supports multiple loras)");
// Quantization ops
#ifndef USE_ROCM
ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm,
"Marlin (Dense) Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
"Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
"gptq_marlin Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_repack", &gptq_marlin_repack,
"gptq_marlin repack from GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
"per-row/column quantization.");
#endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
"Compute FP8 quantized tensor for given scaling factor");
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
"Compute FP8 quantized tensor and scaling factor");
ops.def("moe_align_block_size", &moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such "
"that it is divisible by the block size.");
ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
"Compute int8 quantized tensor for given scaling factor");
ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant,
"Compute int8 quantized tensor and scaling factor");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def("swap_blocks", &swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def("copy_blocks", &copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def("reshape_and_cache", &reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash,
"Reshape the key and value tensors and cache them");
cache_ops.def("convert_fp8", &convert_fp8,
"Convert the key and value cache to fp8 data type");
// Cuda utils
pybind11::module cuda_utils =
m.def_submodule("cuda_utils", "vLLM cuda utils");
cuda_utils.def("get_device_attribute", &get_device_attribute,
"Gets the specified device attribute.");
cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute.");
#ifndef USE_ROCM
// Custom all-reduce kernels
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
custom_ar.def("dispose", &dispose, "dispose");
custom_ar.def("meta_size", &meta_size, "meta_size");
custom_ar.def("register_buffer", &register_buffer, "register_buffer");
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
"get_graph_buffer_ipc_meta");
custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers");
#endif
}


@ -18,7 +18,7 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>


@ -7,7 +7,7 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
}
*/
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "dequantize.cuh"
@ -435,8 +435,8 @@ __global__ void __launch_bounds__(64)
torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros, int split_k_iters, int thx,
int thy) {
torch::Tensor _zeros, int64_t split_k_iters,
int64_t thx, int64_t thy) {
int in_c = _kernel.size(0);
int qout_c = _kernel.size(1);
int out_c = qout_c * 8;
@ -491,7 +491,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros,
int split_k_iters) {
int64_t split_k_iters) {
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));


@ -1,5 +1,5 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/all.h>
#include <cmath>
#include "../../dispatch_utils.h"


@ -1,5 +1,5 @@
#include <stddef.h>
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>


@ -4,7 +4,7 @@
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>


@ -1,7 +1,7 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <torch/all.h>
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,


@ -1,5 +1,5 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>


@ -6,7 +6,7 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa
#include <cstdint>
#include <cstdio>
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
@ -1823,7 +1823,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int bit) {
bool use_exllama, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
@ -1845,7 +1845,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
return c;
}
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit) {
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight(
(uint32_t*)q_weight.data_ptr(),


@ -1867,4 +1867,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return c;
}
#endif


@ -1,6 +1,6 @@
#pragma once
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>


@ -15,7 +15,7 @@
* limitations under the License.
*/
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>


@ -16,7 +16,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>


@ -1,5 +1,4 @@
#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

csrc/registration.h (new file, 22 lines)

@ -0,0 +1,22 @@
#pragma once
#include <Python.h>
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}
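
For concreteness, REGISTER_EXTENSION(_C) expands to roughly the following sketch (with _C standing in for whatever TORCH_EXTENSION_NAME is defined to at build time):

PyMODINIT_FUNC PyInit__C() {
  static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, "_C", nullptr, 0,
                                      nullptr};
  return PyModule_Create(&module);
}

The module itself is empty: the TORCH_LIBRARY blocks run as static initializers when the shared object is loaded, so a plain import of vllm._C is enough to make the ops appear under torch.ops.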

csrc/torch_bindings.cpp (new file, 283 lines)

@ -0,0 +1,283 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include "registration.h"
#include <torch/library.h>
// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
// Activation function used in GeGLU with `tanh` approximation.
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
// Approximate GELU implementation.
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm", torch::kCUDA, &rms_norm);
// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
ops.def(
"batched_rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox,"
" int rot_dim,"
" Tensor cos_sin_cache_offsets) -> ()");
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
// Quantization ops
#ifndef USE_ROCM
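// The quantization defs below pass the C++ function pointer instead of a
// schema string; the dispatcher infers the op schema from the C++ signature
// (another reason these signatures use int64_t/double rather than int/float).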
// Quantized GEMM for AQLM.
ops.def("aqlm_gemm", &aqlm_gemm);
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
// Decompression method for AQLM.
ops.def("aqlm_dequant", &aqlm_dequant);
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
// Quantized GEMM for AWQ.
ops.def("awq_gemm", &awq_gemm);
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
// Dequantization for AWQ.
ops.def("awq_dequantize", &awq_dequantize);
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
ops.def("marlin_gemm", &marlin_gemm);
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
// gptq_marlin repack from GPTQ.
ops.def("gptq_marlin_repack", &gptq_marlin_repack);
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_dq(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales) -> ()");
ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq);
#endif
// Quantized GEMM for GPTQ.
ops.def("gptq_gemm", &gptq_gemm);
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
// Post processing for GPTQ.
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Quantized GEMM for SqueezeLLM.
ops.def(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
"lookup_table) -> ()");
ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);
// Compute FP8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
ops.def(
"moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()");
ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
"()");
ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()");
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops.def(
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// Copy the cache blocks from src to dst.
cache_ops.def(
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
"block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float kv_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache_flash(Tensor key, Tensor value,"
" Tensor! key_cache,"
" Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype) -> ()");
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash);
// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
"kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils
// Gets the specified device attribute.
cuda_utils.def("get_device_attribute", &get_device_attribute);
cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
// Gets the maximum shared memory per block device attribute.
cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute);
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
torch::kCUDA,
&get_max_shared_memory_per_block_device_attribute);
}
#ifndef USE_ROCM
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
custom_ar.def("init_custom_ar", &init_custom_ar);
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
custom_ar.def("should_custom_ar", &should_custom_ar);
custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
custom_ar.def(
"all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
"()");
custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
custom_ar.def("dispose", &dispose);
custom_ar.impl("dispose", torch::kCPU, &dispose);
custom_ar.def("meta_size", &meta_size);
custom_ar.impl("meta_size", torch::kCPU, &meta_size);
custom_ar.def("register_buffer", &register_buffer);
custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
&get_graph_buffer_ipc_meta);
custom_ar.def("register_graph_buffers", &register_graph_buffers);
custom_ar.impl("register_graph_buffers", torch::kCPU,
&register_graph_buffers);
}
#endif
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
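With the pybind11 module gone, Python reaches these kernels through the PyTorch dispatcher: importing the extension executes the TORCH_LIBRARY blocks above, and each expanded suffix becomes its own torch.ops namespace. A minimal sketch, assuming a CUDA build of vLLM is installed (shapes are illustrative):

# Minimal sketch (assumes a CUDA build of vLLM; shapes are illustrative).
import torch
import vllm._C  # noqa: F401 -- loading the .so registers the ops

x = torch.randn(16, 1024, device="cuda", dtype=torch.half)
out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)

# "Tensor!" in the schema marks in-place outputs, so `out` is written here.
torch.ops._C.static_scaled_fp8_quant(out, x, scale)

# Each TORCH_LIBRARY_EXPAND suffix above is a separate namespace.
print(torch.ops._C_cache_ops.swap_blocks)
print(torch.ops._C.static_scaled_fp8_quant.default._schema)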

View File

@ -60,7 +60,7 @@ def remove_prefix(text, prefix):
class CMakeExtension(Extension):
def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:
super().__init__(name, sources=[], **kwa)
super().__init__(name, sources=[], py_limited_api=True, **kwa)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
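Setting py_limited_api=True makes setuptools tag the built extensions for the CPython stable ABI, which is why the artifacts are now named _C.abi3.so and friends instead of carrying a per-interpreter suffix. A hypothetical post-build sanity check (not part of this commit; the build/ path is an assumption):

# Hypothetical check: every built vLLM extension should carry the
# stable-ABI suffix once py_limited_api=True is in effect.
import pathlib

built = sorted(pathlib.Path("build").rglob("vllm/*.so"))
assert all(p.name.endswith(".abi3.so") for p in built), [p.name for p in built]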

View File

@ -1,7 +1,8 @@
import pytest
import torch
from vllm._C import ops
# ruff: noqa: F401
import vllm._C
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
@ -33,7 +34,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
ops.dynamic_scaled_int8_quant(ops_out, x, scales_out)
torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out)
assert torch.allclose(scales_out, scales)
assert torch.allclose(torch_out, ops_out,
@ -60,6 +61,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
out2 = torch.empty_like(x, dtype=torch.int8)
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
ops.static_scaled_int8_quant(out2, x, scale_argument)
torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument)
assert torch.allclose(out1, out2,
atol=1) # big atol to account for rounding errors

View File

@ -1,35 +1,47 @@
from typing import Optional, Tuple, Type
import contextlib
from typing import List, Optional, Tuple, Type
import torch
try:
from vllm._C import cache_ops as vllm_cache_ops
from vllm._C import ops as vllm_ops
import vllm._C
except ImportError as e:
from vllm.logger import init_logger
logger = init_logger(__name__)
logger.warning("Failed to import from vllm._C with %r", e)
with contextlib.suppress(ImportError):
import vllm._moe_C
with contextlib.suppress(ImportError):
# ruff: noqa: F401
import vllm._punica_C
def is_custom_op_supported(op_name: str) -> bool:
op, overloads = torch._C._jit_get_operation(op_name)
return op is not None
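Callers use this probe instead of wrapping imports in try/except; op names take the form <namespace>::<op>, matching the TORCH_LIBRARY schemas above. A usage sketch:

# Usage sketch: gate an optional code path on op availability.
if is_custom_op_supported("_C_custom_ar::meta_size"):
    print("custom all-reduce kernels are available")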
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.silu_and_mul(out, x)
torch.ops._C.silu_and_mul(out, x)
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_and_mul(out, x)
torch.ops._C.gelu_and_mul(out, x)
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_tanh_and_mul(out, x)
torch.ops._C.gelu_tanh_and_mul(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_fast(out, x)
torch.ops._C.gelu_fast(out, x)
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_new(out, x)
torch.ops._C.gelu_new(out, x)
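All of these wrappers share one convention: the kernel writes into a preallocated out tensor. For the fused gate/up variants, x packs both projections along the last dimension, so out is half as wide. A shape sketch, assuming a CUDA build:

# Shape sketch (assumes a CUDA build; d is the intermediate size).
import torch
from vllm import _custom_ops as ops

d = 512
x = torch.randn(8, 2 * d, device="cuda", dtype=torch.half)
out = torch.empty(8, d, device="cuda", dtype=torch.half)
ops.silu_and_mul(out, x)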
# page attention ops
@ -53,7 +65,7 @@ def paged_attention_v1(
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
vllm_ops.paged_attention_v1(
torch.ops._C.paged_attention_v1(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
@ -83,7 +95,7 @@ def paged_attention_v2(
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
vllm_ops.paged_attention_v2(
torch.ops._C.paged_attention_v2(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
@ -100,8 +112,8 @@ def rotary_embedding(
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
is_neox)
torch.ops._C.rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox)
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
@ -109,20 +121,20 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
vllm_ops.batched_rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)
torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)
# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
vllm_ops.rms_norm(out, input, weight, epsilon)
torch.ops._C.rms_norm(out, input, weight, epsilon)
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
# quantization ops
@ -130,13 +142,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx,
thy)
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
thx, thy)
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
# gptq
@ -144,27 +156,27 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor:
return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
vllm_ops.gptq_shuffle(q_weight, q_perm, bit)
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
lookup_table: torch.Tensor) -> None:
vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)
torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)
# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
size_n, size_k)
return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
size_n, size_k)
# marlin_24
@ -172,9 +184,9 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor, num_bits: int, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
workspace, num_bits, size_m, size_n,
size_k)
return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
workspace, num_bits, size_m,
size_n, size_k)
# cutlass
@ -188,7 +200,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
vllm_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
return out
@ -198,21 +210,22 @@ def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return vllm_ops.aqlm_gemm(input, codes, codebooks, scales,
codebook_partition_sizes, bias)
return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
codebook_partition_sizes, bias)
def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes)
return torch.ops._C.aqlm_dequant(codes, codebooks,
codebook_partition_sizes)
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
num_bits)
return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
num_bits)
def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
@ -220,9 +233,9 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int, size_k: int,
is_k_full: bool) -> torch.Tensor:
return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
workspace, num_bits, size_m, size_n,
size_k, is_k_full)
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
workspace, num_bits, size_m, size_n,
size_k, is_k_full)
# fp8
@ -259,9 +272,9 @@ def scaled_fp8_quant(
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
if scale is None:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
vllm_ops.static_scaled_fp8_quant(output, input, scale)
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale
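One wrapper serves both fp8 modes: passing a scale selects the static kernel, while omitting it lets the kernel compute one. A usage sketch, assuming a CUDA build:

# Sketch of both fp8 paths (assumes a CUDA build).
import torch
from vllm import _custom_ops as ops

x = torch.randn(32, 4096, device="cuda", dtype=torch.half)
q_dyn, scale = ops.scaled_fp8_quant(x)      # dynamic: scale computed here
q_stat, _ = ops.scaled_fp8_quant(x, scale)  # static: reuse a known scale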
@ -284,14 +297,14 @@ def scaled_int8_quant(
output = torch.empty_like(input, dtype=torch.int8)
if scale is not None:
# static-per-tensor quantization.
vllm_ops.static_scaled_int8_quant(output, input, scale)
torch.ops._C.static_scaled_int8_quant(output, input, scale)
return output, scale
# dynamic-per-token quantization.
input_scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales)
torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
return output, input_scales
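The int8 wrapper mirrors the fp8 one, except that its dynamic path is per-token: it returns one scale per input row rather than a single tensor-wide value. A sketch:

# Sketch (assumes a CUDA build); note the differing scale shapes.
import torch
from vllm import _custom_ops as ops

y = torch.randn(4, 256, device="cuda", dtype=torch.half)
q, row_scales = ops.scaled_int8_quant(y)    # dynamic: row_scales is (4, 1)
s = torch.tensor([0.1], device="cuda", dtype=torch.float32)
q2, _ = ops.scaled_int8_quant(y, s)         # static: per-tensor scale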
@ -300,9 +313,16 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor) -> None:
vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size,
sorted_token_ids, experts_ids,
num_tokens_post_pad)
torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
sorted_token_ids, experts_ids,
num_tokens_post_pad)
def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies: torch.Tensor,
gating_output: torch.Tensor) -> None:
torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
token_expert_indicies, gating_output)
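topk_softmax fills three preallocated outputs from the router logits; the one-row-per-token layout below is an assumption consistent with its use in fused_moe further down:

# Shape sketch for topk_softmax (assumes a CUDA build of vllm._moe_C).
import torch
from vllm import _custom_ops as ops

num_tokens, num_experts, topk = 4, 8, 2
gating = torch.randn(num_tokens, num_experts, device="cuda")
weights = torch.empty(num_tokens, topk, dtype=torch.float32, device="cuda")
ids = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")
expert_idx = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")
ops.topk_softmax(weights, ids, expert_idx, gating)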
def reshape_and_cache(
@ -314,8 +334,9 @@ def reshape_and_cache(
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, kv_scale)
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype, kv_scale)
def reshape_and_cache_flash(
@ -326,25 +347,115 @@ def reshape_and_cache_flash(
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
) -> None:
vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype)
def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
block_mapping: torch.Tensor) -> None:
vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None:
vllm_cache_ops.swap_blocks(src, dst, block_mapping)
torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
def convert_fp8(output: torch.Tensor,
input: torch.Tensor,
scale: float = 1.0,
kv_dtype: str = "fp8") -> None:
vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype)
torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
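The cache wrappers dispatch into the _C_cache_ops namespace. A small sketch for convert_fp8, assuming fp8 cache entries are stored as raw uint8 bytes:

# Sketch (assumes a CUDA build and a uint8-backed fp8 cache).
import torch
from vllm import _custom_ops as ops

src = torch.randn(2, 16, 8, device="cuda", dtype=torch.half)
dst = torch.empty(src.shape, dtype=torch.uint8, device="cuda")
ops.convert_fp8(dst, src, scale=0.5, kv_dtype="fp8")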
# TODO: cuda_utils, custom_ar
def get_device_attribute(attribute: int, device: int) -> int:
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
# ruff: noqa: E501
return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
device)
# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
handles: List[str], offsets: List[int], rank: int,
full_nvlink: bool) -> int:
return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
offsets, rank, full_nvlink)
def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
full_nvlink: bool) -> bool:
return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
full_nvlink)
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)
def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
out: torch.Tensor) -> None:
torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)
def dispose(fa: int) -> None:
torch.ops._C_custom_ar.dispose(fa)
def meta_size() -> int:
return torch.ops._C_custom_ar.meta_size()
def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
offsets: List[int]) -> None:
return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(fa: int, handles: List[str],
offsets: List[List[int]]) -> None:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
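Several of the custom-allreduce ops are registered on the CPU dispatch key, so they can be called without GPU tensors; meta_size in particular just reports the fixed synchronization-metadata footprint that init_custom_ar expects at the front of its meta buffer. A probe sketch:

# Sketch: probe and call a CPU-dispatched op with no tensor arguments.
from vllm import _custom_ops as ops

if ops.is_custom_op_supported("_C_custom_ar::meta_size"):
    print("sync metadata bytes:", ops.meta_size())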
# punica
def dispatch_bgmv(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.Tensor,
layer_idx: int,
scale: float,
) -> None:
torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx,
scale)
def dispatch_bgmv_low_level(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.Tensor,
layer_idx: int,
scale: float,
h_in: int,
h_out: int,
y_offset: int,
) -> None:
torch.ops._punica_C.dispatch_bgmv_low_level(
y,
x,
w_t_all,
indicies,
layer_idx,
scale,
h_in,
h_out,
y_offset,
)

View File

@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm._C import cache_ops
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
@ -47,11 +47,11 @@ class FlashAttentionBackend(AttentionBackend):
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
@ -60,7 +60,7 @@ class FlashAttentionBackend(AttentionBackend):
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
ops.copy_blocks(key_caches, value_caches, src_to_dists)
@dataclass
@ -285,7 +285,7 @@ class FlashAttentionImpl(AttentionImpl):
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
cache_ops.reshape_and_cache_flash(
ops.reshape_and_cache_flash(
key,
value,
key_cache,

View File

@ -6,6 +6,7 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check)
from vllm.distributed.parallel_state import (
@ -15,7 +16,11 @@ from vllm.logger import init_logger
try:
import pynvml
from vllm._C import custom_ar
# Simulate ImportError if custom_ar ops are not supported.
if not ops.is_custom_op_supported("_C_custom_ar::meta_size"):
raise ImportError("custom_ar", __file__)
custom_ar = True
@contextmanager
def _nvml():
@ -27,7 +32,7 @@ try:
except ImportError:
# For AMD GPUs
custom_ar = None
custom_ar = False
pynvml = None
@contextmanager
@ -97,7 +102,7 @@ class CustomAllreduce:
self._IS_CAPTURING = False
self.disabled = True
if custom_ar is None:
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
return
@ -175,7 +180,7 @@ class CustomAllreduce:
# The metadata consists of two parts: synchronization metadata
# (256 bytes) and a temporary buffer for storing intermediate
# allreduce results.
self.meta = torch.zeros(custom_ar.meta_size() + max_size,
self.meta = torch.zeros(ops.meta_size() + max_size,
dtype=torch.uint8,
device=self.device)
# This is a pre-registered IPC buffer. In eager mode, input tensors
@ -196,9 +201,8 @@ class CustomAllreduce:
self.world_size = world_size
handles, offsets = self._get_ipc_meta(self.meta)
self.full_nvlink = full_nvlink
self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
handles, offsets, rank,
self.full_nvlink)
self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles,
offsets, rank, self.full_nvlink)
self.register_buffer(self.buffer)
@contextmanager
@ -252,31 +256,31 @@ class CustomAllreduce:
def register_buffer(self, inp: torch.Tensor):
handles, offsets = self._get_ipc_meta(inp)
custom_ar.register_buffer(self._ptr, inp, handles, offsets)
ops.register_buffer(self._ptr, inp, handles, offsets)
def register_graph_buffers(self):
handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
logger.info("Registering %d cuda graph addresses", len(offset))
custom_ar.register_graph_buffers(self._ptr, handles, offsets)
ops.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
self.full_nvlink)
return ops.should_custom_ar(inp, self.max_size, self.world_size,
self.full_nvlink)
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
custom_ar.all_reduce_reg(self._ptr, inp, out)
ops.all_reduce_reg(self._ptr, inp, out)
return out
# all reduce, assuming inp tensor is NOT IPC registered
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
return out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
@ -304,7 +308,7 @@ class CustomAllreduce:
def close(self):
if not self.disabled and self._ptr:
custom_ar.dispose(self._ptr)
ops.dispose(self._ptr)
self._ptr = 0
def __del__(self):

View File

@ -4,16 +4,21 @@ from typing import Optional
import torch
from vllm import _custom_ops as ops
def _check_punica_support():
if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
return
def _raise_import_error(e):
if torch.cuda.get_device_capability() < (8, 0):
raise ImportError(
"punica LoRA kernels require compute capability >= 8.0") from e
"punica LoRA kernels require compute capability >= 8.0")
else:
raise ImportError(
"punica LoRA kernels could not be imported. If you built vLLM "
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
"was set.") from e
"was set.")
def bgmv(
@ -41,12 +46,9 @@ def bgmv(
layer_idx: Layer index of the weight matrices.
scale: Scaling factor.
"""
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
_check_punica_support()
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
@ -75,11 +77,9 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
punica_kernels.dispatch_bgmv_low_level(
_check_punica_support()
ops.dispatch_bgmv_low_level(
y,
x,
w_t_all,
@ -122,10 +122,7 @@ def add_lora(y: torch.Tensor,
scale: Scaling factor.
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
"""
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
_check_punica_support()
r = wb_t_all.size(-1)
if buffer is None:
@ -135,9 +132,8 @@ def add_lora(y: torch.Tensor,
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
scale)
ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale)
def add_lora_slice(y: torch.Tensor,
@ -176,10 +172,7 @@ def add_lora_slice(y: torch.Tensor,
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
_check_punica_support()
r = wb_t_all.size(-1)
if buffer is None:
@ -189,7 +182,7 @@ def add_lora_slice(y: torch.Tensor,
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
punica_kernels.dispatch_bgmv_low_level(
ops.dispatch_bgmv_low_level(
buffer,
x,
wa_t_all,
@ -200,7 +193,7 @@ def add_lora_slice(y: torch.Tensor,
buffer.size(1),
0,
)
punica_kernels.dispatch_bgmv_low_level(
ops.dispatch_bgmv_low_level(
y,
buffer,
wb_t_all,

View File

@ -8,7 +8,6 @@ import torch
import triton
import triton.language as tl
import vllm._moe_C as moe_kernels
from vllm import _custom_ops as ops
from vllm.logger import init_logger
@ -355,7 +354,7 @@ def fused_topk(
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,

View File

@ -22,6 +22,7 @@ import psutil
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import enable_trace_function_call, init_logger
T = TypeVar("T")
@ -148,12 +149,8 @@ def is_neuron() -> bool:
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
# NOTE: This import statement should be executed lazily since
# the Neuron-X backend does not have the `cuda_utils` module.
from vllm._C import cuda_utils
max_shared_mem = (
cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu))
ops.get_max_shared_memory_per_block_device_attribute(gpu))
# a value of 0 will cause MAX_SEQ_LEN to become negative and test_attention.py
# will fail
assert max_shared_mem > 0, "max_shared_mem cannot be zero"