#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

#ifdef USE_ROCM
  #include "quantization/fp8/amd/quant_utils.cuh"
#else
  #include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(src_device.index() == dst_device.index(),
                "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  // NOTE(youkaichao): keep in mind that `block_mapping` should be
  // a cpu tensor, otherwise every `item` call will require a gpu-cpu
  // synchronization.
  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");

  char* src_ptr = static_cast<char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(
      src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  const int64_t num_blocks = block_mapping.size(0);
  for (int64_t i = 0; i < num_blocks; i++) {
    int64_t src_block_number = block_mapping[i][0].item<int64_t>();
    int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
                    block_size_in_bytes, memcpy_type, stream);
  }
}
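
// Illustrative usage sketch (not upstream documentation; the cache names here
// are hypothetical). swap_blocks expects `block_mapping` to be a CPU int64
// tensor of shape [num_pairs, 2], where row i copies block block_mapping[i][0]
// of `src` into block block_mapping[i][1] of `dst`. Assuming `gpu_cache` and
// `cpu_cache` are block-major caches with identical per-block shapes:
//
//   // Swap GPU blocks 0 and 3 out to CPU blocks 5 and 1.
//   torch::Tensor mapping =
//       torch::tensor({0, 5, 3, 1}, torch::kInt64).view({2, 2});
//   swap_blocks(gpu_cache, cpu_cache, mapping);
//
// The copies are issued with cudaMemcpyAsync on the current stream, so the
// caller is responsible for synchronizing before reading `dst` on the host.
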
namespace vllm {

// Grid: (num_layers, num_pairs)
template <typename scalar_t>
__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
                                   int64_t* value_cache_ptrs,
                                   const int64_t* __restrict__ block_mapping,
                                   const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache =
      reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

}  // namespace vllm

// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  int64_t key_cache_ptrs[num_layers];
  int64_t value_cache_ptrs[num_layers];
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // block_mapping is a 2D tensor with shape (num_pairs, 2).
  int num_pairs = block_mapping.size(0);

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor =
      torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
          .to(cache_device);
  torch::Tensor value_cache_ptrs_tensor =
      torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
          .to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
      key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
        vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
            key_cache_ptrs_tensor.data_ptr<int64_t>(),
            value_cache_ptrs_tensor.data_ptr<int64_t>(),
            block_mapping.data_ptr<int64_t>(), numel_per_block);
      }));
}
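
// Illustrative usage sketch (not upstream documentation). The kernel above
// reads `block_mapping` as a flat int64 array, pair i at indices [2 * i] and
// [2 * i + 1], so copy_blocks expects a contiguous [num_pairs, 2] int64 tensor
// resident on the same GPU as the caches. Assuming `key_caches` and
// `value_caches` hold one cache tensor per layer:
//
//   // Copy block 0 into block 7 and block 2 into block 4, in every layer.
//   torch::Tensor mapping = torch::tensor({0, 7, 2, 4}, torch::kInt64)
//                               .view({2, 2})
//                               .to(key_caches[0].device());
//   copy_blocks(key_caches, value_caches, mapping);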

namespace vllm {

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x,
                                         //  block_size, x]
    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size,
                                         //  block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x,
    const float* k_scale, const float* v_scale) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx =
        block_idx * num_heads * (head_size / x) * block_size * x +
        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
        block_offset * x + x_offset;
    const int64_t tgt_value_idx =
        block_idx * num_heads * head_size * block_size +
        head_idx * head_size * block_size + head_offset * block_size +
        block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    } else {
      key_cache[tgt_key_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
      value_cache[tgt_value_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
    }
  }
}
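
// Worked example of the indexing above (illustrative numbers, not required by
// the kernel): with num_heads = 8, head_size = 128, x = 8 and block_size = 16,
// head_offset = 21 splits into x_idx = 21 / 8 = 2 and x_offset = 21 % 8 = 5,
// so within one block the key element lands at
//   head_idx * (128 / 8) * 16 * 8 + 2 * 16 * 8 + block_offset * 8 + 5,
// i.e. the [num_blocks, num_heads, head_size/x, block_size, x] layout. The
// value cache uses the plain [num_blocks, num_heads, head_size, block_size]
// layout, so head_offset selects a row and block_offset a column.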

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,     // [num_blocks, block_size, num_heads,
                                         //  head_size]
    cache_t* __restrict__ value_cache,   // [num_blocks, block_size, num_heads,
                                         //  head_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride, const int key_stride, const int value_stride,
    const int num_heads, const int head_size, const int block_size,
    const float* k_scale, const float* v_scale) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;
  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;
    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int64_t tgt_key_value_idx = block_idx * block_stride +
                                      block_offset * num_heads * head_size +
                                      head_idx * head_size + head_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[tgt_key_value_idx] = tgt_key;
      value_cache[tgt_key_value_idx] = tgt_value;
    } else {
      key_cache[tgt_key_value_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
      value_cache[tgt_key_value_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
    }
  }
}

}  // namespace vllm
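
// A similar worked example for the flash layout written by
// reshape_and_cache_flash_kernel above (illustrative numbers only): with
// num_heads = 8, head_size = 128 and block_size = 16, a contiguous
// [num_blocks, block_size, num_heads, head_size] cache has
// block_stride = 16 * 8 * 128, and a token at block_offset = 3 stores its
// head_idx = 1, head_offset = 21 element at
//   block_idx * block_stride + 3 * 8 * 128 + 1 * 128 + 21.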

// KV_T is the data type of the key and value tensors.
// CACHE_T is the stored data type of the kv-cache.
// KV_DTYPE is the real data type of the kv-cache.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)               \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>             \
      <<<grid, block, 0, stream>>>(                                   \
          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
          slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
          num_heads, head_size, block_size, x,                        \
          reinterpret_cast<const float*>(k_scale.data_ptr()),         \
          reinterpret_cast<const float*>(v_scale.data_ptr()));

void reshape_and_cache(
    torch::Tensor& key,    // [num_tokens, num_heads, head_size]
    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale) {
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE)
}
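
// Illustrative note (inferred from the convert_fp8 cases below, not from the
// DISPATCH_BY_KV_CACHE_DTYPE definition itself): for kv_cache_dtype == "auto"
// the cache stores the source dtype directly, so KV_T and CACHE_T are the same
// type and Fp8KVCacheDataType::kAuto makes the kernel do a plain copy; for
// "fp8" / "fp8_e4m3" the cache is byte-packed (CACHE_T = uint8_t) and every
// element goes through fp8::scaled_convert with the k_scale / v_scale factors.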

// KV_T is the data type of the key and value tensors.
// CACHE_T is the stored data type of the kv-cache.
// KV_DTYPE is the real data type of the kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)         \
  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>       \
      <<<grid, block, 0, stream>>>(                                   \
          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
          value_stride, num_heads, head_size, block_size,             \
          reinterpret_cast<const float*>(k_scale.data_ptr()),         \
          reinterpret_cast<const float*>(v_scale.data_ptr()));

void reshape_and_cache_flash(
    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor&
        value_cache,  // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale) {
  // NOTE(woosuk): In vLLM V0, key.size(0) is always equal to
  // slot_mapping.size(0) because both include padding. In vLLM V1, however,
  // key.size(0) can be larger than slot_mapping.size(0) since key includes
  // the padding for CUDA graphs, while slot_mapping does not. In that case,
  // slot_mapping.size(0) represents the actual number of tokens before
  // padding. For compatibility with both cases, we use slot_mapping.size(0)
  // as the number of tokens.
  int num_tokens = slot_mapping.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(1);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);
  int block_stride = key_cache.stride(0);
  TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE_FLASH);
}
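
// Illustrative note: `slot_mapping[t]` is a flat slot index,
// block_number * block_size + offset_within_block, or -1 for padding tokens
// that are skipped. For example, with block_size = 16 a token assigned to
// offset 5 of block 3 gets slot 3 * 16 + 5 = 53, which the kernels above
// decompose back into block_idx = 53 / 16 = 3 and block_offset = 53 % 16 = 5.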

namespace vllm {

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
                                   Tout* __restrict__ dst_cache,
                                   const float scale,
                                   const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
    dst_cache[idx] =
        fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
  }
}

}  // namespace vllm

#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                                \
  vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
      reinterpret_cast<Tin*>(src_cache.data_ptr()),                          \
      reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);

// Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype) {
  torch::Device src_device = src_cache.device();
  torch::Device dst_device = dst_cache.device();
  TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
  TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
  TORCH_CHECK(src_device.index() == dst_device.index(),
              "src and dst must be on the same GPU");
  at::cuda::OptionalCUDAGuard device_guard(src_device);

  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (kv_cache_dtype == "auto") {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    }
  } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
                       vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
                       vllm::Fp8KVCacheDataType::kFp8E4M3);
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
  }
}
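
// Illustrative usage sketch (hedged; the actual test setup lives elsewhere).
// Since convert_fp8 is test-only, a call looks roughly like the following,
// where the shape past dim 0 is arbitrary as long as the cache is contiguous,
// because the kernel only uses src_cache.size(0) and src_cache.stride(0):
//
//   torch::Tensor src = torch::rand(
//       {num_blocks, 8, 16, 16},
//       torch::dtype(torch::kHalf).device(torch::kCUDA));
//   torch::Tensor dst = torch::empty_like(src, torch::dtype(torch::kUInt8));
//   convert_fp8(dst, src, /*scale=*/1.0, /*kv_cache_dtype=*/"fp8");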