[Kernel] Explicitly specify other value in tl.load calls (#9014)
Signed-off-by: Angus Wang <wangjadehao@gmail.com>
This commit is contained in:
parent
6b2d25efc7
commit
c2170a5b39
@ -157,19 +157,22 @@ def _fwd_kernel_inner(
|
||||
k = tl.load(
|
||||
k_ptrs + start_n * stride_kt,
|
||||
mask=offs_n[None, :] + start_n < k_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
k = tl.load(
|
||||
k_ptrs + start_n * stride_kt,
|
||||
mask=(offs_n[None, :] + start_n < k_seqlen) &
|
||||
(offs_d[:, None] < D_HEAD),
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
if EVEN_D:
|
||||
k = tl.load(k_ptrs + start_n * stride_kt)
|
||||
else:
|
||||
k = tl.load(k_ptrs + start_n * stride_kt,
|
||||
mask=offs_d[:, None] < D_HEAD)
|
||||
mask=offs_d[:, None] < D_HEAD,
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
@ -200,19 +203,22 @@ def _fwd_kernel_inner(
|
||||
v = tl.load(
|
||||
v_ptrs + start_n * stride_vt,
|
||||
mask=offs_n[:, None] + start_n < k_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
v = tl.load(
|
||||
v_ptrs + start_n * stride_vt,
|
||||
mask=(offs_n[:, None] + start_n < k_seqlen) &
|
||||
(offs_d[None, :] < D_HEAD),
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
if EVEN_D:
|
||||
v = tl.load(v_ptrs + start_n * stride_vt)
|
||||
else:
|
||||
v = tl.load(v_ptrs + start_n * stride_vt,
|
||||
mask=offs_d[None, :] < D_HEAD)
|
||||
mask=offs_d[None, :] < D_HEAD,
|
||||
other=0.0)
|
||||
|
||||
acc += tl.dot(p, v)
|
||||
|
||||
@ -318,12 +324,13 @@ def _fwd_kernel_batch_inference(
|
||||
q = tl.load(
|
||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||
mask=offs_m[:, None] < q_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
q = tl.load(
|
||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||
other=0,
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
|
||||
|
@ -75,7 +75,9 @@ def _bgmv_expand_kernel(
|
||||
other=0.0,
|
||||
) # [BLOCK_N,BLOCK_K]
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
|
||||
tiled_out = tl.load(c_ptr + current_n * cn_stride,
|
||||
mask=c_mask,
|
||||
other=0.0)
|
||||
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
|
||||
else:
|
||||
accumulator = tl.sum(tiled_a * tiled_b, 1)
|
||||
|
@ -78,7 +78,13 @@ def _bgmv_expand_slice_kernel(
|
||||
) # [BLOCK_N,BLOCK_K]
|
||||
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
|
||||
# explicitly pass in other=None to tell triton that masked values
|
||||
# can be uninitialized. This is OK because the later tl.store
|
||||
# operation uses the same mask, eliminating the risk of garbage
|
||||
# values propagating
|
||||
tiled_out = tl.load(c_ptr + current_n * cn_stride,
|
||||
mask=c_mask,
|
||||
other=None)
|
||||
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
|
||||
else:
|
||||
accumulator = tl.sum(tiled_a * tiled_b, 1)
|
||||
|
@ -88,7 +88,10 @@ def _sgmv_expand_kernel(
|
||||
c_mask = (offset_cm[:, None] <
|
||||
(cur_seq_start + M)) & (offset_cn[None, :] < N)
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
||||
# explicitly pass in other=None to tell triton that masked values
|
||||
# can be uninitialized. This is OK because the later tl.store operation
|
||||
# uses the same mask, eliminating the risk of garbage values propagating
|
||||
tiled_out = tl.load(c_ptr, mask=c_mask, other=None)
|
||||
tiled_c += tiled_out
|
||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
@ -94,7 +94,10 @@ def _sgmv_expand_slice_kernel(
|
||||
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
|
||||
(slice_offset + N))
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
||||
# explicitly pass in other=None to tell triton that masked values
|
||||
# can be uninitialized. This is OK because the later tl.store operation
|
||||
# uses the same mask, eliminating the risk of garbage values propagating
|
||||
tiled_out = tl.load(c_ptr, mask=c_mask, other=None)
|
||||
tiled_c += tiled_out
|
||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
@ -42,7 +42,7 @@ def awq_dequantize_kernel(
|
||||
result_masks = result_masks_y[:, None] & result_masks_x[None, :]
|
||||
|
||||
# Load the weights.
|
||||
iweights = tl.load(qweight_ptr + offsets, masks)
|
||||
iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
|
||||
iweights = tl.interleave(iweights, iweights)
|
||||
iweights = tl.interleave(iweights, iweights)
|
||||
iweights = tl.interleave(iweights, iweights)
|
||||
@ -71,7 +71,7 @@ def awq_dequantize_kernel(
|
||||
zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]
|
||||
|
||||
# Load the zeros.
|
||||
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
|
||||
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
@ -91,7 +91,7 @@ def awq_dequantize_kernel(
|
||||
scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]
|
||||
|
||||
# Load the scales.
|
||||
scales = tl.load(scales_ptr + scale_offsets, scale_masks)
|
||||
scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
|
||||
scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
|
||||
|
||||
# Dequantize.
|
||||
@ -165,10 +165,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
|
||||
masks_k = offsets_k < K
|
||||
masks_a = masks_am[:, None] & masks_k[None, :]
|
||||
a = tl.load(a_ptrs, mask=masks_a)
|
||||
a = tl.load(a_ptrs, mask=masks_a, other=0.0)
|
||||
|
||||
masks_b = masks_k[:, None] & masks_bn[None, :]
|
||||
b = tl.load(b_ptrs, mask=masks_b)
|
||||
b = tl.load(b_ptrs, mask=masks_b, other=0.0)
|
||||
b = tl.interleave(b, b)
|
||||
b = tl.interleave(b, b)
|
||||
b = tl.interleave(b, b)
|
||||
@ -181,7 +181,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
|
||||
masks_zk = offsets_szk < K // group_size
|
||||
masks_z = masks_zk[:, None] & masks_zn[None, :]
|
||||
zeros_ptrs = zeros_ptr + offsets_z
|
||||
zeros = tl.load(zeros_ptrs, mask=masks_z)
|
||||
zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
@ -191,7 +191,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
|
||||
masks_sk = offsets_szk < K // group_size
|
||||
masks_s = masks_sk[:, None] & masks_sn[None, :]
|
||||
scales_ptrs = scales_ptr + offsets_s
|
||||
scales = tl.load(scales_ptrs, mask=masks_s)
|
||||
scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
|
||||
scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
|
||||
|
||||
b = (b >> shifts) & 0xF
|
||||
|
Loading…
x
Reference in New Issue
Block a user