[Kernel][CPU] CPU MLA (#14744)
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
This commit is contained in: parent 4157f563b4, commit 4f044b1d67
@ -38,6 +38,8 @@ function cpu_tests() {
  set -e
  pip install -r vllm/requirements/test.txt
  pip install -r vllm/requirements/cpu.txt
  pytest -v -s tests/kernels/test_cache.py -m cpu_model
  pytest -v -s tests/kernels/test_mla_decode_cpu.py -m cpu_model
  pytest -v -s tests/models/decoder_only/language -m cpu_model
  pytest -v -s tests/models/embedding/language -m cpu_model
  pytest -v -s tests/models/encoder_decoder/language -m cpu_model
@ -190,6 +190,7 @@ set(VLLM_EXT_SRC
  "csrc/cpu/cache.cpp"
  "csrc/cpu/utils.cpp"
  "csrc/cpu/layernorm.cpp"
  "csrc/cpu/mla_decode.cpp"
  "csrc/cpu/pos_encoding.cpp"
  "csrc/cpu/torch_bindings.cpp")
@ -88,6 +88,48 @@ void reshape_and_cache_cpu_impl(
  }
};  // namespace

template <typename scalar_t>
void concat_and_cache_mla_cpu_impl(
    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
    scalar_t* __restrict__ kv_cache,    // [num_blocks, block_size, (kv_lora_rank
                                        //  + pe_dim)]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int num_tokens,                      //
    const int block_stride,                    //
    const int entry_stride,                    //
    const int kv_c_stride,                     //
    const int k_pe_stride,                     //
    const int kv_lora_rank,                    //
    const int pe_dim,                          //
    const int block_size                       //
) {
#pragma omp parallel for
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    const int64_t slot_idx = slot_mapping[token_idx];
    // NOTE: slot_idx can be -1 if the token is padded
    if (slot_idx < 0) {
      continue;
    }
    const int64_t block_idx = slot_idx / block_size;
    const int64_t block_offset = slot_idx % block_size;

    auto copy = [&](const scalar_t* __restrict__ src,
                    scalar_t* __restrict__ dst, int src_stride, int dst_stride,
                    int size, int offset) {
      for (int i = 0; i < size; i++) {
        const int64_t src_idx = token_idx * src_stride + i;
        const int64_t dst_idx =
            block_idx * block_stride + block_offset * entry_stride + i + offset;
        dst[dst_idx] = src[src_idx];
      }
    };

    copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
    copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
  }
}

// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
@ -134,6 +176,38 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
      });
}

void concat_and_cache_mla(
    torch::Tensor& kv_c,          // [num_tokens, kv_lora_rank]
    torch::Tensor& k_pe,          // [num_tokens, pe_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, (kv_lora_rank +
                                  //  pe_dim)]
    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& scale) {
  int num_tokens = slot_mapping.size(0);
  int kv_lora_rank = kv_c.size(1);
  int pe_dim = k_pe.size(1);
  int block_size = kv_cache.size(1);

  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
  TORCH_CHECK(kv_cache_dtype != "fp8");

  int kv_c_stride = kv_c.stride(0);
  int k_pe_stride = k_pe.stride(0);
  int block_stride = kv_cache.stride(0);
  int entry_stride = kv_cache.stride(1);

  VLLM_DISPATCH_FLOATING_TYPES(
      kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
        concat_and_cache_mla_cpu_impl<scalar_t>(
            kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
            kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
            num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
            kv_lora_rank, pe_dim, block_size);
        CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
      });
}

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping) {
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
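For reference, the scatter performed by concat_and_cache_mla_cpu_impl maps each slot index to a (block, offset) pair and writes the compressed KV and rope sections side by side in the paged cache. A hedged PyTorch sketch of the same indexing (not part of the commit; names follow the kernel comments, and the fp8 path is omitted since the wrapper rejects kv_cache_dtype == "fp8"):

import torch

def concat_and_cache_mla_ref(kv_c, k_pe, kv_cache, slot_mapping):
    # kv_c: [num_tokens, kv_lora_rank], k_pe: [num_tokens, pe_dim]
    # kv_cache: [num_blocks, block_size, kv_lora_rank + pe_dim]
    kv_lora_rank = kv_c.size(1)
    block_size = kv_cache.size(1)
    for token_idx, slot in enumerate(slot_mapping.tolist()):
        if slot < 0:  # padded token, nothing to write
            continue
        block_idx, block_offset = divmod(slot, block_size)
        kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx]
        kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx]
    return kv_cache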
@ -130,6 +130,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> {

  __m512i reg;

  explicit BF16Vec32() : reg(_mm512_setzero_si512()) {}

  explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}

  explicit BF16Vec32(__m512i data) : reg(data) {}
csrc/cpu/mla_decode.cpp (new file, 393 lines)
@ -0,0 +1,393 @@
#include "cpu_types.hpp"
#include <float.h>

namespace {
template <typename scalar_t>
struct KernelVecType {
  using qk_load_vec_type = void;
  using qk_vec_type = void;
  using v_load_vec_type = void;
};

template <>
struct KernelVecType<float> {
  using qk_load_vec_type = vec_op::FP32Vec16;
  using qk_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::FP32Vec16;
};

template <>
struct KernelVecType<c10::Half> {
#if defined(__powerpc64__) || defined(__s390x__)
  // Power and s390x architecture-specific vector types
  using qk_load_vec_type = vec_op::FP32Vec16;
  using qk_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::FP32Vec16;
#else
  // Fallback for other architectures, including x86
  using qk_load_vec_type = vec_op::FP16Vec16;
  using qk_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::FP16Vec16;
#endif
};

#ifdef __AVX512BF16__
template <>
struct KernelVecType<c10::BFloat16> {
  using qk_load_vec_type = vec_op::BF16Vec32;
  using qk_vec_type = vec_op::BF16Vec32;
  using v_load_vec_type = vec_op::BF16Vec16;
};
#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
// pass
#else
template <>
struct KernelVecType<c10::BFloat16> {
  using qk_load_vec_type = vec_op::BF16Vec16;
  using qk_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::BF16Vec16;
};
#endif

template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, int HEAD_UNROLL,
          typename qk_vec_type>
void mla_decode_block_head(
    const qk_vec_type* __restrict__ q_vecs,          // [HEAD_UNROLL, head_dim]
    const qk_vec_type* __restrict__ k_vecs,          // [block_size, head_dim]
    const vec_op::FP32Vec16* __restrict v_vecs_f32,  // [block_size, v_head_dim]
    float* __restrict__ acc_out,                     // [HEAD_UNROLL, v_head_dim]
    float* __restrict__ acc_lse,                     // [HEAD_UNROLL]
    const float scale, const int num_tokens) {
  using f32_vec_type = vec_op::FP32Vec16;
  constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
  constexpr int V_NUM_ELEM = f32_vec_type::VEC_ELEM_NUM;

  float logits[BLOCK_SIZE][HEAD_UNROLL] = {};  // initialize to zeros
  float max_val[HEAD_UNROLL];
  std::fill(max_val, max_val + HEAD_UNROLL, -FLT_MAX);

  f32_vec_type acc_vec[BLOCK_SIZE][HEAD_UNROLL];
  for (int i = 0; i < HEAD_DIM; i += QK_NUM_ELEM) {
    // load to registers
    qk_vec_type q_vec[HEAD_UNROLL];

#pragma unroll
    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
      q_vec[unroll] =
          qk_vec_type{q_vecs[(i + unroll * HEAD_DIM) / QK_NUM_ELEM]};

    for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
      qk_vec_type k_vec(k_vecs[(block_offset * HEAD_DIM + i) / QK_NUM_ELEM]);

#pragma unroll
      for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
        vec_op::fma(acc_vec[block_offset][unroll], q_vec[unroll], k_vec);
    }
  }

  for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
#pragma unroll
    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
      const float acc = acc_vec[block_offset][unroll].reduce_sum() * scale;
      logits[block_offset][unroll] = acc;
      max_val[unroll] = std::max(max_val[unroll], acc);
    }
  }

  float sum_exp[HEAD_UNROLL] = {};
  for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
#pragma unroll
    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
      const float val =
          std::exp(logits[block_offset][unroll] - max_val[unroll]);
      logits[block_offset][unroll] = val;
      sum_exp[unroll] += val;
    }
  }

  f32_vec_type this_out[V_HEAD_DIM / V_NUM_ELEM][HEAD_UNROLL];

  for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
    // load to registers
    f32_vec_type scale_[HEAD_UNROLL];

#pragma unroll
    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
      scale_[unroll] =
          f32_vec_type{logits[block_offset][unroll] / sum_exp[unroll]};

    for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
      f32_vec_type v_vec(
          v_vecs_f32[(block_offset * HEAD_DIM + i) / V_NUM_ELEM]);

#pragma unroll
      for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
        vec_op::fma(this_out[i / V_NUM_ELEM][unroll], v_vec, scale_[unroll]);
    }
  }

  // merge attention state
  // section 2.2 in https://arxiv.org/pdf/2501.01005
  f32_vec_type prev_scale[HEAD_UNROLL];
  f32_vec_type curr_scale[HEAD_UNROLL];

#pragma unroll
  for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
    const float prev_lse = acc_lse[unroll];
    const float curr_lse = std::log(sum_exp[unroll]) +
                           max_val[unroll];  // add back max_val to get true lse
    // softmax trick
    const float max_lse = std::max(prev_lse, curr_lse);
    const float prev_sum_exp = std::exp(prev_lse - max_lse);
    const float curr_sum_exp = std::exp(curr_lse - max_lse);

    const float new_sum_exp = prev_sum_exp + curr_sum_exp;
    acc_lse[unroll] = std::log(new_sum_exp) + max_lse;

    prev_scale[unroll] = f32_vec_type{prev_sum_exp / new_sum_exp};
    curr_scale[unroll] = f32_vec_type{curr_sum_exp / new_sum_exp};
  }

  for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
#pragma unroll
    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
      f32_vec_type o_vec(acc_out + i + V_HEAD_DIM * unroll);
      o_vec = o_vec * prev_scale[unroll] +
              this_out[i / V_NUM_ELEM][unroll] * curr_scale[unroll];
      o_vec.save(acc_out + i + V_HEAD_DIM * unroll);
    }
  }

  q_vecs += HEAD_DIM / QK_NUM_ELEM * HEAD_UNROLL;
  acc_out += V_HEAD_DIM * HEAD_UNROLL;
}

template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE,
          typename qk_vec_type>
void mla_decode_block(
    const qk_vec_type* __restrict__ q_vecs,  // [num_heads, head_dim]
    const scalar_t* __restrict__ kv_cache,   // [block_size, head_dim]
    float* __restrict__ acc_out,             // [num_heads, v_head_dim]
    float* __restrict__ acc_lse,             // [num_heads]
    const int num_heads, const float scale, const int num_tokens) {
  using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
  static_assert(
      std::is_same<qk_vec_type,
                   typename KernelVecType<scalar_t>::qk_vec_type>::value);
  using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
  using f32_vec_type = vec_op::FP32Vec16;
  static_assert(qk_load_vec_type::VEC_ELEM_NUM == qk_vec_type::VEC_ELEM_NUM);
  static_assert(v_load_vec_type::VEC_ELEM_NUM == f32_vec_type::VEC_ELEM_NUM);
  constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
  constexpr int V_NUM_ELEM = v_load_vec_type::VEC_ELEM_NUM;

  const qk_vec_type* k_vecs;
  const f32_vec_type* v_vecs_f32;
  float* kv_cache_f32 = nullptr;

  if constexpr (!std::is_same<scalar_t, float>::value) {
    // convert KV cache block to FP32 to reuse it across query heads and
    // attn @ V computation, since FP16/BF16->FP32 is expensive.
    // TODO: move malloc outside of this fn to reuse across iterations.
    const int nbytes = BLOCK_SIZE * HEAD_DIM * sizeof(float);
    kv_cache_f32 = static_cast<float*>(std::aligned_alloc(64, nbytes));

    for (int block_offset = 0; block_offset < num_tokens; ++block_offset)
      for (int i = 0; i < HEAD_DIM; i += V_NUM_ELEM) {
        v_load_vec_type kv_load_vec(kv_cache + block_offset * HEAD_DIM + i);
        f32_vec_type kv_vec_f32(kv_load_vec);
        kv_vec_f32.save(kv_cache_f32 + block_offset * HEAD_DIM + i);
      }

    if constexpr (std::is_same<qk_load_vec_type, qk_vec_type>::value) {
      // for AVX512_BF16, Q @ K.T uses BF16 for K (no conversion)
      // NOTE: in this case, we only need to convert the V section to FP32.
      // But for simplicity, we will convert the whole KV block to FP32.
      k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
    } else {
      k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache_f32);
    }

    // attn @ V always use FP32 for V, since attn is FP32.
    v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache_f32);

  } else {
    // KV cache is FP32. don't need to do anything.
    k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
    v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache);
  }

  // compute 2 heads at the same time to improve ILP and
  // take advantage of register cache for K and V.
  constexpr int HEAD_UNROLL = 2;
  for (int iter = 0; iter < num_heads / HEAD_UNROLL; ++iter) {
    mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, HEAD_UNROLL>(
        q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);

    q_vecs += HEAD_UNROLL * HEAD_DIM / QK_NUM_ELEM;
    acc_out += HEAD_UNROLL * V_HEAD_DIM;
    acc_lse += HEAD_UNROLL;
  }

  // take care of the remaining heads
  for (int iter = 0; iter < num_heads % HEAD_UNROLL; ++iter) {
    mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, 1>(
        q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);

    q_vecs += HEAD_DIM / QK_NUM_ELEM;
    acc_out += V_HEAD_DIM;
    acc_lse += 1;
  }

  if (kv_cache_f32 != nullptr) {
    std::free(kv_cache_f32);
  }
}
}  // namespace

template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE>
void mla_decode_kvcache_cpu_impl(
    scalar_t* __restrict__ out,             // [num_seqs, num_heads, v_head_dim]
    const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_dim]
    const scalar_t* __restrict__ kv_cache,  // [num_blocks, block_size,
                                            //  head_dim]
    const int num_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq, const int o_stride, const int q_stride,
    const int kv_stride, const int num_seqs) {
  using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
  using qk_vec_type = typename KernelVecType<scalar_t>::qk_vec_type;
  constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;

  // shared across threads
  const int max_threads = omp_get_max_threads();
  const int acc_out_nbytes =
      max_threads * num_heads * V_HEAD_DIM * sizeof(float);
  float* acc_out = static_cast<float*>(std::aligned_alloc(64, acc_out_nbytes));
  std::vector<float> acc_lse(max_threads * num_heads);

  // allocate memory to pre-convert query to FP32 later
  float* q_f32;
  constexpr bool PRE_CONVERT_QUERY =
      !std::is_same<scalar_t, float>::value &&
      std::is_same<qk_vec_type, vec_op::FP32Vec16>::value;
  if constexpr (PRE_CONVERT_QUERY) {
    const int q_f32_nbytes = num_heads * HEAD_DIM * sizeof(float);
    q_f32 = static_cast<float*>(std::aligned_alloc(64, q_f32_nbytes));
  }

#pragma omp parallel
  {
    const int num_threads = omp_get_num_threads();
    const int thread_id = omp_get_thread_num();
    float* __restrict__ acc_out_thread =
        acc_out + thread_id * num_heads * V_HEAD_DIM;
    float* __restrict__ acc_lse_thread = acc_lse.data() + thread_id * num_heads;

    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
      // reset accumulator
      std::fill(acc_out_thread, acc_out_thread + num_heads * V_HEAD_DIM, 0.0f);
      std::fill(acc_lse_thread, acc_lse_thread + num_heads, -FLT_MAX);

      const int seq_len = seq_lens[seq_idx];
      const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
      const int last_block_size = seq_len - (block_num - 1) * BLOCK_SIZE;

      const qk_vec_type* q_vecs;
      if constexpr (PRE_CONVERT_QUERY) {
        // pre-convert query to FP32 since FP16/BF16->FP32 is slow.
#pragma omp for
        for (int i = 0; i < num_heads * HEAD_DIM; i += QK_NUM_ELEM) {
          qk_load_vec_type q_load_vec(q + seq_idx * q_stride + i);
          qk_vec_type q_vec(q_load_vec);
          q_vec.save(q_f32 + i);
        }
        q_vecs = reinterpret_cast<const qk_vec_type*>(q_f32);
      } else {
        q_vecs = reinterpret_cast<const qk_vec_type*>(q + seq_idx * q_stride);
      }

#pragma omp for
      for (int block_idx = 0; block_idx < block_num; ++block_idx) {
        const int physical_block_idx =
            block_tables[seq_idx * max_num_blocks_per_seq + block_idx];
        const int num_tokens =
            block_idx < block_num - 1 ? BLOCK_SIZE : last_block_size;

        mla_decode_block<scalar_t, HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE>(
            q_vecs, kv_cache + physical_block_idx * kv_stride, acc_out_thread,
            acc_lse_thread, num_heads, scale, num_tokens);
      }

      // merge attention states across threads
      // section 2.2 in https://arxiv.org/pdf/2501.01005
      // each thread is responsible for 1 head
#pragma omp for
      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
        float* acc_lse_head = acc_lse.data() + head_idx;
        float* acc_out_head = acc_out + head_idx * V_HEAD_DIM;

        float max_val = -FLT_MAX;
        for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
          max_val = std::max(max_val, acc_lse_head[thread_id_ * num_heads]);
        }

        float sum_exp = 0.0f;
        for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
          float val = std::exp(acc_lse_head[thread_id_ * num_heads] - max_val);
          acc_lse_head[thread_id_ * num_heads] = val;
          sum_exp += val;
        }

        float inv_sum = 1.0f / sum_exp;
        float out_head[V_HEAD_DIM] = {};
        for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
          float scale_ = acc_lse_head[thread_id_ * num_heads] * inv_sum;
          for (int i = 0; i < V_HEAD_DIM; ++i) {
            out_head[i] +=
                acc_out_head[thread_id_ * num_heads * V_HEAD_DIM + i] * scale_;
          }
        }

        for (int i = 0; i < V_HEAD_DIM; ++i) {
          vec_op::storeFP32(out_head[i], out + seq_idx * o_stride +
                                             head_idx * V_HEAD_DIM + i);
        }
      }
    }
  }
  if (PRE_CONVERT_QUERY) {
    std::free(q_f32);
  }
  std::free(acc_out);
}

void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& kv_cache, double scale,
                        torch::Tensor& block_tables, torch::Tensor& seq_lens) {
  const int num_seqs = query.size(0);
  const int num_heads = query.size(1);
  const int head_dim = query.size(2);
  const int block_size = kv_cache.size(1);
  const int v_head_dim = out.size(2);

  const int max_num_blocks_per_seq = block_tables.size(1);
  const int o_stride = out.stride(0);
  const int q_stride = query.stride(0);
  const int kv_stride = kv_cache.stride(0);

  VLLM_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "mla_decode_kvcache_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(mla_decode_kvcache_cpu_impl)
        if (head_dim == 576 && v_head_dim == 512 && block_size == 16)
          mla_decode_kvcache_cpu_impl<scalar_t, 576, 512, 16>(
              out.data_ptr<scalar_t>(), query.data_ptr<scalar_t>(),
              kv_cache.data_ptr<scalar_t>(), num_heads, scale,
              block_tables.data_ptr<int>(), seq_lens.data_ptr<int>(),
              max_num_blocks_per_seq, o_stride, q_stride, kv_stride, num_seqs);
        else
          TORCH_CHECK(false, "Unsupported block size: ", block_size);
        CPU_KERNEL_GUARD_OUT(mla_decode_kvcache_cpu_impl)
      });
}
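Both reductions in this kernel, across KV blocks inside mla_decode_block_head and across OpenMP threads in mla_decode_kvcache_cpu_impl, use the attention-state merge from section 2.2 of https://arxiv.org/pdf/2501.01005: two partial softmax-weighted outputs plus their log-sum-exp values combine exactly into the full result. A hedged PyTorch sketch of that merge (illustrative only; the function name is mine, not a vLLM API):

import torch

def merge_partial_attn(o_prev, lse_prev, o_curr, lse_curr):
    # o_*: [..., v_head_dim] partial outputs; lse_*: [...] log-sum-exp of the
    # softmax denominators that produced them.
    max_lse = torch.maximum(lse_prev, lse_curr)      # "softmax trick"
    w_prev = torch.exp(lse_prev - max_lse)
    w_curr = torch.exp(lse_curr - max_lse)
    denom = w_prev + w_curr
    o = (o_prev * (w_prev / denom).unsqueeze(-1) +
         o_curr * (w_curr / denom).unsqueeze(-1))
    lse = torch.log(denom) + max_lse                 # lse of the merged state
    return o, lse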
@ -18,6 +18,10 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
                        const std::optional<torch::Tensor>& azp,
                        const std::optional<torch::Tensor>& bias);

void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& kv_cache, double scale,
                        torch::Tensor& block_tables, torch::Tensor& seq_lens);

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops

@ -150,6 +154,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
      " str kv_cache_dtype,"
      " Tensor k_scale, Tensor v_scale) -> ()");
  cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);

  cache_ops.def(
      "concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
      " Tensor! kv_cache,"
      " Tensor slot_mapping,"
      " str kv_cache_dtype,"
      " Tensor scale) -> ()");
  cache_ops.impl("concat_and_cache_mla", torch::kCPU, &concat_and_cache_mla);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
@ -157,4 +169,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
  utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
  cpu_ops.def(
      "mla_decode_kvcache("
      " Tensor! out, Tensor query, Tensor kv_cache,"
      " float scale, Tensor block_tables, Tensor seq_lens) -> ()");
  cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
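Once these schemas are registered, the ops are reachable from Python through torch.ops, which is what the vllm._custom_ops wrappers and the tests below rely on. A hedged sketch of a direct call (sizes are placeholders; requires a vLLM CPU build so the extension is loaded):

import torch
import vllm._custom_ops as ops  # noqa: F401  (loads the compiled extension)

num_blocks, block_size, kv_lora_rank, pe_dim, num_tokens = 8, 16, 512, 64, 4
kv_c = torch.randn(num_tokens, kv_lora_rank)
k_pe = torch.randn(num_tokens, pe_dim)
kv_cache = torch.zeros(num_blocks, block_size, kv_lora_rank + pe_dim)
slot_mapping = torch.arange(num_tokens, dtype=torch.long)
scale = torch.tensor(1.0, dtype=torch.float32)

torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
                                            "auto", scale)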
@ -749,3 +749,72 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,

    ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
    torch.testing.assert_close(dst, expected)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    device = "cpu"
    kv_cache_dtype = "auto"
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens,
                       qk_rope_head_dim,
                       dtype=dtype,
                       device=device)
    entry_size = kv_lora_rank + qk_rope_head_dim

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                 kv_cache_dtype, device)
    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)

    for i in range(num_tokens):
        slot = slot_mapping[i].item()
        block_idx = slot // block_size
        block_offset = slot % block_size
        ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
        ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]

    if kv_cache_dtype == "fp8":
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache,
                        ref_temp,
                        scale.item(),
                        kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = ref_temp

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
                             kv_cache_dtype, scale)
    torch.testing.assert_close(kv_cache, ref_kv_cache)
tests/kernels/test_mla_decode_cpu.py (new file, 94 lines)
@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import torch.nn.functional as F
from torch import Tensor

import vllm._custom_ops as ops
from vllm.platforms import current_platform


def cdiv(a, b):
    return (a + b - 1) // b


def ref_mla(
    out: Tensor,  # (bs, num_heads, v_head_dim)
    query: Tensor,  # (bs, num_heads, head_dim)
    kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
    scale: float,
    block_tables: Tensor,  # (bs, max_num_blocks)
    seq_lens: Tensor,  # (bs,)
):
    bs, num_heads, v_head_dim = out.shape
    head_dim = query.shape[2]

    for i in range(bs):
        # gather and flatten KV-cache
        kv = kv_cache[
            block_tables[i]]  # (max_num_blocks, block_size, head_dim)
        kv = kv.view(1, -1,
                     head_dim)[:, :seq_lens[i]]  # (1, seq_len, head_dim)
        v = kv[:, :, :v_head_dim]

        q = query[i].view(num_heads, 1, head_dim)
        o = F.scaled_dot_product_attention(q,
                                           kv,
                                           v,
                                           scale=scale,
                                           enable_gqa=True)
        out[i] = o.view(num_heads, v_head_dim)

    return out


@pytest.mark.parametrize("bs", [4])
@pytest.mark.parametrize("mean_seq_len", [256])
@pytest.mark.parametrize("h_q", [16])
@pytest.mark.parametrize("d", [576])
@pytest.mark.parametrize("dv", [512])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_mla_decode_cpu(
    bs: int,
    mean_seq_len: int,
    h_q: int,
    d: int,
    dv: int,
    block_size: int,
    dtype: torch.dtype,
    varlen: bool,
):
    torch.set_default_dtype(dtype)
    torch.manual_seed(0)

    scale = d**(-0.5)
    if varlen:
        seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
        seq_lens = seq_lens.clip(2).to(torch.int32)
    else:
        seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
    max_seq_len = seq_lens.max().item()
    seqlen_pad = cdiv(max_seq_len, 256) * 256  # is this necessary?

    q = torch.randn(bs, h_q, d)
    block_table = torch.arange(bs * seqlen_pad // block_size,
                               dtype=torch.int32)
    block_table = block_table.view(bs, seqlen_pad // block_size)

    kv_cache = torch.randn(block_table.numel(), block_size, d)
    for i, seq_len in enumerate(seq_lens.tolist()):
        kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan")

    out_mla = q.new_zeros(bs, h_q, dv)
    ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table,
                               seq_lens)

    out_ref = q.new_zeros(bs, h_q, dv)
    ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)

    assert not out_mla.isnan().any(), "Likely read out of bounds"
    torch.testing.assert_close(out_mla, out_ref)
@ -124,6 +124,18 @@ def paged_attention_rocm(
                                          kv_cache_dtype, k_scale, v_scale)


def mla_decode_kvcache_cpu(
    out: torch.Tensor,
    query: torch.Tensor,
    kv_cache: torch.Tensor,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
) -> None:
    torch.ops._C_cpu.mla_decode_kvcache(out, query, kv_cache, scale,
                                        block_tables, seq_lens)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
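A hedged usage sketch for the wrapper above, with the shapes the CPU kernel currently supports (head_dim 576, v_head_dim 512, block_size 16; int32 block tables and sequence lengths), mirroring the test added in this commit; the sizes here are placeholders:

import torch
import vllm._custom_ops as ops

num_seqs, num_heads, head_dim, v_head_dim, block_size = 2, 16, 576, 512, 16
blocks_per_seq = 4

query = torch.randn(num_seqs, num_heads, head_dim)
kv_cache = torch.randn(num_seqs * blocks_per_seq, block_size, head_dim)
block_tables = torch.arange(num_seqs * blocks_per_seq,
                            dtype=torch.int32).view(num_seqs, blocks_per_seq)
seq_lens = torch.full((num_seqs, ), 48, dtype=torch.int32)
out = query.new_zeros(num_seqs, num_heads, v_head_dim)

ops.mla_decode_kvcache_cpu(out, query, kv_cache, head_dim**-0.5, block_tables,
                           seq_lens)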
@ -187,15 +187,28 @@ class ipex_ops:
        gen_: torch.Generator,
        logits_soft_cap: float,
    ) -> None:
        if ipex.__version__.endswith("cpu"):
            if logits_soft_cap != 0.0:
                raise ValueError("IPEX CPU does not support logits_soft_cap")
            ipex.llm.functional.varlen_attention(query.contiguous(),
                                                 key.contiguous(),
                                                 value.contiguous(), out,
                                                 seqlen_q.int(), seqlen_k.int(),
                                                 max_seqlen_q, max_seqlen_k,
                                                 pdropout, softmax_scale,
                                                 zero_tensors, is_causal,
                                                 return_softmax, gen_,
                                                 logits_soft_cap)
                                                 seqlen_q.int(),
                                                 seqlen_k.int(), max_seqlen_q,
                                                 max_seqlen_k, pdropout,
                                                 softmax_scale, zero_tensors,
                                                 is_causal, return_softmax,
                                                 gen_)
        else:  # XPU build
            ipex.llm.functional.varlen_attention(query.contiguous(),
                                                 key.contiguous(),
                                                 value.contiguous(), out,
                                                 seqlen_q.int(),
                                                 seqlen_k.int(), max_seqlen_q,
                                                 max_seqlen_k, pdropout,
                                                 softmax_scale, zero_tensors,
                                                 is_causal, return_softmax,
                                                 gen_, logits_soft_cap)

    @staticmethod
    def reshape_and_cache(
vllm/attention/backends/cpu_mla.py (new file, 303 lines)
@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

import vllm._custom_ops as ops
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadataBuilder,
                                              AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder


class CPUMLABackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "CPU_MLA"

    @staticmethod
    def get_metadata_cls() -> Type["CPUMLAMetadata"]:
        return CPUMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
        return CPUMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["MLACommonState"]:
        return MLACommonState

    @staticmethod
    def get_impl_cls() -> Type["CPUMLAImpl"]:
        return CPUMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        ops.copy_blocks_mla(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [576]


@dataclass
class CPUMLAMetadata(TorchSDPAMetadata):
    # New for MLA
    # Input positions for rotary embeddings since for MLA the rotary
    # position embeddings are applied inside the attention backend
    input_positions: torch.Tensor = None

    # required by MLACommonImpl
    is_profile_run: bool = False


class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder
        assert not self.chunked_prefill, \
            "chunked prefill is currently not supported"

    def prepare(self):
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
        input_data = self.input_data
        prefill_seq_lens = seq_lens[0:input_data.num_prefills]
        prefill_query_lens = query_lens[0:input_data.num_prefills]
        slot_mapping = torch.tensor(input_data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # metadata for prefill
        if input_data.num_prefills > 0:
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            query_start_loc = torch.zeros(input_data.num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)

            # for chunked-prefill
            if self.chunked_prefill:
                prefill_block_tables = make_tensor_with_pad(
                    self.input_data.prefill_block_tables,
                    pad=0,
                    dtype=torch.int32,
                    device="cpu",
                )
            else:
                prefill_block_tables = None

        else:
            query_start_loc = None
            kv_start_loc = None
            max_query_len = None
            max_kv_len = None
            prefill_block_tables = None

        # metadata for decode
        if input_data.num_decode_tokens != 0:
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[input_data.num_prefills:],
                dtype=torch.int32,
                device="cpu",
            )
            block_tables = make_tensor_with_pad(
                self.input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[:input_data.num_prefills],
                dtype=torch.int32,
                device="cpu",
            )

        # For multi-modal models
        placeholder_index_maps = None
        if len(input_data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                input_data.multi_modal_placeholder_maps.items()
            }

        return CPUMLAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=input_data.max_decode_seq_len,
            num_prefills=input_data.num_prefills,
            num_prefill_tokens=input_data.num_prefill_tokens,
            num_decode_tokens=input_data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            input_positions=torch.tensor([self.input_data.input_positions]))


class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "CPUMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CPUMLAImpl")

        # states is implemented.
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CPUMLAImpl with FP8 KV cache not yet supported")

    def _forward_prefill(
            self,
            q: torch.Tensor,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:

        prefill_metadata = attn_metadata.prefill_metadata
        assert prefill_metadata is not None

        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim
        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                           value=0)

        output = torch.empty_like(q)
        ipex_ops.varlen_attention(
            query=q,
            key=k,
            value=v_padded,
            out=output,
            seqlen_q=prefill_metadata.query_start_loc,
            seqlen_k=prefill_metadata.query_start_loc,
            max_seqlen_q=prefill_metadata.max_query_len,
            max_seqlen_k=prefill_metadata.max_query_len,
            pdropout=0.0,
            softmax_scale=self.scale,
            zero_tensors=False,
            is_causal=True,
            return_softmax=False,
            gen_=None,
            logits_soft_cap=0.0,
        )

        # remove padding
        output = output.view(-1, self.num_heads,
                             q.shape[-1])[..., :v.shape[-1]]
        output = output.reshape(-1, self.num_heads * v.shape[-1])
        return self.o_proj(output)[0]

    def _forward_decode(
            self,
            q_nope: torch.Tensor,
            q_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)

        # Run MQA
        ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
                                   decode_meta.block_tables,
                                   decode_meta.seq_lens_tensor)
        return self._v_up_proj_and_o_proj(o)
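To summarize the decode path in _forward_decode: the no-rope and rope query halves are concatenated into one 576-dim head, MQA runs directly against the compressed cache, and the 512-dim latent output is mapped back through _v_up_proj_and_o_proj. A hedged shape sketch (DeepSeek-style dims, kv_lora_rank = 512, qk_rope_head_dim = 64; illustrative only, not part of the commit):

import torch

num_tokens, num_heads, kv_lora_rank, rope_dim = 3, 16, 512, 64

q_nope = torch.randn(num_tokens, num_heads, kv_lora_rank)
q_pe = torch.randn(num_tokens, num_heads, rope_dim)

q = torch.cat([q_nope, q_pe], dim=-1)                  # [3, 16, 576] query
o = q.new_empty(num_tokens, num_heads, kv_lora_rank)   # [3, 16, 512] latent out
# ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, scale,
#                            block_tables, seq_lens)   # fills `o` in place
# The latent output is then projected with self._v_up_proj_and_o_proj(o).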
@ -204,7 +204,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase, RowParallelLinear,
                                               UnquantizedLinearMethod)
@ -212,18 +211,27 @@ from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version

if HAS_TRITON:
    from vllm.attention.ops.triton_flash_attention import triton_attention
    from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
else:
    merge_attn_states = None
    triton_attention = None

try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
    is_vllm_fa = True
except ImportError:
    is_vllm_fa = False
    try:
        # For rocm use upstream flash attention
        from flash_attn import flash_attn_varlen_func
        is_vllm_fa = False

        from vllm.attention.ops.triton_flash_attention import triton_attention
    except ImportError:
        flash_attn_varlen_func = None

if TYPE_CHECKING:
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@ -37,6 +37,9 @@ class CpuPlatform(Platform):
                             use_mla: bool) -> str:
        if selected_backend and selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if use_mla:
            logger.info("Using CPU MLA backend.")
            return "vllm.attention.backends.cpu_mla.CPUMLABackend"
        logger.info("Using Torch SDPA backend.")
        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"

@ -129,9 +132,6 @@ class CpuPlatform(Platform):
        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # MLA attention is not supported
        os.environ["VLLM_MLA_DISABLE"] = "1"

        # Intel OpenMP setting
        ld_prealod_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_prealod_str:
@ -469,6 +469,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        ) if needs_attn_backend else None

        # Multi-modal data support
@ -66,6 +66,7 @@ class CPUCacheEngine:
            cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )

        # Initialize the cache.