2023-02-16 07:47:03 +00:00
|
|
|
#include <torch/extension.h>
|
|
|
|
#include <ATen/cuda/CUDAContext.h>
|
2024-01-03 11:09:59 +08:00
|
|
|
#include <c10/cuda/CUDAGuard.h>
|
2023-02-16 07:47:03 +00:00
|
|
|
|
2023-12-08 15:16:52 +08:00
|
|
|
#include "cuda_compat.h"
|
2023-09-02 14:59:47 +09:00
|
|
|
#include "dispatch_utils.h"
|
2024-02-02 01:35:09 +08:00
|
|
|
#ifdef ENABLE_FP8_E5M2
|
2024-01-29 08:43:54 +08:00
|
|
|
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
2024-02-02 01:35:09 +08:00
|
|
|
#endif
|
2023-09-02 14:59:47 +09:00
|
|
|
|
2023-02-18 19:22:57 +00:00
|
|
|
#include <algorithm>
|
2023-02-16 07:47:03 +00:00
|
|
|
#include <cassert>
|
|
|
|
#include <map>
|
2023-03-10 09:58:21 -08:00
|
|
|
#include <vector>
|
2023-02-16 07:47:03 +00:00
|
|
|
|
2024-02-02 01:35:09 +08:00
|
|
|
#ifdef USE_ROCM
|
|
|
|
#include <hip/hip_bf16.h>
|
|
|
|
typedef __hip_bfloat16 __nv_bfloat16;
|
|
|
|
#endif
|
|
|
|
|
2023-03-10 09:58:21 -08:00
|
|
|
// Copies whole cache blocks from `src` to `dst` as described by `block_mapping`
// (source block number -> destination block number).
//
// Supported placements: GPU->GPU (same device), GPU->CPU and CPU->GPU.
// CPU->CPU is rejected. A block is the slab `tensor[i]`, so its size in bytes is
// `element_size() * tensor[0].numel()`, and block i starts at byte offset
// i * block_size_in_bytes.
//
// The copies are enqueued with cudaMemcpyAsync on the current CUDA stream of
// the CUDA-side device. This function does not wait for them to finish.
// NOTE(review): the offset arithmetic assumes both tensors are laid out
// contiguously along dim 0 — confirm at call sites.
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  // Nothing to copy. Returning here also avoids indexing src[0] on a cache
  // that has no blocks.
  if (block_mapping.empty()) {
    return;
  }

  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  // A mismatched block size would make every memcpy straddle block boundaries
  // in one of the two tensors.
  TORCH_CHECK(
    dst.element_size() * dst[0].numel() == block_size_in_bytes,
    "src and dst must have the same block size in bytes");
  const int64_t num_src_blocks = src.size(0);
  const int64_t num_dst_blocks = dst.size(0);

  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    const int64_t src_block_number = pair.first;
    const int64_t dst_block_number = pair.second;
    // Reject out-of-range block numbers on the host instead of issuing an
    // out-of-bounds memcpy.
    TORCH_CHECK(
      src_block_number >= 0 && src_block_number < num_src_blocks,
      "src block number ", src_block_number, " is out of range [0, ", num_src_blocks, ")");
    TORCH_CHECK(
      dst_block_number >= 0 && dst_block_number < num_dst_blocks,
      "dst block number ", dst_block_number, " is out of range [0, ", num_dst_blocks, ")");
    const int64_t src_offset = src_block_number * block_size_in_bytes;
    const int64_t dst_offset = dst_block_number * block_size_in_bytes;
    const cudaError_t err = cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
    TORCH_CHECK(
      err == cudaSuccess,
      "cudaMemcpyAsync failed in swap_blocks: ", cudaGetErrorString(err));
  }
}
|
2023-02-18 19:22:57 +00:00
|
|
|
|
2023-06-17 03:07:40 -07:00
|
|
|
namespace vllm {

// Copies one cache block onto another inside a single layer's key cache and
// value cache.
//
// Launch layout: blockIdx.x selects the layer and blockIdx.y selects the
// (src, dst) pair. The threads of a block stride by blockDim.x over the
// numel_per_block elements of that block.
// key_cache_ptrs / value_cache_ptrs: per-layer cache base addresses, stored as
// int64_t and reinterpreted as scalar_t*.
// block_mapping: flat pairs [src0, dst0, src1, dst1, ...].
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer = blockIdx.x;
  const int pair = blockIdx.y;

