From 280d62b8a2cfd456e42aa3e4f9abb59fb87124ca Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 15 Apr 2025 20:58:37 +0800 Subject: [PATCH] [Kernel] Remove redundant Exp calculations (#16123) Signed-off-by: DefTruth --- vllm/attention/ops/triton_merge_attn_states.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py index 9671b933..250426d9 100644 --- a/vllm/attention/ops/triton_merge_attn_states.py +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -66,7 +66,10 @@ def merge_attn_states_kernel( max_lse = tl.maximum(p_lse, s_lse) p_lse = p_lse - max_lse s_lse = s_lse - max_lse - out_se = (tl.exp(p_lse) + tl.exp(s_lse)) + # Will reuse precomputed Exp values for scale factor computation. + p_se = tl.exp(p_lse) + s_se = tl.exp(s_lse) + out_se = (p_se + s_se) if OUTPUT_LSE: out_lse = tl.log(out_se) + max_lse @@ -84,8 +87,8 @@ def merge_attn_states_kernel( # NOTE(woosuk): Be careful with the numerical stability. # We should compute the scale first, and then multiply it with the output. # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. - p_scale = tl.exp(p_lse) / out_se - s_scale = tl.exp(s_lse) / out_se + p_scale = p_se / out_se + s_scale = s_se / out_se out = p_out * p_scale + s_out * s_scale tl.store(output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,