#include <cstring>  // std::memcpy
#include <map>
#include <vector>

#include "cpu_types.hpp"

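// Pick the dtype dispatch used by the entry points below: on x86-64 the macro
// also covers the FP8 E5M2 KV-cache dtype, elsewhere it dispatches over the
// standard floating-point types only.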
#if defined(__x86_64__)
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
#else
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
#endif

namespace {
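// Copies whole blocks inside each layer's key and value caches. mapping_pairs
// is a [num_pairs, 2] tensor of (src_block, dst_block) indices; every pair
// moves element_num_per_block elements per cache with a raw memcpy, and the
// (layer, pair) loop nest is parallelized with OpenMP.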
template <typename scalar_t>
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
                          std::vector<torch::Tensor> const& value_caches,
                          const torch::Tensor& mapping_pairs,
                          const int element_num_per_block,
                          const int layer_num) {
  const size_t pair_num = mapping_pairs.size(0);
  const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
  for (int layer = 0; layer < layer_num; ++layer) {
    for (size_t pair = 0; pair < pair_num; ++pair) {
      int64_t source_offset =
          element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
      int64_t target_offset =
          element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
      scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
      scalar_t* source_ptr = key_cache_ptr + source_offset;
      scalar_t* target_ptr = key_cache_ptr + target_offset;
      std::memcpy(target_ptr, source_ptr, block_bytes);

      scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
      source_ptr = value_cache_ptr + source_offset;
      target_ptr = value_cache_ptr + target_offset;
      std::memcpy(target_ptr, source_ptr, block_bytes);
    }
  }
}

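// Scatters each token's key/value heads into the paged CPU cache, assuming
// the blocked layouts
//   key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
//   value_cache: [num_blocks, num_heads, head_size, block_size]
// (consistent with the sizes read in reshape_and_cache() below).
// slot_mapping[token] is a flat slot index; e.g. with block_size = 16, slot 37
// lands in block 37 / 16 = 2 at offset 37 % 16 = 5. Negative slots (padded
// tokens) are skipped.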
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
    const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
    scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
    const int64_t* __restrict__ slot_mapping, const int num_tokens,
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x) {
  const int block_elem_num = num_heads * head_size * block_size;

#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
      const int64_t slot_idx = slot_mapping[token_idx];
      if (slot_idx >= 0) {
        int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
        int src_value_head_idx =
            token_idx * value_stride + head_idx * head_size;
        const scalar_t* src_key_head_ptr = key + src_key_head_idx;
        const scalar_t* src_value_head_ptr = value + src_value_head_idx;
        const int64_t block_index = slot_idx / block_size;
        const int64_t block_offset = slot_idx % block_size;
        scalar_t* target_key_head_ptr = key_cache +
                                        block_elem_num * block_index +
                                        head_idx * block_size * head_size;
        scalar_t* target_value_head_ptr = value_cache +
                                          block_elem_num * block_index +
                                          head_idx * block_size * head_size;

        for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
          const int64_t target_offset =
              src_key_idx * block_size + block_offset * x;
          for (int i = 0; i < x; ++i) {
            target_key_head_ptr[target_offset + i] =
                src_key_head_ptr[src_key_idx + i];
          }
        }

        for (int src_value_idx = 0; src_value_idx < head_size;
             ++src_value_idx) {
          const int64_t target_offset =
              src_value_idx * block_size + block_offset;
          target_value_head_ptr[target_offset] =
              src_value_head_ptr[src_value_idx];
        }
      }
    }
  }
}
}; // namespace
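
// Writes one MLA cache entry per token: the compressed latent vector kv_c
// fills the first kv_lora_rank elements of the entry, followed by the pe_dim
// elements of the rotary positional part k_pe, at the slot selected by
// slot_mapping.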
template <typename scalar_t>
void concat_and_cache_mla_cpu_impl(
    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
    scalar_t* __restrict__ kv_cache,  // [num_blocks, block_size, (kv_lora_rank
                                      // + pe_dim)]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int num_tokens,                      //
    const int block_stride,                    //
    const int entry_stride,                    //
    const int kv_c_stride,                     //
    const int k_pe_stride,                     //
    const int kv_lora_rank,                    //
    const int pe_dim,                          //
    const int block_size                       //
) {
#pragma omp parallel for
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    const int64_t slot_idx = slot_mapping[token_idx];
    // NOTE: slot_idx can be -1 if the token is padded
    if (slot_idx < 0) {
      continue;
    }
    const int64_t block_idx = slot_idx / block_size;
    const int64_t block_offset = slot_idx % block_size;

    auto copy = [&](const scalar_t* __restrict__ src,
                    scalar_t* __restrict__ dst, int src_stride, int dst_stride,
                    int size, int offset) {
      for (int i = 0; i < size; i++) {
        const int64_t src_idx = token_idx * src_stride + i;
        const int64_t dst_idx =
            block_idx * block_stride + block_offset * entry_stride + i + offset;
        dst[dst_idx] = src[src_idx];
      }
    };

    copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
    copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
  }
}

// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
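// block_mapping is a [num_pairs, 2] tensor of (src_block, dst_block) indices,
// forwarded to copy_blocks_cpu_impl as mapping_pairs.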
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping) {
  unsigned num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }

  const int element_num_per_block = key_caches[0][0].numel();
  DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
    CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
    copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
                                   element_num_per_block, num_layers);
    CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
  });
}

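// key/value are [num_tokens, num_heads, head_size] (possibly strided along
// the token dimension); the cache layouts are documented above.
// kv_cache_dtype, k_scale and v_scale are part of the signature but unused by
// this CPU implementation.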
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
                       const std::string& kv_cache_dtype,
                       torch::Tensor& k_scale, torch::Tensor& v_scale) {
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
    CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
    reshape_and_cache_cpu_impl<scalar_t>(
        key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
        num_heads, head_size, block_size, x);
    CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
  });
}

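// The fp8 KV-cache dtype is rejected below; scale is part of the signature
// but unused by this CPU implementation.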
void concat_and_cache_mla(
    torch::Tensor& kv_c,          // [num_tokens, kv_lora_rank]
    torch::Tensor& k_pe,          // [num_tokens, pe_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, (kv_lora_rank +
                                  // pe_dim)]
    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& scale) {
  int num_tokens = slot_mapping.size(0);
  int kv_lora_rank = kv_c.size(1);
  int pe_dim = k_pe.size(1);
  int block_size = kv_cache.size(1);

  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
  TORCH_CHECK(kv_cache_dtype != "fp8");

  int kv_c_stride = kv_c.stride(0);
  int k_pe_stride = k_pe.stride(0);
  int block_stride = kv_cache.stride(0);
  int entry_stride = kv_cache.stride(1);

  VLLM_DISPATCH_FLOATING_TYPES(
      kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
        concat_and_cache_mla_cpu_impl<scalar_t>(
            kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
            kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
            num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
            kv_lora_rank, pe_dim, block_size);
        CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
      });
}

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping) {
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.");
}