[Minor] Remove gather_cached_kv kernel (#3043)
This commit is contained in:
parent
cfc15a1031
commit
d6e4a130b0
@ -23,13 +23,6 @@ void reshape_and_cache(
|
|||||||
torch::Tensor& slot_mapping,
|
torch::Tensor& slot_mapping,
|
||||||
const std::string& kv_cache_dtype);
|
const std::string& kv_cache_dtype);
|
||||||
|
|
||||||
void gather_cached_kv(
|
|
||||||
torch::Tensor& key,
|
|
||||||
torch::Tensor& value,
|
|
||||||
torch::Tensor& key_cache,
|
|
||||||
torch::Tensor& value_cache,
|
|
||||||
torch::Tensor& slot_mapping);
|
|
||||||
|
|
||||||
// Just for unittest
|
// Just for unittest
|
||||||
void convert_fp8_e5m2(
|
void convert_fp8_e5m2(
|
||||||
torch::Tensor& src_cache,
|
torch::Tensor& src_cache,
|
||||||
|
@ -269,167 +269,6 @@ void reshape_and_cache(
|
|||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
// Grid: (num_blocks, block_size).
|
|
||||||
template<typename scalar_t>
|
|
||||||
__global__ void gather_cached_kv_kernel(
|
|
||||||
scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size]
|
|
||||||
scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size]
|
|
||||||
const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
|
||||||
const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
|
||||||
const int* __restrict__ slot_mapping, // [num_tokens]
|
|
||||||
const int key_stride,
|
|
||||||
const int value_stride,
|
|
||||||
const int num_heads,
|
|
||||||
const int head_size,
|
|
||||||
const int block_size,
|
|
||||||
const int x) {
|
|
||||||
const int token_idx = blockIdx.x;
|
|
||||||
const int slot_idx = slot_mapping[token_idx];
|
|
||||||
const int block_idx = slot_idx / block_size;
|
|
||||||
const int block_offset = slot_idx % block_size;
|
|
||||||
|
|
||||||
const int num_tokens = num_heads * head_size;
|
|
||||||
for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
|
|
||||||
const int tgt_key_idx = token_idx * key_stride + i;
|
|
||||||
const int tgt_value_idx = token_idx * value_stride + i;
|
|
||||||
|
|
||||||
const int head_idx = i / head_size;
|
|
||||||
const int head_offset = i % head_size;
|
|
||||||
const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
|
|
||||||
const int x_offset = head_offset % x;
|
|
||||||
|
|
||||||
const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
|
||||||
+ head_idx * (head_size / x) * block_size * x
|
|
||||||
+ x_idx * block_size * x
|
|
||||||
+ block_offset * x
|
|
||||||
+ x_offset;
|
|
||||||
const int src_value_idx = block_idx * num_heads * head_size * block_size
|
|
||||||
+ head_idx * head_size * block_size
|
|
||||||
+ head_offset * block_size
|
|
||||||
+ block_offset;
|
|
||||||
|
|
||||||
key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
|
|
||||||
value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename scalar_t>
|
|
||||||
__global__ void gather_cached_kv_kernel_optimized(
|
|
||||||
scalar_t *__restrict__ key, // [num_tokens, [stride], num_heads, head_size]
|
|
||||||
scalar_t *__restrict__ value, // [num_tokens, [stride], num_heads, head_size]
|
|
||||||
const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
|
||||||
const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
|
||||||
const int *__restrict__ slot_mapping, // [num_tokens]
|
|
||||||
const int key_stride,
|
|
||||||
const int value_stride,
|
|
||||||
const int num_heads,
|
|
||||||
const int head_size,
|
|
||||||
const int block_size,
|
|
||||||
const int x)
|
|
||||||
{
|
|
||||||
const int token_idx = blockIdx.x;
|
|
||||||
const int slot_idx = slot_mapping[token_idx];
|
|
||||||
const int block_idx = slot_idx / block_size;
|
|
||||||
const int block_offset = slot_idx % block_size;
|
|
||||||
|
|
||||||
const int dim = num_heads * head_size;
|
|
||||||
assert(dim % 4 == 0); // this is true for known use cases
|
|
||||||
const int unroll_factor = 4;
|
|
||||||
const int unrolled_dim = dim / unroll_factor;
|
|
||||||
|
|
||||||
for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
|
|
||||||
{
|
|
||||||
int tgt_key_indices[unroll_factor];
|
|
||||||
int tgt_value_indices[unroll_factor];
|
|
||||||
int src_key_indices[unroll_factor];
|
|
||||||
int src_value_indices[unroll_factor];
|
|
||||||
scalar_t keys_to_store[unroll_factor];
|
|
||||||
scalar_t values_to_store[unroll_factor];
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int j = 0; j < unroll_factor; ++j)
|
|
||||||
{
|
|
||||||
int index = i + j * unrolled_dim;
|
|
||||||
|
|
||||||
const int tgt_key_idx = token_idx * key_stride + index;
|
|
||||||
const int tgt_value_idx = token_idx * value_stride + index;
|
|
||||||
|
|
||||||
const int head_idx = index / head_size;
|
|
||||||
const int head_offset = index % head_size;
|
|
||||||
const int x_idx = head_offset / x;
|
|
||||||
const int x_offset = head_offset % x;
|
|
||||||
|
|
||||||
const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
|
||||||
+ head_idx * (head_size / x) * block_size * x
|
|
||||||
+ x_idx * block_size * x
|
|
||||||
+ block_offset * x
|
|
||||||
+ x_offset;
|
|
||||||
const int src_value_idx = block_idx * num_heads * head_size * block_size
|
|
||||||
+ head_idx * head_size * block_size
|
|
||||||
+ head_offset * block_size
|
|
||||||
+ block_offset;
|
|
||||||
|
|
||||||
tgt_key_indices[j] = tgt_key_idx;
|
|
||||||
tgt_value_indices[j] = tgt_value_idx;
|
|
||||||
src_key_indices[j] = src_key_idx;
|
|
||||||
src_value_indices[j] = src_value_idx;
|
|
||||||
|
|
||||||
keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
|
|
||||||
values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int j = 0; j < unroll_factor; ++j)
|
|
||||||
{
|
|
||||||
key[tgt_key_indices[j]] = keys_to_store[j];
|
|
||||||
value[tgt_value_indices[j]] = values_to_store[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace vllm
|
|
||||||
|
|
||||||
void gather_cached_kv(
|
|
||||||
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
|
|
||||||
torch::Tensor& value, // [out] [num_tokens, num_heads, head_size]
|
|
||||||
torch::Tensor& key_cache, // [in] [num_blocks, num_heads, head_size/x, block_size, x]
|
|
||||||
torch::Tensor& value_cache, // [in] [num_blocks, num_heads, head_size, block_size]
|
|
||||||
torch::Tensor& slot_mapping) // [in] [num_tokens]
|
|
||||||
{
|
|
||||||
int num_tokens = key.size(0);
|
|
||||||
int num_heads = key.size(1);
|
|
||||||
int head_size = key.size(2);
|
|
||||||
int block_size = key_cache.size(3);
|
|
||||||
int x = key_cache.size(4);
|
|
||||||
|
|
||||||
int key_stride = key.stride(0);
|
|
||||||
int value_stride = value.stride(0);
|
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
||||||
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
|
|
||||||
key.scalar_type(),
|
|
||||||
"gather_cached_kv_kernel_optimized",
|
|
||||||
[&] {
|
|
||||||
vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
|
|
||||||
key.data_ptr<scalar_t>(),
|
|
||||||
value.data_ptr<scalar_t>(),
|
|
||||||
key_cache.data_ptr<scalar_t>(),
|
|
||||||
value_cache.data_ptr<scalar_t>(),
|
|
||||||
slot_mapping.data_ptr<int>(),
|
|
||||||
key_stride,
|
|
||||||
value_stride,
|
|
||||||
num_heads,
|
|
||||||
head_size,
|
|
||||||
block_size,
|
|
||||||
x);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace vllm {
|
|
||||||
|
|
||||||
template<typename Tout, typename Tin>
|
template<typename Tout, typename Tin>
|
||||||
__global__ void convert_fp8_e5m2_kernel(
|
__global__ void convert_fp8_e5m2_kernel(
|
||||||
const Tin* __restrict__ src_cache,
|
const Tin* __restrict__ src_cache,
|
||||||
|
@ -79,10 +79,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
"reshape_and_cache",
|
"reshape_and_cache",
|
||||||
&reshape_and_cache,
|
&reshape_and_cache,
|
||||||
"Reshape the key and value tensors and cache them");
|
"Reshape the key and value tensors and cache them");
|
||||||
cache_ops.def(
|
|
||||||
"gather_cached_kv",
|
|
||||||
&gather_cached_kv,
|
|
||||||
"Gather key and value from the cache into contiguous QKV tensors");
|
|
||||||
cache_ops.def(
|
cache_ops.def(
|
||||||
"convert_fp8_e5m2",
|
"convert_fp8_e5m2",
|
||||||
&convert_fp8_e5m2,
|
&convert_fp8_e5m2,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user