[Bugfix][Kernel] fix potential CUDA graph breakage in the merge_attn_states kernel (#16693)
Signed-off-by: DefTruth <qiustudent_r@163.com>
parent facbe2a114
commit e82ee40de3
@@ -107,13 +107,14 @@ __global__ void merge_attn_states_kernel(

 #define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)                     \
   {                                                                         \
-    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
-        reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr,     \
-        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),              \
-        reinterpret_cast<float*>(prefix_lse.data_ptr()),                    \
-        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),              \
-        reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,        \
-        num_heads, head_size);                                              \
+    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS>                   \
+        <<<grid, block, 0, stream>>>(                                       \
+            reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
+            reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),          \
+            reinterpret_cast<float*>(prefix_lse.data_ptr()),                \
+            reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),          \
+            reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,    \
+            num_heads, head_size);                                          \
   }

 /*@brief Merges the attention states from prefix and suffix
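The substantive change in this hunk is the launch configuration: <<<grid, block>>> implicitly enqueues the kernel on the legacy default stream, while the new four-argument form <<<grid, block, 0, stream>>> passes 0 bytes of dynamic shared memory and an explicit stream. A minimal standalone sketch of that syntax (the names are illustrative, not vLLM code):

#include <cuda_runtime.h>

__global__ void scale_kernel(float* x, float s, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= s;
}

int main() {
  const int n = 1024;
  float* x = nullptr;
  cudaMalloc(&x, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  dim3 block(128);
  dim3 grid((n + block.x - 1) / block.x);

  // Four-argument launch configuration: grid, block, dynamic shared
  // memory in bytes (0 here), and the stream to enqueue on. Omitting
  // the last two arguments launches on the legacy default stream.
  scale_kernel<<<grid, block, 0, stream>>>(x, 2.0f, n);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  cudaFree(x);
  return 0;
}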
@@ -122,10 +123,10 @@ __global__ void merge_attn_states_kernel(
  * @param output [n,h,d] The output tensor to store the merged attention states.
  * @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
  * @param prefix_output [n,h,d] The prefix attention states.
- * @param prefix_lse [h,d] The log-sum-exp values for the prefix attention
+ * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
  * states.
  * @param suffix_output [n,h,d] The suffix attention states.
- * @param suffix_lse [h,d] The log-sum-exp values for the suffix attention
+ * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
  * states.
  */
 template <typename scalar_t>
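Note the shape fix: prefix_lse and suffix_lse hold one scalar per (head, token) pair, hence [h,n] rather than [h,d]. The kernel body is not part of this diff, but the operation the docstring describes is the standard log-sum-exp merge of two partial attention results. A host-side sketch for a single (token, head) pair, assuming that standard formulation (names are illustrative):

#include <algorithm>
#include <cmath>
#include <vector>

// Merge prefix/suffix attention outputs of length head_size, each with
// a scalar log-sum-exp value, into one output (and optionally its lse).
void merge_one(const std::vector<float>& prefix_out, float prefix_lse,
               const std::vector<float>& suffix_out, float suffix_lse,
               std::vector<float>& out, float* out_lse) {
  // Shift by the max before exponentiating for numerical stability.
  const float m = std::max(prefix_lse, suffix_lse);
  const float p = std::exp(prefix_lse - m);
  const float s = std::exp(suffix_lse - m);
  const float inv_sum = 1.0f / (p + s);
  for (std::size_t d = 0; d < out.size(); ++d) {
    out[d] = (p * prefix_out[d] + s * suffix_out[d]) * inv_sum;
  }
  if (out_lse != nullptr) {
    *out_lse = m + std::log(p + s);  // merged log-sum-exp
  }
}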
@@ -146,13 +147,17 @@ void merge_attn_states_launcher(torch::Tensor& output,
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
   }
-  // process one pack elements per thread. float -> 4, half/bf16 -> 8
+  // Process one pack of elements per thread: the pack_size is 4 for
+  // float and 8 for half/bf16.
   const uint threads_per_head = head_size / pack_size;
   const uint total_threads = num_tokens * num_heads * threads_per_head;

   dim3 block(NUM_THREADS);
   dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);

+  const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
   LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
 }
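Why this fixes graph capture: CUDA stream capture records only the work enqueued on the stream being captured, and launching a kernel on the legacy default stream while a global-mode capture is active is an invalid interaction that fails the capture. Fetching at::cuda::getCurrentCUDAStream() makes the launcher enqueue on whatever stream PyTorch is currently using, including one under torch.cuda.CUDAGraph capture, while the OptionalCUDAGuard pins the launch to prefix_output's device. A standalone sketch of the failure mode (illustrative names, not vLLM code):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void noop_kernel() {}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

  // Captured: enqueued on the stream that is being captured.
  noop_kernel<<<1, 1, 0, stream>>>();

  // Uncommenting the next line would launch on the legacy default
  // stream during capture and invalidate the capture:
  // noop_kernel<<<1, 1>>>();

  cudaError_t err = cudaStreamEndCapture(stream, &graph);
  std::printf("capture %s\n", err == cudaSuccess ? "succeeded" : "failed");

  if (err == cudaSuccess) {
    cudaGraphExec_t exec;
    cudaGraphInstantiateWithFlags(&exec, graph, 0);
    cudaGraphLaunch(exec, stream);
    cudaStreamSynchronize(stream);
    cudaGraphExecDestroy(exec);
    cudaGraphDestroy(graph);
  }
  cudaStreamDestroy(stream);
  return 0;
}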
|