[Minor] Remove gather_cached_kv kernel (#3043)

This commit is contained in:
Woosuk Kwon 2024-02-26 15:00:54 -08:00 committed by GitHub
parent cfc15a1031
commit d6e4a130b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 0 additions and 172 deletions

View File

@ -23,13 +23,6 @@ void reshape_and_cache(
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype); const std::string& kv_cache_dtype);
void gather_cached_kv(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping);
// Just for unittest // Just for unittest
void convert_fp8_e5m2( void convert_fp8_e5m2(
torch::Tensor& src_cache, torch::Tensor& src_cache,

View File

@ -269,167 +269,6 @@ void reshape_and_cache(
namespace vllm { namespace vllm {
// Grid: (num_blocks, block_size).
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size]
scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size]
const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int* __restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size,
const int x) {
const int token_idx = blockIdx.x;
const int slot_idx = slot_mapping[token_idx];
const int block_idx = slot_idx / block_size;
const int block_offset = slot_idx % block_size;
const int num_tokens = num_heads * head_size;
for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
const int tgt_key_idx = token_idx * key_stride + i;
const int tgt_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
const int x_offset = head_offset % x;
const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int src_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
}
}
template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
scalar_t *__restrict__ key, // [num_tokens, [stride], num_heads, head_size]
scalar_t *__restrict__ value, // [num_tokens, [stride], num_heads, head_size]
const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int *__restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size,
const int x)
{
const int token_idx = blockIdx.x;
const int slot_idx = slot_mapping[token_idx];
const int block_idx = slot_idx / block_size;
const int block_offset = slot_idx % block_size;
const int dim = num_heads * head_size;
assert(dim % 4 == 0); // this is true for known use cases
const int unroll_factor = 4;
const int unrolled_dim = dim / unroll_factor;
for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
{
int tgt_key_indices[unroll_factor];
int tgt_value_indices[unroll_factor];
int src_key_indices[unroll_factor];
int src_value_indices[unroll_factor];
scalar_t keys_to_store[unroll_factor];
scalar_t values_to_store[unroll_factor];
#pragma unroll
for (int j = 0; j < unroll_factor; ++j)
{
int index = i + j * unrolled_dim;
const int tgt_key_idx = token_idx * key_stride + index;
const int tgt_value_idx = token_idx * value_stride + index;
const int head_idx = index / head_size;
const int head_offset = index % head_size;
const int x_idx = head_offset / x;
const int x_offset = head_offset % x;
const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int src_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
tgt_key_indices[j] = tgt_key_idx;
tgt_value_indices[j] = tgt_value_idx;
src_key_indices[j] = src_key_idx;
src_value_indices[j] = src_value_idx;
keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
}
#pragma unroll
for (int j = 0; j < unroll_factor; ++j)
{
key[tgt_key_indices[j]] = keys_to_store[j];
value[tgt_value_indices[j]] = values_to_store[j];
}
}
}
} // namespace vllm
void gather_cached_kv(
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
torch::Tensor& value, // [out] [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [in] [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [in] [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping) // [in] [num_tokens]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
key.scalar_type(),
"gather_cached_kv_kernel_optimized",
[&] {
vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
key_stride,
value_stride,
num_heads,
head_size,
block_size,
x);
});
}
namespace vllm {
template<typename Tout, typename Tin> template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel( __global__ void convert_fp8_e5m2_kernel(
const Tin* __restrict__ src_cache, const Tin* __restrict__ src_cache,

View File

@ -79,10 +79,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"reshape_and_cache", "reshape_and_cache",
&reshape_and_cache, &reshape_and_cache,
"Reshape the key and value tensors and cache them"); "Reshape the key and value tensors and cache them");
cache_ops.def(
"gather_cached_kv",
&gather_cached_kv,
"Gather key and value from the cache into contiguous QKV tensors");
cache_ops.def( cache_ops.def(
"convert_fp8_e5m2", "convert_fp8_e5m2",
&convert_fp8_e5m2, &convert_fp8_e5m2,