  // Element offsets of the source and destination blocks within a layer.
  const int64_t src_base = block_mapping[2 * pair] * numel_per_block;
  const int64_t dst_base = block_mapping[2 * pair + 1] * numel_per_block;

  // Key cache first...
  scalar_t* keys = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer]);
  for (int elem = threadIdx.x; elem < numel_per_block; elem += blockDim.x) {
    keys[dst_base + elem] = keys[src_base + elem];
  }

  // ...then the value cache, using the same block offsets.
  scalar_t* values = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer]);
  for (int elem = threadIdx.x; elem < numel_per_block; elem += blockDim.x) {
    values[dst_base + elem] = values[src_base + elem];
  }
}

} // namespace vllm
|
2023-04-07 17:45:07 -07:00
|
|
|
|
2023-03-10 09:58:21 -08:00
|
|
|
// Copies cache blocks inside every layer's key and value cache on the GPU.
//
// block_mapping maps a source block number to all destination block numbers
// that should receive a copy of it. The mapping is flattened into
// [src0, dst0, src1, dst1, ...] pairs, and a single kernel launch covers every
// (layer, pair) combination.
// Only key_caches[0] is checked to be on a CUDA device; every other cache is
// presumably on the same device — confirm at call sites.
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  // Compare the two size_t counts directly (no signed/unsigned comparison).
  TORCH_CHECK(
    key_caches.size() == value_caches.size(),
    "key_caches and value_caches must have the same number of layers");
  const int num_layers = key_caches.size();
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create block mapping array: flattened (src, dst) pairs.
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    const int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  const int num_pairs = block_mapping_vec.size() / 2;
  // Nothing to copy. Returning here also avoids a grid with a zero dimension,
  // which is an invalid launch configuration.
  if (num_pairs == 0) {
    return;
  }
  const int numel_per_block = key_caches[0][0].numel();
  if (numel_per_block == 0) {
    return;
  }

  // Per-layer cache base addresses, stored as int64_t so they can be shipped
  // to the GPU in an int64 tensor. std::vector replaces the former
  // variable-length arrays, which are a compiler extension rather than
  // standard C++.
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_vec.data(), {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel: grid.x = layer, grid.y = (src, dst) pair.
  // NOTE(review): CUDA caps gridDim.y at 65535, so more pairs than that would
  // fail to launch.
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int64_t>(),
        numel_per_block);
    }));
}
|
|
|
|
|
2023-06-17 03:07:40 -07:00
|
|
|
namespace vllm {

// Scatters each token's key/value vectors into the paged KV cache.
//
// Launch layout: one thread block per token (blockIdx.x is the token index).
// The block's threads stride by blockDim.x over that token's
// num_heads * head_size elements. slot_mapping[token] is the flat cache slot;
// a negative slot marks a padding token, which is skipped.
//
// Cache layouts, as indexed below:
//   key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
//   value_cache: [num_blocks, num_heads, head_size, block_size]
// When is_fp8_e5m2_kv_cache is true, each element goes through
// fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t> before being stored.
// That path needs ENABLE_FP8_E5M2 at compile time and otherwise hits
// assert(false).
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int64_t token = blockIdx.x;
  const int64_t slot = slot_mapping[token];
  // Padding tokens carry a negative slot and are not written to the cache.
  if (slot < 0) {
    return;
  }

