#include "cpu_types.hpp" #include namespace { template struct KernelVecType { using qk_load_vec_type = void; using qk_vec_type = void; using v_load_vec_type = void; }; template <> struct KernelVecType { using qk_load_vec_type = vec_op::FP32Vec16; using qk_vec_type = vec_op::FP32Vec16; using v_load_vec_type = vec_op::FP32Vec16; }; template <> struct KernelVecType { #if defined(__powerpc64__) || defined(__s390x__) // Power and s390x architecture-specific vector types using qk_load_vec_type = vec_op::FP32Vec16; using qk_vec_type = vec_op::FP32Vec16; using v_load_vec_type = vec_op::FP32Vec16; #else // Fallback for other architectures, including x86 using qk_load_vec_type = vec_op::FP16Vec16; using qk_vec_type = vec_op::FP32Vec16; using v_load_vec_type = vec_op::FP16Vec16; #endif }; #ifdef __AVX512BF16__ template <> struct KernelVecType { using qk_load_vec_type = vec_op::BF16Vec32; using qk_vec_type = vec_op::BF16Vec32; using v_load_vec_type = vec_op::BF16Vec16; }; #elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT) // pass #else template <> struct KernelVecType { using qk_load_vec_type = vec_op::BF16Vec16; using qk_vec_type = vec_op::FP32Vec16; using v_load_vec_type = vec_op::BF16Vec16; }; #endif template void mla_decode_block_head( const qk_vec_type* __restrict__ q_vecs, // [HEAD_UNROLL, head_dim] const qk_vec_type* __restrict__ k_vecs, // [block_size, head_dim] const vec_op::FP32Vec16* __restrict v_vecs_f32, // [block_size, v_head_dim] float* __restrict__ acc_out, // [HEAD_UNROLL, v_head_dim] float* __restrict__ acc_lse, // [HEAD_UNROLL] const float scale, const int num_tokens) { using f32_vec_type = vec_op::FP32Vec16; constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM; constexpr int V_NUM_ELEM = f32_vec_type::VEC_ELEM_NUM; float logits[BLOCK_SIZE][HEAD_UNROLL] = {}; // initialize to zeros float max_val[HEAD_UNROLL]; std::fill(max_val, max_val + HEAD_UNROLL, -FLT_MAX); f32_vec_type acc_vec[BLOCK_SIZE][HEAD_UNROLL]; for (int i = 0; i < HEAD_DIM; i += QK_NUM_ELEM) { // load to registers qk_vec_type q_vec[HEAD_UNROLL]; #pragma unroll for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) q_vec[unroll] = qk_vec_type{q_vecs[(i + unroll * HEAD_DIM) / QK_NUM_ELEM]}; for (int block_offset = 0; block_offset < num_tokens; ++block_offset) { qk_vec_type k_vec(k_vecs[(block_offset * HEAD_DIM + i) / QK_NUM_ELEM]); #pragma unroll for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) vec_op::fma(acc_vec[block_offset][unroll], q_vec[unroll], k_vec); } } for (int block_offset = 0; block_offset < num_tokens; ++block_offset) { #pragma unroll for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) { const float acc = acc_vec[block_offset][unroll].reduce_sum() * scale; logits[block_offset][unroll] = acc; max_val[unroll] = std::max(max_val[unroll], acc); } } float sum_exp[HEAD_UNROLL] = {}; for (int block_offset = 0; block_offset < num_tokens; ++block_offset) { #pragma unroll for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) { const float val = std::exp(logits[block_offset][unroll] - max_val[unroll]); logits[block_offset][unroll] = val; sum_exp[unroll] += val; } } f32_vec_type this_out[V_HEAD_DIM / V_NUM_ELEM][HEAD_UNROLL]; for (int block_offset = 0; block_offset < num_tokens; ++block_offset) { // load to registers f32_vec_type scale_[HEAD_UNROLL]; #pragma unroll for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) scale_[unroll] = f32_vec_type{logits[block_offset][unroll] / sum_exp[unroll]}; for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) { f32_vec_type v_vec( v_vecs_f32[(block_offset * HEAD_DIM + i) / 
  // merge attention state
  // section 2.2 in https://arxiv.org/pdf/2501.01005
  f32_vec_type prev_scale[HEAD_UNROLL];
  f32_vec_type curr_scale[HEAD_UNROLL];

#pragma unroll
  for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
    const float prev_lse = acc_lse[unroll];
    const float curr_lse = std::log(sum_exp[unroll]) +
                           max_val[unroll];  // add back max_val to get true lse

    // softmax trick
    const float max_lse = std::max(prev_lse, curr_lse);
    const float prev_sum_exp = std::exp(prev_lse - max_lse);
    const float curr_sum_exp = std::exp(curr_lse - max_lse);

    const float new_sum_exp = prev_sum_exp + curr_sum_exp;
    acc_lse[unroll] = std::log(new_sum_exp) + max_lse;

    prev_scale[unroll] = f32_vec_type{prev_sum_exp / new_sum_exp};
    curr_scale[unroll] = f32_vec_type{curr_sum_exp / new_sum_exp};
  }

  for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
#pragma unroll
    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
      f32_vec_type o_vec(acc_out + i + V_HEAD_DIM * unroll);
      o_vec = o_vec * prev_scale[unroll] +
              this_out[i / V_NUM_ELEM][unroll] * curr_scale[unroll];
      o_vec.save(acc_out + i + V_HEAD_DIM * unroll);
    }
  }

  q_vecs += HEAD_DIM / QK_NUM_ELEM * HEAD_UNROLL;
  acc_out += V_HEAD_DIM * HEAD_UNROLL;
}

// Processes one physical KV-cache block for all query heads of a sequence.
template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, typename scalar_t,
          typename qk_vec_type>
void mla_decode_block(
    const qk_vec_type* __restrict__ q_vecs,  // [num_heads, head_dim]
    const scalar_t* __restrict__ kv_cache,   // [block_size, head_dim]
    float* __restrict__ acc_out,             // [num_heads, v_head_dim]
    float* __restrict__ acc_lse,             // [num_heads]
    const int num_heads, const float scale, const int num_tokens) {
  using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
  static_assert(
      std::is_same<qk_vec_type,
                   typename KernelVecType<scalar_t>::qk_vec_type>::value);
  using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
  using f32_vec_type = vec_op::FP32Vec16;
  static_assert(qk_load_vec_type::VEC_ELEM_NUM == qk_vec_type::VEC_ELEM_NUM);
  static_assert(v_load_vec_type::VEC_ELEM_NUM == f32_vec_type::VEC_ELEM_NUM);
  constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
  constexpr int V_NUM_ELEM = v_load_vec_type::VEC_ELEM_NUM;

  const qk_vec_type* k_vecs;
  const f32_vec_type* v_vecs_f32;
  float* kv_cache_f32 = nullptr;

  if constexpr (!std::is_same<scalar_t, float>::value) {
    // convert KV cache block to FP32 to reuse it across query heads and
    // attn @ V computation, since FP16/BF16->FP32 is expensive.
    // TODO: move malloc outside of this fn to reuse across iterations.
    const int nbytes = BLOCK_SIZE * HEAD_DIM * sizeof(float);
    kv_cache_f32 = static_cast<float*>(std::aligned_alloc(64, nbytes));

    for (int block_offset = 0; block_offset < num_tokens; ++block_offset)
      for (int i = 0; i < HEAD_DIM; i += V_NUM_ELEM) {
        v_load_vec_type kv_load_vec(kv_cache + block_offset * HEAD_DIM + i);
        f32_vec_type kv_vec_f32(kv_load_vec);
        kv_vec_f32.save(kv_cache_f32 + block_offset * HEAD_DIM + i);
      }

    if constexpr (std::is_same<qk_load_vec_type, qk_vec_type>::value) {
      // for AVX512_BF16, Q @ K.T uses BF16 for K (no conversion)
      // NOTE: in this case, we only need to convert the V section to FP32.
      // But for simplicity, we will convert the whole KV block to FP32.
      k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
    } else {
      k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache_f32);
    }

    // attn @ V always uses FP32 for V, since attn is FP32.
    v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache_f32);

  } else {
    // KV cache is FP32. don't need to do anything.
    k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
    v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache);
  }
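  // Layout note (added for clarity): each cached token stores HEAD_DIM values.
  // Q @ K.T consumes the full HEAD_DIM row through k_vecs, while attn @ V only
  // reads the first V_HEAD_DIM elements of each row through v_vecs_f32; both
  // pointers index rows with a stride of HEAD_DIM.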
  // compute 2 heads at the same time to improve ILP and
  // take advantage of register cache for K and V.
  constexpr int HEAD_UNROLL = 2;
  for (int iter = 0; iter < num_heads / HEAD_UNROLL; ++iter) {
    mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, HEAD_UNROLL>(
        q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);

    q_vecs += HEAD_UNROLL * HEAD_DIM / QK_NUM_ELEM;
    acc_out += HEAD_UNROLL * V_HEAD_DIM;
    acc_lse += HEAD_UNROLL;
  }

  // take care of the remaining heads, one head at a time
  for (int iter = 0; iter < num_heads % HEAD_UNROLL; ++iter) {
    mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, 1>(
        q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);

    q_vecs += HEAD_DIM / QK_NUM_ELEM;
    acc_out += V_HEAD_DIM;
    acc_lse += 1;
  }

  if (kv_cache_f32 != nullptr) {
    std::free(kv_cache_f32);
  }
}
}  // namespace

template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, typename scalar_t>
void mla_decode_kvcache_cpu_impl(
    scalar_t* __restrict__ out,             // [num_seqs, num_heads, v_head_dim]
    const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_dim]
    const scalar_t* __restrict__ kv_cache,  // [num_blocks, block_size,
                                            // head_dim]
    const int num_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq, const int o_stride, const int q_stride,
    const int kv_stride, const int num_seqs) {
  using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
  using qk_vec_type = typename KernelVecType<scalar_t>::qk_vec_type;
  constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;

  // shared across threads
  const int max_threads = omp_get_max_threads();
  const int acc_out_nbytes =
      max_threads * num_heads * V_HEAD_DIM * sizeof(float);
  float* acc_out = static_cast<float*>(std::aligned_alloc(64, acc_out_nbytes));
  std::vector<float> acc_lse(max_threads * num_heads);

  // allocate memory to pre-convert query to FP32 later
  float* q_f32 = nullptr;
  constexpr bool PRE_CONVERT_QUERY =
      !std::is_same<scalar_t, float>::value &&
      std::is_same<qk_vec_type, vec_op::FP32Vec16>::value;
  if constexpr (PRE_CONVERT_QUERY) {
    const int q_f32_nbytes = num_heads * HEAD_DIM * sizeof(float);
    q_f32 = static_cast<float*>(std::aligned_alloc(64, q_f32_nbytes));
  }

#pragma omp parallel
  {
    const int num_threads = omp_get_num_threads();
    const int thread_id = omp_get_thread_num();
    float* __restrict__ acc_out_thread =
        acc_out + thread_id * num_heads * V_HEAD_DIM;
    float* __restrict__ acc_lse_thread = acc_lse.data() + thread_id * num_heads;

    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
      // reset accumulator
      std::fill(acc_out_thread, acc_out_thread + num_heads * V_HEAD_DIM, 0.0f);
      std::fill(acc_lse_thread, acc_lse_thread + num_heads, -FLT_MAX);

      const int seq_len = seq_lens[seq_idx];
      const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
      const int last_block_size = seq_len - (block_num - 1) * BLOCK_SIZE;

      const qk_vec_type* q_vecs;
      if constexpr (PRE_CONVERT_QUERY) {
        // pre-convert query to FP32 since FP16/BF16->FP32 is slow.
#pragma omp for
        for (int i = 0; i < num_heads * HEAD_DIM; i += QK_NUM_ELEM) {
          qk_load_vec_type q_load_vec(q + seq_idx * q_stride + i);
          qk_vec_type q_vec(q_load_vec);
          q_vec.save(q_f32 + i);
        }
        q_vecs = reinterpret_cast<const qk_vec_type*>(q_f32);
      } else {
        q_vecs = reinterpret_cast<const qk_vec_type*>(q + seq_idx * q_stride);
      }
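      // Parallelization note (added for clarity): the omp for below splits the
      // KV-cache blocks of this sequence across threads. Each thread folds its
      // blocks into its own (acc_out_thread, acc_lse_thread) partial state, and
      // the implicit barrier at the end of the loop makes it safe to merge the
      // per-thread states right afterwards.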
#pragma omp for
      for (int block_idx = 0; block_idx < block_num; ++block_idx) {
        const int physical_block_idx =
            block_tables[seq_idx * max_num_blocks_per_seq + block_idx];
        const int num_tokens =
            block_idx < block_num - 1 ? BLOCK_SIZE : last_block_size;

        mla_decode_block<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE>(
            q_vecs, kv_cache + physical_block_idx * kv_stride, acc_out_thread,
            acc_lse_thread, num_heads, scale, num_tokens);
      }

      // merge attention states across threads
      // section 2.2 in https://arxiv.org/pdf/2501.01005
      // each thread is responsible for 1 head
#pragma omp for
      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
        float* acc_lse_head = acc_lse.data() + head_idx;
        float* acc_out_head = acc_out + head_idx * V_HEAD_DIM;

        float max_val = -FLT_MAX;
        for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
          max_val = std::max(max_val, acc_lse_head[thread_id_ * num_heads]);
        }

        float sum_exp = 0.0f;
        for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
          float val = std::exp(acc_lse_head[thread_id_ * num_heads] - max_val);
          acc_lse_head[thread_id_ * num_heads] = val;
          sum_exp += val;
        }

        float inv_sum = 1.0f / sum_exp;
        float out_head[V_HEAD_DIM] = {};
        for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
          float scale_ = acc_lse_head[thread_id_ * num_heads] * inv_sum;
          for (int i = 0; i < V_HEAD_DIM; ++i) {
            out_head[i] +=
                acc_out_head[thread_id_ * num_heads * V_HEAD_DIM + i] * scale_;
          }
        }

        for (int i = 0; i < V_HEAD_DIM; ++i) {
          vec_op::storeFP32(out_head[i], out + seq_idx * o_stride +
                                             head_idx * V_HEAD_DIM + i);
        }
      }
    }
  }

  if (PRE_CONVERT_QUERY) {
    std::free(q_f32);
  }
  std::free(acc_out);
}

void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& kv_cache, double scale,
                        torch::Tensor& block_tables, torch::Tensor& seq_lens) {
  const int num_seqs = query.size(0);
  const int num_heads = query.size(1);
  const int head_dim = query.size(2);
  const int block_size = kv_cache.size(1);
  const int v_head_dim = out.size(2);

  const int max_num_blocks_per_seq = block_tables.size(1);
  const int o_stride = out.stride(0);
  const int q_stride = query.stride(0);
  const int kv_stride = kv_cache.stride(0);

  VLLM_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "mla_decode_kvcache_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(mla_decode_kvcache_cpu_impl)
        if (head_dim == 576 && v_head_dim == 512 && block_size == 16)
          mla_decode_kvcache_cpu_impl<576, 512, 16>(
              out.data_ptr<scalar_t>(), query.data_ptr<scalar_t>(),
              kv_cache.data_ptr<scalar_t>(), num_heads, scale,
              block_tables.data_ptr<int>(), seq_lens.data_ptr<int>(),
              max_num_blocks_per_seq, o_stride, q_stride, kv_stride, num_seqs);
        else
          TORCH_CHECK(false, "Unsupported MLA decode config: head_dim=",
                      head_dim, ", v_head_dim=", v_head_dim,
                      ", block_size=", block_size);
        CPU_KERNEL_GUARD_OUT(mla_decode_kvcache_cpu_impl)
      });
}
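// Usage sketch (illustrative, not part of the kernel): the dispatcher above is
// only compiled for head_dim == 576, v_head_dim == 512 and block_size == 16,
// and expects int32 block_tables / seq_lens tensors. A hypothetical caller
// could look like:
//
//   auto out = torch::empty({num_seqs, num_heads, 512}, query.options());
//   mla_decode_kvcache(out, query, kv_cache, scale, block_tables, seq_lens);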