[Hardware][Intel] Add CPU inference backend (#3634)

Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Yuan Zhou <yuan.zhou@intel.com>
bigPYJ1151 2024-04-02 13:07:30 +08:00 committed by GitHub
parent eb69d68804
commit 0e3f06fe9c
24 changed files with 2747 additions and 5 deletions


@@ -0,0 +1,14 @@
# This script builds the CPU Docker image and runs the offline inference example inside the container.
# It serves as a sanity check for compilation and basic model usage.
set -ex
# Try building the docker image
docker build -t cpu-test -f Dockerfile.cpu .
# Set up cleanup
remove_docker_container() { docker rm -f cpu-test || true; }
trap remove_docker_container EXIT
remove_docker_container
# Run the image and launch offline inference
docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py
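For context, VLLM_CPU_KVCACHE_SPACE caps the CPU KV cache size (in GiB) for the test run, and the check exercises roughly the following kind of script (a minimal sketch, not the exact contents of examples/offline_inference.py; the model and sampling parameters are illustrative):

from vllm import LLM, SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="facebook/opt-125m")  # any small FP32/BF16 model is enough for the sanity check
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)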


@@ -8,6 +8,9 @@ steps:
      queue: amd
    command: bash .buildkite/run-amd-test.sh
  - label: "CPU Test"
    command: bash .buildkite/run-cpu-test.sh
  - label: ":docker: build image"
    commands:
      - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."


@@ -2,7 +2,10 @@ cmake_minimum_required(VERSION 3.21)
project(vllm_extensions LANGUAGES CXX)
option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
@@ -76,6 +79,19 @@ find_package(Torch REQUIRED)
find_library(torch_python_LIBRARY torch_python PATHS
"${TORCH_INSTALL_PREFIX}/lib")
#
# Forward the non-CUDA device extensions to external CMake scripts.
#
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
if (VLLM_TARGET_DEVICE STREQUAL "cpu")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
else()
message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
endif()
return()
endif()
#
# Set up GPU language and check the torch version and warn if it isn't
# what is expected.

Dockerfile.cpu (new file)

@@ -0,0 +1,20 @@
# This vLLM Dockerfile is used to construct an image that can build and run vLLM on the x86 CPU platform.
FROM ubuntu:22.04
RUN apt-get update -y \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
RUN pip install --upgrade pip \
&& pip install wheel packaging ninja "setuptools>=49.4.0" numpy
COPY ./ /workspace/vllm
WORKDIR /workspace/vllm
RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
CMD ["/bin/bash"]

cmake/cpu_extension.cmake (new file)

@@ -0,0 +1,90 @@
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
#
# Define environment variables for special configurations
#
if(DEFINED ENV{VLLM_CPU_AVX512BF16})
set(ENABLE_AVX512BF16 ON)
endif()
include_directories("${CMAKE_SOURCE_DIR}/csrc")
#
# Check the compile flags
#
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-DVLLM_CPU_EXTENSION")
execute_process(COMMAND cat /proc/cpuinfo
RESULT_VARIABLE CPUINFO_RET
OUTPUT_VARIABLE CPUINFO)
if (NOT CPUINFO_RET EQUAL 0)
message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
endif()
function (find_isa CPUINFO TARGET OUT)
string(FIND ${CPUINFO} ${TARGET} ISA_FOUND)
if(NOT ISA_FOUND EQUAL -1)
set(${OUT} ON PARENT_SCOPE)
else()
set(${OUT} OFF PARENT_SCOPE)
endif()
endfunction()
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
if (AVX512_FOUND)
list(APPEND CXX_COMPILE_FLAGS
"-mavx512f"
"-mavx512vl"
"-mavx512bw"
"-mavx512dq")
find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
else()
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
endif()
else()
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
endif()
else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.")
endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
#
# Define extension targets
#
#
# _C extension
#
set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp"
"csrc/cpu/attention.cpp"
"csrc/cpu/cache.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/pybind.cpp")
define_gpu_extension_target(
_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
WITH_SOABI
)
add_custom_target(default)
message(STATUS "Enabling C extension.")
add_dependencies(default _C)

csrc/cpu/activation.cpp (new file)

@@ -0,0 +1,148 @@
#include "cpu_types.hpp"
namespace {
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &),
bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
scalar_t *__restrict__ output) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(d % VEC_ELEM_NUM == 0);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
for (int j = 0; j < d; j += VEC_ELEM_NUM) {
int start = i * d;
if constexpr (is_gated) {
start *= 2;
}
const scalar_vec_t x(input + start + j);
const vec_op::FP32Vec8 f32_x(x);
vec_op::FP32Vec8 f32_ans = func(f32_x);
if constexpr (is_gated) {
const scalar_vec_t y(input + start + d + j);
const vec_op::FP32Vec8 f32_y(y);
f32_ans = f32_y * f32_ans;
}
const scalar_vec_t result(f32_ans);
result.save(output + i * d + j);
}
}
}
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0);
return x / (ones + (zeros - x).exp());
}
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f);
const vec_op::FP32Vec8 w3(0.5);
const vec_op::FP32Vec8 x3 = x * x * x;
const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh();
return w3 * x * (ones + t);
}
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f);
const vec_op::FP32Vec8 w3(0.5);
const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh();
return w3 * x * (ones + t);
}
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2);
const vec_op::FP32Vec8 w2(0.5);
return x * w2 * (ones + (x * w1).er());
}
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
const vec_op::FP32Vec8 w2(0.5);
const vec_op::FP32Vec8 w3(0.044715);
const vec_op::FP32Vec8 x_3 = x * x * x;
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
return x * w2 * (ones + inner.tanh());
}
}; // namespace
void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
});
}
void gelu_and_mul(torch::Tensor &out, // [..., d]
torch::Tensor &input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "gelu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
});
}
void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
torch::Tensor &input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
activation_kernel<scalar_t, gelu_tanh_act, true>(
num_tokens, d, input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
});
}
void gelu_new(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_new_impl)
activation_kernel<scalar_t, gelu_new_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_new_impl)
});
}
void gelu_fast(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_fast_impl)
activation_kernel<scalar_t, gelu_fast_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
});
}

csrc/cpu/attention.cpp (new file)

@@ -0,0 +1,744 @@
#include "cpu_types.hpp"
namespace {
template <typename scalar_t> struct KernelVecType {
using q_load_vec_type = void;
using q_vec_type = void;
using k_load_vec_type = void;
using k_vec_type = void;
using qk_acc_vec_type = void;
using v_load_vec_type = void;
};
template <> struct KernelVecType<float> {
using q_load_vec_type = vec_op::FP32Vec4;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP32Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP32Vec16;
};
#ifdef __AVX512BF16__
template <> struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::BF16Vec32;
using k_load_vec_type = vec_op::BF16Vec32;
using k_vec_type = vec_op::BF16Vec32;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#else
template <> struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#endif
template <typename T>
FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
const int capacity) {
T max = data[0];
for (int i = 1; i < size; ++i) {
max = max >= data[i] ? max : data[i];
}
T sum = 0;
for (int i = 0; i < size; ++i) {
data[i] = std::exp(data[i] - max);
sum += data[i];
}
int i = 0;
for (; i < size; ++i) {
data[i] /= sum;
}
for (; i < capacity; ++i) {
data[i] = 0;
}
return {max, sum};
}
template <typename T>
FORCE_INLINE std::pair<T, T>
reduceSoftmaxAlibi(T *data, const int size, const int capacity,
const float alibi_slope, const int start_index,
const int context_len) {
data[0] += alibi_slope * (start_index - context_len + 1);
T max = data[0];
for (int i = 1; i < size; ++i) {
T qk = data[i] + alibi_slope * (start_index + i - context_len + 1);
data[i] = qk;
max = max >= qk ? max : qk;
}
T sum = 0;
for (int i = 0; i < size; ++i) {
data[i] = std::exp(data[i] - max);
sum += data[i];
}
int i = 0;
for (; i < size; ++i) {
data[i] /= sum;
}
for (; i < capacity; ++i) {
data[i] = 0;
}
return {max, sum};
}
template <typename T>
FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data,
const int size) {
T max = max_data[0];
for (int i = 1; i < size; ++i) {
max = max >= max_data[i] ? max : max_data[i];
}
T rescaled_sum = 0;
for (int i = 0; i < size; ++i) {
T rescale_factor = std::exp(max_data[i] - max);
rescaled_sum += rescale_factor * sum_data[i];
sum_data[i] *= rescale_factor;
}
for (int i = 0; i < size; ++i) {
sum_data[i] /= rescaled_sum + 1e-8;
}
}
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int x>
struct reduceQKBlockKernel {
using q_load_vec_type = typename KernelVecType<scalar_t>::q_load_vec_type;
using q_vec_type = typename KernelVecType<scalar_t>::q_vec_type;
using k_load_vec_type = typename KernelVecType<scalar_t>::k_load_vec_type;
using k_vec_type = typename KernelVecType<scalar_t>::k_vec_type;
using qk_acc_vec_type = typename KernelVecType<scalar_t>::qk_acc_vec_type;
constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x;
constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP;
constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4;
static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4);
static_assert(k_load_vec_type::get_elem_num() % x == 0);
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
FORCE_INLINE static void call(const scalar_t *__restrict__ q,
const scalar_t *__restrict__ k_block,
float *__restrict__ logits, float scale,
const int token_num) {
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
qk_acc_vec_type group_accums[MAX_GROUP_NUM];
if (token_num == BLOCK_SIZE) {
for (int q_offset = 0; q_offset < HEAD_SIZE;
q_offset += x, k_block += x * BLOCK_SIZE) {
q_load_vec_type q_load_group_vec(q + q_offset);
q_vec_type q_group_vec(q_load_group_vec);
vec_op::unroll_loop<int, MAX_GROUP_NUM>(
[k_block, &q_group_vec, &group_accums](int token_group_idx) {
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
TOKEN_PER_GROUP);
k_vec_type k_group_vec(k_load_group_vec);
vec_op::fma(group_accums[token_group_idx], q_group_vec,
k_group_vec);
vec_op::prefetch(k_block + x * BLOCK_SIZE +
token_group_idx * x * TOKEN_PER_GROUP);
});
}
} else {
for (int q_offset = 0; q_offset < HEAD_SIZE;
q_offset += x, k_block += x * BLOCK_SIZE) {
q_load_vec_type q_load_group_vec(q + q_offset);
q_vec_type q_group_vec(q_load_group_vec);
for (int token_group_start = 0; token_group_start < group_num;
token_group_start += UNROLL_GROUP_NUM) {
vec_op::unroll_loop<int, UNROLL_GROUP_NUM>(
[token_group_start, k_block, &q_group_vec,
&group_accums](int token_group_idx) {
token_group_idx += token_group_start;
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
TOKEN_PER_GROUP);
k_vec_type k_group_vec(k_load_group_vec);
vec_op::fma(group_accums[token_group_idx], q_group_vec,
k_group_vec);
vec_op::prefetch(k_block + x * BLOCK_SIZE +
token_group_idx * x * TOKEN_PER_GROUP);
});
}
}
}
for (int token_group_idx = 0; token_group_idx < group_num;
++token_group_idx) {
vec_op::unroll_loop<int, TOKEN_PER_GROUP>(
[&group_accums, logits, scale, token_group_idx](int token_idx) {
float dot_v =
group_accums[token_group_idx]
.template reduce_sub_sum<qk_acc_vec_type::get_elem_num() /
TOKEN_PER_GROUP>(token_idx);
logits[token_group_idx * TOKEN_PER_GROUP + token_idx] =
dot_v * scale;
});
}
}
};
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
int HEAD_PARTITION_SIZE, typename acc_t>
FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
acc_t &&acc) {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
static_assert(BLOCK_SIZE == ELEM_NUM);
vec_op::FP32Vec16 prob_vec(prob);
vec_op::unroll_loop<int, HEAD_PARTITION_SIZE>([&](int head_elem_idx) {
v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx);
vec_op::FP32Vec16 fp32_v_vec(v_vec);
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
});
}
}; // namespace
// Paged attention v1
namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
struct paged_attention_v1_impl {
static void
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads) {
constexpr int x = 16 / sizeof(scalar_t);
const int num_queries_per_kv = num_heads / num_kv_heads;
static_assert(BLOCK_SIZE == 16);
int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0);
const int parallel_work_item_num = omp_get_max_threads();
size_t logits_bytes =
parallel_work_item_num * max_context_len_padded * sizeof(float);
float *logits = (float *)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_context_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int context_len = context_lens[seq_idx];
const int *seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num =
context_len - (block_num - 1) * BLOCK_SIZE;
float *__restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_context_len_padded;
// Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits =
thread_block_logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
}
// Compute softmax
if (alibi_slopes) {
reduceSoftmaxAlibi(thread_block_logits, context_len,
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
context_len);
} else {
reduceSoftmax(thread_block_logits, context_len,
block_num * BLOCK_SIZE);
}
// Compute value
constexpr int head_elem_num_per_partition = 16;
constexpr int head_partition_num =
HEAD_SIZE / head_elem_num_per_partition;
for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr =
thread_block_logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
head_elem_num_per_partition>(
prob_vec_ptr, v_block_cache_ptr, accums);
if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
vec_op::unroll_loop<int, head_elem_num_per_partition>(
[&](int head_elem_idx) {
if (head_elem_idx % 2 == 0) {
vec_op::prefetch(next_v_block_cache_ptr +
BLOCK_SIZE * head_elem_idx);
}
});
}
}
vec_op::unroll_loop<int, head_elem_num_per_partition>(
[&](int head_elem_idx) {
float value = accums[head_elem_idx].reduce_sum();
vec_op::storeFP32(value, out_ptr + head_elem_idx);
});
}
}
}
std::free(logits);
}
};
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads);
template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher(
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
: nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>();
switch (head_size) {
case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 256:
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
context_lens, max_context_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V1_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
} // namespace
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
torch::Tensor &key_cache, torch::Tensor &value_cache,
int num_kv_heads, float scale,
torch::Tensor &block_tables,
torch::Tensor &context_lens, int block_size,
int max_context_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype) {
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl)
});
}
// Paged attention v2
namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl {
static void call(
scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads, const int max_num_partitions) {
constexpr int x = 16 / sizeof(scalar_t);
const int num_queries_per_kv = num_heads / num_kv_heads;
static_assert(BLOCK_SIZE == 16);
static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0);
static_assert(PARTITION_SIZE % BLOCK_SIZE == 0);
#pragma omp parallel for collapse(3) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int partition_idx = 0; partition_idx < max_num_partitions;
++partition_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= context_len)
continue;
const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
const bool no_reduce = (partition_num == 1);
const int context_token_num =
(std::min(context_len, start_token_idx + PARTITION_SIZE) -
start_token_idx);
const int block_num =
(context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num =
context_token_num - (block_num - 1) * BLOCK_SIZE;
const int *seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
// Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits =
logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
}
std::pair<float, float> max_and_sum;
if (alibi_slopes) {
max_and_sum = reduceSoftmaxAlibi(
logits, context_token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, context_len);
} else {
max_and_sum = reduceSoftmax(logits, context_token_num,
block_num * BLOCK_SIZE);
}
auto &&[max_logit, exp_sum] = max_and_sum;
scalar_t *__restrict__ output_buffer = nullptr;
if (!no_reduce) {
auto idx = seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
max_logits[idx] = max_logit;
exp_sums[idx] = exp_sum;
output_buffer =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE +
partition_idx * HEAD_SIZE;
} else {
output_buffer =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
}
// Compute value
constexpr int head_elem_num_per_partition = 16;
constexpr int head_partition_num =
HEAD_SIZE / head_elem_num_per_partition;
for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr =
output_buffer + head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr =
logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
head_elem_num_per_partition>(
prob_vec_ptr, v_block_cache_ptr, accums);
if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
vec_op::unroll_loop<int, head_elem_num_per_partition>(
[&](int head_elem_idx) {
if (head_elem_idx % 2 == 0) {
vec_op::prefetch(next_v_block_cache_ptr +
BLOCK_SIZE * head_elem_idx);
}
});
}
}
vec_op::unroll_loop<int, head_elem_num_per_partition>(
[&](int head_elem_idx) {
float value = accums[head_elem_idx].reduce_sum();
vec_op::storeFP32(value, out_ptr + head_elem_idx);
});
}
}
}
}
// Rescale partition softmax and store the factors to exp_sums
#pragma omp parallel for collapse(2) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx];
const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1)
continue;
reducePartitonSoftmax(
max_logits + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions,
exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions,
partition_num);
}
}
// Reduce values
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
constexpr int head_elem_num_per_group =
16; // Note: not aligned to the cache line size, because some HEAD_SIZE
    // values are not multiples of 64 bytes.
static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
const float *__restrict__ rescale_factors = exp_sums;
#pragma omp parallel for collapse(3) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
const int context_len = context_lens[seq_idx];
const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1)
continue;
const float *__restrict__ seq_head_rescale_factors =
rescale_factors + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
const scalar_t *__restrict__ seq_head_tmp_out =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE +
group_idx * head_elem_num_per_group;
scalar_t *__restrict__ seq_head_output =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
group_idx * head_elem_num_per_group;
vec_op::FP32Vec16 acc;
for (int i = 0; i < partition_num; ++i) {
vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]);
v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE);
vec_op::FP32Vec16 fp32_value(value);
acc = acc + fp32_value * rescale_factor;
}
v_load_vec_type cast_acc(acc);
cast_acc.save(seq_head_output);
}
}
}
}
};
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions);
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
void paged_attention_v2_impl_launcher(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int max_num_partitions = exp_sums.size(-1);
// NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
: nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr());
float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr());
T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>();
switch (head_size) {
case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 256:
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, block_size, \
max_context_len, alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V2_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
} // namespace
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch::Tensor &max_logits, torch::Tensor &tmp_out,
torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads,
float scale, torch::Tensor &block_tables,
torch::Tensor &context_lens, int block_size,
int max_context_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype) {
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
});
}

csrc/cpu/cache.cpp (new file)

@@ -0,0 +1,139 @@
#include <map>
#include <vector>
#include "cpu_types.hpp"
namespace {
template <typename scalar_t>
void copy_blocks_cpu_impl(
std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches,
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs,
const int element_num_per_block, const int layer_num) {
const size_t pair_num = mapping_pairs.size();
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
int64_t target_offset =
element_num_per_block * mapping_pairs[pair].second;
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t *source_ptr = key_cache_ptr + source_offset;
scalar_t *target_ptr = key_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes);
scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
source_ptr = value_cache_ptr + source_offset;
target_ptr = value_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes);
}
}
}
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
const scalar_t *__restrict__ key, const scalar_t *__restrict__ value,
scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache,
const int64_t *__restrict__ slot_mapping, const int num_tokens,
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x) {
const int block_elem_num = num_heads * head_size * block_size;
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx >= 0) {
int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
int src_value_head_idx =
token_idx * value_stride + head_idx * head_size;
const scalar_t *src_key_head_ptr = key + src_key_head_idx;
const scalar_t *src_value_head_ptr = value + src_value_head_idx;
const int64_t block_index = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
scalar_t *target_key_head_ptr = key_cache +
block_elem_num * block_index +
head_idx * block_size * head_size;
scalar_t *target_value_head_ptr = value_cache +
block_elem_num * block_index +
head_idx * block_size * head_size;
for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
const int64_t target_offset =
src_key_idx * block_size + block_offset * x;
for (int i = 0; i < x; ++i) {
target_key_head_ptr[target_offset + i] =
src_key_head_ptr[src_key_idx + i];
}
}
for (int src_value_idx = 0; src_value_idx < head_size;
++src_value_idx) {
const int64_t target_offset =
src_value_idx * block_size + block_offset;
target_value_head_ptr[target_offset] =
src_value_head_ptr[src_value_idx];
}
}
}
}
}
}; // namespace
void copy_blocks(std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches,
const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}
std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
mapping_pairs.reserve(block_mapping.size());
for (const auto &pair : block_mapping) {
for (const auto &dst : pair.second) {
mapping_pairs.emplace_back(pair.first, dst);
}
}
const int element_num_per_block = key_caches[0][0].numel();
VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
});
}
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
torch::Tensor &key_cache, torch::Tensor &value_cache,
torch::Tensor &slot_mapping,
const std::string &kv_cache_dtype) {
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
reshape_and_cache_cpu_impl<scalar_t>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
value_stride, num_heads, head_size, block_size, x);
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
});
}
void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
const std::map<int64_t, int64_t> &block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
}
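From Python, these cache ops are reached through the compiled extension. A rough usage sketch for copy_blocks, assuming the extension is importable as vllm._C and using placeholder cache shapes:

import torch
from vllm._C import cache_ops  # extension module built from this file

num_layers, num_blocks = 2, 8
block_elems = 8 * 16 * 16  # num_heads * head_size * block_size (placeholder layout)
key_caches = [torch.zeros(num_blocks, block_elems) for _ in range(num_layers)]
value_caches = [torch.zeros(num_blocks, block_elems) for _ in range(num_layers)]

# Copy block 0 into blocks 3 and 5 of every layer's key and value cache.
cache_ops.copy_blocks(key_caches, value_caches, {0: [3, 5]})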

csrc/cpu/cpu_types.hpp (new file)

@@ -0,0 +1,352 @@
#ifndef CPU_TYPES_HPP
#define CPU_TYPES_HPP
#include <immintrin.h>
#include <torch/extension.h>
namespace vec_op {
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
(f(std::integral_constant<T, indexes>{}), ...);
}
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F &&f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T> struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
struct FP32Vec8;
struct FP32Vec16;
#ifdef __AVX512FP16__
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__m128h reg;
explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}
explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}
explicit FP16Vec8(__m128h data) : reg(data) {}
FP16Vec8 operator*(const FP16Vec8 &b) const {
return FP16Vec8(_mm_mul_ph(reg, b.reg));
}
FP16Vec8 operator+(const FP16Vec8 &b) const {
return FP16Vec8(_mm_add_ph(reg, b.reg));
}
FP16Vec8 operator-(const FP16Vec8 &b) const {
return FP16Vec8(_mm_sub_ph(reg, b.reg));
}
FP16Vec8 operator/(const FP16Vec8 &b) const {
return FP16Vec8(_mm_div_ph(reg, b.reg));
}
void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
};
#endif
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__m128i reg;
explicit BF16Vec8(const void *ptr)
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
explicit BF16Vec8(const FP32Vec8 &);
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
__m256i reg;
explicit BF16Vec16(const void *ptr)
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
explicit BF16Vec16(const FP32Vec16 &);
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
};
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
__m512i reg;
explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
explicit BF16Vec32(__m512i data) : reg(data) {}
explicit BF16Vec32(BF16Vec8 &vec8_data)
: reg((__m512i)_mm512_inserti32x4(
_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
(__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1),
(__m128i)vec8_data.reg, 2),
(__m128i)vec8_data.reg, 3)) {}
void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
};
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
union AliasReg {
__m128 reg;
float values[VEC_ELEM_NUM];
};
__m128 reg;
explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {}
explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}
explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}
explicit FP32Vec4(__m128 data) : reg(data) {}
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
__m256 reg;
float values[VEC_ELEM_NUM];
};
__m256 reg;
explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {}
explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}
explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}
explicit FP32Vec8(__m256 data) : reg(data) {}
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
#ifdef __AVX512FP16__
explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
#endif
explicit FP32Vec8(const BF16Vec8 &v)
: reg(_mm256_castsi256_ps(
_mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
return result;
}
FP32Vec8 exp() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]),
expf(ar.values[5]), expf(ar.values[4]),
expf(ar.values[3]), expf(ar.values[2]),
expf(ar.values[1]), expf(ar.values[0])));
}
FP32Vec8 tanh() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]),
tanhf(ar.values[5]), tanhf(ar.values[4]),
tanhf(ar.values[3]), tanhf(ar.values[2]),
tanhf(ar.values[1]), tanhf(ar.values[0])));
}
FP32Vec8 er() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]),
erf(ar.values[5]), erf(ar.values[4]),
erf(ar.values[3]), erf(ar.values[2]),
erf(ar.values[1]), erf(ar.values[0])));
}
FP32Vec8 operator*(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_mul_ps(reg, b.reg));
}
FP32Vec8 operator+(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_add_ps(reg, b.reg));
}
FP32Vec8 operator-(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_sub_ps(reg, b.reg));
}
FP32Vec8 operator/(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_div_ps(reg, b.reg));
}
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m512 reg;
float values[VEC_ELEM_NUM];
};
__m512 reg;
explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {}
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}
explicit FP32Vec16(__m512 data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
explicit FP32Vec16(const FP32Vec4 &data)
: reg((__m512)_mm512_inserti32x4(
_mm512_inserti32x4(
_mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
(__m128i)data.reg, 1),
(__m128i)data.reg, 2),
(__m128i)data.reg, 3)) {}
explicit FP32Vec16(const FP32Vec8 &data)
: reg((__m512)_mm512_inserti32x8(
_mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}
explicit FP32Vec16(const BF16Vec16 &v)
: reg(_mm512_castsi512_ps(
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_mul_ps(reg, b.reg));
}
FP32Vec16 operator+(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_add_ps(reg, b.reg));
}
FP32Vec16 operator-(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_sub_ps(reg, b.reg));
}
FP32Vec16 operator/(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_div_ps(reg, b.reg));
}
float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
template <int group_size> float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
__mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
return _mm512_mask_reduce_add_ps(mask, reg);
}
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
};
template <typename T> struct VecType { using vec_type = void; };
template <typename T> using vec_t = typename VecType<T>::vec_type;
template <> struct VecType<float> { using vec_type = FP32Vec8; };
#ifdef __AVX512FP16__
template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
#endif
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
#ifdef __AVX512FP16__
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
*reinterpret_cast<_Float16 *>(ptr) = v;
}
#endif
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
acc = acc + a * b;
}
#ifdef __AVX512BF16__
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
}
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
: reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
}
#else
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
reinterpret_cast<c10::BFloat16 *>(&v);
*ptr = *(v_ptr + 1);
}
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg(_mm256_cvtepi32_epi16(
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
: reg(_mm512_cvtepi32_epi16(
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
#endif
inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
}; // namespace vec_op
#endif

csrc/cpu/layernorm.cpp (new file)

@@ -0,0 +1,117 @@
#include "cpu_types.hpp"
namespace {
template <typename scalar_t>
void rms_norm_impl(scalar_t *__restrict__ out,
const scalar_t *__restrict__ input,
const scalar_t *__restrict__ weight, const float epsilon,
const int num_tokens, const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
vec_op::FP32Vec8 variance(0.0);
auto input_p = input + i * hidden_size;
auto output_p = out + i * hidden_size;
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
scalar_vec_t x(input_p + j);
vec_op::FP32Vec8 fp32_x(x);
variance = variance + fp32_x * fp32_x;
}
float s_variance =
1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
vec_op::FP32Vec8 fp32_s_variance(s_variance);
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
scalar_vec_t x(input_p + j);
scalar_vec_t w(weight + j);
vec_op::FP32Vec8 fp32_x(x);
vec_op::FP32Vec8 fp32_w(w);
vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w;
scalar_vec_t out(fp32_out);
out.save(output_p + j);
}
}
}
template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
scalar_t *__restrict__ residual,
const scalar_t *__restrict__ weight,
const float epsilon, const int num_tokens,
const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
vec_op::FP32Vec8 variance(0.0);
auto input_p = input + i * hidden_size;
auto residual_p = residual + i * hidden_size;
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
scalar_vec_t x(input_p + j);
scalar_vec_t res(residual_p + j);
vec_op::FP32Vec8 fp32_x(x);
vec_op::FP32Vec8 fp32_res(res);
fp32_x = fp32_x + fp32_res;
variance = variance + fp32_x * fp32_x;
scalar_vec_t out(fp32_x);
out.save(residual_p + j);
}
float s_variance =
1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
vec_op::FP32Vec8 fp32_s_variance(s_variance);
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
scalar_vec_t w(weight + j);
scalar_vec_t res(residual_p + j);
vec_op::FP32Vec8 fp32_w(w);
vec_op::FP32Vec8 fp32_res(res);
vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w;
scalar_vec_t out(fp32_out);
out.save(input_p + j);
}
}
}
} // namespace
void rms_norm(torch::Tensor &out, torch::Tensor &input,
torch::Tensor &weight, float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
CPU_KERNEL_GUARD_IN(rms_norm_impl)
rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size);
CPU_KERNEL_GUARD_OUT(rms_norm_impl)
});
}
void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual,
torch::Tensor &weight, float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "fused_add_rms_norm_impl", [&] {
CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl)
fused_add_rms_norm_impl(
input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl)
});
}

csrc/cpu/pos_encoding.cpp (new file)

@@ -0,0 +1,199 @@
#include "cpu_types.hpp"
namespace {
template <typename scalar_t>
void rotary_embedding_impl(
const int64_t
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
/// [num_tokens, num_heads, head_size]
scalar_t
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
// [num_tokens, num_kv_heads, head_size]
const scalar_t
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
constexpr int ELEM_SIZE = sizeof(scalar_t);
const int embed_dim = rot_dim / 2;
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
#pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
for (int i = 0; i < num_heads; ++i) {
const int head_idx = i;
const int64_t token_head =
token_idx * query_stride + head_idx * head_size;
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const scalar_vec_t q_x(query + out_x);
const scalar_vec_t q_y(query + out_y);
vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
vec_op::FP32Vec8 fp32_q_x(q_x);
vec_op::FP32Vec8 fp32_q_y(q_y);
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
scalar_vec_t(out1).save(query + out_x);
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
scalar_vec_t(out2).save(query + out_y);
}
}
for (int i = 0; i < num_kv_heads; ++i) {
const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const scalar_vec_t k_x(key + out_x);
const scalar_vec_t k_y(key + out_y);
vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
vec_op::FP32Vec8 fp32_k_x(k_x);
vec_op::FP32Vec8 fp32_k_y(k_y);
auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
scalar_vec_t(out1).save(key + out_x);
auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
scalar_vec_t(out2).save(key + out_y);
}
}
}
}
template <typename scalar_t>
void rotary_embedding_gptj_impl(
const int64_t
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
/// [num_tokens, num_heads, head_size]
scalar_t
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
// [num_tokens, num_kv_heads, head_size]
const scalar_t
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) {
const int embed_dim = rot_dim / 2;
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_heads; ++i) {
int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i;
const int64_t token_head =
token_idx * query_stride + head_idx * head_size;
scalar_t *head_query = token_head + query;
for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j;
const int x_index = 2 * rot_offset;
const int y_index = 2 * rot_offset + 1;
const float cos = cos_cache_ptr[rot_offset];
const float sin = sin_cache_ptr[rot_offset];
const float x = head_query[x_index];
const float y = head_query[y_index];
head_query[x_index] = x * cos - y * sin;
head_query[y_index] = y * cos + x * sin;
}
}
}
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) {
int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
scalar_t *head_key = key + token_head;
for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j;
const int x_index = 2 * rot_offset;
const int y_index = 2 * rot_offset + 1;
const float cos = cos_cache_ptr[rot_offset];
const float sin = sin_cache_ptr[rot_offset];
const float x = head_key[x_index];
const float y = head_key[y_index];
head_key[x_index] = x * cos - y * sin;
head_key[y_index] = y * cos + x * sin;
}
}
}
}
} // namespace
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
torch::Tensor &key, int head_size,
torch::Tensor &cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
int64_t key_stride = key.stride(-2);
int64_t query_stride = query.stride(-2);
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(), "rotary_embedding_impl", [&] {
CPU_KERNEL_GUARD_IN(rotary_embedding_impl)
if (is_neox) {
rotary_embedding_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
head_size, num_tokens);
} else {
rotary_embedding_gptj_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
head_size, num_tokens);
}
CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
});
}

73
csrc/cpu/pybind.cpp Normal file
View File

@ -0,0 +1,73 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops
ops.def(
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops
ops.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm
ops.def(
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def(
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding
ops.def(
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def(
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
}

View File

@ -0,0 +1,87 @@
.. _installation_cpu:
Installation with CPU
========================
vLLM initially supports basic model inference and serving on the x86 CPU platform, with FP32 and BF16 data types.
Table of contents:
#. :ref:`Requirements <cpu_backend_requirements>`
#. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>`
#. :ref:`Build from source <build_cpu_backend_from_source>`
#. :ref:`Performance tips <cpu_backend_performance_tips>`
.. _cpu_backend_requirements:
Requirements
------------
* OS: Linux
* Compiler: gcc/g++ >= 12.3.0 (recommended)
* Instruction set architecture (ISA): AVX512 is required (see the quick check below).
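As a quick sanity check (the exact flag list depends on the CPU and kernel), you can list the AVX512-related flags reported by the host:

.. code-block:: console

    $ grep -o 'avx512[a-z_0-9]*' /proc/cpuinfo | sort -u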
.. _cpu_backend_quick_start_dockerfile:
Quick start using Dockerfile
----------------------------
.. code-block:: console
$ docker build -f Dockerfile.cpu -t vllm-cpu-env --shm-size=4g .
$ docker run -it \
--rm \
--network=host \
--cpuset-cpus=<cpu-id-list, optional> \
--cpuset-mems=<memory-node, optional> \
vllm-cpu-env
.. _build_cpu_backend_from_source:
Build from source
-----------------
- First, install the required compiler. We recommend ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.04, you can run:
.. code-block:: console
$ sudo apt-get update -y
$ sudo apt-get install -y gcc-12 g++-12
$ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
- Second, install the Python packages required to build the vLLM CPU backend:
.. code-block:: console
$ pip install --upgrade pip
$ pip install wheel packaging ninja setuptools>=49.4.0 numpy
$ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
- Finally, build and install the vLLM CPU backend:
.. code-block:: console
$ VLLM_TARGET_DEVICE=cpu python setup.py install
.. note::
- BF16 is the default data type in the current CPU backend (i.e., the backend casts FP16 to BF16), and it is compatible with all CPUs that support the AVX512 ISA.
- AVX512_BF16 is an ISA extension that provides native BF16 data type conversion and vector product instructions, bringing some performance improvement over pure AVX512. The CPU backend build script checks the host CPU flags to determine whether to enable AVX512_BF16.
- To force-enable AVX512_BF16 for cross-compilation, set the environment variable ``VLLM_CPU_AVX512BF16=1`` before building.
.. _cpu_backend_performance_tips:
Performance tips
-----------------
- The vLLM CPU backend uses the environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV cache size (e.g., ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache); a larger value allows vLLM to run more requests in parallel. Set this parameter according to the hardware configuration and the memory needs of other processes.
- The vLLM CPU backend uses OpenMP for thread-parallel computation. For the best performance, isolate the CPU cores used by the OpenMP threads from other thread pools (such as a web-service event loop) to avoid CPU oversubscription.
- When running the CPU backend on a bare-metal machine, it is recommended to disable hyper-threading.
- When running the CPU backend on a multi-socket machine with NUMA, bind the CPU cores and memory nodes to avoid remote memory access. ``numactl`` is a useful tool for CPU core and memory binding on NUMA platforms; the ``--cpuset-cpus`` and ``--cpuset-mems`` arguments of ``docker run`` serve the same purpose for containers. A combined example follows this list.
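For example, the following command (the KV cache size and NUMA node IDs are illustrative; adjust them to your hardware) runs the repository's offline inference example pinned to the cores and memory of NUMA node 0:

.. code-block:: console

    $ VLLM_CPU_KVCACHE_SPACE=40 numactl --cpunodebind=0 --membind=0 \
        python3 examples/offline_inference.py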

View File

@ -63,6 +63,7 @@ Documentation
getting_started/installation
getting_started/amd-installation
getting_started/neuron-installation
getting_started/cpu-installation
getting_started/quickstart
.. toctree::

15
requirements-cpu.txt Normal file
View File

@ -0,0 +1,15 @@
cmake>=3.21
ninja # For faster builds.
psutil
ray >= 2.9
sentencepiece # Required for LLaMA tokenizer.
numpy
transformers >= 4.38.0 # Required for Gemma.
fastapi
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
torch == 2.1.2+cpu
triton >= 2.1.0
filelock == 3.13.3
py-cpuinfo

View File

@ -15,6 +15,8 @@ from torch.utils.cpp_extension import CUDA_HOME
ROOT_DIR = os.path.dirname(__file__)
logger = logging.getLogger(__name__)
# Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu]
VLLM_TARGET_DEVICE = os.getenv("VLLM_TARGET_DEVICE", "cuda")
# vLLM only supports Linux platform
assert sys.platform.startswith(
@ -112,6 +114,7 @@ class cmake_build_ext(build_ext):
'-DCMAKE_BUILD_TYPE={}'.format(cfg),
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir),
'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp),
'-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE),
]
verbose = bool(int(os.getenv('VERBOSE', '0')))
@ -185,11 +188,14 @@ class cmake_build_ext(build_ext):
def _is_cuda() -> bool:
return torch.version.cuda is not None and not _is_neuron()
return VLLM_TARGET_DEVICE == "cuda" \
and torch.version.cuda is not None \
and not _is_neuron()
def _is_hip() -> bool:
return torch.version.hip is not None
return (VLLM_TARGET_DEVICE == "cuda"
or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None
def _is_neuron() -> bool:
@ -201,6 +207,10 @@ def _is_neuron() -> bool:
return torch_neuronx_installed
def _is_cpu() -> bool:
return VLLM_TARGET_DEVICE == "cpu"
def _install_punica() -> bool:
return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
@ -296,6 +306,8 @@ def get_vllm_version() -> str:
if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"+neuron{neuron_version_str}"
elif _is_cpu():
version += "+cpu"
else:
raise RuntimeError("Unknown runtime environment")
@ -322,6 +334,9 @@ def get_requirements() -> List[str]:
elif _is_neuron():
with open(get_path("requirements-neuron.txt")) as f:
requirements = f.read().strip().split("\n")
elif _is_cpu():
with open(get_path("requirements-cpu.txt")) as f:
requirements = f.read().strip().split("\n")
else:
raise ValueError(
"Unsupported platform, please use CUDA, ROCM or Neuron.")

View File

@ -0,0 +1,253 @@
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
class TorchSDPABackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
return TorchSDPABackendImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
return TorchSDPAMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
prompt_lens: Optional[List[int]]
prompt_lens_tensor: Optional[torch.Tensor]
num_prompt_tokens: int
num_generation_tokens: int
max_subquery_len: Optional[int] = None
max_prompt_len: Optional[int] = None
subquery_start_loc: Optional[torch.Tensor] = None
seq_start_loc: Optional[torch.Tensor] = None
use_cuda_graph: bool = False
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it needs to be set per prompt when ALiBi
# slopes are used, due to a limitation of the xformers API.
# This field will not appear in __repr__ or __init__.
self.attn_bias: Optional[List[torch.Tensor]] = None
class TorchSDPABackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
assert len(alibi_slopes) == num_heads
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype)
if attn_metadata.is_prompt:
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=1)
if attn_metadata.attn_bias is None:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.prompt_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.prompt_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = [None] * len(attn_metadata.prompt_lens)
attn_metadata.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
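# Prompt tokens are packed along the token dimension, so run SDPA on
# one sequence at a time with its own attention bias.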
start = 0
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype)
for prompt_len, mask in zip(attn_metadata.prompt_lens,
attn_metadata.attn_bias):
end = start + prompt_len
sub_out = scaled_dot_product_attention(
query[:, start:end, :],
key[:, start:end, :],
value[:, start:end, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).movedim(query.dim() - 2, 0)
output[start:end, :, :] = sub_out
start = end
else:
# prefix-enabled attention
raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.")
else:
# Decoding run.
output = PagedAttention.forward_decode(
query,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
prompt_lens: List[int],
) -> List[torch.Tensor]:
attn_biases = []
for prompt_len in prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)
bias.mul_(alibi_slopes[:, None, None])
inf_mask = torch.empty(
(1, prompt_len, prompt_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))
return attn_biases
def _make_sliding_window_bias(
prompt_lens: List[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> List[torch.Tensor]:
attn_biases = []
for prompt_len in prompt_lens:
tensor = torch.full(
(1, prompt_len, prompt_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
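# torch.log maps the 0/1 mask to -inf/0, turning it into an
# additive attention bias.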
mask = torch.log(mask)
attn_biases.append(mask.to(dtype))
return attn_biases

View File

@ -5,7 +5,7 @@ import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_hip
from vllm.utils import is_cpu, is_hip
logger = init_logger(__name__)
@ -17,6 +17,10 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
elif is_cpu():
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
else:
logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import ( # noqa: F401
@ -29,6 +33,8 @@ def _can_use_flash_attn(dtype: torch.dtype) -> bool:
# AMD GPUs.
logger.info("Cannot use FlashAttention backend for AMD GPUs.")
return False
if is_cpu():
return False
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("Cannot use FlashAttention backend for Volta and Turing "

View File

@ -10,7 +10,8 @@ from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, get_nvcc_cuda_version, is_hip, is_neuron
from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
is_neuron)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@ -598,6 +599,8 @@ class DeviceConfig:
# Automated device type detection
if is_neuron():
self.device_type = "neuron"
elif is_cpu():
self.device_type = "cpu"
else:
# We don't call torch.cuda.is_available() here to
# avoid initializing CUDA before workers are forked

View File

@ -332,7 +332,7 @@ class EngineArgs:
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
choices=["auto", "cuda", "neuron"],
choices=["auto", "cuda", "neuron", "cpu"],
help='Device type for vLLM execution.')
# Related to Vision-language models such as llava
parser.add_argument(

View File

@ -178,6 +178,9 @@ class LLMEngine:
if device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif parallel_config.worker_use_ray:
initialize_ray_cluster(parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor

View File

@ -0,0 +1,154 @@
import os
from typing import Dict, List, Optional
import torch
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
logger = init_logger(__name__)
class CPUExecutor(ExecutorBase):
def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], *args, **kwargs) -> None:
assert device_config.device_type == "cpu"
assert lora_config is None, "cpu backend doesn't support LoRA"
model_config = _verify_and_get_model_config(model_config)
cache_config = _verify_and_get_cache_config(cache_config)
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
# Instantiate the worker and load the model to CPU.
self._init_worker()
self._init_cache()
def _init_worker(self):
from vllm.worker.cpu_worker import CPUWorker
assert self.parallel_config.world_size == 1, (
"CPUExecutor only supports single CPU socket currently.")
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = CPUWorker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
self.driver_worker.init_device()
self.driver_worker.load_model()
def _init_cache(self) -> None:
num_cpu_blocks = self.driver_worker.get_cpu_cache_block_num(
block_size=self.cache_config.block_size,
cache_space=self.cache_config.cpu_kvcache_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
)
logger.info(f"# CPU blocks: {num_cpu_blocks}")
if num_cpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
"initializing the engine.")
max_seq_len = self.cache_config.block_size * num_cpu_blocks
if self.model_config.max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({self.model_config.max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
"initializing the engine.")
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
self.cache_config.num_gpu_blocks = num_cpu_blocks # type: ignore
self.cache_config.num_cpu_blocks = 0 # type: ignore
# Initialize the cache.
self.driver_worker.init_cache_engine(cache_config=self.cache_config)
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError("LoRA is not implemented for cpu backend.")
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError("LoRA is not implemented for cpu backend.")
def list_loras(self) -> List[int]:
raise NotImplementedError("LoRA is not implemented for cpu backend.")
def check_health(self) -> None:
# CPUExecutor will always be healthy as long as
# it's running.
return
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
if config.dtype == torch.float16:
logger.warning("float16 is not supported on CPU, casting to bfloat16.")
config.dtype = torch.bfloat16
if not config.enforce_eager:
logger.warning(
"CUDA graph is not supported on CPU, fallback to the eager "
"mode.")
config.enforce_eager = True
return config
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
_GB = 1 << 30
if config.enable_prefix_caching:
logger.warning("Prefix caching is not supported on CPU, disable it.")
config.enable_prefix_caching = False
kv_cache_space_str = os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")
kv_cache_space = int(kv_cache_space_str)
if kv_cache_space >= 0:
if kv_cache_space == 0:
config.cpu_kvcache_space_bytes = 4 * _GB # type: ignore
logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
"for CPU backend is not set, using 4 by default.")
else:
config.cpu_kvcache_space_bytes = kv_cache_space * _GB # type: ignore
else:
raise RuntimeError(
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
f" {kv_cache_space}, expect a positive integer value.")
return config

View File

@ -117,6 +117,13 @@ def is_hip() -> bool:
return torch.version.hip is not None
@lru_cache(maxsize=None)
def is_cpu() -> bool:
from importlib.metadata import version
is_cpu_flag = "cpu" in version("vllm")
return is_cpu_flag
@lru_cache(maxsize=None)
def is_neuron() -> bool:
try:
@ -362,6 +369,9 @@ def is_pin_memory_available() -> bool:
elif is_neuron():
print_warning_once("Pin memory is not supported on Neuron.")
return False
elif is_cpu():
print_warning_once("Pin memory is not supported on CPU.")
return False
return True

280
vllm/worker/cpu_worker.py Normal file
View File

@ -0,0 +1,280 @@
"""A CPU worker class."""
from typing import Dict, List, Optional
import torch
import torch.distributed
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.model_runner import ModelRunner
logger = init_logger(__name__)
class CPUModelRunner(ModelRunner):
def load_model(self) -> None:
self.model = get_model(self.model_config,
self.device_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
class CPUCacheEngine:
"""Manages the KV cache for CPU backend.
This class is responsible for initializing and managing CPU KV
caches. It also provides methods for performing KV cache operations, such
as copying.
"""
def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig) -> None:
assert device_config.device_type == "cpu"
self.cache_config = cache_config
self.model_config = model_config
self.parallel_config = parallel_config
self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config)
self.num_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size
# Note: In CacheConfig, num_gpu_blocks is actually num_cpu_blocks
# for the CPU backend, because we want to reuse the KV cache management
# in the scheduler.
self.num_cpu_blocks = cache_config.num_gpu_blocks
if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
else:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Get attention backend.
self.attn_backend = get_attn_backend(model_config.dtype)
# Initialize the cache.
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
def _allocate_kv_cache(
self,
num_blocks: int,
) -> List[torch.Tensor]:
"""Allocates KV cache on CPU."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_heads, self.head_size)
kv_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
kv_cache.append(
torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
return kv_cache
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
@staticmethod
def get_cache_block_size(
block_size: int,
cache_dtype: str,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config)
key_cache_block = block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
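# Worked example with illustrative values: block_size=16, num_heads=32,
# head_size=128, num_layers=32 and a 2-byte dtype (BF16) give
# 32 * (2 * 16 * 32 * 128) * 2 bytes = 8 MiB per cache block.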
if cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
dtype_size = torch.tensor([], dtype=dtype).element_size()
return dtype_size * total
class CPUWorker:
"""A worker class that executes (a partition of) the model on a CPU socket.
Each worker is associated with a single CPU socket. The worker is
responsible for maintaining the KV cache and executing the model on the
CPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
self.model_runner = CPUModelRunner(model_config,
parallel_config,
scheduler_config,
device_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
self.cache_engine = None
self.cpu_cache = None
def init_device(self) -> None:
self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
def get_cpu_cache_block_num(
self,
block_size: int,
cache_space: int,
cache_dtype: str,
) -> int:
"""
Args:
block_size: The size of the cache block.
cache_space: The size of the CPU KV cache space in bytes.
"""
# For CPU device, the block number will be calculated based on the
# cpu_kvcache_space.
cache_block_size = CPUCacheEngine.get_cache_block_size(
block_size, cache_dtype, self.model_config, self.parallel_config)
num_cpu_blocks = int(cache_space // cache_block_size)
num_cpu_blocks = max(num_cpu_blocks, 0)
return num_cpu_blocks
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
self.cache_engine = CPUCacheEngine(self.cache_config,
self.model_config,
self.parallel_config,
self.device_config)
self.cpu_cache = self.cache_engine.cpu_cache
self.model_runner.block_size = self.cache_engine.block_size
assert self.cpu_cache is not None
# Populate the cache to warmup the memory
for layer_cache in self.cpu_cache:
layer_cache.fill_(0)
def cache_copy(
self,
blocks_to_copy: Dict[int, List[int]],
) -> None:
if blocks_to_copy:
self.cache_engine.copy(blocks_to_copy)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
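# The CPU cache engine does not support swapping, so the scheduler
# should never request it.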
assert len(blocks_to_swap_in) == 0
assert len(blocks_to_swap_out) == 0
data = {
"num_seq_groups": num_seq_groups,
"blocks_to_copy": blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
data = broadcast_tensor_dict(src=0)
num_seq_groups = data["num_seq_groups"]
blocks_to_copy = data["blocks_to_copy"]
self.cache_copy(blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return {}
output = self.model_runner.execute_model(seq_group_metadata_list,
self.cpu_cache)
return output
def init_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
parallel_config = self.parallel_config
rank = self.rank
distributed_init_method = self.distributed_init_method
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch "
"world size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
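# Use the gloo backend, which supports CPU tensors (NCCL requires GPUs).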
backend = "gloo"
torch.distributed.init_process_group(
backend=backend,
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cpu())
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)