[Kernel] Remove redundant Exp calculations (#16123)
Signed-off-by: DefTruth <qiustudent_r@163.com>
This commit is contained in:
parent 1666e66443
commit 280d62b8a2
@@ -66,7 +66,10 @@ def merge_attn_states_kernel(
|
|||||||
max_lse = tl.maximum(p_lse, s_lse)
|
max_lse = tl.maximum(p_lse, s_lse)
|
||||||
p_lse = p_lse - max_lse
|
p_lse = p_lse - max_lse
|
||||||
s_lse = s_lse - max_lse
|
s_lse = s_lse - max_lse
|
||||||
out_se = (tl.exp(p_lse) + tl.exp(s_lse))
|
# Will reuse precomputed Exp values for scale factor computation.
|
||||||
|
p_se = tl.exp(p_lse)
|
||||||
|
s_se = tl.exp(s_lse)
|
||||||
|
out_se = (p_se + s_se)
|
||||||
|
|
||||||
if OUTPUT_LSE:
|
if OUTPUT_LSE:
|
||||||
out_lse = tl.log(out_se) + max_lse
|
out_lse = tl.log(out_se) + max_lse
|
||||||
@@ -84,8 +87,8 @@ def merge_attn_states_kernel(
|
|||||||
# NOTE(woosuk): Be careful with the numerical stability.
|
# NOTE(woosuk): Be careful with the numerical stability.
|
||||||
# We should compute the scale first, and then multiply it with the output.
|
# We should compute the scale first, and then multiply it with the output.
|
||||||
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
|
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
|
||||||
p_scale = tl.exp(p_lse) / out_se
|
p_scale = p_se / out_se
|
||||||
s_scale = tl.exp(s_lse) / out_se
|
s_scale = s_se / out_se
|
||||||
out = p_out * p_scale + s_out * s_scale
|
out = p_out * p_scale + s_out * s_scale
|
||||||
tl.store(output + token_idx * num_heads * HEAD_SIZE +
|
tl.store(output + token_idx * num_heads * HEAD_SIZE +
|
||||||
head_idx * HEAD_SIZE + head_arange,
|
head_idx * HEAD_SIZE + head_arange,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user