diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh
index e45e1848..05744bb5 100644
--- a/.buildkite/run-cpu-test.sh
+++ b/.buildkite/run-cpu-test.sh
@@ -38,6 +38,8 @@ function cpu_tests() {
     set -e
     pip install -r vllm/requirements/test.txt
     pip install -r vllm/requirements/cpu.txt
+    pytest -v -s tests/kernels/test_cache.py -m cpu_model
+    pytest -v -s tests/kernels/test_mla_decode_cpu.py -m cpu_model
     pytest -v -s tests/models/decoder_only/language -m cpu_model
     pytest -v -s tests/models/embedding/language -m cpu_model
     pytest -v -s tests/models/encoder_decoder/language -m cpu_model
diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index 345b75d6..b57d9e22 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -190,6 +190,7 @@ set(VLLM_EXT_SRC
     "csrc/cpu/cache.cpp"
     "csrc/cpu/utils.cpp"
     "csrc/cpu/layernorm.cpp"
+    "csrc/cpu/mla_decode.cpp"
     "csrc/cpu/pos_encoding.cpp"
     "csrc/cpu/torch_bindings.cpp")
diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp
index d726ee93..69f6d06e 100644
--- a/csrc/cpu/cache.cpp
+++ b/csrc/cpu/cache.cpp
@@ -88,6 +88,48 @@ void reshape_and_cache_cpu_impl(
 }
 };  // namespace
 
+template <typename scalar_t>
+void concat_and_cache_mla_cpu_impl(
+    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
+    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
+    scalar_t* __restrict__ kv_cache,    // [num_blocks, block_size, (kv_lora_rank
+                                        //  + pe_dim)]
+    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+    const int num_tokens,                      //
+    const int block_stride,                    //
+    const int entry_stride,                    //
+    const int kv_c_stride,                     //
+    const int k_pe_stride,                     //
+    const int kv_lora_rank,                    //
+    const int pe_dim,                          //
+    const int block_size                       //
+) {
+#pragma omp parallel for
+  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    const int64_t slot_idx = slot_mapping[token_idx];
+    // NOTE: slot_idx can be -1 if the token is padded
+    if (slot_idx < 0) {
+      continue;
+    }
+    const int64_t block_idx = slot_idx / block_size;
+    const int64_t block_offset = slot_idx % block_size;
+
+    auto copy = [&](const scalar_t* __restrict__ src,
+                    scalar_t* __restrict__ dst, int src_stride, int dst_stride,
+                    int size, int offset) {
+      for (int i = 0; i < size; i++) {
+        const int64_t src_idx = token_idx * src_stride + i;
+        const int64_t dst_idx =
+            block_idx * block_stride + block_offset * entry_stride + i + offset;
+        dst[dst_idx] = src[src_idx];
+      }
+    };
+
+    copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
+    copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
+  }
+}
+
 // Note: the key_caches and value_caches vectors are constant but
 // not the Tensors they contain. The vectors need to be const refs
 // in order to satisfy pytorch's C++ operator registration code.
@@ -134,6 +176,38 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
       });
 }
 
+void concat_and_cache_mla(
+    torch::Tensor& kv_c,          // [num_tokens, kv_lora_rank]
+    torch::Tensor& k_pe,          // [num_tokens, pe_dim]
+    torch::Tensor& kv_cache,      // [num_blocks, block_size, (kv_lora_rank +
+                                  //  pe_dim)]
+    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
+    const std::string& kv_cache_dtype, torch::Tensor& scale) {
+  int num_tokens = slot_mapping.size(0);
+  int kv_lora_rank = kv_c.size(1);
+  int pe_dim = k_pe.size(1);
+  int block_size = kv_cache.size(1);
+
+  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
+  TORCH_CHECK(kv_cache_dtype != "fp8");
+
+  int kv_c_stride = kv_c.stride(0);
+  int k_pe_stride = k_pe.stride(0);
+  int block_stride = kv_cache.stride(0);
+  int entry_stride = kv_cache.stride(1);
+
+  VLLM_DISPATCH_FLOATING_TYPES(
+      kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
+        CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
+        concat_and_cache_mla_cpu_impl<scalar_t>(
+            kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
+            kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
+            num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
+            kv_lora_rank, pe_dim, block_size);
+        CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
+      });
+}
+
 void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                  const torch::Tensor& block_mapping) {
   TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
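Note (reviewer aid, not part of the diff): the new CPU `concat_and_cache_mla` kernel scatters the compressed KV (`kv_c`) and the rotary key part (`k_pe`) into a paged cache of shape `[num_blocks, block_size, kv_lora_rank + pe_dim]`, one token per slot. A minimal PyTorch sketch of the same semantics, assuming contiguous inputs (names are illustrative):

```python
import torch

def concat_and_cache_mla_ref(kv_c: torch.Tensor,         # [num_tokens, kv_lora_rank]
                             k_pe: torch.Tensor,         # [num_tokens, pe_dim]
                             kv_cache: torch.Tensor,     # [num_blocks, block_size, kv_lora_rank + pe_dim]
                             slot_mapping: torch.Tensor  # [num_tokens], -1 marks padded tokens
                             ) -> None:
    block_size = kv_cache.size(1)
    kv_lora_rank = kv_c.size(1)
    for token_idx, slot in enumerate(slot_mapping.tolist()):
        if slot < 0:  # padded token, skipped by the kernel as well
            continue
        block_idx, block_offset = divmod(slot, block_size)
        kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx]
        kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx]
```

This mirrors the reference loop used by the new `test_concat_and_cache_mla_cpu` test further down.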
diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp
index a9369e1f..4568699b 100644
--- a/csrc/cpu/cpu_types_x86.hpp
+++ b/csrc/cpu/cpu_types_x86.hpp
@@ -130,6 +130,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
 
   __m512i reg;
 
+  explicit BF16Vec32() : reg(_mm512_setzero_si512()) {}
+
   explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
 
   explicit BF16Vec32(__m512i data) : reg(data) {}
diff --git a/csrc/cpu/mla_decode.cpp b/csrc/cpu/mla_decode.cpp
new file mode 100644
index 00000000..37bd463b
--- /dev/null
+++ b/csrc/cpu/mla_decode.cpp
@@ -0,0 +1,393 @@
+#include "cpu_types.hpp"
+#include <float.h>
+
+namespace {
+template <typename scalar_t>
+struct KernelVecType {
+  using qk_load_vec_type = void;
+  using qk_vec_type = void;
+  using v_load_vec_type = void;
+};
+
+template <>
+struct KernelVecType<float> {
+  using qk_load_vec_type = vec_op::FP32Vec16;
+  using qk_vec_type = vec_op::FP32Vec16;
+  using v_load_vec_type = vec_op::FP32Vec16;
+};
+
+template <>
+struct KernelVecType<c10::Half> {
+#if defined(__powerpc64__) || defined(__s390x__)
+  // Power and s390x architecture-specific vector types
+  using qk_load_vec_type = vec_op::FP32Vec16;
+  using qk_vec_type = vec_op::FP32Vec16;
+  using v_load_vec_type = vec_op::FP32Vec16;
+#else
+  // Fallback for other architectures, including x86
+  using qk_load_vec_type = vec_op::FP16Vec16;
+  using qk_vec_type = vec_op::FP32Vec16;
+  using v_load_vec_type = vec_op::FP16Vec16;
+#endif
+};
+
+#ifdef __AVX512BF16__
+template <>
+struct KernelVecType<c10::BFloat16> {
+  using qk_load_vec_type = vec_op::BF16Vec32;
+  using qk_vec_type = vec_op::BF16Vec32;
+  using v_load_vec_type = vec_op::BF16Vec16;
+};
+#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
+// pass
+#else
+template <>
+struct KernelVecType<c10::BFloat16> {
+  using qk_load_vec_type = vec_op::BF16Vec16;
+  using qk_vec_type = vec_op::FP32Vec16;
+  using v_load_vec_type = vec_op::BF16Vec16;
+};
+#endif
+
+template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, int HEAD_UNROLL,
+          typename qk_vec_type>
+void mla_decode_block_head(
+    const qk_vec_type* __restrict__ q_vecs,  // [HEAD_UNROLL, head_dim]
+    const qk_vec_type* __restrict__ k_vecs,  // [block_size, head_dim]
+    const vec_op::FP32Vec16* __restrict v_vecs_f32,  // [block_size,
+                                                     //  v_head_dim]
+    float* __restrict__ acc_out,  // [HEAD_UNROLL, v_head_dim]
+    float* __restrict__ acc_lse,  // [HEAD_UNROLL]
+    const float scale, const int num_tokens) {
+  using f32_vec_type = vec_op::FP32Vec16;
+  constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
+  constexpr int V_NUM_ELEM = f32_vec_type::VEC_ELEM_NUM;
+
+  float logits[BLOCK_SIZE][HEAD_UNROLL] = {};  // initialize to zeros
+  float max_val[HEAD_UNROLL];
+  std::fill(max_val, max_val + HEAD_UNROLL, -FLT_MAX);
+
+  f32_vec_type acc_vec[BLOCK_SIZE][HEAD_UNROLL];
+  for (int i = 0; i < HEAD_DIM; i += QK_NUM_ELEM) {
+    // load to registers
+    qk_vec_type q_vec[HEAD_UNROLL];
+
+#pragma unroll
+    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
+      q_vec[unroll] =
+          qk_vec_type{q_vecs[(i + unroll * HEAD_DIM) / QK_NUM_ELEM]};
+
+    for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
+      qk_vec_type k_vec(k_vecs[(block_offset * HEAD_DIM + i) / QK_NUM_ELEM]);
+
+#pragma unroll
+      for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
+        vec_op::fma(acc_vec[block_offset][unroll], q_vec[unroll], k_vec);
+    }
+  }
+
+  for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
+#pragma unroll
+    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
+      const float acc = acc_vec[block_offset][unroll].reduce_sum() * scale;
+      logits[block_offset][unroll] = acc;
+      max_val[unroll] = std::max(max_val[unroll], acc);
+    }
+  }
+
+  float sum_exp[HEAD_UNROLL] = {};
+  for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
+#pragma unroll
+    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
+      const float val =
+          std::exp(logits[block_offset][unroll] - max_val[unroll]);
+      logits[block_offset][unroll] = val;
+      sum_exp[unroll] += val;
+    }
+  }
+
+  f32_vec_type this_out[V_HEAD_DIM / V_NUM_ELEM][HEAD_UNROLL];
+
+  for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
+    // load to registers
+    f32_vec_type scale_[HEAD_UNROLL];
+
+#pragma unroll
+    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
+      scale_[unroll] =
+          f32_vec_type{logits[block_offset][unroll] / sum_exp[unroll]};
+
+    for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
+      f32_vec_type v_vec(
+          v_vecs_f32[(block_offset * HEAD_DIM + i) / V_NUM_ELEM]);
+
+#pragma unroll
+      for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
+        vec_op::fma(this_out[i / V_NUM_ELEM][unroll], v_vec, scale_[unroll]);
+    }
+  }
+
+  // merge attention state
+  // section 2.2 in https://arxiv.org/pdf/2501.01005
+  f32_vec_type prev_scale[HEAD_UNROLL];
+  f32_vec_type curr_scale[HEAD_UNROLL];
+
+#pragma unroll
+  for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
+    const float prev_lse = acc_lse[unroll];
+    const float curr_lse = std::log(sum_exp[unroll]) +
+                           max_val[unroll];  // add back max_val to get true lse
+    // softmax trick
+    const float max_lse = std::max(prev_lse, curr_lse);
+    const float prev_sum_exp = std::exp(prev_lse - max_lse);
+    const float curr_sum_exp = std::exp(curr_lse - max_lse);
+
+    const float new_sum_exp = prev_sum_exp + curr_sum_exp;
+    acc_lse[unroll] = std::log(new_sum_exp) + max_lse;
+
+    prev_scale[unroll] = f32_vec_type{prev_sum_exp / new_sum_exp};
+    curr_scale[unroll] = f32_vec_type{curr_sum_exp / new_sum_exp};
+  }
+
+  for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
+#pragma unroll
+    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
+      f32_vec_type o_vec(acc_out + i + V_HEAD_DIM * unroll);
+      o_vec = o_vec * prev_scale[unroll] +
+              this_out[i / V_NUM_ELEM][unroll] * curr_scale[unroll];
+      o_vec.save(acc_out + i + V_HEAD_DIM * unroll);
+    }
+  }
+
+  q_vecs += HEAD_DIM / QK_NUM_ELEM * HEAD_UNROLL;
+  acc_out += V_HEAD_DIM * HEAD_UNROLL;
+}
+
+template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, typename scalar_t,
+          typename qk_vec_type>
+void mla_decode_block(
+    const qk_vec_type* __restrict__ q_vecs,  // [num_heads, head_dim]
+    const scalar_t* __restrict__ kv_cache,   // [block_size, head_dim]
+    float* __restrict__ acc_out,             // [num_heads, v_head_dim]
+    float* __restrict__ acc_lse,             // [num_heads]
+    const int num_heads, const float scale, const int num_tokens) {
+  using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
+  static_assert(
+      std::is_same<qk_vec_type,
+                   typename KernelVecType<scalar_t>::qk_vec_type>::value);
+  using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
+  using f32_vec_type = vec_op::FP32Vec16;
+  static_assert(qk_load_vec_type::VEC_ELEM_NUM == qk_vec_type::VEC_ELEM_NUM);
+  static_assert(v_load_vec_type::VEC_ELEM_NUM == f32_vec_type::VEC_ELEM_NUM);
+  constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
+  constexpr int V_NUM_ELEM = v_load_vec_type::VEC_ELEM_NUM;
+
+  const qk_vec_type* k_vecs;
+  const f32_vec_type* v_vecs_f32;
+  float* kv_cache_f32 = nullptr;
+
+  if constexpr (!std::is_same<scalar_t, float>::value) {
+    // convert KV cache block to FP32 to reuse it across query heads and
+    // attn @ V computation, since FP16/BF16->FP32 is expensive.
+    // TODO: move malloc outside of this fn to reuse across iterations.
+    const int nbytes = BLOCK_SIZE * HEAD_DIM * sizeof(float);
+    kv_cache_f32 = static_cast<float*>(std::aligned_alloc(64, nbytes));
+
+    for (int block_offset = 0; block_offset < num_tokens; ++block_offset)
+      for (int i = 0; i < HEAD_DIM; i += V_NUM_ELEM) {
+        v_load_vec_type kv_load_vec(kv_cache + block_offset * HEAD_DIM + i);
+        f32_vec_type kv_vec_f32(kv_load_vec);
+        kv_vec_f32.save(kv_cache_f32 + block_offset * HEAD_DIM + i);
+      }
+
+    if constexpr (std::is_same<qk_load_vec_type, qk_vec_type>::value) {
+      // for AVX512_BF16, Q @ K.T uses BF16 for K (no conversion)
+      // NOTE: in this case, we only need to convert the V section to FP32.
+      // But for simplicity, we will convert the whole KV block to FP32.
+      k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
+    } else {
+      k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache_f32);
+    }
+
+    // attn @ V always use FP32 for V, since attn is FP32.
+    v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache_f32);
+
+  } else {
+    // KV cache is FP32. don't need to do anything.
+    k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
+    v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache);
+  }
+
+  // compute 2 heads at the same time to improve ILP and
+  // take advantage of register cache for K and V.
+  constexpr int HEAD_UNROLL = 2;
+  for (int iter = 0; iter < num_heads / HEAD_UNROLL; ++iter) {
+    mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, HEAD_UNROLL>(
+        q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);
+
+    q_vecs += HEAD_UNROLL * HEAD_DIM / QK_NUM_ELEM;
+    acc_out += HEAD_UNROLL * V_HEAD_DIM;
+    acc_lse += HEAD_UNROLL;
+  }
+
+  // take care of the remaining heads
+  for (int iter = 0; iter < num_heads % HEAD_UNROLL; ++iter) {
+    mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, 1>(
+        q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);
+
+    q_vecs += HEAD_DIM / QK_NUM_ELEM;
+    acc_out += V_HEAD_DIM;
+    acc_lse += 1;
+  }
+
+  if (kv_cache_f32 != nullptr) {
+    std::free(kv_cache_f32);
+  }
+}
+}  // namespace
+
+template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, typename scalar_t>
+void mla_decode_kvcache_cpu_impl(
+    scalar_t* __restrict__ out,             // [num_seqs, num_heads, v_head_dim]
+    const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_dim]
+    const scalar_t* __restrict__ kv_cache,  // [num_blocks, block_size,
+                                            //  head_dim]
+    const int num_heads, const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_blocks_per_seq, const int o_stride, const int q_stride,
+    const int kv_stride, const int num_seqs) {
+  using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
+  using qk_vec_type = typename KernelVecType<scalar_t>::qk_vec_type;
+  constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
+
+  // shared across threads
+  const int max_threads = omp_get_max_threads();
+  const int acc_out_nbytes =
+      max_threads * num_heads * V_HEAD_DIM * sizeof(float);
+  float* acc_out = static_cast<float*>(std::aligned_alloc(64, acc_out_nbytes));
+  std::vector<float> acc_lse(max_threads * num_heads);
+
+  // allocate memory to pre-convert query to FP32 later
+  float* q_f32;
+  constexpr bool PRE_CONVERT_QUERY =
+      !std::is_same<scalar_t, float>::value &&
+      std::is_same<qk_vec_type, vec_op::FP32Vec16>::value;
+  if constexpr (PRE_CONVERT_QUERY) {
+    const int q_f32_nbytes = num_heads * HEAD_DIM * sizeof(float);
+    q_f32 = static_cast<float*>(std::aligned_alloc(64, q_f32_nbytes));
+  }
+
+#pragma omp parallel
+  {
+    const int num_threads = omp_get_num_threads();
+    const int thread_id = omp_get_thread_num();
+    float* __restrict__ acc_out_thread =
+        acc_out + thread_id * num_heads * V_HEAD_DIM;
+    float* __restrict__ acc_lse_thread = acc_lse.data() + thread_id * num_heads;
+
+    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
+      // reset accumulator
+      std::fill(acc_out_thread, acc_out_thread + num_heads * V_HEAD_DIM, 0.0f);
+      std::fill(acc_lse_thread, acc_lse_thread + num_heads, -FLT_MAX);
+
+      const int seq_len = seq_lens[seq_idx];
+      const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
+      const int last_block_size = seq_len - (block_num - 1) * BLOCK_SIZE;
+
+      const qk_vec_type* q_vecs;
+      if constexpr (PRE_CONVERT_QUERY) {
+// pre-convert query to FP32 since FP16/BF16->FP32 is slow.
+#pragma omp for
+        for (int i = 0; i < num_heads * HEAD_DIM; i += QK_NUM_ELEM) {
+          qk_load_vec_type q_load_vec(q + seq_idx * q_stride + i);
+          qk_vec_type q_vec(q_load_vec);
+          q_vec.save(q_f32 + i);
+        }
+        q_vecs = reinterpret_cast<const qk_vec_type*>(q_f32);
+      } else {
+        q_vecs = reinterpret_cast<const qk_vec_type*>(q + seq_idx * q_stride);
+      }
+
+#pragma omp for
+      for (int block_idx = 0; block_idx < block_num; ++block_idx) {
+        const int physical_block_idx =
+            block_tables[seq_idx * max_num_blocks_per_seq + block_idx];
+        const int num_tokens =
+            block_idx < block_num - 1 ? BLOCK_SIZE : last_block_size;
+
+        mla_decode_block<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE>(
+            q_vecs, kv_cache + physical_block_idx * kv_stride, acc_out_thread,
+            acc_lse_thread, num_heads, scale, num_tokens);
+      }
+
+// merge attention states across threads
+// section 2.2 in https://arxiv.org/pdf/2501.01005
+// each thread is responsible for 1 head
+#pragma omp for
+      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
+        float* acc_lse_head = acc_lse.data() + head_idx;
+        float* acc_out_head = acc_out + head_idx * V_HEAD_DIM;
+
+        float max_val = -FLT_MAX;
+        for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
+          max_val = std::max(max_val, acc_lse_head[thread_id_ * num_heads]);
+        }
+
+        float sum_exp = 0.0f;
+        for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
+          float val = std::exp(acc_lse_head[thread_id_ * num_heads] - max_val);
+          acc_lse_head[thread_id_ * num_heads] = val;
+          sum_exp += val;
+        }
+
+        float inv_sum = 1.0f / sum_exp;
+        float out_head[V_HEAD_DIM] = {};
+        for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
+          float scale_ = acc_lse_head[thread_id_ * num_heads] * inv_sum;
+          for (int i = 0; i < V_HEAD_DIM; ++i) {
+            out_head[i] +=
+                acc_out_head[thread_id_ * num_heads * V_HEAD_DIM + i] * scale_;
+          }
+        }
+
+        for (int i = 0; i < V_HEAD_DIM; ++i) {
+          vec_op::storeFP32(out_head[i], out + seq_idx * o_stride +
+                                             head_idx * V_HEAD_DIM + i);
+        }
+      }
+    }
+  }
+  if (PRE_CONVERT_QUERY) {
+    std::free(q_f32);
+  }
+  std::free(acc_out);
+}
+
+void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
+                        torch::Tensor& kv_cache, double scale,
+                        torch::Tensor& block_tables, torch::Tensor& seq_lens) {
+  const int num_seqs = query.size(0);
+  const int num_heads = query.size(1);
+  const int head_dim = query.size(2);
+  const int block_size = kv_cache.size(1);
+  const int v_head_dim = out.size(2);
+
+  const int max_num_blocks_per_seq = block_tables.size(1);
+  const int o_stride = out.stride(0);
+  const int q_stride = query.stride(0);
+  const int kv_stride = kv_cache.stride(0);
+
+  VLLM_DISPATCH_FLOATING_TYPES(
+      query.scalar_type(), "mla_decode_kvcache_cpu_impl", [&] {
+        CPU_KERNEL_GUARD_IN(mla_decode_kvcache_cpu_impl)
+        if (head_dim == 576 && v_head_dim == 512 && block_size == 16)
+          mla_decode_kvcache_cpu_impl<576, 512, 16>(
+              out.data_ptr<scalar_t>(), query.data_ptr<scalar_t>(),
+              kv_cache.data_ptr<scalar_t>(), num_heads, scale,
+              block_tables.data_ptr<int>(), seq_lens.data_ptr<int>(),
+              max_num_blocks_per_seq, o_stride, q_stride, kv_stride, num_seqs);
+        else
+          TORCH_CHECK(false, "Unsupported block size: ", block_size);
+        CPU_KERNEL_GUARD_OUT(mla_decode_kvcache_cpu_impl)
+      });
+}
\ No newline at end of file
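Note (reviewer aid, not part of the diff): `mla_decode_block_head` processes one KV block at a time and folds each partial result into a running output/log-sum-exp pair; the per-head reduction across OpenMP threads at the end of `mla_decode_kvcache_cpu_impl` performs the same merge. A small PyTorch sketch of that merge rule (section 2.2 of https://arxiv.org/pdf/2501.01005), with illustrative names:

```python
import torch

def merge_attn_states(o_a: torch.Tensor, lse_a: torch.Tensor,
                      o_b: torch.Tensor, lse_b: torch.Tensor):
    """Merge two partial attention results computed over disjoint KV ranges."""
    max_lse = torch.maximum(lse_a, lse_b)   # softmax trick for numerical stability
    w_a = torch.exp(lse_a - max_lse)
    w_b = torch.exp(lse_b - max_lse)
    denom = w_a + w_b
    o = o_a * (w_a / denom)[..., None] + o_b * (w_b / denom)[..., None]
    lse = torch.log(denom) + max_lse
    return o, lse
```

Because this merge is associative and commutative, KV blocks (and threads) can be processed in any order and combined at the end.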
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 5d1c5f4c..ef5a2fb5 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -18,6 +18,10 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
                         const std::optional<torch::Tensor>& azp,
                         const std::optional<torch::Tensor>& bias);
 
+void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
+                        torch::Tensor& kv_cache, double scale,
+                        torch::Tensor& block_tables, torch::Tensor& seq_lens);
+
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
 
@@ -150,6 +154,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "    str kv_cache_dtype,"
       "    Tensor k_scale, Tensor v_scale) -> ()");
   cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
+
+  cache_ops.def(
+      "concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
+      "                     Tensor! kv_cache,"
+      "                     Tensor slot_mapping,"
+      "                     str kv_cache_dtype,"
+      "                     Tensor scale) -> ()");
+  cache_ops.impl("concat_and_cache_mla", torch::kCPU, &concat_and_cache_mla);
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
@@ -157,4 +169,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
   utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
 }
 
+TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
+  cpu_ops.def(
+      "mla_decode_kvcache("
+      "   Tensor! out, Tensor query, Tensor kv_cache,"
+      "   float scale, Tensor block_tables, Tensor seq_lens) -> ()");
+  cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
+}
+
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index f7936989..89912281 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -749,3 +749,72 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
 
     ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
     torch.testing.assert_close(dst, expected)
+
+
+@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
+@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.cpu_model
+@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
+@torch.inference_mode()
+def test_concat_and_cache_mla_cpu(
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    num_tokens: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+) -> None:
+    device = "cpu"
+    kv_cache_dtype = "auto"
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+
+    total_slots = num_blocks * block_size
+    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
+    slot_mapping = torch.tensor(slot_mapping_lst,
+                                dtype=torch.long,
+                                device=device)
+
+    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
+    k_pe = torch.randn(num_tokens,
+                       qk_rope_head_dim,
+                       dtype=dtype,
+                       device=device)
+    entry_size = kv_lora_rank + qk_rope_head_dim
+
+    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
+    kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
+                                 kv_cache_dtype, device)
+    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
+
+    for i in range(num_tokens):
+        slot = slot_mapping[i].item()
+        block_idx = slot // block_size
+        block_offset = slot % block_size
+        ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
+        ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
+
+    if kv_cache_dtype == "fp8":
+        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
+        ops.convert_fp8(ref_kv_cache,
+                        ref_temp,
+                        scale.item(),
+                        kv_dtype=kv_cache_dtype)
+    else:
+        ref_kv_cache = ref_temp
+
+    opcheck(
+        torch.ops._C_cache_ops.concat_and_cache_mla,
+        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
+        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
+    )
+
+    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
+                             kv_cache_dtype, scale)
+    torch.testing.assert_close(kv_cache, ref_kv_cache)
diff --git a/tests/kernels/test_mla_decode_cpu.py b/tests/kernels/test_mla_decode_cpu.py
new file mode 100644
index 00000000..8cebe32c
--- /dev/null
+++ b/tests/kernels/test_mla_decode_cpu.py
@@ -0,0 +1,94 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+import vllm._custom_ops as ops
+from vllm.platforms import current_platform
+
+
+def cdiv(a, b):
+    return (a + b - 1) // b
+
+
+def ref_mla(
+        out: Tensor,  # (bs, num_heads, v_head_dim)
+        query: Tensor,  # (bs, num_heads, head_dim)
+        kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
+        scale: float,
+        block_tables: Tensor,  # (bs, max_num_blocks)
+        seq_lens: Tensor,  # (bs,)
+):
+    bs, num_heads, v_head_dim = out.shape
+    head_dim = query.shape[2]
+
+    for i in range(bs):
+        # gather and flatten KV-cache
+        kv = kv_cache[
+            block_tables[i]]  # (max_num_blocks, block_size, head_dim)
+        kv = kv.view(1, -1,
+                     head_dim)[:, :seq_lens[i]]  # (1, seq_len, head_dim)
+        v = kv[:, :, :v_head_dim]
+
+        q = query[i].view(num_heads, 1, head_dim)
+        o = F.scaled_dot_product_attention(q,
+                                           kv,
+                                           v,
+                                           scale=scale,
+                                           enable_gqa=True)
+        out[i] = o.view(num_heads, v_head_dim)
+
+    return out
+
+
+@pytest.mark.parametrize("bs", [4])
+@pytest.mark.parametrize("mean_seq_len", [256])
+@pytest.mark.parametrize("h_q", [16])
+@pytest.mark.parametrize("d", [576])
+@pytest.mark.parametrize("dv", [512])
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
+@pytest.mark.parametrize("varlen", [False, True])
+@pytest.mark.cpu_model
+@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
+def test_mla_decode_cpu(
+    bs: int,
+    mean_seq_len: int,
+    h_q: int,
+    d: int,
+    dv: int,
+    block_size: int,
+    dtype: torch.dtype,
+    varlen: bool,
+):
+    torch.set_default_dtype(dtype)
+    torch.manual_seed(0)
+
+    scale = d**(-0.5)
+    if varlen:
+        seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
+        seq_lens = seq_lens.clip(2).to(torch.int32)
+    else:
+        seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
+    max_seq_len = seq_lens.max().item()
+    seqlen_pad = cdiv(max_seq_len, 256) * 256  # is this necessary?
+
+    q = torch.randn(bs, h_q, d)
+    block_table = torch.arange(bs * seqlen_pad // block_size,
+                               dtype=torch.int32)
+    block_table = block_table.view(bs, seqlen_pad // block_size)
+
+    kv_cache = torch.randn(block_table.numel(), block_size, d)
+    for i, seq_len in enumerate(seq_lens.tolist()):
+        kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan")
+
+    out_mla = q.new_zeros(bs, h_q, dv)
+    ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table,
+                               seq_lens)
+
+    out_ref = q.new_zeros(bs, h_q, dv)
+    ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
+
+    assert not out_mla.isnan().any(), "Likely read out of bounds"
+    torch.testing.assert_close(out_mla, out_ref)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index d68c097f..dc07bad4 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -124,6 +124,18 @@ def paged_attention_rocm(
                                   kv_cache_dtype, k_scale, v_scale)
 
 
+def mla_decode_kvcache_cpu(
+    out: torch.Tensor,
+    query: torch.Tensor,
+    kv_cache: torch.Tensor,
+    scale: float,
+    block_tables: torch.Tensor,
+    seq_lens: torch.Tensor,
+) -> None:
+    torch.ops._C_cpu.mla_decode_kvcache(out, query, kv_cache, scale,
+                                        block_tables, seq_lens)
+
+
 # pos encoding ops
 def rotary_embedding(
     positions: torch.Tensor,
diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
index a7b909d2..c3d210c2 100644
--- a/vllm/_ipex_ops.py
+++ b/vllm/_ipex_ops.py
@@ -187,15 +187,28 @@ class ipex_ops:
         gen_: torch.Generator,
         logits_soft_cap: float,
     ) -> None:
-        ipex.llm.functional.varlen_attention(query.contiguous(),
-                                             key.contiguous(),
-                                             value.contiguous(), out,
-                                             seqlen_q.int(), seqlen_k.int(),
-                                             max_seqlen_q, max_seqlen_k,
-                                             pdropout, softmax_scale,
-                                             zero_tensors, is_causal,
-                                             return_softmax, gen_,
-                                             logits_soft_cap)
+        if ipex.__version__.endswith("cpu"):
+            if logits_soft_cap != 0.0:
+                raise ValueError("IPEX CPU does not support logits_soft_cap")
+            ipex.llm.functional.varlen_attention(query.contiguous(),
+                                                 key.contiguous(),
+                                                 value.contiguous(), out,
+                                                 seqlen_q.int(),
+                                                 seqlen_k.int(), max_seqlen_q,
+                                                 max_seqlen_k, pdropout,
+                                                 softmax_scale, zero_tensors,
+                                                 is_causal, return_softmax,
+                                                 gen_)
+        else:  # XPU build
+            ipex.llm.functional.varlen_attention(query.contiguous(),
+                                                 key.contiguous(),
+                                                 value.contiguous(), out,
+                                                 seqlen_q.int(),
+                                                 seqlen_k.int(), max_seqlen_q,
+                                                 max_seqlen_k, pdropout,
+                                                 softmax_scale, zero_tensors,
+                                                 is_causal, return_softmax,
+                                                 gen_, logits_soft_cap)
 
     @staticmethod
     def reshape_and_cache(
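Note (reviewer aid, not part of the diff): the new Python wrapper is a thin pass-through to the `_C_cpu` op. A minimal call sketch, assuming a vLLM build with the CPU extension; the kernel currently only dispatches for `head_dim == 576`, `v_head_dim == 512`, `block_size == 16` (the shapes below are examples):

```python
import torch
import vllm._custom_ops as ops

bs, num_heads, head_dim, v_head_dim, block_size = 2, 16, 576, 512, 16
seq_lens = torch.tensor([48, 17], dtype=torch.int32)
max_blocks = (int(seq_lens.max()) + block_size - 1) // block_size
block_tables = torch.arange(bs * max_blocks,
                            dtype=torch.int32).view(bs, max_blocks)

q = torch.randn(bs, num_heads, head_dim)  # decode: one token per sequence
kv_cache = torch.randn(bs * max_blocks, block_size, head_dim)
out = q.new_zeros(bs, num_heads, v_head_dim)

ops.mla_decode_kvcache_cpu(out, q, kv_cache, head_dim**-0.5, block_tables,
                           seq_lens)
```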
diff --git a/vllm/attention/backends/cpu_mla.py b/vllm/attention/backends/cpu_mla.py
new file mode 100644
index 00000000..e2d16908
--- /dev/null
+++ b/vllm/attention/backends/cpu_mla.py
@@ -0,0 +1,303 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+
+import vllm._custom_ops as ops
+from vllm._ipex_ops import ipex_ops
+from vllm.attention.backends.abstract import (AttentionBackend,
+                                              AttentionMetadataBuilder,
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
+from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
+from vllm.utils import make_tensor_with_pad
+from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
+
+
+class CPUMLABackend(AttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "CPU_MLA"
+
+    @staticmethod
+    def get_metadata_cls() -> Type["CPUMLAMetadata"]:
+        return CPUMLAMetadata
+
+    @staticmethod
+    def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
+        return CPUMLAMetadataBuilder
+
+    @staticmethod
+    def get_state_cls() -> Type["MLACommonState"]:
+        return MLACommonState
+
+    @staticmethod
+    def get_impl_cls() -> Type["CPUMLAImpl"]:
+        return CPUMLAImpl
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,  # assumed to be 1 for MLA
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return (num_blocks, block_size, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        ops.copy_blocks_mla(kv_caches, src_to_dists)
+
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [576]
+
+
+@dataclass
+class CPUMLAMetadata(TorchSDPAMetadata):
+    # New for MLA
+    # Input positions for rotary embeddings since for MLA the rotary
+    # position embeddings are applied inside the attention backend
+    input_positions: torch.Tensor = None
+
+    # required by MLACommonImpl
+    is_profile_run: bool = False
+
+
+class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
+
+    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
+        self.chunked_prefill = input_builder.chunked_prefill
+        self.input_builder = input_builder
+        assert not self.chunked_prefill, \
+            "chunked prefill is currently not supported"
+
+    def prepare(self):
+        self.input_data = self.input_builder.input_data
+
+    def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
+        input_data = self.input_data
+        prefill_seq_lens = seq_lens[0:input_data.num_prefills]
+        prefill_query_lens = query_lens[0:input_data.num_prefills]
+        slot_mapping = torch.tensor(input_data.slot_mapping,
+                                    dtype=torch.long,
+                                    device="cpu")
+
+        # metadata for prefill
+        if input_data.num_prefills > 0:
+            query_lens_tensor = torch.tensor(prefill_query_lens,
+                                             dtype=torch.int32,
+                                             device="cpu")
+            kv_lens_tensor = torch.tensor(prefill_seq_lens,
+                                          dtype=torch.int32,
+                                          device="cpu")
+            query_start_loc = torch.zeros(input_data.num_prefills + 1,
+                                          dtype=torch.int32,
+                                          device="cpu")
+            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
+                                       dtype=torch.int32,
+                                       device="cpu")
+            torch.cumsum(query_lens_tensor,
+                         dim=0,
+                         dtype=torch.int32,
+                         out=query_start_loc[1:])
+            torch.cumsum(kv_lens_tensor,
+                         dim=0,
+                         dtype=torch.int32,
+                         out=kv_start_loc[1:])
+            max_query_len = max(prefill_query_lens)
+            max_kv_len = max(prefill_seq_lens)
+
+            # for chunked-prefill
+            if self.chunked_prefill:
+                prefill_block_tables = make_tensor_with_pad(
+                    self.input_data.prefill_block_tables,
+                    pad=0,
+                    dtype=torch.int32,
+                    device="cpu",
+                )
+            else:
+                prefill_block_tables = None
+
+        else:
+            query_start_loc = None
+            kv_start_loc = None
+            max_query_len = None
+            max_kv_len = None
+            prefill_block_tables = None
+
+        # metadata for decode
+        if input_data.num_decode_tokens != 0:
+            seq_lens_tensor = torch.tensor(
+                input_data.seq_lens[input_data.num_prefills:],
+                dtype=torch.int32,
+                device="cpu",
+            )
+            block_tables = make_tensor_with_pad(
+                self.input_data.decode_block_tables,
+                pad=0,
+                dtype=torch.int32,
+                device="cpu",
+            )
+        else:
+            block_tables = torch.tensor([])
+            seq_lens_tensor = torch.tensor(
+                input_data.seq_lens[:input_data.num_prefills],
+                dtype=torch.int32,
+                device="cpu",
+            )
+
+        # For multi-modal models
+        placeholder_index_maps = None
+        if len(input_data.multi_modal_inputs_list) != 0:
+            placeholder_index_maps = {
+                modality: placeholder_map.index_map()
+                for modality, placeholder_map in
+                input_data.multi_modal_placeholder_maps.items()
+            }
+
+        return CPUMLAMetadata(
+            chunked_prefill=self.chunked_prefill,
+            seq_lens=prefill_seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
+            max_query_len=max_query_len,
+            max_kv_len=max_kv_len,
+            query_start_loc=query_start_loc,
+            kv_start_loc=kv_start_loc,
+            max_decode_seq_len=input_data.max_decode_seq_len,
+            num_prefills=input_data.num_prefills,
+            num_prefill_tokens=input_data.num_prefill_tokens,
+            num_decode_tokens=input_data.num_decode_tokens,
+            block_tables=block_tables,
+            prefill_block_tables=prefill_block_tables,
+            slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=placeholder_index_maps,
+            enable_kv_scales_calculation=False,
+            input_positions=torch.tensor([self.input_data.input_positions]))
+
+
+class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
+
+    def __init__(
+            self,
+            num_heads: int,
+            head_size: int,
+            scale: float,
+            num_kv_heads: int,
+            alibi_slopes: Optional[List[float]],
+            sliding_window: Optional[int],
+            kv_cache_dtype: str,
+            blocksparse_params: Optional[Dict[str, Any]],
+            logits_soft_cap: Optional[float],
+            attn_type: str,
+            # MLA Specific Arguments
+            **mla_args) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads,
+                         alibi_slopes, sliding_window, kv_cache_dtype,
+                         blocksparse_params, logits_soft_cap, attn_type,
+                         **mla_args)
+
+        unsupported_features = [
+            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
+        ]
+        if any(unsupported_features):
+            raise NotImplementedError(
+                "CPUMLAImpl does not support one of the following: "
+                "alibi_slopes, sliding_window, blocksparse_params, "
+                "logits_soft_cap")
+
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "CPUMLAImpl")
+
+        # states is implemented.
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "CPUMLAImpl with FP8 KV cache not yet supported")
+
+    def _forward_prefill(
+        self,
+        q: torch.Tensor,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: CPUMLAMetadata,  # type: ignore[override]
+    ) -> torch.Tensor:
+
+        prefill_metadata = attn_metadata.prefill_metadata
+        assert prefill_metadata is not None
+
+        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
+            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope, v = kv_nope\
+            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+
+        # For MLA the v head dim is smaller than qk head dim so we pad out
+        # v with 0s to match the qk head dim
+        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
+                                           value=0)
+
+        output = torch.empty_like(q)
+        ipex_ops.varlen_attention(
+            query=q,
+            key=k,
+            value=v_padded,
+            out=output,
+            seqlen_q=prefill_metadata.query_start_loc,
+            seqlen_k=prefill_metadata.query_start_loc,
+            max_seqlen_q=prefill_metadata.max_query_len,
+            max_seqlen_k=prefill_metadata.max_query_len,
+            pdropout=0.0,
+            softmax_scale=self.scale,
+            zero_tensors=False,
+            is_causal=True,
+            return_softmax=False,
+            gen_=None,
+            logits_soft_cap=0.0,
+        )
+
+        # remove padding
+        output = output.view(-1, self.num_heads,
+                             q.shape[-1])[..., :v.shape[-1]]
+        output = output.reshape(-1, self.num_heads * v.shape[-1])
+        return self.o_proj(output)[0]
+
+    def _forward_decode(
+        self,
+        q_nope: torch.Tensor,
+        q_pe: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: CPUMLAMetadata,  # type: ignore[override]
+    ) -> torch.Tensor:
+        assert kv_c_and_k_pe_cache.numel() > 0
+
+        decode_meta = attn_metadata.decode_metadata
+        assert decode_meta is not None
+
+        q = torch.cat([q_nope, q_pe], dim=-1)
+        o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)
+
+        # Run MQA
+        ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
+                                   decode_meta.block_tables,
+                                   decode_meta.seq_lens_tensor)
+        return self._v_up_proj_and_o_proj(o)
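Note (reviewer aid, not part of the diff): in `CPUMLAImpl._forward_decode`, the two query halves are concatenated and attention runs as MQA directly against the compressed cache; only the first `kv_lora_rank` entries of each cached row act as the value. Per sequence, the kernel computes the equivalent of this sketch (illustrative only; the real path is the paged `ops.mla_decode_kvcache_cpu` call), after which `self._v_up_proj_and_o_proj` from `MLACommonImpl` maps the result back to the model dimension:

```python
import torch

def mla_decode_ref(q_nope: torch.Tensor,        # [num_heads, kv_lora_rank]
                   q_pe: torch.Tensor,          # [num_heads, rope_dim]
                   kv_cache_seq: torch.Tensor,  # [seq_len, kv_lora_rank + rope_dim]
                   scale: float) -> torch.Tensor:
    q = torch.cat([q_nope, q_pe], dim=-1)       # [num_heads, 576] for DeepSeek
    kv_lora_rank = q_nope.shape[-1]
    k = kv_cache_seq                            # keys use the full cached row
    v = kv_cache_seq[:, :kv_lora_rank]          # values: compressed KV part only
    attn = torch.softmax((q @ k.T) * scale, dim=-1)  # one shared KV head (MQA)
    return attn @ v                             # [num_heads, kv_lora_rank]
```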
diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 1b1ab314..8d70afe2 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -204,7 +204,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
-from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -212,18 +211,27 @@ from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
+from vllm.triton_utils import HAS_TRITON
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
 from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
+if HAS_TRITON:
+    from vllm.attention.ops.triton_flash_attention import triton_attention
+    from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
+else:
+    merge_attn_states = None
+    triton_attention = None
+
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
     is_vllm_fa = True
 except ImportError:
-    # For rocm use upstream flash attention
-    from flash_attn import flash_attn_varlen_func
     is_vllm_fa = False
-
-from vllm.attention.ops.triton_flash_attention import triton_attention
+    try:
+        # For rocm use upstream flash attention
+        from flash_attn import flash_attn_varlen_func
+    except ImportError:
+        flash_attn_varlen_func = None
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 4b10b298..0eb747a4 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -37,6 +37,9 @@ class CpuPlatform(Platform):
                              use_mla: bool) -> str:
         if selected_backend and selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
+        if use_mla:
+            logger.info("Using CPU MLA backend.")
+            return "vllm.attention.backends.cpu_mla.CPUMLABackend"
         logger.info("Using Torch SDPA backend.")
         return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
 
@@ -129,9 +132,6 @@ class CpuPlatform(Platform):
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
 
-        # MLA attention is not supported
-        os.environ["VLLM_MLA_DISABLE"] = "1"
-
         # Intel OpenMP setting
         ld_prealod_str = os.getenv("LD_PRELOAD", "")
         if "libiomp5.so" in ld_prealod_str:
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 8407f073..9f4b1886 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -469,6 +469,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
+            use_mla=self.model_config.use_mla,
         ) if needs_attn_backend else None
 
         # Multi-modal data support
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 70d2924a..b93aae9c 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -66,6 +66,7 @@ class CPUCacheEngine:
             cache_config.cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
+            use_mla=self.model_config.use_mla,
         )
 
         # Initialize the cache.
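Note (reviewer aid, not part of the diff): with the platform change above, CPU builds no longer force `VLLM_MLA_DISABLE=1`, so MLA models select `CPUMLABackend` automatically. A rough end-to-end sketch; the model name, dtype, and sampling settings are illustrative examples, not requirements of this PR:

```python
from vllm import LLM, SamplingParams

# On a CPU-only build, this should log "Using CPU MLA backend." during startup.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat",
          trust_remote_code=True,
          dtype="bfloat16")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
```

Setting `VLLM_MLA_DISABLE=1` in the environment still falls back to the non-MLA Torch SDPA path.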