  const int64_t block_idx = slot / block_size;
  const int64_t block_offset = slot % block_size;
  // Start of this token's row in the (possibly strided) key / value inputs.
  const int64_t key_row = token * key_stride;
  const int64_t value_row = token * value_stride;
  const int chunks_per_head = head_size / x;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    // Nested (Horner) form of the per-dimension stride sums for the layouts
    // documented above.
    const int64_t head_base = block_idx * num_heads + head_idx;
    const int64_t tgt_key_idx =
        ((head_base * chunks_per_head + x_idx) * block_size + block_offset) * x + x_offset;
    const int64_t tgt_value_idx =
        (head_base * head_size + head_offset) * block_size + block_offset;

    const scalar_t k = key[key_row + i];
    const scalar_t v = value[value_row + i];
    if constexpr (is_fp8_e5m2_kv_cache) {
#ifdef ENABLE_FP8_E5M2
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(k);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(v);
#else
      assert(false);
#endif
    } else {
      key_cache[tgt_key_idx] = k;
      value_cache[tgt_value_idx] = v;
    }
  }
}

} // namespace vllm
|
2023-05-03 14:09:44 -07:00
|
|
|
|
2024-01-29 08:43:54 +08:00
|
|
|
// Launches vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE>.
// Expects the enclosing scope to define grid, block, stream, key, value,
// key_cache, value_cache, slot_mapping, key_stride, value_stride, num_heads,
// head_size, block_size and x. The do { } while (0) wrapper makes the
// expansion a single statement, so `CALL_RESHAPE_AND_CACHE(...);` is safe in
// any statement position (e.g. an unbraced if/else).
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE)            \
  do {                                                                        \
    vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE>       \
      <<<grid, block, 0, stream>>>(                                           \
        reinterpret_cast<KV_T*>(key.data_ptr()),                              \
        reinterpret_cast<KV_T*>(value.data_ptr()),                            \
        reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                     \
        reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                   \
        slot_mapping.data_ptr<int64_t>(),                                     \
        key_stride,                                                           \
        value_stride,                                                         \
        num_heads,                                                            \
        head_size,                                                            \
        block_size,                                                           \
        x);                                                                   \
  } while (0)
|
|
|
|
|
2023-05-03 14:09:44 -07:00
|
|
|
// Writes the key/value vectors of every token into the paged KV cache at the
// slots given by slot_mapping. Negative slots are skipped by the kernel.
//
// kv_cache_dtype selects how elements are stored:
//   "auto"     - stored unchanged, in the key/value dtype;
//   "fp8_e5m2" - converted to uint8 fp8 e5m2 (requires a build with
//                ENABLE_FP8_E5M2).
// Supported key/value dtypes: float32, float16, bfloat16. Anything else throws.
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  // Validate everything before any early exit, so bad arguments are always
  // reported instead of silently producing no writes.
  const bool use_fp8_e5m2 = kv_cache_dtype == "fp8_e5m2";
  if (!use_fp8_e5m2 && kv_cache_dtype != "auto") {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
#ifndef ENABLE_FP8_E5M2
  // Without FP8 support the kernel's fp8 path is only assert(false).
  TORCH_CHECK(
    !use_fp8_e5m2,
    "fp8_e5m2 kv cache is not supported in this build (ENABLE_FP8_E5M2 is not defined)");
#endif
  const at::ScalarType kv_dtype = key.scalar_type();
  TORCH_CHECK(
    kv_dtype == at::ScalarType::Float || kv_dtype == at::ScalarType::Half ||
      kv_dtype == at::ScalarType::BFloat16,
    "Unsupported data type of key/value: ", kv_dtype);

  // Nothing to write. Returning here also avoids a zero-sized grid/block,
  // which is an invalid launch configuration.
  if (num_tokens == 0 || num_heads * head_size == 0) {
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Half is passed as uint16_t: the "auto" path only copies bits, and the fp8
  // path's vec_conversion is instantiated for uint16_t.
  if (!use_fp8_e5m2) {
    if (kv_dtype == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, false);
    } else if (kv_dtype == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
    } else {  // BFloat16 (checked above)
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
    }
  } else {
    if (kv_dtype == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
    } else if (kv_dtype == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
    } else {  // BFloat16 (checked above)
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
    }
  }
}
|
|
|
|
|
2023-06-17 03:07:40 -07:00
|
|
|
namespace vllm {

// Element-wise conversion between a cache of Tin and a cache of Tout using
// fp8_e5m2_unscaled::vec_conversion<Tout, Tin>.
//
// Launch layout: one thread block per cache block (blockIdx.x). The block's
// threads stride by blockDim.x over the block_stride elements of that block.
// src_cache and dst_cache are indexed with the same flat offset, so both must
// share the same element layout. Without ENABLE_FP8_E5M2 the loop body is only
// assert(false).
template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  // 64-bit index: block_stride is int64_t, and a 32-bit int loop counter could
  // overflow (undefined behavior) before reaching it.
  for (int64_t i = threadIdx.x; i < block_stride; i += blockDim.x) {
    const int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
    assert(false);
#endif
  }
}

} // namespace vllm
|
|
|
|
|
|
|
|
// Launches vllm::convert_fp8_e5m2_kernel<Tout, Tin>. Expects the enclosing
// scope to define grid, block, stream, src_cache, dst_cache and block_stride.
// The do { } while (0) wrapper makes the expansion a single statement, so
// `CALL_CONVERT_FP8_E5M2(...);` is safe in any statement position.
#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                 \
  do {                                                                  \
    vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
      reinterpret_cast<Tin*>(src_cache.data_ptr()),                     \
      reinterpret_cast<Tout*>(dst_cache.data_ptr()),                    \
      block_stride);                                                    \
  } while (0)
