Fix CUDA kernel index data type in vllm/csrc/quantization/gptq_marlin/awq_marlin_repack.cu +10 (#15160)
Signed-off-by: Lu Fang <lufang@fb.com>
Co-authored-by: Richard Barnes <rbarnes@meta.com>
This commit is contained in:
parent
25f560a62c
commit
051da7efe3
@ -14,7 +14,7 @@ __global__ void awq_marlin_repack_kernel(
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
int start_k_tile = blockIdx.x * block_k_tiles;
|
||||
auto start_k_tile = blockIdx.x * block_k_tiles;
|
||||
if (start_k_tile >= k_tiles) {
|
||||
return;
|
||||
}
|
||||
@ -51,8 +51,8 @@ __global__ void awq_marlin_repack_kernel(
|
||||
int4* sh_ptr = sh + stage_size * pipe;
|
||||
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
|
||||
@ -70,8 +70,8 @@ __global__ void awq_marlin_repack_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int th_id = threadIdx.x % 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
auto th_id = threadIdx.x % 32;
|
||||
|
||||
if (warp_id >= 4) {
|
||||
return;
|
||||
@ -265,4 +265,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
|
||||
m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
|
||||
}
|
||||
}
|
||||
|
@ -460,7 +460,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr, int size_m,
|
||||
int size_k, int lda, int block_rows) {
|
||||
int start_row = block_rows * blockIdx.x;
|
||||
auto start_row = block_rows * blockIdx.x;
|
||||
int finish_row = start_row + block_rows;
|
||||
if (finish_row > size_m) {
|
||||
finish_row = size_m;
|
||||
@ -484,7 +484,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
int base_k = 0;
|
||||
|
||||
for (int i = 0; i < iters; i++) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
auto cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
@ -494,7 +494,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
|
||||
if (rest) {
|
||||
if (threadIdx.x < rest) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
auto cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
@ -723,8 +723,8 @@ __global__ void Marlin(
|
||||
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
|
||||
b_gl_rd += b_sh_stride * slice_col;
|
||||
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
||||
int b_sh_wr = threadIdx.x * b_thread_vecs;
|
||||
int b_sh_rd = threadIdx.x * b_thread_vecs;
|
||||
auto b_sh_wr = threadIdx.x * b_thread_vecs;
|
||||
auto b_sh_rd = threadIdx.x * b_thread_vecs;
|
||||
|
||||
// For act_order
|
||||
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
|
||||
@ -743,7 +743,7 @@ __global__ void Marlin(
|
||||
s_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
}
|
||||
int s_sh_wr = threadIdx.x;
|
||||
auto s_sh_wr = threadIdx.x;
|
||||
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
||||
|
||||
// Zero-points
|
||||
@ -756,7 +756,7 @@ __global__ void Marlin(
|
||||
zp_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
}
|
||||
int zp_sh_wr = threadIdx.x;
|
||||
auto zp_sh_wr = threadIdx.x;
|
||||
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
|
||||
|
||||
// We use a different scale layout for grouped and column-wise quantization as
|
||||
@ -1047,7 +1047,7 @@ __global__ void Marlin(
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
|
||||
int warp_row = warp_id / n_warps;
|
||||
@ -1085,7 +1085,7 @@ __global__ void Marlin(
|
||||
|
||||
// Determine "position" inside the thread-block (based on warp and
|
||||
// thread-id)
|
||||
int warp_id = threadIdx.x / 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps =
|
||||
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
|
||||
|
||||
@ -1094,7 +1094,7 @@ __global__ void Marlin(
|
||||
|
||||
cur_k += warp_row * 16;
|
||||
|
||||
int th_id = threadIdx.x % 32;
|
||||
auto th_id = threadIdx.x % 32;
|
||||
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
|
||||
|
||||
int s_col_shift =
|
||||
@ -1159,7 +1159,7 @@ __global__ void Marlin(
|
||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||
}
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
|
||||
int warp_row = warp_id / n_warps;
|
||||
@ -1197,7 +1197,7 @@ __global__ void Marlin(
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
|
||||
int warp_row = warp_id / n_warps;
|
||||
@ -1323,7 +1323,7 @@ __global__ void Marlin(
|
||||
auto thread_block_reduce = [&]() {
|
||||
constexpr int red_off = threads / b_sh_stride_threads / 2;
|
||||
if (red_off >= 1) {
|
||||
int red_idx = threadIdx.x / b_sh_stride_threads;
|
||||
auto red_idx = threadIdx.x / b_sh_stride_threads;
|
||||
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
|
||||
constexpr int red_sh_delta = b_sh_stride_threads;
|
||||
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
||||
@ -1390,7 +1390,7 @@ __global__ void Marlin(
|
||||
4 * (threadIdx.x / 32) + threadIdx.x % 4;
|
||||
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
||||
constexpr int c_sh_wr_delta = active_threads;
|
||||
int c_sh_wr = threadIdx.x;
|
||||
auto c_sh_wr = threadIdx.x;
|
||||
|
||||
int row = (threadIdx.x % 32) / 4;
|
||||
|
||||
|
@ -15,7 +15,7 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
int start_k_tile = blockIdx.x * block_k_tiles;
|
||||
auto start_k_tile = blockIdx.x * block_k_tiles;
|
||||
if (start_k_tile >= k_tiles) {
|
||||
return;
|
||||
}
|
||||
@ -71,8 +71,8 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
|
||||
if constexpr (has_perm) {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
uint32_t const* sh_perm_int_ptr =
|
||||
reinterpret_cast<uint32_t const*>(sh_perm_ptr);
|
||||
@ -88,8 +88,8 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
|
||||
} else {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
int first_k_packed = first_k / pack_factor;
|
||||
@ -109,8 +109,8 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int th_id = threadIdx.x % 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
auto th_id = threadIdx.x % 32;
|
||||
|
||||
if (warp_id >= 4) {
|
||||
return;
|
||||
@ -339,4 +339,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
|
||||
m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
|
||||
}
|
||||
}
|
||||
|
@ -277,12 +277,12 @@ __global__ void Marlin(
|
||||
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
|
||||
b_gl_rd += b_sh_stride * slice_col;
|
||||
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
||||
int b_sh_wr = threadIdx.x;
|
||||
int b_sh_rd = threadIdx.x;
|
||||
auto b_sh_wr = threadIdx.x;
|
||||
auto b_sh_rd = threadIdx.x;
|
||||
|
||||
int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
||||
s_sh_stride * slice_col + threadIdx.x;
|
||||
int s_sh_wr = threadIdx.x;
|
||||
auto s_sh_wr = threadIdx.x;
|
||||
int s_sh_rd;
|
||||
// We use a different scale layout for grouped and column-wise quantization as
|
||||
// we scale a `half2` tile in column-major layout in the former and in
|
||||
@ -455,7 +455,7 @@ __global__ void Marlin(
|
||||
auto thread_block_reduce = [&]() {
|
||||
constexpr int red_off = threads / b_sh_stride / 2;
|
||||
if (red_off >= 1) {
|
||||
int red_idx = threadIdx.x / b_sh_stride;
|
||||
auto red_idx = threadIdx.x / b_sh_stride;
|
||||
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
|
||||
constexpr int red_sh_delta = b_sh_stride;
|
||||
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
|
||||
@ -522,7 +522,7 @@ __global__ void Marlin(
|
||||
4 * (threadIdx.x / 32) + threadIdx.x % 4;
|
||||
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
||||
constexpr int c_sh_wr_delta = active_threads;
|
||||
int c_sh_wr = threadIdx.x;
|
||||
auto c_sh_wr = threadIdx.x;
|
||||
|
||||
int row = (threadIdx.x % 32) / 4;
|
||||
|
||||
|
@ -353,10 +353,10 @@ __global__ void Marlin(
|
||||
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
|
||||
b_gl_rd += b_sh_stride * slice_col;
|
||||
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
||||
int b_sh_wr = threadIdx.x;
|
||||
int b_sh_rd = threadIdx.x;
|
||||
auto b_sh_wr = threadIdx.x;
|
||||
auto b_sh_rd = threadIdx.x;
|
||||
|
||||
int s_tok_gl_rd = threadIdx.x;
|
||||
auto s_tok_gl_rd = threadIdx.x;
|
||||
// NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10,
|
||||
// 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for
|
||||
// thread 0, 1, 2, 3. For more details, refer to mma operand A layout as
|
||||
@ -368,8 +368,8 @@ __global__ void Marlin(
|
||||
int s_tok_sh_rd = (threadIdx.x % 32) / 4;
|
||||
bool s_tok_sh_wr_pred = threadIdx.x < prob_m;
|
||||
|
||||
int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
|
||||
int s_ch_sh_wr = threadIdx.x;
|
||||
auto s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
|
||||
auto s_ch_sh_wr = threadIdx.x;
|
||||
int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
2 * ((threadIdx.x % 32) % 4);
|
||||
bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride;
|
||||
@ -558,7 +558,7 @@ __global__ void Marlin(
|
||||
auto thread_block_reduce = [&]() {
|
||||
constexpr int red_off = threads / b_sh_stride / 2;
|
||||
if (red_off >= 1) {
|
||||
int red_idx = threadIdx.x / b_sh_stride;
|
||||
auto red_idx = threadIdx.x / b_sh_stride;
|
||||
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
|
||||
constexpr int red_sh_delta = b_sh_stride;
|
||||
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
|
||||
@ -628,7 +628,7 @@ __global__ void Marlin(
|
||||
8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
|
||||
c_gl_wr += (4 * thread_n_blocks) * slice_col;
|
||||
constexpr int c_sh_wr_delta = active_threads * 2;
|
||||
int c_sh_wr = 2 * threadIdx.x;
|
||||
auto c_sh_wr = 2 * threadIdx.x;
|
||||
|
||||
int row = (threadIdx.x % 32) / 4;
|
||||
|
||||
|
@ -273,15 +273,15 @@ __global__ void Marlin_24(
|
||||
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
|
||||
b_gl_rd += b_sh_stride * slice_col;
|
||||
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
||||
int b_sh_wr = threadIdx.x * b_thread_vecs;
|
||||
int b_sh_rd = threadIdx.x * b_thread_vecs;
|
||||
auto b_sh_wr = threadIdx.x * b_thread_vecs;
|
||||
auto b_sh_rd = threadIdx.x * b_thread_vecs;
|
||||
|
||||
int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) +
|
||||
(threadIdx.x % (m_sh_stride));
|
||||
m_gl_rd += (m_sh_stride)*slice_col;
|
||||
m_gl_rd += m_gl_rd_delta_o * slice_row;
|
||||
int m_sh_wr = threadIdx.x;
|
||||
int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
|
||||
auto m_sh_wr = threadIdx.x;
|
||||
auto m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
|
||||
|
||||
int s_gl_rd;
|
||||
if constexpr (group_blocks == -1) {
|
||||
@ -291,7 +291,7 @@ __global__ void Marlin_24(
|
||||
s_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
|
||||
int s_sh_wr = threadIdx.x;
|
||||
auto s_sh_wr = threadIdx.x;
|
||||
int s_sh_rd;
|
||||
// We use a different scale layout for grouped and column-wise quantization as
|
||||
// we scale a `half2` tile in column-major layout in the former and in
|
||||
@ -516,7 +516,7 @@ __global__ void Marlin_24(
|
||||
auto thread_block_reduce = [&]() {
|
||||
constexpr int red_off = threads / b_sh_stride_threads / 2;
|
||||
if (red_off >= 1) {
|
||||
int red_idx = threadIdx.x / b_sh_stride_threads;
|
||||
auto red_idx = threadIdx.x / b_sh_stride_threads;
|
||||
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
|
||||
constexpr int red_sh_delta = b_sh_stride_threads;
|
||||
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
||||
@ -583,7 +583,7 @@ __global__ void Marlin_24(
|
||||
8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
|
||||
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
||||
constexpr int c_sh_wr_delta = active_threads;
|
||||
int c_sh_wr = threadIdx.x;
|
||||
auto c_sh_wr = threadIdx.x;
|
||||
|
||||
int col = 2 * ((threadIdx.x % 32) % 4);
|
||||
|
||||
|
@ -284,18 +284,18 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
|
||||
// clang-format on
|
||||
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
const int laneid = threadIdx.x % WARP_SIZE;
|
||||
const auto warpid = threadIdx.x / WARP_SIZE;
|
||||
const auto laneid = threadIdx.x % WARP_SIZE;
|
||||
const int lane4id = laneid % 4;
|
||||
const int lane16id = laneid % 16;
|
||||
const int rowid = laneid / 16;
|
||||
|
||||
const int seq_idx = blockIdx.x;
|
||||
const int partition_idx = blockIdx.y;
|
||||
const auto seq_idx = blockIdx.x;
|
||||
const auto partition_idx = blockIdx.y;
|
||||
|
||||
constexpr int T_PAR_SIZE = 256; // token partition size set to 256
|
||||
|
||||
const int max_num_partitions = gridDim.y;
|
||||
const auto max_num_partitions = gridDim.y;
|
||||
|
||||
const int context_len = context_lens[seq_idx];
|
||||
|
||||
@ -346,9 +346,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
// can be interpreted as B8x16 for 8 bit types
|
||||
_B16x8 Klocal[TLOOP][QKHELOOP];
|
||||
|
||||
const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
|
||||
const int wg_start_kv_head_idx = blockIdx.z;
|
||||
const int total_num_heads = gridDim.z * GQA_RATIO;
|
||||
const auto wg_start_head_idx = blockIdx.z * GQA_RATIO;
|
||||
const auto wg_start_kv_head_idx = blockIdx.z;
|
||||
const auto total_num_heads = gridDim.z * GQA_RATIO;
|
||||
|
||||
// for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps
|
||||
// each mfma takes QH16xT16x16HE across warp
|
||||
@ -789,14 +789,14 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
|
||||
// clang-format on
|
||||
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
const int laneid = threadIdx.x % WARP_SIZE;
|
||||
const auto warpid = threadIdx.x / WARP_SIZE;
|
||||
const auto laneid = threadIdx.x % WARP_SIZE;
|
||||
const int lane4id = laneid % 4;
|
||||
|
||||
const int seq_idx = blockIdx.x;
|
||||
const int partition_idx = blockIdx.y;
|
||||
const int partition_size = blockDim.x;
|
||||
const int max_num_partitions = gridDim.y;
|
||||
const auto seq_idx = blockIdx.x;
|
||||
const auto partition_idx = blockIdx.y;
|
||||
const auto partition_size = blockDim.x;
|
||||
const auto max_num_partitions = gridDim.y;
|
||||
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int partition_start_token_idx = partition_idx * partition_size;
|
||||
@ -838,8 +838,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
qk_max[h] = -FLT_MAX;
|
||||
}
|
||||
|
||||
const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
|
||||
const int wg_start_kv_head_idx = blockIdx.z;
|
||||
const auto wg_start_head_idx = blockIdx.z * GQA_RATIO;
|
||||
const auto wg_start_kv_head_idx = blockIdx.z;
|
||||
|
||||
const int warp_start_token_idx =
|
||||
partition_start_token_idx + warpid * WARP_SIZE;
|
||||
@ -857,7 +857,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
|
||||
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
// token id within partition
|
||||
const int local_token_idx = threadIdx.x;
|
||||
const auto local_token_idx = threadIdx.x;
|
||||
// token id within sequence
|
||||
const int global_token_idx = partition_start_token_idx + local_token_idx;
|
||||
|
||||
@ -1126,7 +1126,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int num_heads = gridDim.z * GQA_RATIO;
|
||||
const auto num_heads = gridDim.z * GQA_RATIO;
|
||||
float* max_logits_ptr =
|
||||
max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
|
||||
float* exp_sums_ptr =
|
||||
@ -1268,14 +1268,14 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
// max_num_partitions, head_size]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int max_num_partitions) {
|
||||
const int num_heads = gridDim.x;
|
||||
const int head_idx = blockIdx.x;
|
||||
const int seq_idx = blockIdx.y;
|
||||
const auto num_heads = gridDim.x;
|
||||
const auto head_idx = blockIdx.x;
|
||||
const auto seq_idx = blockIdx.y;
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
|
||||
[[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
[[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;
|
||||
const auto warpid = threadIdx.x / WARP_SIZE;
|
||||
[[maybe_unused]] const auto laneid = threadIdx.x % WARP_SIZE;
|
||||
|
||||
__shared__ float shared_global_exp_sum;
|
||||
// max num partitions supported is warp_size * NPAR_LOOPS
|
||||
@ -1294,7 +1294,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NPAR_LOOPS; i++) {
|
||||
const int partition_no = i * WARP_SIZE + threadIdx.x;
|
||||
const auto partition_no = i * WARP_SIZE + threadIdx.x;
|
||||
valid_partition[i] =
|
||||
(partition_no < num_partitions) ? partition_no : last_valid_partition;
|
||||
}
|
||||
@ -1324,7 +1324,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NPAR_LOOPS; i++) {
|
||||
const int partition_no = i * WARP_SIZE + threadIdx.x;
|
||||
const auto partition_no = i * WARP_SIZE + threadIdx.x;
|
||||
rescaled_exp_sum[i] *= (partition_no < num_partitions)
|
||||
? expf(reg_max_logit[i] - max_logit)
|
||||
: 0.0f;
|
||||
@ -1336,7 +1336,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NPAR_LOOPS; i++) {
|
||||
const int partition_no = i * WARP_SIZE + threadIdx.x;
|
||||
const auto partition_no = i * WARP_SIZE + threadIdx.x;
|
||||
shared_exp_sums[partition_no] = rescaled_exp_sum[i];
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user