From e82ee40de3362afda8671e6f5daece0eaa7f0d51 Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Wed, 16 Apr 2025 18:31:39 +0800
Subject: [PATCH] [Bugfix][Kernel] fix potential cuda graph broken for
 merge_attn_states kernel (#16693)

Signed-off-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
---
 csrc/attention/merge_attn_states.cu | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu
index 7af0cace..14e5edd7 100644
--- a/csrc/attention/merge_attn_states.cu
+++ b/csrc/attention/merge_attn_states.cu
@@ -107,13 +107,14 @@ __global__ void merge_attn_states_kernel(
 
 #define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)                     \
   {                                                                         \
-    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
-        reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr,     \
-        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),              \
-        reinterpret_cast<float*>(prefix_lse.data_ptr()),                    \
-        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),              \
-        reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,        \
-        num_heads, head_size);                                              \
+    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS>                   \
+        <<<grid, block, 0, stream>>>(                                       \
+            reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
+            reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),          \
+            reinterpret_cast<float*>(prefix_lse.data_ptr()),                \
+            reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),          \
+            reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,    \
+            num_heads, head_size);                                          \
   }
 
 /*@brief Merges the attention states from prefix and suffix
@@ -122,10 +123,10 @@ __global__ void merge_attn_states_kernel(
  * @param output [n,h,d] The output tensor to store the merged attention states.
  * @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
  * @param prefix_output [n,h,d] The prefix attention states.
- * @param prefix_lse [h,d] The log-sum-exp values for the prefix attention
+ * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
  * states.
  * @param suffix_output [n,h,d] The suffix attention states.
- * @param suffix_lse [h,d] The log-sum-exp values for the suffix attention
+ * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
  * states.
 */
 template <typename scalar_t, const uint NUM_THREADS>
@@ -146,13 +147,17 @@ void merge_attn_states_launcher(torch::Tensor& output,
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
   }
-  // process one pack elements per thread. float -> 4, half/bf16 -> 8
+  // Process one pack elements per thread. for float, the
+  // pack_size is 4 for half/bf16, the pack_size is 8.
   const uint threads_per_head = head_size / pack_size;
   const uint total_threads = num_tokens * num_heads * threads_per_head;
 
   dim3 block(NUM_THREADS);
   dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
 
+  const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
   LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
 }
 