|
|
|
|
|
|
|
|
// Converts a whole KV cache between fp8 e5m2 (stored as uint8) and
// float32 / float16 / bfloat16. The direction is picked from the dtypes:
//   - src_cache is float32/float16/bfloat16 -> dst_cache must be uint8;
//   - otherwise src_cache must be uint8 and dst_cache float32/float16/bfloat16.
// Unsupported combinations throw instead of silently doing nothing.
// NOTE(review): the kernel indexes both tensors with flat offsets derived from
// src_cache.stride(0), so both are presumably expected to be contiguous with
// the same shape — confirm at call sites.
void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache)
{
#ifndef ENABLE_FP8_E5M2
  // Without FP8 support the kernel body is only assert(false).
  TORCH_CHECK(false, "convert_fp8_e5m2 requires a build with ENABLE_FP8_E5M2");
#endif
  const at::ScalarType src_dtype = src_cache.scalar_type();
  const at::ScalarType dst_dtype = dst_cache.scalar_type();
  auto is_wide_float = [](at::ScalarType t) {
    return t == at::ScalarType::Float || t == at::ScalarType::Half ||
           t == at::ScalarType::BFloat16;
  };
  const bool to_fp8 = is_wide_float(src_dtype);
  if (to_fp8) {
    TORCH_CHECK(
      dst_dtype == at::ScalarType::Byte,
      "convert_fp8_e5m2: dst_cache must be uint8 when src_cache is ", src_dtype);
  } else {
    TORCH_CHECK(
      src_dtype == at::ScalarType::Byte && is_wide_float(dst_dtype),
      "convert_fp8_e5m2: unsupported data types src=", src_dtype, ", dst=", dst_dtype);
  }

  // Nothing to convert. Returning here also avoids a zero-sized launch.
  if (src_cache.numel() == 0) {
    return;
  }

  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  // Launch on src_cache's device; the other host entry points guard the same way.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(src_cache));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Half is passed as uint16_t, matching the vec_conversion instantiations.
  if (to_fp8) {
    if (src_dtype == at::ScalarType::Float) {
      CALL_CONVERT_FP8_E5M2(uint8_t, float);
    } else if (src_dtype == at::ScalarType::Half) {
      CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
    } else {  // BFloat16 (checked above)
      CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
    }
  } else {
    if (dst_dtype == at::ScalarType::Float) {
      CALL_CONVERT_FP8_E5M2(float, uint8_t);
    } else if (dst_dtype == at::ScalarType::Half) {
      CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
    } else {  // BFloat16 (checked above)
      CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
    }
  }
}
|